In [1]:
import torch
import torch.nn as nn
import random

# Vocabulary sizes handed to the encoder / decoder embedding tables.
INPUT_DIM = 1111
OUTPUT_DIM = 2222

# Embedding widths (encoder and decoder use the same value).
ENC_EMB_DIM = DEC_EMB_DIM = 256

# Recurrent stack: hidden-state width and number of stacked layers.
HID_DIM = 512
N_LAYERS = 2

# Dropout probabilities (encoder and decoder use the same value).
ENC_DROPOUT = DEC_DROPOUT = 0.5

# Device the model and its output buffer are placed on.
device = 'cpu'

In [36]:
class Encoder(nn.Module):
    """Embed a source token sequence and encode it with a stacked GRU.

    Args:
        input_dim: Number of rows in the embedding table (source vocabulary size).
        emb_dim: Width of each token embedding.
        hid_dim: Width of the GRU hidden state.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability, applied to the embeddings and between
            stacked GRU layers.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(
            num_embeddings=input_dim,
            embedding_dim=emb_dim
        )

        self.gru = nn.GRU(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            # nn.GRU applies dropout only between stacked layers and warns when
            # a non-zero value is given for a single layer, so pass 0 in that case.
            dropout=dropout if n_layers > 1 else 0.0
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: Integer token ids of shape [src sent len, batch size].

        Returns:
            A tuple ``(output, hidden)``:
            output -- top-layer GRU states for every source position,
                shape [src sent len, batch size, hid dim];
            hidden -- final state of every GRU layer,
                shape [n layers, batch size, hid dim].
        """
        # embedded = [src sent len, batch size, emb dim]
        embedded = self.dropout(self.embedding(src))

        # No initial state is passed, so the GRU starts from zeros.
        output, hidden = self.gru(embedded)

        return output, hidden

