In [1]:
# Project modules. NOTE(review): the wildcard imports hide which module each
# helper used below comes from; prefer explicit imports once the full list of
# used names is confirmed.
from src.seq2seq import *
from src.attention import *
from src.utils import *
from src.layers import MaskedCrossEntropyLoss

# Explicit imports stay after the wildcard imports so they win any name clash.
import math

import torch
import torch.optim as optim

# Setup

In [2]:
# Supported language codes:
#   en - English
#   de - German
#   fr - French
#   cs - Czech
lang1, lang2 = 'en', 'de'  # lang1 is reported as the source, lang2 as the target

# Load the sentences for the chosen language pair.
src_sentences, tgt_sentences = load_data(lang1, lang2)

In [3]:
# Data-pipeline settings, all forwarded to make_dataset below.
SEED = 1234
TEST_SIZE = 0.2
BATCH_SIZE = 32
MAX_VOCAB = 10000

# Seed PyTorch's global RNG so that everything downstream that draws from it
# (parameter initialisation, any torch-driven shuffling) repeats across
# Restart & Run All.
# NOTE(review): if make_dataset splits or shuffles with another RNG (e.g.
# Python's `random` or NumPy), that RNG must be seeded too -- confirm in src.
torch.manual_seed(SEED)

# Returns the two vocabularies and the train/test loaders used by every later cell.
src_vocab, tgt_vocab, train_loader, test_loader = make_dataset(
    src_sentences, tgt_sentences, TEST_SIZE, BATCH_SIZE, MAX_VOCAB
)

In [4]:
# Report how many examples each loader's dataset holds (one line per split).
print(
    f"Number of training examples: {len(train_loader.dataset)}",
    f"Number of testing examples: {len(test_loader.dataset)}",
    sep="\n",
)

Number of training examples: 24011
Number of testing examples: 6003


In [5]:
# Report the size of each vocabulary, labelled with its language code.
print("\n".join((
    f"Unique tokens in source ({lang1}) vocabulary: {len(src_vocab)}",
    f"Unique tokens in target ({lang2}) vocabulary: {len(tgt_vocab)}",
)))

Unique tokens in source (en) vocabulary: 9516
Unique tokens in target (de) vocabulary: 10000


# Make the Model

In [7]:
# --- Encoder arguments ---
ENC_UNITS = ENC_EMBEDDING = 128  # both encoder sizes are 128
ENC_NUM_LAYERS = 1
SRC_VOCAB_SIZE = len(src_vocab)

# --- Decoder arguments (units, embedding and depth mirror the encoder) ---
DEC_UNITS, DEC_EMBEDDING, DEC_NUM_LAYERS = ENC_UNITS, ENC_EMBEDDING, ENC_NUM_LAYERS
TGT_VOCAB_SIZE = len(tgt_vocab)

# --- Seq2Seq arguments ---
# Passed to Seq2Seq as its teacher-forcing setting; 0.0 presumably means the
# decoder is never fed the ground-truth token -- TODO confirm in src.seq2seq.
TEACHER_FORCING = 0.0
# One more than the last-dimension size of the training dataset's second tensor
# (presumably the padded target length -- TODO confirm).
MAX_LENGTH = 1 + train_loader.dataset.tensors[1].size(-1)
SOS_TOKEN = tgt_vocab.SOS_token

In [8]:
# Build the two halves first (encoder before decoder), then wrap them.
encoder = Encoder(
    ENC_UNITS,
    ENC_EMBEDDING,
    SRC_VOCAB_SIZE,
    ENC_NUM_LAYERS,
)
decoder = Decoder(
    DEC_UNITS,
    DEC_EMBEDDING,
    TGT_VOCAB_SIZE,
    DEC_NUM_LAYERS,
)

# The wrapper also receives the teacher-forcing setting, the maximum length
# and the start-of-sequence token.
seq2seq = Seq2Seq(
    encoder,
    decoder,
    TEACHER_FORCING,
    MAX_LENGTH,
    SOS_TOKEN,
)

# Sanity check on model size.
print(f'The model has {count_parameters(seq2seq):,} trainable parameters')

The model has 3,986,192 trainable parameters


In [9]:
# Adam over every model parameter, with the optimizer's default settings.
optimizer = optim.Adam(
    seq2seq.parameters(),
)

# Loss is told the target vocabulary's padding token via `pad_tok`.
criterion = MaskedCrossEntropyLoss(
    pad_tok=tgt_vocab.PAD_token,
)

# Train

In [10]:
# Train for N_EPOCHS passes over train_loader, evaluating on test_loader
# after each pass and printing loss and perplexity for both.
N_EPOCHS = 2
CLIP = 1  # forwarded to train(); presumably the gradient-clipping threshold -- confirm in src

for epoch in range(N_EPOCHS):

    train_loss = train(seq2seq, train_loader, optimizer, criterion, CLIP, src_vocab.PAD_token)
    valid_loss = evaluate(seq2seq, test_loader, criterion)

    # Perplexity is exp(loss). `math` is imported explicitly in the imports
    # cell rather than relied upon to leak in through a wildcard import.
    train_ppl = math.exp(train_loss)
    valid_ppl = math.exp(valid_loss)

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {train_ppl:7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {valid_ppl:7.3f}')


100%|██████████| 751/751 [06:49<00:00,  1.94it/s]


torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 

torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])


  0%|          | 0/751 [00:00<?, ?it/s]

torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([19, 47, 10000]) torch.Size([19, 47])
Epoch: 01
	Train Loss: 4.917 | Train PPL: 136.589
	 Val. Loss: 4.419 |  Val. PPL:  82.988



100%|██████████| 751/751 [07:00<00:00,  2.04it/s]


torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 

torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([32, 47, 10000]) torch.Size([32, 47])
torch.Size([19, 47, 10000]) torch.Size([19, 47])
Epoch: 02
	Train Los

In [19]:
# Example source-language sentence to translate; tokens (including the final
# period) are already separated by single spaces.
sent = "a woman on a subway is falling asleep ."

In [20]:
# Translate the example sentence with the trained model; the bare call is the
# cell's last expression, so its return value is displayed.
# NOTE(review): the recorded output keeps emitting <eos> after the sentence
# ends and its second element is None -- translate does not appear to truncate
# at the first <eos>; confirm in src.
translate(
    sent,
    seq2seq,
    src_vocab,
    tgt_vocab,
    src_vocab.PAD_token,
)

(['<sos> eine frau in einem einem . . <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>'],
 None)