In [1]:
import os
import sys

# Put the parent of the current working directory on sys.path (presumably the
# repo root, where `models` and `transformer` live) so the project imports
# below resolve. os.path.dirname is portable, unlike splitting on "/".
sys.path.insert(1, os.path.dirname(os.path.abspath("")))

import torch

from models import ClassifierTransformer, DecoderOnlyTransformer

from transformer.blocks.utils import ShiftRight, DownsamplingLayer, UpsamplingLayer

from transformer.layers.multi_head_attention.attention_mechanism.attn_params import (
    CosformerParams,
    VanillaParams,
    PerformerParams,
)

# Fix torch's global RNG so the random inputs and freshly initialised layer
# weights in the cells below are reproducible under Restart & Run All.
torch.manual_seed(0)


In [2]:
def weight_model(model):
    """Print and return the in-memory size of a model's parameters and buffers.

    Every tensor yielded by ``model.parameters()`` and ``model.buffers()`` is
    counted as ``nelement() * element_size()`` bytes.

    Args:
        model: a ``torch.nn.Module`` (or anything whose ``parameters()`` and
            ``buffers()`` yield tensors).

    Returns:
        float: total size in MiB (bytes / 1024**2). The same value is printed
        as ``model size: <x>MB`` with three decimals.
    """
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # Returned as well as printed so callers can reuse or assert on it.
    return size_all_mb

# Decoder-only transformer

In [3]:
batch_size, max_length, d_model = 8, 16, 1024
vocab_size = 30
num_heads = 4

method_params = VanillaParams()

model = DecoderOnlyTransformer(
    d_model=d_model,
    vocab_size=vocab_size,
    structure="6x1024",
    num_heads=num_heads,
    method_params=method_params,
    apply_rotary_pos_enc=True,
    dropout=0.1,
    attn_has_outproj=True,
    act_fun="relu",
    pos_enc_type="learnable",
    device="cpu",
)

# Draw token ids from the whole vocabulary; a fixed high=10 only ever
# exercised ids 0-9 of the 30 available.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, max_length))
weight_model(model)

out = model(x)
# One score per vocabulary entry for every position (matches the recorded output).
assert out.shape == (batch_size, max_length, vocab_size)
out.shape

model size: 200.498MB


torch.Size([8, 16, 30])

# Utility layers: ShiftRight, DownsamplingLayer, UpsamplingLayer

In [4]:
sr = ShiftRight(2)
x = torch.rand(2, 4, 4)
print(x)
shifted = sr(x)
print(shifted)

# Check what the printout shows instead of eyeballing it: along dim 1 the
# first two positions are zero and the remaining rows are x's first rows.
assert (shifted[:, :2] == 0).all()
assert torch.equal(shifted[:, 2:], x[:, :-2])

tensor([[[0.5094, 0.1889, 0.6503, 0.2646],
         [0.7025, 0.8269, 0.8765, 0.6991],
         [0.8861, 0.7245, 0.1752, 0.3057],
         [0.7518, 0.5491, 0.4045, 0.6779]],

        [[0.1519, 0.6030, 0.0880, 0.7727],
         [0.6214, 0.6128, 0.2682, 0.8686],
         [0.6315, 0.3781, 0.0280, 0.6445],
         [0.2567, 0.3052, 0.6363, 0.6880]]])
tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.5094, 0.1889, 0.6503, 0.2646],
         [0.7025, 0.8269, 0.8765, 0.6991]],

        [[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.1519, 0.6030, 0.0880, 0.7727],
         [0.6214, 0.6128, 0.2682, 0.8686]]])


In [5]:
ds = DownsamplingLayer(4, 2)
x = torch.rand(2, 4, 4)
print(x)
downsampled = ds(x)
print(downsampled)

# Recorded output: (2, 4, 4) -> (2, 2, 4), i.e. dim 1 halves, last dim kept.
assert downsampled.shape == (2, 2, 4)

tensor([[[0.7694, 0.6985, 0.6660, 0.1967],
         [0.0062, 0.5241, 0.8500, 0.3925],
         [0.1267, 0.6699, 0.8163, 0.7819],
         [0.4394, 0.6064, 0.1130, 0.0887]],

        [[0.8162, 0.2403, 0.0268, 0.4737],
         [0.9123, 0.3872, 0.7385, 0.2919],
         [0.5827, 0.2822, 0.8665, 0.6915],
         [0.0313, 0.4579, 0.0186, 0.8751]]])
