In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils import *

In [4]:
train_iter, val_iter, DE, EN = load_data(max_len=20, min_freq=5, batch_size=32)

In [20]:
class Seq2Seq(nn.Module):
    
    def __init__(self, enc_vocab_size, dec_vocab_size, enc_embed_dim, dec_embed_dim, hidden_size,
                 enc_num_layers, dec_num_layers):
        # Save hyperparameters
        super(Seq2Seq, self).__init__()
        self.enc_vocab_size = enc_vocab_size
        self.dec_vocab_size = dec_vocab_size
        self.enc_embed_dim = enc_embed_dim
        self.dec_embed_dim = dec_embed_dim
        self.hidden_size = hidden_size
        self.enc_num_layers = enc_num_layers
        self.dec_num_layers = dec_num_layers
        
        # Layers
        self.enc_embedding = nn.Embedding(enc_vocab_size, enc_embed_dim)
        self.enc_lstm = nn.LSTM(input_size=enc_embed_dim, hidden_size=hidden_size, num_layers=enc_num_layers)
        self.dec_embedding = nn.Embedding(dec_vocab_size, dec_embed_dim)
        self.dec_lstm = nn.LSTM(input_size=dec_embed_dim, hidden_size=hidden_size, num_layers=dec_num_layers)
        self.linear = nn.Linear(hidden_size, dec_vocab_size)
        
        # Weight initialization
        for p in self.enc_lstm.parameters():
            p.data.uniform_(-0.08, 0.08)
        for p in self.dec_lstm.parameters():
            p.data.uniform_(-0.08, 0.08)
    
    def forward(self, src, trg):        
        # Encoder
        enc_input = self.enc_embedding(src)
        _, hidden = self.enc_lstm(enc_input)
        
        # Decoder
        dec_input = self.dec_embedding(trg)
        output, _ = self.dec_lstm(dec_input, hidden)
        output = self.linear(output)
        log_probs = F.log_softmax(output, dim=2)
        return log_probs

In [192]:
PAD_IDX = EN.vocab.stoi[EN.pad_token]

In [193]:
s2s = Seq2Seq(enc_vocab_size=len(DE.vocab), dec_vocab_size=len(EN.vocab), 
              enc_embed_dim=1000, dec_embed_dim=1000, hidden_size=1000, 
              enc_num_layers=4, dec_num_layers=4)

optimizer = optim.SGD(params=s2s.parameters(), lr=0.7)
loss_func = nn.NLLLoss(ignore_index=PAD_IDX)

In [191]:
trg_input = batch.trg[:-1]
trg_output = batch.trg[1:]
log_probs = s2s(batch.src, trg_input)
loss = loss_func(log_probs.view(-1, len(EN.vocab)), trg_output.view(-1))
print(loss.data.numpy()[0])
#loss.backward()
#optimizer.step()

10.544413


In [178]:
log_probs.view(-1, len(EN.vocab)).gather(dim=0, index=Variable(pad_mask))

RuntimeError: Expected object of type Variable[torch.LongTensor] but found type Variable[torch.ByteTensor] for argument #2 'index'

In [42]:
loss_func(log_probs, trg_output)

ValueError: Expected 2 or 4 dimensions (got 3)

In [23]:
log_probs

Variable containing:
(  0  ,.,.) = 
 -9.3663 -9.4079 -9.3922  ...  -9.3665 -9.3700 -9.4017
 -9.3962 -9.3847 -9.3883  ...  -9.3876 -9.3993 -9.3922
 -9.3923 -9.3636 -9.3684  ...  -9.3778 -9.3797 -9.4046
           ...             ⋱             ...          
 -9.3825 -9.4073 -9.3846  ...  -9.4410 -9.4001 -9.4076
 -9.3447 -9.3790 -9.4073  ...  -9.4271 -9.4367 -9.3828
 -9.4030 -9.3787 -9.3682  ...  -9.3813 -9.3810 -9.3948

