In [1]:
import os, sys
sys.path.insert(1, "/".join(os.path.abspath('').split("/")[:-1]))

import torch

from models import ClassifierTransformer, DecoderOnlyTransformer

from transformer.blocks.utils import ShiftRight, DownsamplingLayer, UpsamplingLayer

from transformer.layers.multi_head_attention.attention_mechanism.attn_params import CosformerParams, VanillaParams, PerformerParams

  @autocast(enabled = False)
  @autocast(enabled = False)


In [2]:
def weight_model(model):
    """Print and return the memory footprint of a model's parameters and buffers.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables whose
            items provide ``nelement()`` and ``element_size()`` (for example a
            ``torch.nn.Module``).

    Returns:
        float: Combined size of all parameters and buffers, in bytes / 1024**2.
    """
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # Also hand the value back so callers can compare or log sizes programmatically.
    return size_all_mb

# Decoder only

In [3]:
# Small dimensions for a quick smoke test of the decoder-only model.
batch_size = 8
max_length = 5
d_model = 27
vocab_size = 30
num_heads = 3

device = "cpu"

# Attention variant used by the model (Cosformer parameters, all defaults).
method_params = CosformerParams()

model = DecoderOnlyTransformer(
    d_model=d_model,
    vocab_size=vocab_size,
    structure="6x512",
    num_heads=num_heads,
    method_params=method_params,
    apply_rotary_pos_enc=False,
    dropout=0.1,
    attn_has_outproj=True,
    act_fun="relu",
    pos_enc_type="learnable",
    device=device,
)

# Random integer token ids in [0, 10) with shape (batch_size, max_length).
x = torch.randint(0, 10, (batch_size, max_length)).to(device)

# Report the model's memory footprint, then display the output shape.
weight_model(model)
model(x).shape

model size: 0.215MB
torch.Size([8, 3, 5, 9])
torch.Size([8, 3, 5, 9])
torch.Size([8, 3, 5, 9])
torch.Size([8, 3, 5, 9])
torch.Size([8, 3, 5, 9])
torch.Size([8, 3, 5, 9])


torch.Size([8, 5, 30])

# ShiftRight, DownsamplingLayer and UpsamplingLayer

In [4]:
# Apply ShiftRight(2) to a random (2, 4, 4) tensor and print input, then output.
sr = ShiftRight(2)
x = torch.rand(2, 4, 4)
print(x)
print(sr(x))

tensor([[[0.1447, 0.7891, 0.2759, 0.5191],
         [0.6061, 0.3653, 0.2133, 0.3352],
         [0.1933, 0.5763, 0.0270, 0.0031],
         [0.5236, 0.0224, 0.2713, 0.4525]],

        [[0.4392, 0.7575, 0.7873, 0.8863],
         [0.5406, 0.5315, 0.5795, 0.2743],
         [0.0926, 0.1922, 0.9510, 0.3846],
         [0.0574, 0.1483, 0.2068, 0.1438]]])
tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.1447, 0.7891, 0.2759, 0.5191],
         [0.6061, 0.3653, 0.2133, 0.3352]],

        [[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.4392, 0.7575, 0.7873, 0.8863],
         [0.5406, 0.5315, 0.5795, 0.2743]]])


In [5]:
# Apply DownsamplingLayer(4, 2) to a random (2, 4, 4) tensor and print input, then output.
# NOTE(review): the arguments look like (feature size, downsampling factor) - confirm.
ds = DownsamplingLayer(4, 2)
x = torch.rand(2, 4, 4)
print(x)
print(ds(x))

tensor([[[0.7418, 0.7062, 0.9541, 0.5732],
         [0.5811, 0.1247, 0.9714, 0.4106],
         [0.6326, 0.7896, 0.2048, 0.3308],
         [0.5752, 0.2586, 0.6429, 0.5329]],

        [[0.4346, 0.0238, 0.2349, 0.8429],
         [0.4130, 0.6950, 0.2370, 0.5562],
         [0.8920, 0.3862, 0.2767, 0.8145],
         [0.3674, 0.7929, 0.8570, 0.4274]]])
tensor([[[ 0.3341, -0.6941, -0.8276,  0.6963],
         [ 0.1383, -0.5178, -0.7615,  0.5409]],

        [[ 0.2818, -0.2188, -0.4360,  0.4416],
         [ 0.3419, -0.4404, -0.7808,  0.4683]]], grad_fn=<ViewBackward0>)


