System information (Google Colab free tier with T4 GPU)
- OS Platform and Distribution: Ubuntu 18.04.5 LTS
- Flax version: 0.5.2
- Jax version: 0.3.14
- Jaxlib version: 0.3.14+cuda11.cudnn805
- Python version: 3.7.13
- GPU/TPU model and memory: Tesla T4, 16 GB
- CUDA version (if applicable): 11.1
Problem you have encountered:
I have been trying to train some models on Google Colab (free tier). There seems to be a leak in system memory (not GPU memory): each call to a jitted apply function of an nn.Module uses more system memory, and usage keeps growing until Google Colab crashes.
What I have tried:
- Just JAX
- Network structure
- When trying to create a minimal reproducible example, a one-layer network
    class Net(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(32, (3, 3))(x)
            return x
  does not show a memory leak.
  However, a two-layer network does:
    class Net(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(32, (3, 3))(x)
            x = nn.Conv(32, (3, 3))(x)
            return x
I have not tried with other layers (e.g. using nn.Dense instead of nn.Conv) but I think the above example is minimal enough.
- Other platforms
- I have not encountered the memory leak on other platforms. I tried my local machine and Kaggle, and both ran without a memory leak. So this could be an issue with Colab itself, but I don't know how to debug it.
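For comparing platforms, a minimal sketch of how the growth could be measured follows. It is not taken from the linked notebook: `rss_mib`, `track_memory`, and `make_flax_step` are hypothetical helper names, the input shape `(1, 32, 32, 3)` is an assumption, and the measurement uses the standard-library `resource` module, which reports peak (not current) resident memory and is only available on Unix-like systems.

```python
import resource
import sys


def rss_mib():
    """Peak resident set size of this process in MiB (stdlib only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in KiB on Linux and in bytes on macOS.
    return peak / (1024 ** 2 if sys.platform == "darwin" else 1024)


def track_memory(step_fn, n_steps=100, every=10):
    """Call step_fn repeatedly and record (step, peak RSS in MiB) samples."""
    samples = []
    for step in range(n_steps):
        step_fn()
        if step % every == 0:
            samples.append((step, rss_mib()))
    return samples


def make_flax_step():
    """Build a zero-argument function that runs one jitted apply call.

    Imports are local so the helpers above work without JAX installed.
    The input shape below is an assumption, not taken from the notebook.
    """
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Net(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(32, (3, 3))(x)
            x = nn.Conv(32, (3, 3))(x)
            return x

    net = Net()
    x = jnp.ones((1, 32, 32, 3))
    params = net.init(jax.random.PRNGKey(0), x)
    apply_fn = jax.jit(net.apply)
    # block_until_ready() waits for the asynchronous computation to finish.
    return lambda: apply_fn(params, x).block_until_ready()
```

Running `track_memory(make_flax_step(), n_steps=10000, every=1000)` on each platform and comparing the recorded values should show whether peak memory keeps climbing after the first (compilation) call. Because the value is a peak, it can only show growth, never a decrease.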
Screenshot of error:

The system memory usage growth is also reflected in Colab's resource monitor.
Steps to reproduce:
Colab notebook to reproduce: https://colab.research.google.com/drive/1oWSDpYIDUgFfAe26XqsX9Ufa-Pe_nF4d?usp=sharing