(  1  ,.,.) = 
 -9.3695 -9.3911 -9.3999  ...  -9.3650 -9.3933 -9.3750
 -9.3867 -9.3951 -9.3867  ...  -9.3795 -9.4205 -9.3736
 -9.4123 -9.3689 -9.3652  ...  -9.3674 -9.3808 -9.3767
           ...             ⋱             ...          
 -9.3727 -9.4051 -9.3952  ...  -9.4006 -9.4122 -9.3818
 -9.3394 -9.3859 -9.3884  ...  -9.4193 -9.4603 -9.3485
 -9.3991 -9.3729 -9.3645  ...  -9.3790 -9.4025 -9.3558

(  2  ,.,.) = 
 -9.3883 -9.3761 -9.3869  ...  -9.3706 -9.4057 -9.3441
 -9.3766 -9.4017 -9.3757  ...  -9.3776 -9.4232 -9.3633
 -9.4331 -9.3874 -9.3648  ...  -9.3599 -9.3788 -9.34

In [18]:
n_epochs = 30

In [16]:
for batch in train_iter:
    probs = s2s(batch.src, batch.trg)
    

Variable containing:
(  0  ,.,.) = 
1.00000e-04 *
  0.9086  0.8767  0.8406  ...   0.8312  0.8667  0.8676
  0.8599  0.8723  0.8775  ...   0.8443  0.8595  0.8847
  0.8765  0.8831  0.8667  ...   0.8624  0.8572  0.8702
           ...             ⋱             ...          
  0.8780  0.8896  0.9148  ...   0.8576  0.8404  0.8707
  0.9144  0.8611  0.9002  ...   0.8774  0.8762  0.8810
  0.8903  0.8734  0.8738  ...   0.8327  0.8584  0.8585

(  1  ,.,.) = 
1.00000e-04 *
  0.9047  0.8695  0.8538  ...   0.8336  0.8776  0.8597
  0.8576  0.8813  0.8825  ...   0.8583  0.8924  0.8739
  0.8845  0.8861  0.8734  ...   0.8838  0.8615  0.8698
           ...             ⋱             ...          
  0.8826  0.8759  0.8980  ...   0.8677  0.8751  0.8571
  0.9216  0.8464  0.8761  ...   0.8626  0.9146  0.8696
  0.8961  0.8649  0.8738  ...   0.8600  0.8766  0.8507

(  2  ,.,.) = 
1.00000e-04 *
  0.8877  0.8683  0.8575  ...   0.8428  0.9021  0.8648
  0.8582  0.8852  0.8840  ...   0.8611  0.9231  0.8668
  0.8833  

In [19]:
probs

Variable containing:
(  0  ,.,.) = 
1.00000e-04 *
  0.9086  0.8767  0.8406  ...   0.8312  0.8667  0.8676
  0.8599  0.8723  0.8775  ...   0.8443  0.8595  0.8847
  0.8765  0.8831  0.8667  ...   0.8624  0.8572  0.8702
           ...             ⋱             ...          
  0.8780  0.8896  0.9148  ...   0.8576  0.8404  0.8707
  0.9144  0.8611  0.9002  ...   0.8774  0.8762  0.8810
  0.8903  0.8734  0.8738  ...   0.8327  0.8584  0.8585

(  1  ,.,.) = 
1.00000e-04 *
  0.9047  0.8695  0.8538  ...   0.8336  0.8776  0.8597
  0.8576  0.8813  0.8825  ...   0.8583  0.8924  0.8739
  0.8845  0.8861  0.8734  ...   0.8838  0.8615  0.8698
           ...             ⋱             ...          
  0.8826  0.8759  0.8980  ...   0.8677  0.8751  0.8571
  0.9216  0.8464  0.8761  ...   0.8626  0.9146  0.8696
  0.8961  0.8649  0.8738  ...   0.8600  0.8766  0.8507

(  2  ,.,.) = 
1.00000e-04 *
  0.8877  0.8683  0.8575  ...   0.8428  0.9021  0.8648
  0.8582  0.8852  0.8840  ...   0.8611  0.9231  0.8668
  0.8833  