In [1]:
import torch

from src.init_states import ZerosState, TrainableState
from src.init_vars import BasicVars
from src.init_contexts import EmptyContext
from src.embeddings import BasicEmbedding, IdentityEmbedding

from src.encoders import RNNEncoder
from src.decoders import RNNDecoder
#from src.architectures import RNNEncoder, RNNDecoder

## Init states

In [4]:
# Build an initial recurrent state with `ZerosState` for a one-layer GRU and inspect its shape.
cell, hidden_size, num_layers = 'GRU', 4, 1
batch_size = 2

init_state = ZerosState(cell, hidden_size, num_layers)
state = init_state(batch_size)
state.shape  # ::state:: [num_layers, batch_size, hidden_size]

torch.Size([1, 2, 4])

In [8]:
# Same configuration, but with a `TrainableState` initial state.
# NOTE(review): `a`/`b` presumably bound the initialisation range — confirm in src.init_states.
cell, hidden_size, num_layers = 'GRU', 4, 1
batch_size = 2

init_state = TrainableState(cell, hidden_size, num_layers, a=-0.8, b=0.8)
state = init_state(batch_size)
state.shape  # ::state:: [num_layers, batch_size, hidden_size]

torch.Size([1, 2, 4])

## Init vars

In [4]:
# Decoder variables from `BasicVars`: only `num_variables` is supplied, every other input is None.
enc_output = formula = variables = None
num_variables = 5

init_dec_var = BasicVars()
var = init_dec_var(enc_output, formula, num_variables, variables)
var.shape  # ::var:: [batch_size, seq_len, feature_size]; batch_size is 1 here

torch.Size([1, 5, 1])

## Init contexts

In [5]:
# Decoder context from `EmptyContext`: one row per batch element and no feature columns.
enc_output = formula = variables = None
num_variables, batch_size = 5, 2

init_dec_context = EmptyContext()
context = init_dec_context(enc_output, formula, num_variables, variables, batch_size)
context.shape  # ::context:: [batch_size, feature_size=0]

torch.Size([2, 0])

## Embeddings

In [2]:
# Embed random integer labels with `BasicEmbedding`.
num_labels = 5
embedding_size = 10

batch_size = 2
seq_len = 4
# `high` is exclusive in torch.randint, so high=num_labels samples every label 0..num_labels-1
# (the previous high=num_labels-1 never produced the last label).
X_labels = torch.randint(low=0, high=num_labels, size=(batch_size, seq_len, 1), dtype=torch.long)
# ::X_labels:: [batch_size, seq_len, features_size=1]

embedding = BasicEmbedding(num_labels, embedding_size)
X = embedding(X_labels)
# ::X:: [batch_size, seq_len, num_features=embedding_size]
assert X.shape == (batch_size, seq_len, embedding_size)
X.shape

torch.Size([2, 4, 10])

In [4]:
# `IdentityEmbedding` on real-valued features: the feature dimension stays `feature_size`
# (as the recorded output shows), so no `embedding_size` is involved here.
batch_size = 2
seq_len = 4
feature_size = 5
X_features = torch.rand((batch_size, seq_len, feature_size))
# ::X_features:: [batch_size, seq_len, feature_size]

embedding = IdentityEmbedding()
X = embedding(X_features)
# ::X:: [batch_size, seq_len, feature_size]
assert X.shape == (batch_size, seq_len, feature_size)
X.shape

torch.Size([2, 4, 5])

## Encoders

In [3]:
# Run a single-layer GRU `RNNEncoder` over random, already-embedded inputs.
cell = 'GRU'
embedding_size = 7
hidden_size = 16
num_layers = 1
dropout = 0

batch_size = 2
seq_len = 5

enc_embedding = IdentityEmbedding()
X = torch.rand((batch_size, seq_len, embedding_size))
# ::X:: [batch_size, seq_len, features_size=embedding_size]

encoder = RNNEncoder(cell = cell,
                     embedding = enc_embedding,
                     embedding_size = embedding_size,
                     hidden_size = hidden_size,
                     num_layers = num_layers,
                     dropout = dropout)
encoder.eval()
output, state = encoder(X)
# output shape: [seq_len, batch_size, hidden_size] — sequence-first, unlike the batch-first input
# state shape: [num_layers, batch_size, hidden_size]  (not checked here — TODO confirm)
assert output.shape == (seq_len, batch_size, hidden_size)
output.shape

torch.Size([5, 2, 16])


## Decoders

In [2]:
# Run one forward pass of a GRU `RNNDecoder` on random variables, previous assignments and context.
cell = 'GRU'
hidden_size = 16
num_layers = 1
dropout = 0
clip_logits_c = 0

batch_size = 2
seq_len = 3
embedding_size = 8
num_assignment_labels = 3
num_variable_labels = 5
assignment_emb = BasicEmbedding(num_labels=num_assignment_labels, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=num_variable_labels, embedding_size=embedding_size)
# NOTE(review): presumably variable embedding + assignment embedding + context, each
# `embedding_size` wide — confirm against RNNDecoder.
input_size = embedding_size * 3

# `high` is exclusive in torch.randint: sample every valid label of each embedding table.
var = torch.randint(0, num_variable_labels, [batch_size, seq_len, 1])
# ::var:: [batch_size, seq_len, feature_size=1]
a_prev = torch.randint(0, num_assignment_labels, [batch_size, seq_len, 1])
# ::a_prev:: [batch_size, seq_len, feature_size=1]
context = torch.rand([batch_size, embedding_size])
# ::context:: [batch_size, feature_size=embedding_size]
state = torch.rand([num_layers, batch_size, hidden_size])
# ::state:: [num_layers, batch_size, hidden_size]
X = (var, a_prev, context)

decoder = RNNDecoder(input_size = input_size,
                     cell = cell,
                     assignment_emb = assignment_emb,
                     variable_emb = variable_emb,
                     hidden_size = hidden_size,
                     num_layers = num_layers,
                     dropout = dropout,
                     clip_logits_c = clip_logits_c)
decoder.eval()
output, state = decoder(X, state)
# output shape: [batch_size, seq_len, 2]
# state shape: [num_layers, batch_size, hidden_size]
assert output.shape == (batch_size, seq_len, 2)
assert state.shape == (num_layers, batch_size, hidden_size)
print(output.shape)
print(state.shape)

torch.Size([2, 3, 2])
torch.Size([1, 2, 16])
