In [2]:
from datetime import datetime

import torch
import torch.nn as nn

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.models import BaselineEncoder, Performer, PerformerAttention, InformerAttention

In [3]:
device = 'cuda'

hidden_size = 128
num_heads = 4
num_layers = 12
dim_feedforward = num_heads * hidden_size

bs = 16
L = 2 ** 10

x = torch.randn(bs, L, hidden_size, device=device)
att = torch.ones(bs, L, device=device) > 0

In [4]:
attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
model = BaselineEncoder(attention, hidden_size, num_layers, dim_feedforward).to(device)

start = datetime.now()
print(model(x, att)[0].shape)
print(datetime.now() - start)

torch.Size([16, 1024, 128])
0:00:01.772717


In [6]:
attention = PerformerAttention(hidden_dim=hidden_size, num_heads=num_heads)
model = BaselineEncoder(attention, hidden_size, num_layers, dim_feedforward).to(device)

start = datetime.now()
print(model(x, att)[0].shape)
print(datetime.now() - start)

torch.Size([16, 1024, 128])
0:00:00.024001


In [4]:
config = {
    'head_size': hidden_size  // num_heads, 
    'length': L, 
    'target_len': 70, 
    'attn_func': 'softmax', 
    'attn_num_basis': 100, 
    'attn_drop': 0.1, 
    'infinite_memory': True, 
    'n_layers': num_layers, 
    'n_heads': num_heads, 
    'd_model': hidden_size, 
    'mask': True, 
    'mask_type': 'cnn', 
    'kl_regularizer': True, 
    'sigma_0': 0, 
    'mu_0': 0,
    'share_mask': True,
    'device': 'cpu'
}

attention = InformerAttention(**config)
# model = BaselineEncoder(attention, hidden_size, num_layers, dim_feedforward).to(device)

start = datetime.now()
print(model(x, att)[0].shape)
print(datetime.now() - start)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [9]:
config = {
    'dim': hidden_size,
    'depth': num_layers,
    'heads': num_heads,
    'dim_head': hidden_size
}

model = Performer(**config).to(device)
start = datetime.now()
print(model(x).shape)
print(datetime.now() - start)

torch.Size([16, 1024, 128])
0:00:00.069004
