In [1]:
import sys
# Put the parent directory on the import path so the local `modules` package
# (imported in the cells below) can be found.
sys.path.append('..')

In [None]:
import modules.initialize as minit

# Project-level setup from `modules.initialize`; it runs before the attention
# layers are imported. NOTE(review): what it initializes is not visible here.
minit.initialize(verbose=True)

In [4]:
# Attention layers under test from the local `modules` package.
# NOTE(review): EncdecMultiheadAttn is not used by any cell in this notebook.
from modules import SelfMultiheadAttn, EncdecMultiheadAttn
import torch

In [5]:
# The layer and all benchmark tensors are placed on the current CUDA device.
device = torch.device('cuda')

In [6]:
# Benchmark shape / model configuration for the `modules` layer.
seq_length   = 2048  # sequence length (leading dim of the test input below)
sequences    = 2     # number of sequences in the batch
hidden_dim   = 2048  # embedding size passed to the layer
heads        = 32    # number of attention heads
dropout_prob = 0.0   # dropout disabled

In [None]:
# Layer under test. NOTE(review): the meaning of `include_norm_add` and
# `impl='default'` is defined inside `modules` and not visible here
# (presumably a fused norm + residual add and the non-fast kernel) — confirm.
tst_layer = SelfMultiheadAttn(hidden_dim, 
                               heads, 
                               dropout=dropout_prob, 
                               bias=True, 
                               include_norm_add=True, 
                               impl='default')
tst_layer = tst_layer.to(device)

In [None]:
# Random input laid out as (seq_length, sequences, hidden_dim), i.e. sequence-first.
# requires_grad so every backward pass also computes the input gradient.
tst_inputs = torch.randn(seq_length, sequences, hidden_dim, device=device).requires_grad_(True)

In [None]:
%%time
for i in range(100):
    tst_outputs,_ = tst_layer.forward(tst_inputs, 
                                       tst_inputs, 
                                       tst_inputs,
                                       key_padding_mask=None, 
                                       need_weights=False, 
                                       attn_mask=None,
                                       is_training=True)


    tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

In [None]:
torch.cuda.max_memory_allocated() / 2**30

### Lean — `LeanSelfAttention` from petabert

In [12]:
# Set LEAN_USE_JIT=0 before importing the Lean modules, in case the library
# reads it at import time. NOTE(review): presumably disables JIT paths — confirm.
%env LEAN_USE_JIT=0
import sys
# Make '../../petabert/' importable so `lib.modules.attn` resolves.
sys.path.append('../../petabert/')

from lib.modules.attn import LeanSelfAttention
import torch, torch.nn as nn

env: LEAN_USE_JIT=0


In [13]:
# Same shape configuration as the `modules` section above.
seq_length   = 2048
sequences    = 2
hidden_dim   = 2048
heads        = 32
dropout_prob = 0.0
device=torch.device('cuda')

In [14]:
# NOTE(review): `residual` / `checkpoint_attention_core` semantics live in petabert;
# presumably no residual add and no activation checkpointing — confirm.
tst_layer = LeanSelfAttention(hidden_dim, 
                               heads, 
                               dropout=dropout_prob, 
                               residual=False, checkpoint_attention_core=False)
tst_layer = tst_layer.to(device)

In [15]:
# Unlike the `modules` section, the input here is batch-first:
# (sequences, seq_length, hidden_dim).
tst_inputs = torch.randn(sequences, seq_length, hidden_dim, device=device).requires_grad_(True)

In [16]:
%%time
for i in range(100):
    tst_outputs, = tst_layer.forward(tst_inputs)
    tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

CPU times: user 2.56 s, sys: 1.62 s, total: 4.18 s
Wall time: 4.19 s


In [17]:
# Peak CUDA tensor memory in GiB since process start or the last peak-stats reset.
torch.cuda.max_memory_allocated() / 2**30

4.3751220703125

In [None]:
%%time
with torch.cuda.amp.autocast(True):
    for i in range(100):
        tst_outputs, = tst_layer.forward(tst_inputs)
        tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

In [None]:
torch.cuda.max_memory_allocated() / 2**30

### PyTorch — stock `nn.MultiheadAttention` baseline (note: larger config, seq_length = hidden_dim = 4096)

In [None]:
import sys
import torch, torch.nn as nn

In [None]:
# NOTE(review): seq_length and hidden_dim are 4096 here vs 2048 in the other two
# sections, so timings/memory are not directly comparable across sections.
seq_length   = 4096
sequences    = 2
hidden_dim   = 4096
heads        = 32
dropout_prob = 0.0  # not passed to BaselineAttn, which takes no dropout argument
device=torch.device('cuda')

In [None]:
class BaselineAttn(nn.Module):
    """Reference block built from stock PyTorch: LayerNorm -> self-attention -> LayerNorm.

    Args:
        dim: embedding size (last dimension of the input); must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.dim, self.heads = dim, heads
        self.pre_norm = nn.LayerNorm(dim)
        # add_bias_kv appends one learned key/value position to every sequence;
        # batch_first=True means inputs are (batch, seq, dim).
        self.attn = nn.MultiheadAttention(dim, heads, add_bias_kv=True, batch_first=True)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, input, attn_mask=None):
        """Apply pre-norm, self-attention and post-norm.

        Args:
            input: tensor of shape (batch, seq, dim).
            attn_mask: optional mask forwarded to ``nn.MultiheadAttention``; either
                (seq, seq) or (batch * heads, seq, seq). Bool True blocks a position,
                float values are added to the attention scores. ``None`` = no mask.

        Returns:
            Tensor of the same shape as ``input``.
        """
        normed = self.pre_norm(input)
        # need_weights=False: the weights are discarded, and requesting them (the
        # nn.MultiheadAttention default) materializes head-averaged attention weights
        # and prevents the optimized attention path, distorting time/memory numbers.
        output, _ = self.attn(normed, normed, normed, attn_mask=attn_mask, need_weights=False)
        return self.post_norm(output)


In [None]:
tst_layer = BaselineAttn(hidden_dim, heads)
tst_layer = tst_layer.to(device)

In [None]:
# Batch-first input: (sequences, seq_length, hidden_dim).
tst_inputs = torch.randn(sequences, seq_length, hidden_dim, device=device).requires_grad_(True)
# 3D mask of shape (batch * heads, seq, seq), the per-head layout nn.MultiheadAttention accepts.
# NOTE(review): this is a float32 mask, so its values are *added* to the attention
# scores rather than blocking positions. With the config above it holds
# 2 * 32 * 4096 * 4096 float32 values = 4 GiB, which is included in the
# max_memory_allocated readouts. The other sections run without any mask.
attn_mask = torch.ones(tst_inputs.shape[0] * heads, tst_inputs.shape[1], tst_inputs.shape[1], 
                       device=tst_inputs.device)


In [None]:
%%time
for i in range(100):
    tst_outputs = tst_layer.forward(tst_inputs, attn_mask)
    tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

In [None]:
torch.cuda.max_memory_allocated() / 2**30

In [None]:
%%time
with torch.cuda.amp.autocast(True):
    for i in range(100):
        tst_outputs = tst_layer.forward(tst_inputs, attn_mask)
        tst_outputs.backward(torch.randn_like(tst_inputs))
torch.cuda.synchronize(device)

In [None]:
torch.cuda.max_memory_allocated() / 2**30