In [1]:
import sys
from pathlib import Path

# Add project root to Python path
project_root = Path.cwd().parent if Path.cwd().name == 'set_transformer' else Path.cwd()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))


In [2]:
import torch
import torch.nn as nn
from set_transformer.set_transformer import MAB, SAB, ISAB, PMA
from fatransformer.model_components import GLU, MHA

In [3]:
layer = MAB(d_model=128, num_heads=4)

x = torch.randn(20, 10, 128)

out = layer(x)

print(out.shape)

torch.Size([20, 10, 128])


In [4]:
layer = SAB(d_model=128, num_heads=4)

x = torch.randn(20, 10, 128)

out = layer(x)

print(out.shape)

torch.Size([20, 10, 128])


In [13]:
layer = ISAB(d_model=128, num_heads=4, m=5)

x = torch.randn(20, 10, 128)

out = layer(x)

print(out.shape)

torch.Size([20, 10, 128])


In [7]:
layer = PMA(d_model=128, num_heads=4, k=5)

x = torch.randn(20, 10, 128)

out = layer(x)

print(out.shape)

torch.Size([20, 5, 128])


In [8]:
class Encoder(nn.Module):
    def __init__(self, d_model: int, num_heads: int, m: int, dropout: float = 0.0):
        super().__init__()
        self.sab1 = SAB(d_model, num_heads, dropout)
        self.sab2 = SAB(d_model, num_heads, dropout)
        
    def forward(self, x: torch.Tensor):
        x = self.sab1(x)
        x = self.sab2(x)
        return x

enc = Encoder(d_model=128, num_heads=4, m=5)

x = torch.randn(20, 10, 128)

out = enc(x)

print(out.shape)

torch.Size([20, 10, 128])


In [12]:
class Decoder(nn.Module):
    def __init__(self, d_model: int, num_heads: int, k: int, dropout: float = 0.0):
        super().__init__()
        self.pma = PMA(d_model, num_heads, k=1)
        self.sab = SAB(d_model, num_heads, dropout)
        self.ff = GLU(d_model, 2*d_model, d_model)
        
    def forward(self, x: torch.Tensor):
        x = self.pma(x)
        x = self.sab(x)
        x = self.ff(x)
        return x

dec = Decoder(d_model=128, num_heads=4, k=5)

x = torch.randn(20, 10, 128)

out = dec(x)

print(out.shape)
        

torch.Size([20, 1, 128])
