In [1]:
import torch
from torch import nn
from models.networks import GlobalContextualDeepTransition
from models.utils import sample_sequence, pad_packed_sequence

In [2]:
# Sizes handed positionally to GlobalContextualDeepTransition, listed in the
# same order the constructor call uses them.
INPUT_UNITS = 300        # also the first argument to sample_sequence
ENCODER_UNITS = 128
DECODER_UNITS = 256
TRANSITION_LENGTH = 4    # presumably the depth of the deep-transition cell -- TODO confirm in models.networks
OUTPUT_UNITS = 3         # the printed model output carries 3 values per step

# Second argument to sample_sequence: one entry per sequence in the sampled
# batch (the run below reports these same four lengths back).
LENGTHS = [5, 3, 4, 2]

In [3]:
# Instantiate the network under test from the sizes declared in the config cell.
model = GlobalContextualDeepTransition(
    INPUT_UNITS,
    ENCODER_UNITS,
    DECODER_UNITS,
    TRANSITION_LENGTH,
    OUTPUT_UNITS,
)

# sample_sequence yields the two objects the model's forward call expects.
# NOTE(review): no RNG seed is set in the visible cells; if the model init or
# sample_sequence draw random values, the printed numbers change on every run.
seq, req = sample_sequence(INPUT_UNITS, LENGTHS)

# Forward pass; the result is later unpacked with pad_packed_sequence.
outputs = model(seq, req)

In [4]:
# Unpack the model output into a padded tensor plus the per-sequence lengths.
# The result is bound to a NEW name on purpose: rebinding `outputs` here would
# make this cell non-idempotent, since a second run would hand the already
# padded tensor back to pad_packed_sequence.
padded_outputs, lengths = pad_packed_sequence(outputs)
print(lengths)

# Swap the first two axes for display: t b u -> b t u
# (time, batch, units) -> (batch, time, units).
print(padded_outputs.permute(1, 0, 2))

tensor([5, 3, 4, 2])
tensor([[[ 0.0702, -0.2220,  0.4252],
         [ 0.5801, -0.2577,  0.9527],
         [-0.3754,  0.4508, -0.8836],
         [-0.5972, -0.2501,  0.9223],
         [-0.1064, -0.1981,  0.7893]],

        [[-0.7044,  0.0223, -0.7247],
         [ 0.0876, -0.7223, -0.4285],
         [ 0.1697,  0.4949, -0.1468],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[-0.3285,  0.5826,  0.2639],
         [-0.1795,  0.4274, -0.3431],
         [-1.1322, -0.1675,  0.1704],
         [ 0.2928, -0.3433,  0.1543],
         [ 0.0000,  0.0000,  0.0000]],

        [[-0.0683,  0.0737, -0.2362],
         [ 0.0636, -0.0392,  0.5872],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]], grad_fn=<PermuteBackward>)


In [5]:
# Total number of scalar entries across every tensor returned by
# model.parameters().
x = 0
for p in model.parameters():
    x += p.numel()
print(f'The encoder-decoder model has {x:,} parameters')

The encoder-decoder model has 2,285,315 parameters
