<a href="https://colab.research.google.com/github/morganmcg1/reformer-fastai/blob/main/exploration/chunked_feed_forward_mem_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Show the GPU attached to this runtime and its current memory usage.
!nvidia-smi

Mon Nov  9 17:55:05 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P8    12W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import torch
import torch.nn as nn
import torch.autograd.profiler as profiler

Both modules are stripped-down versions of lucidrains' code.

In [3]:
class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear.

    The hidden layer is four times as wide as the input, and the output has
    the same last-dimension size as the input.

    Args:
        d_model: size of the last dimension of the input (and of the output).
        p: dropout probability applied after the activation.
    """

    def __init__(self, d_model, p=0.):
        super().__init__()
        d_hidden = d_model * 4
        # Submodules are created in the order in which forward() applies them.
        self.lin1 = nn.Linear(d_model, d_hidden)
        self.act = nn.GELU()
        self.drop = nn.Dropout(p)
        self.lin2 = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        """Run ``x`` through lin1, act, drop and lin2 and return the result."""
        hidden = self.drop(self.act(self.lin1(x)))
        return self.lin2(hidden)

In [4]:
class Chunk(nn.Module):
    """Apply ``func`` to slices of the input along ``dim`` and concatenate the results.

    Args:
        chunks: number of slices requested from ``Tensor.chunk``; with ``1``
            the input is passed to ``func`` unsplit.
        func: callable applied to every slice.
        dim: dimension along which the input is split and re-joined.
    """

    def __init__(self, chunks, func, dim=1):
        super().__init__()
        self.dim = dim
        self.chunks = chunks
        self.func = func

    def forward(self, x):
        """Return ``func`` applied slice-wise to ``x``, re-joined along ``self.dim``."""
        if self.chunks == 1:
            # Nothing to split. This branch used to call ``self.fn``, an
            # attribute that is never set, and therefore raised AttributeError.
            return self.func(x)
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.func(c) for c in chunks], dim = self.dim)

## sanity check

In [5]:
# config for the sanity check; the input tensor has shape (bs, sl, d_model)
bs = 1
sl = 2 ** 8      # size of dim 1, the dimension Chunk splits by default
d_model = 768    # size of the last dimension, passed to FeedForward
p = 0.0          # dropout probability passed to FeedForward

In [6]:
# Run the plain and the chunked forward pass on the same random input.
x = torch.randn(bs, sl, d_model)
ff = FeedForward(d_model, p)
chunk_ff = Chunk(10, ff)  # wraps the same module instance, so weights are shared
with torch.no_grad():
    out, c_out = ff(x), chunk_ff(x)
# Both results must agree element-wise within an absolute tolerance of 1e-5.
assert (out - c_out).abs().lt(1e-5).all()

In [7]:
# Remove the sanity-check tensors and modules from the notebook namespace
# before the large experiments below reuse the same names.
del x, ff, chunk_ff, out, c_out

## test 1 forward pass + no_grad

### vanilla ff

In [8]:
# config for test 1 (vanilla ff); the input tensor has shape (bs, sl, d_model)
bs = 8
sl = 2 ** 16     # size of dim 1
d_model = 1024   # size of the last dimension, passed to FeedForward
p = 0.0          # dropout probability passed to FeedForward

In [9]:
# Create the random input directly on the GPU.
x = torch.randn((bs, sl, d_model), device="cuda")

In [10]:
# GPU memory in use after allocating the input, before the forward pass.
!nvidia-smi

Mon Nov  9 17:31:12 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P0    30W /  70W |   2997MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [11]:
# Vanilla feed-forward: a single forward pass over the whole input with
# gradient tracking disabled, profiled for CUDA memory usage.
ff = FeedForward(d_model, p).cuda()
with profiler.profile(
    record_shapes=True, profile_memory=True, use_cuda=True
) as prof, torch.no_grad():
    out = ff(x)
# Ten rows of the profiler summary, sorted by CUDA memory usage.
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

RuntimeError: ignored

### chunked ff

In [6]:
# config for test 1 (chunked ff); the input tensor has shape (bs, sl, d_model)
bs = 8
sl = 2 ** 16     # size of dim 1, the dimension Chunk splits by default
d_model = 1024   # size of the last dimension, passed to FeedForward
p = 0.0          # dropout probability passed to FeedForward

In [7]:
# Create the random input directly on the GPU.
x = torch.randn((bs, sl, d_model), device="cuda")

In [8]:
# GPU memory in use after allocating the input, before the forward pass.
!nvidia-smi

Mon Nov  9 17:34:31 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P0    30W /  70W |   2997MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [9]:
# Chunked feed-forward: Chunk requests 10 slices of x along dim 1 and runs
# ff on each of them; gradient tracking is disabled.
ff = FeedForward(d_model, p).cuda()
chunk_ff = Chunk(10, ff)
with profiler.profile(
    record_shapes=True, profile_memory=True, use_cuda=True
) as prof, torch.no_grad():
    out = chunk_ff(x)
# Ten rows of the profiler summary, sorted by CUDA memory usage.
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
           aten::empty        38.83%       4.553ms        38.83%       4.553ms     111.042us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      20.02 Gb      20.02 Gb            41  
          aten::matmul         3.44%     403.090us        41.27%       4.838ms     241.905us     188.290us        

## test 2 forward pass

### vanilla ff

In [5]:
# config for test 2 (vanilla ff); the input tensor has shape (bs, sl, d_model)
bs = 6
sl = 2 ** 16     # size of dim 1
d_model = 1024   # size of the last dimension, passed to FeedForward
p = 0.0          # dropout probability passed to FeedForward

In [6]:
# Create the input directly on the GPU (as in test 1) instead of building it
# on the host and copying it over with .cuda().
x = torch.randn(bs, sl, d_model, device='cuda')

In [7]:
# GPU memory in use after allocating the input, before the forward pass.
!nvidia-smi

Mon Nov  9 17:45:32 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   70C    P0    34W /  70W |   2485MiB / 15079MiB |     26%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [8]:
# Vanilla feed-forward, forward pass only. Unlike test 1 there is no
# torch.no_grad() here, so autograd is active during the forward pass.
ff = FeedForward(d_model, p).cuda()
with profiler.profile(
    record_shapes=True,
    profile_memory=True,
    use_cuda=True,
) as prof:
    out = ff(x)
# Ten rows of the profiler summary, sorted by CUDA memory usage.
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

RuntimeError: ignored

### chunked ff

In [4]:
# config for test 2 (chunked ff); the input tensor has shape (bs, sl, d_model)
bs = 6
sl = 2 ** 16     # size of dim 1, the dimension Chunk splits by default
d_model = 1024   # size of the last dimension, passed to FeedForward
p = 0.0          # dropout probability passed to FeedForward

In [6]:
# Create the input directly on the GPU (as in test 1) instead of building it
# on the host and copying it over with .cuda().
x = torch.randn(bs, sl, d_model, device='cuda')

In [7]:
# GPU memory in use after allocating the input, before the forward pass.
!nvidia-smi

Mon Nov  9 17:49:30 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P0    31W /  70W |   2485MiB / 15079MiB |     12%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [8]:
# Chunked feed-forward, forward pass only, with autograd active: Chunk
# requests 10 slices of x along dim 1 and runs ff on each of them.
ff = FeedForward(d_model, p).cuda()
chunk_ff = Chunk(10, ff)
with profiler.profile(
    record_shapes=True,
    profile_memory=True,
    use_cuda=True,
) as prof:
    out = chunk_ff(x)
# Ten rows of the profiler summary, sorted by CUDA memory usage.
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

RuntimeError: ignored

## test 3 forward + backward pass

### vanilla ff

In [5]:
# config
bs = 4
sl = 2**16
d_model = 1024
p = 0.
x = torch.randn(bs, sl, d_model).cuda()
!nvidia-smi

Mon Nov  9 17:57:02 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P0    30W /  70W |   1973MiB / 15079MiB |     20%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
# Vanilla feed-forward: forward pass followed by a backward pass.
ff = FeedForward(d_model, p).cuda()
with profiler.profile(
    record_shapes=True,
    profile_memory=True,
    use_cuda=True,
) as prof:
    out = ff(x)
    # Reduce to a scalar so backward() needs no explicit gradient argument.
    out.sum().backward()
# Ten rows of the profiler summary, sorted by CUDA memory usage.
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

RuntimeError: ignored

### chunked ff

In [7]:
# config
bs = 4
sl = 2**16
d_model = 1024
p = 0.
x = torch.randn(bs, sl, d_model).cuda()
!nvidia-smi

Mon Nov  9 18:00:18 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P0    33W /  70W |  12269MiB / 15079MiB |     38%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [8]:
# Chunked feed-forward: forward pass followed by a backward pass. Chunk
# requests 10 slices of x along dim 1 and runs ff on each of them.
ff = FeedForward(d_model, p).cuda()
chunk_ff = Chunk(10, ff)
with profiler.profile(
    record_shapes=True,
    profile_memory=True,
    use_cuda=True,
) as prof:
    out = chunk_ff(x)
    # Reduce to a scalar so backward() needs no explicit gradient argument.
    out.sum().backward()
# Ten rows of the profiler summary, sorted by CUDA memory usage.
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

RuntimeError: ignored