In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import pytorch_lightning as pl

import numpy as np
import os


In [2]:
TEMP_CONFIG = {'dictonary_size': 10, 'features': 64, 'hidden_size': 128, 'max_sentence_length': 10}
dict_size = TEMP_CONFIG['dictonary_size']
features = TEMP_CONFIG['features']
hidden_size = TEMP_CONFIG['hidden_size']
max_sentence_length = TEMP_CONFIG['max_sentence_length']

In [3]:
emb = nn.Embedding(10, 5)
rnn = nn.RNN(5, 20, batch_first=True)

In [4]:
rand = torch.randint(low=0, high=5, size=(1, 6))
print(rand.shape)
embs = emb(rand)
print(embs.shape)
outs, hidden = rnn(embs, torch.zeros(1, 1, 20))
print(outs.shape, hidden.shape)

torch.Size([1, 6])
torch.Size([1, 6, 5])
torch.Size([1, 6, 20]) torch.Size([1, 1, 20])


In [114]:
class Encoder(torch.nn.Module):
    def __init__(self, max_sentence_length, dictionary_size, features, hidden_size) -> None:
        super(Encoder, self).__init__()
        self.dictionary_size = dictionary_size
        self.hidden_size = hidden_size
        self.features = features
        self.max_sentence_length = max_sentence_length
        self.emb = nn.Embedding(dictionary_size, features)
        self.rnn = nn.RNN(input_size=features, hidden_size=hidden_size, batch_first=True)
        
    def forward(self, batch, hidden_state):
        embeddings = self.emb(batch)
        print(embeddings.shape)
        output, hidden_state = self.rnn(embeddings, hidden_state)
        return output, hidden_state
    
    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size)

class Decoder(torch.nn.Module):
    def __init__(self, max_sentence_length, dictionary_size, hidden_size, output_size) -> None:
        super(Decoder, self).__init__()
        self.max_sentence_length = max_sentence_length
        self.dictionary_size = dictionary_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        
        self.emb = nn.Embedding(self.dictionary_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size*2, self.max_sentence_length)
        # self.attn_combined = nn.Linear(self.hidden_size*3, self.hidden_size)
        self.gru = nn.GRU(self.hidden_size*3, self.hidden_size, batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, batch, hidden_state, encoder_outputs):
        embeddings = self.emb(batch)
        cat = torch.cat([embeddings[0], hidden_state[0]], dim=1)
        attn = self.attn(cat)
        attn_weights = F.softmax(attn, dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs)
        output = torch.cat([embeddings[0], attn_applied[0], hidden_state[0]], dim=1).unsqueeze(0)
        output, hidden_state = self.gru(output, hidden_state)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size)


In [115]:
encoder = Encoder(max_sentence_length, dict_size, features, hidden_size)
rand_int = torch.randint(low=0, high=dict_size, size=(1, max_sentence_length))

In [116]:
rand_int

tensor([[7, 0, 9, 7, 5, 1, 5, 0, 9, 8]])

In [117]:
hidden_0 = encoder.init_hidden()
output, hidden = encoder(rand_int, hidden_0)

torch.Size([1, 10, 64])


In [118]:
output.shape, hidden.shape

(torch.Size([1, 10, 128]), torch.Size([1, 1, 128]))

In [119]:
decoder = Decoder(max_sentence_length=max_sentence_length, dictionary_size=20, hidden_size=hidden_size, output_size=10)

In [120]:
rand_int = torch.randint(low=0, high=dict_size, size=(1, 1))

In [121]:
hidden_0 = decoder.init_hidden()

In [122]:
input = torch.randn(1, 1, 4)
mat2 = torch.randn(1, 4, 5)
res = torch.bmm(input, mat2)
res.size()

torch.Size([1, 1, 5])

In [125]:
decoder(rand_int, hidden_0, output) 

(tensor([[-2.2863, -2.3585, -2.3466, -2.2032, -2.2372, -2.3966, -2.1776, -2.2995,
          -2.3824, -2.3653]], grad_fn=<LogSoftmaxBackward0>),
 tensor([[[ 0.2300, -0.4224,  0.3249, -0.2062, -0.4335,  0.8271,  0.0484,
           -0.3034, -0.7401, -0.3866,  0.1695, -0.1123, -0.2303,  0.2145,
            0.3402,  0.0562, -0.2733,  0.2462, -0.5417, -0.1870, -0.1122,
           -0.0526, -0.1126,  0.4738, -0.2139, -0.1675,  0.3899, -0.2158,
            0.1488, -0.3018, -0.2774,  0.0140, -0.4432, -0.0302,  0.2154,
           -0.3638, -0.2621,  0.1858, -0.4662,  0.3248, -0.1305,  0.1152,
            0.2403, -0.1253,  0.4422, -0.2238,  0.3515, -0.1218, -0.1680,
           -0.1718, -0.0535, -0.3999, -0.1258,  0.0321, -0.0918,  0.1172,
            0.2668,  0.1775, -0.4352, -0.1927,  0.5146,  0.2775,  0.1411,
            0.2379, -0.3669,  0.6606,  0.0108, -0.4858,  0.0624, -0.3505,
           -0.0713, -0.4220,  0.2365, -0.4199,  0.2361, -0.0405,  0.4378,
           -0.5760,  0.3931,  0.5880, -0.4