In [37]:
class Decoder(nn.Module):
    """Single-step GRU decoder with dot-product attention over encoder outputs.

    Args:
        output_dim: Number of rows in the embedding table and number of output
            logits (target vocabulary size).
        emb_dim: Width of each token embedding.
        hid_dim: Width of the GRU hidden state; must equal the last dimension
            of the ``encoder_output`` passed to ``forward``.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability, applied to the embeddings and between
            stacked GRU layers.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(
            num_embeddings=output_dim,
            embedding_dim=emb_dim
        )

        self.gru = nn.GRU(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            # nn.GRU applies dropout only between stacked layers and warns when
            # a non-zero value is given for a single layer, so pass 0 in that case.
            dropout=dropout if n_layers > 1 else 0.0
        )

        # The projection consumes [attention context ; GRU output], hence 2 * hid_dim.
        self.out = nn.Linear(
            in_features=2 * hid_dim,
            out_features=output_dim
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input, hidden, encoder_output):
        """Run one decoding step.

        Args:
            input: Token ids for the current step, shape [batch size].
            hidden: Previous GRU state, shape [n layers, batch size, hid dim].
            encoder_output: Encoder states for every source position,
                shape [src sent len, batch size, hid dim].

        Returns:
            A tuple ``(prediction, hidden)``:
            prediction -- unnormalized scores, shape [batch size, output dim];
            hidden -- updated GRU state, shape [n layers, batch size, hid dim].
        """
        # The GRU expects a leading time axis: [batch size] -> [1, batch size].
        input = input.unsqueeze(0)

        # embedded = [1, batch size, emb dim]
        embedded = self.dropout(self.embedding(input))

        # output = [1, batch size, hid dim]
        output, hidden = self.gru(embedded, hidden)

        # Dot product between the decoder output and every encoder state:
        # [batch, 1, hid] x [batch, hid, src len] -> [batch, 1, src len],
        # then moved to [1, batch, src len].
        attn_scores = torch.bmm(
            output.transpose(0, 1),
            encoder_output.transpose(0, 1).transpose(2, 1),
        ).transpose(0, 1)

        # Normalize over the source positions.
        attn_weights = nn.functional.softmax(attn_scores, dim=2)

        # [src len, batch, 1] * [src len, batch, hid], summed over src len
        # -> context = [batch size, hid dim]
        context = (attn_weights.transpose(0, 2) * encoder_output).sum(dim=0)

        prediction = self.out(torch.cat([context, output.squeeze(0)], dim=1))

        return prediction, hidden

In [38]:
class Seq2Seq(nn.Module):
    """Couple an encoder and a decoder into one sequence-to-sequence model.

    The encoder's final hidden state seeds the decoder, and the full encoder
    output is handed to the decoder at every step.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()

        # Both halves must agree on state shape, since the encoder's hidden
        # state is fed straight into the decoder.
        assert encoder.hid_dim == decoder.hid_dim, \
            "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.n_layers == decoder.n_layers, \
            "Encoder and decoder must have equal number of layers!"

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg`` step by step, optionally with teacher forcing.

        Args:
            src: Source token ids, shape [src sent len, batch size].
            trg: Target token ids, shape [trg sent len, batch size]; row 0 is
                used as the first decoder input.
            teacher_forcing_ratio: Probability of feeding the ground-truth
                token (rather than the model's own top prediction) as the
                next decoder input, e.g. 0.75 uses ground truth 75% of the time.

        Returns:
            Tensor of shape [trg sent len, batch size, decoder.output_dim];
            row 0 is left as zeros because no prediction is made for it.
        """
        n_steps = trg.shape[0]
        n_seqs = trg.shape[1]
        vocab_size = self.decoder.output_dim

        # Buffer that collects the decoder scores for every step.
        outputs = torch.zeros(n_steps, n_seqs, vocab_size).to(self.device)

        # The encoder's last hidden state becomes the decoder's initial state.
        encoder_output, hidden = self.encoder(src)

        # Decoding starts from the first target row.
        step_input = trg[0, :]

        for t in range(1, n_steps):
            step_scores, hidden = self.decoder(step_input, hidden, encoder_output)
            outputs[t] = step_scores

            if random.random() < teacher_forcing_ratio:
                # Teacher forcing: continue from the ground-truth token.
                step_input = trg[t]
            else:
                # Otherwise continue from the highest-scoring predicted token.
                step_input = step_scores.max(dim=1).indices

        return outputs

In [39]:
# Build the two halves from the notebook-level hyperparameters.
enc = Encoder(
    input_dim=INPUT_DIM,
    emb_dim=ENC_EMB_DIM,
    hid_dim=HID_DIM,
    n_layers=N_LAYERS,
    dropout=ENC_DROPOUT,
)
dec = Decoder(
    output_dim=OUTPUT_DIM,
    emb_dim=DEC_EMB_DIM,
    hid_dim=HID_DIM,
    n_layers=N_LAYERS,
    dropout=DEC_DROPOUT,
)

# Wrap them into one model and move it to the configured device;
# the bare name on the last line displays the module tree.
model = Seq2Seq(encoder=enc, decoder=dec, device=device).to(device)
model

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(1111, 256)
    (gru): GRU(256, 512, num_layers=2, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(2222, 256)
    (gru): GRU(256, 512, num_layers=2, dropout=0.5)
    (out): Linear(in_features=1024, out_features=2222, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [40]:
SENT_LEN = 21
BATCH_SIZE = 10

# Random token ids for a shape smoke test. The upper bounds come from the
# vocabulary constants (torch.randint's `high` is exclusive), so every id is a
# valid embedding index and the target batch covers the whole target vocabulary.
batch_src = torch.randint(0, INPUT_DIM, (SENT_LEN, BATCH_SIZE))
# The target is one step shorter than the source, so the two lengths differ.
batch_trg = torch.randint(0, OUTPUT_DIM, (SENT_LEN - 1, BATCH_SIZE))

print(f"batch SRC shape: {batch_src.shape}")
print(f"batch TRG shape: {batch_trg.shape}")

batch SRC shape: torch.Size([21, 10])
batch TRG shape: torch.Size([20, 10])


In [41]:
# Forward pass on the random batch (default teacher forcing ratio) to check the output shape.
model_output = model(src=batch_src, trg=batch_trg)
print(f"model output shape {model_output.shape}")

model output shape torch.Size([20, 10, 2222])


In [144]:
# Scratch tensor of ones with shape [10, 512], used for a concatenation shape check.
asd = torch.ones(10, 512)

In [147]:
# Shape check: joining two [10, 512] tensors along dim 1 yields [10, 1024].
torch.cat([asd, asd], dim=1).shape

torch.Size([10, 1024])

In [20]:
# NOTE(review): `output` is not defined by any visible top-level cell — presumably a
# leftover from debugging Decoder.forward interactively; confirm or delete this cell.
output.shape

torch.Size([1, 10, 512])

In [30]:
# Shape after swapping the first two axes of `output`.
# NOTE(review): `output` is not defined by any visible top-level cell — confirm or delete.
output.transpose(0, 1).shape

torch.Size([10, 1, 512])

In [32]:
# Throwaway 512 -> 512 linear layer used for a shape experiment.
tmp_linear = nn.Linear(512, 512)

In [35]:
# Shape of `output` after the throwaway linear layer.
# NOTE(review): `output` is not defined by any visible top-level cell — confirm or delete.
tmp_linear(output).shape

torch.Size([1, 10, 512])

In [31]:
# Shape of `encoder_output` after the same two transposes Decoder.forward applies before bmm.
# NOTE(review): `encoder_output` is not defined by any visible top-level cell — confirm or delete.
encoder_output.transpose(0, 1).transpose(2, 1).shape

torch.Size([10, 512, 21])

In [29]:
# Scratch copy of the attention computation from Decoder.forward, kept for shape debugging.
# NOTE(review): `output` and `encoder_output` are not defined by any visible top-level
# cell — presumably kernel leftovers; confirm or delete this cell.
attn_scores = torch.bmm(
    output.transpose(0, 1),
    encoder_output.transpose(0, 1).transpose(2, 1),   
).transpose(0, 1)

# Normalize the scores over the last axis.
attn_scores_sftmax = nn.functional.softmax(attn_scores, dim=2)
# Weight `encoder_output` by the normalized scores, then sum over axis 0.
attn_matr = attn_scores_sftmax.transpose(0, 2) * encoder_output
attn_vect = attn_matr.sum(dim=0)

# prediction = self.out(torch.cat([attn_vect, output.squeeze(0)], dim=1))