In [1]:
import os
import sys
# Make the parent directory importable — presumably the repo root that contains the local
# `brainle` package imported later in this notebook. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate entry to sys.path.
PROJECT_ROOT = os.path.abspath(os.pardir)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, einsum
from einops import parse_shape, rearrange, repeat

def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Return the number of scalar parameter values in ``model``.

    Args:
        model: Module whose parameters (from ``model.parameters()``) are counted.
        trainable_only: If True (the default, and the original behaviour), count only
            parameters with ``requires_grad=True``. If False, also count frozen ones.

    Returns:
        Sum of ``numel()`` over the selected parameter tensors; 0 for a module
        with no (selected) parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [3]:
from brainle.models.architectures.attention import AttentionBase, SABlock, RABlock, MABlock, FeedForwardBlock, TransformerBlock, PatcherBlock, UnpatcherBlock, ConvTention, ConvTeNet
        
# AttentionBase sanity check: 10 queries attend over 20 key/value tokens.
# Here q and k have width 12 (= in_features) and v has width 24 (= out_features);
# the recorded output shows the result keeps the query length with width 24.
att = AttentionBase(
    in_features = 12,
    out_features = 24,
    num_heads = 4,
)
q = torch.rand(2, 10, 12)
k = torch.rand(2, 20, 12)
v = torch.rand(2, 20, 24)
with torch.no_grad():  # shape check only; no autograd graph needed
    att_out = att(q, k, v)
print(att_out.shape)
assert att_out.shape == (2, 10, 24), att_out.shape
print(f"Params: {count_parameters(att)}")

torch.Size([2, 10, 24])
Params: 600


In [4]:
# SABlock: 12 -> 24 features; the recorded output shows the 10-token length is preserved.
block = SABlock(
    in_features = 12,
    out_features = 24,
    num_heads = 4
)

with torch.no_grad():  # shape check only; no autograd graph needed
    out = block(torch.rand(2, 10, 12))
print(out.shape)
assert out.shape == (2, 10, 24), out.shape
print(f"Params: {count_parameters(block)}")

torch.Size([2, 10, 24])
Params: 1176


In [5]:
# RABlock: 12 -> 24 features, and (per the recorded output) 10 input tokens are
# reduced to `out_tokens` = 5 output tokens.
block = RABlock(
    in_tokens = 10,
    out_tokens = 5,
    in_features = 12,
    out_features = 24,
    num_heads = 4
)

with torch.no_grad():  # shape check only; no autograd graph needed
    out = block(torch.rand(2, 10, 12))
print(out.shape)
assert out.shape == (2, 5, 24), out.shape
print(f"Params: {count_parameters(block)}")

torch.Size([2, 5, 24])
Params: 1092


In [6]:
# MABlock with memory_size=512: 12 -> 24 features; the recorded output shows the
# 10-token length is preserved.
block = MABlock(
    memory_size = 512,
    in_features = 12,
    out_features = 24,
    num_heads = 4
)

with torch.no_grad():  # shape check only; no autograd graph needed
    out = block(torch.rand(2, 10, 12))
print(out.shape)
assert out.shape == (2, 10, 24), out.shape
print(f"Params: {count_parameters(block)}")

torch.Size([2, 10, 24])
Params: 19176


In [7]:
# FeedForwardBlock on 512 features (multiplier 4); the recorded output shows the
# input shape is preserved.
block = FeedForwardBlock(
    features = 512,
    multiplier = 4,
    dropout = 0.1
)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = block(torch.rand(2, 10, 512))
print(out.shape)
assert out.shape == (2, 10, 512), out.shape
print(f"Params: {count_parameters(block)}")

torch.Size([2, 10, 512])
Params: 2100736


In [8]:
# TransformerBlock on 256 features with 2 heads; the recorded output shows the
# input shape is preserved.
block = TransformerBlock(
    features = 256,
    num_heads = 2,
    dropout_attention = 0.1,
    dropout_mlp = 0.1,
    mlp_multiplier = 4
)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = block(torch.rand(2, 10, 256))
print(out.shape)
assert out.shape == (2, 10, 256), out.shape
print(f"Params: {count_parameters(block)}")

torch.Size([2, 10, 256])
Params: 788992


