In [1]:
import torch
from src.architectures import RNNEncoder, RNNDecoder
from src.embeddings import BasicEmbedding
from src.initial_states import ZerosState, TrainableState

Testing Embeddings

In [3]:
embedding = BasicEmbedding(num_labels=5, embedding_size=6)
embedding.eval()

# X must be: [batch_size, seq_len]
X = torch.randint(low=0, high=4, size=(2, 4), dtype=torch.long)
output = embedding(X)
# output shape: [batch_size, seq_len, num_features=embedding_size]
output.shape

torch.Size([2, 4, 6])

Testing RNNEncoder

In [4]:
cell = 'GRU'
embedding_size = 7
hidden_size = 16
num_layers = 1
enc_embedding = BasicEmbedding(num_labels=5, embedding_size=embedding_size)

encoder = RNNEncoder(cell=cell, embedding=enc_embedding, embedding_size=embedding_size, hidden_size=hidden_size,
                    num_layers=num_layers, dropout=0)
encoder.eval()

# X must be: [batch_size, seq_len]
X = torch.randint(low=0, high=4, size=(2, 4), dtype=torch.long)
output, state = encoder(X)
# output shape: [seq_len, batch_size, hidden_size]
# state shape: [num_layers, batch_size, hidden_size]
print(output.shape)

torch.Size([4, 2, 16])


Testing Initial States

In [10]:
cell = 'GRU'
batch_size = 2
hidden_size = 4
num_layers = 1

init_state = ZerosState(cell, batch_size, hidden_size, num_layers)
state = init_state()
# state shape [num_layers, batch_size, hidden_size]
state.shape

torch.Size([1, 2, 4])

In [8]:
cell = 'GRU'
batch_size = 2
hidden_size = 4
num_layers = 1

init_state = TrainableState(cell, batch_size, hidden_size, num_layers)
state = init_state()
# state shape [num_layers, batch_size, hidden_size]
state.shape

torch.Size([1, 2, 4])

Testing RNNDecoder

In [2]:
cell = 'GRU'
embedding_size = 7
input_size = embedding_size * 3
hidden_size = 16
num_layers = 1

batch_size = 2
seq_len = 3

assignment_emb = BasicEmbedding(num_labels=3, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=5, embedding_size=embedding_size)

decoder  = RNNDecoder(input_size=input_size, cell=cell, assignment_emb=assignment_emb, variable_emb=variable_emb,
                        hidden_size=hidden_size, num_layers=num_layers)
decoder.eval()

# var must be: [batch_size, seq_len]
var = torch.randint(0, 4, [batch_size, seq_len])

# a_prev must be: [batch_size, seq_len]
a_prev = torch.randint(0, 2, [batch_size, seq_len])

# context must be: [batch_size, feature_size]
context = torch.rand([batch_size, embedding_size])

# state must be: [num_layers, batch_size, hidden_size]
state = torch.rand([num_layers, batch_size, hidden_size])

X = (var, a_prev, context)
output, state = decoder(X, state)
# output shape: [batch_size, seq_len, 2]
# state shape: [num_layers, batch_size, hidden_size]
print(output.shape)
print(state.shape)

torch.Size([2, 3, 2])
torch.Size([1, 2, 16])


In [None]:
cell = 'GRU'
embedding_size = 7
input_size = embedding_size * 3
hidden_size = 16
num_layers = 1

batch_size = 2
seq_len = 3

assignment_emb = BasicEmbedding(num_labels=3, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=5, embedding_size=embedding_size)

decoder  = RNNDecoder(input_size=input_size, cell=cell, assignment_emb=assignment_emb, variable_emb=variable_emb,
                        hidden_size=hidden_size, num_layers=num_layers)
decoder.eval()

# var must be: [batch_size, seq_len]
var = torch.randint(0, 4, [batch_size, seq_len])

# a_prev must be: [batch_size, seq_len]
a_prev = torch.randint(0, 2, [batch_size, seq_len])

# context must be: [batch_size, feature_size]
context = torch.rand([batch_size, embedding_size])

# state must be: [num_layers, batch_size, hidden_size]
state = torch.rand([num_layers, batch_size, hidden_size])

X = (var, a_prev, context)
output, state = decoder(X, state)
# output shape: [batch_size, seq_len, 2]
# state shape: [num_layers, batch_size, hidden_size]
print(output.shape)
print(state.shape)

In [14]:
var = torch.zeros([4, 2, 3])
# var shape: [seq_len, batch_size, features_size=var_embedding_size]

a_prev = torch.ones([4, 2, 3])
# var shape: [seq_len, batch_size, features_size=var_embedding_size]
        

# context must be: [batch_size, feature_size]
context = torch.rand([2, 0])
# Broadcasting context
context = context.repeat(var.shape[0], 1, 1)
# context shape: [seq_len, batch_size, features_size]

dec_input = torch.cat((var, a_prev, context), -1)
# dec_input shape: [seq_len, batch_size, features_size=input_size]

In [15]:
dec_input.shape

torch.Size([4, 2, 6])

In [15]:
class Class1():
    def __init__(self, var1, var2, **kwargs):
        self.var1= var1
        self.var2 = var2

    def method(self):
        raise NotImplementedError

class Class2(Class1):
    def __init__(self, var1, var2, var3, **kwargs):
        super().__init__(var1, var2, **kwargs)
        self.var3= var3


    def method(self):
        print(self.var1)
        print(self.var2)
        print(self.var3)

In [39]:
batch_size = 2
seq_len = 3
emb_size = 4
var = torch.zeros([batch_size, seq_len, emb_size]).permute(1,0,2)
context = torch.empty([batch_size, 0])

print(var)
print(context)

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.]]])
tensor([], size=(2, 0))


In [40]:
context = context.repeat(var.shape[0], 1, 1)
context

tensor([], size=(3, 2, 0))

In [43]:
torch.cat((var, context), -1)
#3,2,8

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

In [21]:
action_prev = torch.tensor(2).reshape(1,1)
action_prev.dtype

torch.int64

In [17]:
dec_X = torch.unsqueeze(torch.tensor([2], dtype=torch.long), dim=0)

In [20]:
dec_X.dtype

torch.int64

In [49]:
torch.tensor([2] * 5, dtype=torch.long).reshape(-1,1).shape
# ::action_prev:: [batch_size=1, seq_len=1]

torch.Size([5, 1])

In [51]:
torch.tensor([i for i in range(5)]).reshape(1,-1,1).shape

torch.Size([1, 5, 1])

In [53]:
t = torch.rand([2,3,5])
# ::var:: [batch_size, seq_len, feature_size]
t

tensor([[[0.5221, 0.4866, 0.6199, 0.5961, 0.7663],
         [0.2851, 0.2966, 0.9776, 0.7473, 0.8604],
         [0.2733, 0.1481, 0.4454, 0.6126, 0.7616]],

        [[0.4640, 0.7413, 0.3950, 0.3177, 0.7367],
         [0.9951, 0.9485, 0.0726, 0.0920, 0.8593],
         [0.8353, 0.4101, 0.1689, 0.0460, 0.5975]]])

In [61]:
t[:,2:3,:].shape

torch.Size([2, 1, 5])