tensor([[[-0.0042, -0.0704, -0.7758,  0.7127],
         [-0.3964,  0.3781, -0.5826,  0.5683]],

        [[-0.4693, -0.0624, -0.7208,  0.3145],
         [-0.1135,  0.0774, -0.7928,  0.7974]]], grad_fn=<ViewBackward0>)


In [6]:
us = UpsamplingLayer(4, 2)
x = torch.rand(2, 4, 4)
print(x)
upsampled = us(x)
print(upsampled)

# Recorded output: (2, 4, 4) -> (2, 8, 4), i.e. dim 1 doubles, last dim kept.
assert upsampled.shape == (2, 8, 4)

tensor([[[0.9008, 0.6034, 0.1423, 0.6246],
         [0.6567, 0.9595, 0.5829, 0.9248],
         [0.2808, 0.6677, 0.4141, 0.9785],
         [0.9868, 0.5687, 0.6689, 0.4301]],

        [[0.5266, 0.1060, 0.3934, 0.0352],
         [0.7165, 0.7234, 0.8946, 0.7398],
         [0.1611, 0.4775, 0.7440, 0.9856],
         [0.1146, 0.0751, 0.3384, 0.8298]]])
tensor([[[-6.3578e-01, -2.6518e-01, -1.8701e-01,  2.2499e-01],
         [ 8.4219e-01, -6.6328e-01,  2.7802e-01,  3.6642e-01],
         [-6.1478e-01, -2.4426e-01, -6.9666e-02,  2.8947e-01],
         [ 7.7288e-01, -9.0025e-01,  1.7178e-01,  4.0270e-01],
         [-4.2795e-01, -3.0789e-01, -3.3951e-02,  4.4959e-01],
         [ 5.7789e-01, -9.3200e-01,  2.4812e-01,  5.0994e-01],
         [-4.9557e-01,  5.7479e-03,  8.9243e-03,  7.7401e-02],
         [ 8.4704e-01, -6.5836e-01,  1.6229e-01,  3.6996e-02]],

        [[-7.9698e-02,  6.2831e-02, -2.0158e-04,  1.5463e-01],
         [ 6.4092e-01, -4.8436e-01,  1.7835e-01, -1.2602e-02],
         [-4.6890e-0

# Classifier

In [7]:
batch_size, max_length, d_model = 8, 16, 27
vocab_size = 100
num_heads = 9

# Use Apple's MPS backend when it is available and fall back to CPU otherwise;
# a hard-coded "mps" makes the cell fail on every other machine.
device = "mps" if torch.backends.mps.is_available() else "cpu"

model = ClassifierTransformer(
    d_model=d_model,
    vocab_size=vocab_size,
    structure="3x16",
    num_classes=2,
    num_heads=num_heads,
    method_params=CosformerParams(),
    apply_rotary_pos_enc=True,
    dropout=0.1,
    attn_has_outproj=True,
    act_fun="gelu",
    pos_enc_type="learnable",
    norm_before=True,
    device=device,
)

# Sample on CPU then move, so the draw comes from the (seeded) CPU generator.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, max_length)).to(device)

weight_model(model)
# FIXME: the last recorded run of this cell (device="mps") raised
#   RuntimeError: einsum(): the number of subscripts in the equation (4) does
#   not match the number of dimensions (3) for operand 0
# presumably inside the Cosformer attention. The U-transformer cell below uses
# CosformerParams on CPU without error; whether this failure depends on the
# device or on this configuration (d_model=27, num_heads=9 -> head dim 3,
# structure="3x16") has not been established.
model(x).shape

model size: 0.096MB


RuntimeError: einsum(): the number of subscripts in the equation (4) does not match the number of dimensions (3) for operand 0 and no ellipsis was given

# U-transformer

In [None]:
batch_size, max_length, d_model = 8, 16, 64
vocab_size = 100
num_heads = 4
num_classes = 2

model = ClassifierTransformer(
    d_model=d_model,
    vocab_size=vocab_size,
    structure="3x64,3x32",
    num_classes=num_classes,
    num_heads=num_heads,
    method_params=CosformerParams(),
    apply_rotary_pos_enc=True,
    dropout=0.1,
    attn_has_outproj=True,
    act_fun="gelu",
    pos_enc_type="learnable",
    norm_before=True,
    device="cpu",
)

# Tie the token range to the vocabulary instead of repeating the literal 100.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, max_length))

out = model(x)
# One row per sequence, one column per class (matches the recorded output).
assert out.shape == (batch_size, num_classes)
out.shape

torch.Size([8, 2])