In [1]:
import os
import sys
# Put the parent directory on the import path (presumably so the project-local
# `brainle` package imported further down resolves when the notebook is run
# from its own folder -- confirm against the repository layout).
# The membership guard keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, einsum
from einops import parse_shape, rearrange, repeat

def count_parameters(model: nn.Module):
    """Count the scalar elements held by the trainable parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        The summed ``numel()`` of every parameter with ``requires_grad`` set;
        parameters that do not require gradients are left out. ``0`` when
        there is nothing to count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [3]:
from brainle.models.architectures.attention import AttentionBase, SABlock, RABlock, MABlock 
        
# Smoke test: call AttentionBase directly with three separately supplied inputs.
att = AttentionBase(in_features=12, out_features=24, num_heads=4)

# q has 10 tokens of width 12; k and v both have 20 tokens, of width 12 and 24.
# (Presumably query / key / value -- confirm against AttentionBase.forward.)
# Created left to right so the random draws happen in the order q, k, v.
q, k, v = torch.rand(2, 10, 12), torch.rand(2, 20, 12), torch.rand(2, 20, 24)

# Report the output shape and the number of trainable parameters.
print(att(q, k, v).shape)
print(f"Params: {count_parameters(att)}")

torch.Size([2, 10, 24])
Params: 600


In [4]:
# Smoke test: SABlock takes a single input tensor.
block = SABlock(in_features=12, out_features=24, num_heads=4)

# Batch of 2, 10 tokens, 12 features per token (matches in_features).
out = block(torch.rand(2, 10, 12))

# Report the output shape and the number of trainable parameters.
print(out.shape)
print(f"Params: {count_parameters(block)}")

torch.Size([2, 10, 24])
Params: 1176


In [5]:
# Smoke test: RABlock is configured with token counts as well as feature widths.
block = RABlock(
    in_tokens=10,
    out_tokens=5,
    in_features=12,
    out_features=24,
    num_heads=4,
)

# Input has 10 tokens (matches in_tokens) of width 12 (matches in_features).
out = block(torch.rand(2, 10, 12))

# Report the output shape and the number of trainable parameters.
print(out.shape)
print(f"Params: {count_parameters(block)}")

torch.Size([2, 5, 24])
Params: 1092


In [6]:
# Smoke test: MABlock additionally takes a memory_size (512 here).
block = MABlock(memory_size=512, in_features=12, out_features=24, num_heads=4)

# Batch of 2, 10 tokens, 12 features per token (matches in_features).
out = block(torch.rand(2, 10, 12))

# Report the output shape and the number of trainable parameters.
print(out.shape)
print(f"Params: {count_parameters(block)}")

torch.Size([2, 10, 24])
Params: 19176