In [9]:
# Patch / unpatch round trip on a tiny 6-token sequence (each row = its token index).
# The recorded output shows kernel 4 / stride 2 / padding 1 yields 3 overlapping
# patches of 4 tokens (zero-filled at the edges), and the unpatcher restores the input.
patcher = PatcherBlock(
    kernel_size = 4,
    stride = 2,
    padding = 1
)

unpatcher = UnpatcherBlock(
    kernel_size = 4,
    stride = 2,
    padding = 1
)
# Distinct stage names instead of rebinding `x`, so each intermediate stays inspectable.
tokens = torch.tensor([[ [1,1,1], [2,2,2], [3,3,3], [4,4,4], [5,5,5], [6,6,6] ]]).float()
print(tokens, tokens.shape)
patches = patcher(tokens)
print(patches, patches.shape)
restored = unpatcher(patches)
print(restored, restored.shape)
assert patches.shape == (1, 3, 4, 3), patches.shape
assert torch.allclose(restored, tokens), "unpatcher(patcher(tokens)) should reproduce tokens"

tensor([[[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.],
         [5., 5., 5.],
         [6., 6., 6.]]]) torch.Size([1, 6, 3])
tensor([[[[0., 0., 0.],
          [1., 1., 1.],
          [2., 2., 2.],
          [3., 3., 3.]],

         [[2., 2., 2.],
          [3., 3., 3.],
          [4., 4., 4.],
          [5., 5., 5.]],

         [[4., 4., 4.],
          [5., 5., 5.],
          [6., 6., 6.],
          [0., 0., 0.]]]]) torch.Size([1, 3, 4, 3])
tensor([[[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.],
         [5., 5., 5.],
         [6., 6., 6.]]]) torch.Size([1, 6, 3])


In [10]:
# Encoder/decoder pair built from ConvTention: the encoder shortens the sequence and the
# decoder lengthens it back. The recorded shapes (1024 -> 512 -> 1024 tokens) are
# consistent with tokens_out = tokens_in / stride * out_patch_tokens for padding 0:
# encoder 1024 / 4 * 2 = 512, decoder 512 / 2 * 4 = 1024 — TODO confirm against ConvTention.
encode = ConvTention(
    in_features = 256,
    out_features = 256,
    num_heads = 8,
    num_layers = 4,
    out_patch_tokens = 2,
    kernel_size = 4,
    stride = 4,
    padding = 0,
    memory_size = 512,
    dropout = 0.1
)
decode = ConvTention(
    in_features = 256,
    out_features = 256,
    num_heads = 8,
    num_layers = 4,
    out_patch_tokens = 4,
    kernel_size = 2,
    stride = 2,
    padding = 0,
    memory_size = 512,
    dropout = 0.1
)

with torch.no_grad():  # shape check only; no autograd graph needed
    encoded = encode(torch.rand(2, 1024, 256))
    print(encoded.shape)
    out = decode(encoded)
print(out.shape)
assert encoded.shape == (2, 512, 256), encoded.shape
assert out.shape == (2, 1024, 256), out.shape
# NOTE: this reports the encoder's parameter count only; `decode` is a separate module.
print(f"Params: {count_parameters(encode)}")

torch.Size([2, 512, 256])
torch.Size([2, 1024, 256])
Params: 3746816


In [11]:
# End-to-end ConvTeNet: integer token ids in, one vocab_size-wide vector per input
# position out (recorded output: (2, 2048) ids -> (2, 2048, 800)).
vocab_size = 800  # single source of truth for the model and the sampled ids below
net = ConvTeNet(
    vocabulary_size = vocab_size,
    embedding_dim = 256,
    num_layers = 7,
    num_heads = 8,
    use_skip = True
)

# Sample ids from the same vocabulary the model was built with: valid range [0, vocab_size).
x = torch.randint(low=0, high=vocab_size, size=(2, 2048))
with torch.no_grad():  # ~47M parameters: don't keep activations for backprop in a shape check
    out = net(x)
print(out.shape)
assert out.shape == (2, 2048, vocab_size), out.shape
print(f"Params: {count_parameters(net)}")

torch.Size([2, 2048, 800])
Params: 47368480
