In [1]:
import numpy as np
import torch

from models.meshed_memory import EncoderLayer, MeshedMemoryEncoder, DecoderLayer, Decoder
from utils import count_trainable_parameters

In [2]:
# Generate a reproducible dummy batch of detection features.
# Layout is (batch size, detections, feature size); the last axis matches the
# in_size=1024 the encoder is built with.
SEED = 0
torch.manual_seed(SEED)  # fixed seed so Restart & Run All yields the same batch
# torch.randn already returns float32, so no numpy round-trip or .float() cast
# is needed.
dummy = torch.randn(5, 50, 1024)

In [3]:
# Hyper-parameters for a two-layer meshed-memory encoder, kept in one mapping
# so they are easy to tweak between runs.
encoder_config = dict(
    in_size=1024,
    num_layers=2,
    out_size=512,
    key_size=32,
    value_size=32,
    num_heads=8,
    dropout_rate=0.1,
    feedforward_size=512,
)
encoder = MeshedMemoryEncoder(**encoder_config)

# Forward the dummy batch; the encoder returns the encoded features and a mask.
encoded, mask = encoder(dummy)

# Smoke-test report: output shapes followed by the trainable-parameter count.
for tensor in (encoded, mask):
    print(tensor.shape)
print(count_trainable_parameters(encoder))

torch.Size([5, 2, 50, 512])
torch.Size([5, 1, 1, 50])
2631680


In [10]:
# Sizes shared by the decoder configuration and the dummy targets, so the two
# cannot drift apart.
VOCAB_SIZE = 100
MAX_SEQUENCE_LEN = 85

# Two-layer decoder configured to consume the two-layer encoder output
# (encoded_size matches the encoder's out_size of 512).
decoder = Decoder(
    num_layers=2,
    num_encoder_layers=2,
    max_sequence_len=MAX_SEQUENCE_LEN,
    pad_token=0,
    out_size=512,
    key_size=32,
    value_size=32,
    feedforward_size=512,
    encoded_size=512,
    vocab_size=VOCAB_SIZE,
    num_heads=8,
    dropout_rate=0.1,
)

# Dummy target token ids, one full-length sequence per encoded sample.
# `high` is exclusive, so passing VOCAB_SIZE covers every id 0..VOCAB_SIZE-1
# (the previous randint(0, 99) could never emit id 99). The dtype is pinned to
# int64 because NumPy's default integer width is platform-dependent.
# NOTE(review): id 0 is also the pad token, so some positions are drawn as
# padding, as before; confirm that is desired for this smoke test.
y = torch.randint(
    low=0,
    high=VOCAB_SIZE,
    size=(encoded.size(0), MAX_SEQUENCE_LEN),
    dtype=torch.long,
)

out = decoder(y, encoded, mask)
print(f"Output size: {out.size()}")
print(count_trainable_parameters(decoder))

Mask Size: torch.Size([5, 85, 1])
Self Attention Size: torch.Size([5, 1, 85, 85])
Output size: torch.Size([5, 85, 100])
4310628