In [6]:
# Apply UpsamplingLayer(4, 2) to a random (2, 4, 4) tensor and print input, then output.
# NOTE(review): the arguments look like (feature size, upsampling factor) - confirm.
us = UpsamplingLayer(4, 2)
x = torch.rand(2, 4, 4)
print(x)
print(us(x))

tensor([[[0.9054, 0.4136, 0.4122, 0.7331],
         [0.4790, 0.6093, 0.5146, 0.1313],
         [0.1505, 0.3770, 0.6415, 0.8040],
         [0.6130, 0.8362, 0.3947, 0.3664]],

        [[0.6297, 0.5365, 0.3815, 0.2813],
         [0.8473, 0.0233, 0.5794, 0.3076],
         [0.2159, 0.5347, 0.8286, 0.4088],
         [0.0399, 0.3134, 0.8756, 0.2861]]])
tensor([[[-1.3707e-01,  1.1878e-01, -4.3913e-01,  3.4336e-01],
         [ 7.2635e-02,  8.2196e-01,  3.6462e-01,  8.3568e-01],
         [ 4.9052e-02,  3.7318e-01, -2.8793e-01, -6.6212e-02],
         [ 2.7756e-01,  1.0222e+00,  5.1599e-01,  4.4831e-01],
         [-1.0623e-01, -1.0758e-02, -3.7532e-01,  2.0876e-01],
         [-1.2667e-01,  6.7255e-01,  2.2077e-01,  5.0847e-01],
         [ 8.9319e-02,  3.7498e-01, -3.5878e-01,  8.2826e-03],
         [ 2.5358e-01,  1.0491e+00,  5.0260e-01,  6.3076e-01]],

        [[ 1.5055e-04,  3.0380e-01, -3.2292e-01,  9.6576e-02],
         [ 2.6433e-01,  9.3608e-01,  4.5376e-01,  6.0556e-01],
         [-2.1672e-0

# Classifier

In [7]:
batch_size, max_length, d_model = 8, 16, 27
vocab_size = 100
num_heads = 9

# Use Apple's Metal (MPS) backend when it is available and fall back to CPU
# otherwise, so this cell also runs on machines without MPS support.
device = "mps" if torch.backends.mps.is_available() else "cpu"

model = ClassifierTransformer(
    d_model=d_model,
    vocab_size=vocab_size,
    structure="3x16",
    num_classes=2,
    num_heads=num_heads,
    method_params=CosformerParams(),
    apply_rotary_pos_enc=True,
    dropout=0.1,
    attn_has_outproj=True,
    act_fun="gelu",
    pos_enc_type="learnable",
    norm_before=True,
    device=device,
)

# Random token ids drawn from the whole vocabulary, moved to the model's device.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, max_length)).to(device)

# Report the model's memory footprint, then display the output shape.
weight_model(model)
model(x).shape

model size: 0.096MB
torch.Size([8, 9, 16, 3])
torch.Size([8, 9, 16, 3])
torch.Size([8, 9, 16, 3])


torch.Size([8, 2])

# U-transformer

In [8]:
# Dimensions for the two-stage ("U"-shaped) classifier configuration.
batch_size = 8
max_length = 16
d_model = 64
vocab_size = 100
num_heads = 4

# NOTE(review): "3x64,3x32" presumably describes two stages of three layers
# each - confirm against the parser of `structure`.
model = ClassifierTransformer(
    d_model=d_model,
    vocab_size=vocab_size,
    structure="3x64,3x32",
    num_classes=2,
    num_heads=num_heads,
    method_params=CosformerParams(),
    apply_rotary_pos_enc=True,
    dropout=0.1,
    attn_has_outproj=True,
    act_fun="gelu",
    pos_enc_type="learnable",
    norm_before=True,
    device="cpu",
)

# Random integer token ids in [0, 100) with shape (batch_size, max_length).
x = torch.randint(0, 100, (batch_size, max_length))

# Display the shape of the classifier output.
model(x).shape

torch.Size([8, 4, 16, 16])
torch.Size([8, 4, 16, 16])
torch.Size([8, 4, 16, 16])
torch.Size([8, 4, 8, 16])
torch.Size([8, 4, 8, 16])
torch.Size([8, 4, 8, 16])


torch.Size([8, 2])