In [1]:
import sys
sys.path.append('..')

In [2]:
import modules.initialize as minit

minit.initialize(verbose=True)

Detected CUDA files, patching ldflags
Emitting ninja build file /storage/hdd1/jheuristic/exp/decentralized/jheuristic/adapters/mytorchcudamodules/notebooks/../csrc/build/build.ninja...
Building extension module fused_mix_prec_layer_norm_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module fused_mix_prec_layer_norm_cuda...


In [3]:
from modules import SelfMultiheadAttn, EncdecMultiheadAttn
import torch

In [4]:
device = torch.device('cuda')

In [5]:
seq_length   = 4096
sequences    = 2
hidden_dim   = 4096
heads        = 32
dropout_prob = 0.0

In [6]:
tst_layer = SelfMultiheadAttn(hidden_dim, 
                               heads, 
                               dropout=dropout_prob, 
                               bias=True, 
                               include_norm_add=True, 
                               impl='default')
tst_layer = tst_layer.to(device)

In [7]:
tst_inputs = torch.randn(seq_length, sequences, hidden_dim, device=device).requires_grad_(True)

In [8]:
%%time
for i in range(100):
    tst_outputs,_ = tst_layer.forward(tst_inputs, 
                                       tst_inputs, 
                                       tst_inputs,
                                       key_padding_mask=None, 
                                       need_weights=False, 
                                       attn_mask=None,
                                       is_training=True)


    tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

  head_dim = inputs.size(2) // heads_t[0]


CPU times: user 28.5 s, sys: 21.9 s, total: 50.4 s
Wall time: 50.4 s


In [9]:
torch.cuda.max_memory_allocated() / 2**30

27.312759399414062

### Lean

In [1]:
%env LEAN_USE_JIT=0
import sys
sys.path.append('../..')

from lib.modules.attn import LeanSelfAttention
import torch, torch.nn as nn

env: LEAN_USE_JIT=0


In [2]:
seq_length   = 4096
sequences    = 2
hidden_dim   = 4096
heads        = 32
dropout_prob = 0.0
device=torch.device('cuda')

In [3]:
tst_layer = LeanSelfAttention(hidden_dim, 
                               heads, 
                               dropout=dropout_prob, 
                               residual=False, checkpoint_attention_core=False)
tst_layer = tst_layer.to(device)

In [4]:
tst_inputs = torch.randn(sequences, seq_length, hidden_dim, device=device).requires_grad_(True)

In [5]:
%%time
for i in range(100):
    tst_outputs, = tst_layer.forward(tst_inputs)
    tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

CPU times: user 28 s, sys: 19.9 s, total: 48 s
Wall time: 48 s


In [6]:
torch.cuda.max_memory_allocated() / 2**30

17.500244140625

In [5]:
%%time
with torch.cuda.amp.autocast(True):
    for i in range(100):
        tst_outputs, = tst_layer.forward(tst_inputs)
        tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

CPU times: user 8.9 s, sys: 6.7 s, total: 15.6 s
Wall time: 15.7 s


In [6]:
torch.cuda.max_memory_allocated() / 2**30

15.406532287597656

### PyTorch

In [1]:
import sys
import torch, torch.nn as nn

In [2]:
seq_length   = 4096
sequences    = 2
hidden_dim   = 4096
heads        = 32
dropout_prob = 0.0
device=torch.device('cuda')

In [3]:
class BaselineAttn(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.dim, self.heads = dim, heads
        self.pre_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, add_bias_kv=True, batch_first=True)
        self.post_norm = nn.LayerNorm(dim)
    
    def forward(self, input, attn_mask):
        input = self.pre_norm(input)
        output, weights = self.attn.forward(input, input, input, attn_mask=attn_mask)
        return self.post_norm(output)


In [4]:
tst_layer = BaselineAttn(hidden_dim, heads)
tst_layer = tst_layer.to(device)

In [5]:
tst_inputs = torch.randn(sequences, seq_length, hidden_dim, device=device).requires_grad_(True)
attn_mask = torch.ones(tst_inputs.shape[0] * heads, tst_inputs.shape[1], tst_inputs.shape[1], 
                       device=tst_inputs.device)


In [6]:
%%time
for i in range(100):
    tst_outputs = tst_layer.forward(tst_inputs, attn_mask)
    tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

CPU times: user 29.5 s, sys: 26 s, total: 55.5 s
Wall time: 55.5 s


In [7]:
torch.cuda.max_memory_allocated() / 2**30

21.758270263671875

In [6]:
%%time
with torch.cuda.amp.autocast(True):
    for i in range(100):
        tst_outputs = tst_layer.forward(tst_inputs, attn_mask)
        tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

CPU times: user 15.5 s, sys: 13.5 s, total: 29 s
Wall time: 29.1 s


In [7]:
torch.cuda.max_memory_allocated() / 2**30

19.60052490234375