In [1]:
import torch
from torch import nn
from models.networks import GlobalContextualDeepTransition
from models.utils import sample_sequence, pad_packed_sequence

In [2]:
# Lengths of the sequences in the sample batch: one entry per sequence, so the
# batch size is len(LENGTHS) and the padded time dimension is max(LENGTHS).
LENGTHS = [5, 3, 4, 2]
INPUT_UNITS = 300       # per-step input feature size passed to sample_sequence and the model
ENCODER_UNITS = 128     # presumably the encoder hidden size — TODO confirm in models.networks
DECODER_UNITS = 256     # presumably the decoder hidden size — TODO confirm in models.networks
OUTPUT_UNITS = 3        # size of the model's per-step output vector (last dim of the outputs)
TRANSITION_LENGTH = 4   # forwarded to GlobalContextualDeepTransition; semantics defined there
SEED = 42

# Model initialisation and sample_sequence presumably draw from torch's global
# RNG; seeding it here makes Restart & Run All give identical outputs each time.
# The trailing ';' suppresses the Generator repr in the cell output.
torch.manual_seed(SEED);

In [3]:
# Build the network from the configuration constants (positional order as
# expected by GlobalContextualDeepTransition's constructor).
model = GlobalContextualDeepTransition(
    INPUT_UNITS,
    ENCODER_UNITS,
    DECODER_UNITS,
    TRANSITION_LENGTH,
    OUTPUT_UNITS,
)

# Draw a random sample batch with one sequence per entry in LENGTHS.
# NOTE(review): the role of the second return value (`req`) is defined in
# models.utils.sample_sequence — confirm before relying on it.
seq, req = sample_sequence(INPUT_UNITS, LENGTHS)

# Forward pass; the result is unpacked with pad_packed_sequence in the next cell.
outputs = model(seq, req)

In [4]:
# Unpack the model's packed output into a zero-padded, time-major tensor
# (t b u) plus the true length of each sequence. Bind a NEW name instead of
# overwriting `outputs`, so re-running this cell still receives the packed input.
padded_outputs, lengths = pad_packed_sequence(outputs)

# Sanity checks: lengths round-trip, and the padded shape is (t, b, u).
assert lengths.tolist() == LENGTHS
assert padded_outputs.shape == (max(LENGTHS), len(LENGTHS), OUTPUT_UNITS)

print(lengths)
# permute(1, 0, 2): t b u -> b t u, i.e. one block of rows per sequence.
print(padded_outputs.permute(1, 0, 2))

tensor([5, 3, 4, 2])
tensor([[[0.3607, 0.2622, 0.3771],
         [0.1129, 0.4447, 0.4424],
         [0.4100, 0.3693, 0.2207],
         [0.4938, 0.3382, 0.1680],
         [0.4703, 0.2745, 0.2552]],

        [[0.3537, 0.4559, 0.1903],
         [0.3856, 0.1869, 0.4275],
         [0.5828, 0.1617, 0.2555],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.4115, 0.3041, 0.2844],
         [0.1522, 0.4710, 0.3769],
         [0.3236, 0.2290, 0.4473],
         [0.2705, 0.4538, 0.2757],
         [0.0000, 0.0000, 0.0000]],

        [[0.3240, 0.3408, 0.3352],
         [0.3646, 0.4096, 0.2258],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]], grad_fn=<PermuteBackward>)


In [5]:
# Total number of scalar weights across all registered parameter tensors
# (model.parameters() yields them regardless of requires_grad).
n_params = sum(param.numel() for param in model.parameters())
print(f'The encoder-decoder model has {n_params:,} parameters')

The encoder-decoder model has 2,285,315 parameters
