In [1]:
import numpy as np
import torch

from models.model_utils import count_parameters
from models.meshed_memory import EncoderLayer, MeshedMemoryEncoder, DecoderLayer, Decoder
from data.augmentation import Flickr30KRegionalFeatures

In [2]:
# Generate dummy region features shaped (batch_size, num_detections, feature_dim).
# Seed both RNGs so Restart & Run All reproduces the same inputs (and dropout masks);
# the numpy seed also covers the np.random.randint call used for the decoder tokens.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
# randn yields float64; convert once, explicitly, to float32 for the model.
dummy = torch.from_numpy(np.random.randn(50, 50, 1024)).float()

In [3]:
# Build the meshed-memory encoder and run it once on the dummy detections.
encoder_num_layers = 2
encoder_out_size = 256
encoder = MeshedMemoryEncoder(
    in_size=dummy.size(-1),  # feature dim of the dummy detections (1024), kept in sync with `dummy`
    num_layers=encoder_num_layers,
    out_size=encoder_out_size,
    key_size=32,
    value_size=32,
    num_heads=8,
    dropout_rate=0.1,
    feedforward_size=256,
    num_mem_slots=8,
)
encoded, mask = encoder(dummy)

# Inline shape checks. Expected layouts follow the recorded output of this cell:
# encoded -> (batch, num_layers, detections, out_size), mask -> (batch, 1, 1, detections).
batch_size, num_detections = dummy.shape[:2]
assert encoded.shape == (batch_size, encoder_num_layers, num_detections, encoder_out_size), encoded.shape
assert mask.shape == (batch_size, 1, 1, num_detections), mask.shape

print(encoded.size())
print(mask.size())
print(count_parameters(encoder))

torch.Size([50, 2, 50, 256])
torch.Size([50, 1, 1, 50])
(1062656, 1062656)


In [5]:
# Build the decoder and run it once on random token ids conditioned on `encoded`.
max_sequence_len = 20
vocab_size = 2000
decoder = Decoder(
    num_layers=2,
    # NOTE(review): presumably must equal the encoder's num_layers (encoded.size(1)) — confirm.
    num_encoder_layers=2,
    max_sequence_len=max_sequence_len,
    pad_token=0,
    out_size=256,
    key_size=32,
    value_size=32,
    feedforward_size=256,
    encoded_size=256,
    vocab_size=vocab_size,
    num_heads=8,
    dropout_rate=0.1,
)

# One dummy token sequence per encoded sample; the batch is taken from `encoded`
# so the two cells cannot silently disagree. Ids are drawn from [0, 99), well inside
# the vocabulary (note: 0 is also the pad token, so some positions are padding).
batch_size = encoded.size(0)
y = torch.tensor(np.random.randint(0, 99, size=(batch_size, max_sequence_len)))
out = decoder(y, encoded, mask)

# Logits over the vocabulary for every position, per the recorded output.
assert out.shape == (batch_size, max_sequence_len, vocab_size), out.shape

print(f"Output size: {out.size()}")
print(count_parameters(decoder))

torch.Size([1, 1, 20, 20])
torch.Size([50, 20])
Output size: torch.Size([50, 20, 2000])
(2870224, 2875600)
