In [3]:
# NOTE: original import order is kept on purpose; with star imports in play,
# reordering could change which binding of a name wins.
from src.seq2seq import *
from src.attention import *
from src.utils import *
from src.layers import MaskedCrossEntropyLoss
import torch
import torch.optim as optim
import random
import math

# Setup

In [56]:
# Translation direction: lang1 is the source language, lang2 the target.
# Available language codes:
#   en - English
#   de - German
#   fr - French
#   cs - Czech
lang1, lang2 = 'de', 'en'

# load_data hands back the source-side and target-side sentences for the pair.
src_sentences, tgt_sentences = load_data(lang1, lang2)

In [57]:
# Data-pipeline settings; all of them are passed straight to make_dataset.
TEST_SIZE = 0.2     # presumably the fraction of pairs held out for validation - confirm in make_dataset
BATCH_SIZE = 32
MAX_VOCAB = 17000   # vocabulary limit handed to make_dataset
SEED = 42

# Seed the RNGs this notebook imports before anything stochastic runs, so a
# Restart & Run All starts from the same random state every time.
# NOTE(review): whether make_dataset's split/shuffle draws from these
# generators is not visible here; if it uses NumPy or its own RNG, seed that too.
random.seed(SEED)
torch.manual_seed(SEED)

# Builds both vocabularies and the train/validation loaders in one call.
src_vocab, tgt_vocab, train_loader, valid_loader = make_dataset(
    src_sentences, tgt_sentences, TEST_SIZE, BATCH_SIZE, MAX_VOCAB
)

In [58]:
# Report how many examples each loader's dataset holds, one line per split.
print(
    f"Number of training examples: {len(train_loader.dataset)}",
    f"Number of validation examples: {len(valid_loader.dataset)}",
    sep="\n",
)

Number of training examples: 24011
Number of validation examples: 6003


In [59]:
# Report the size of each vocabulary, labelled with its language code.
print(
    f"Unique tokens in source ({lang1}) vocabulary: {len(src_vocab)}",
    f"Unique tokens in target ({lang2}) vocabulary: {len(tgt_vocab)}",
    sep="\n",
)

Unique tokens in source (de) vocabulary: 16441
Unique tokens in target (en) vocabulary: 9381


# Make the Model

In [51]:
# Hyper-parameters for the encoder, the attention decoder and the Seq2Seq
# wrapper. Everything a reader might tune for the model lives in this cell.

# ENCODER ARGS
ENC_UNITS = 128
ENC_EMBEDDING = 128
SRC_VOCAB_SIZE = len(src_vocab)
ENC_NUM_LAYERS = 1

# ATTENTION DECODER ARGS
# The decoder reuses the encoder's sizes and layer count.
DEC_UNITS = ENC_UNITS
DEC_EMBEDDING = ENC_EMBEDDING
TGT_VOCAB_SIZE = len(tgt_vocab)
DEC_NUM_LAYERS = ENC_NUM_LAYERS

# Attention layer choices:
#   ConcatAttention, GeneralAttention,
#   DotAttention, MeanAttention, LastInSeqAttention
ATTN_LAYER = GeneralAttention
ATTN_HIDDEN_SIZE = 64

# SEQ2SEQ ARGS
# NOTE(review): presumably the teacher-forcing probability, in which case 0.
# means the decoder never sees the reference tokens during training - confirm
# against Seq2Seq.
TEACHER_FORCING = 0.
# Size of the last dimension of the training target tensor, plus one.
MAX_LENGTH = train_loader.dataset.tensors[1].size(-1) + 1
SOS_TOKEN = tgt_vocab.SOS_token

In [52]:
# Instantiate the encoder and the attention decoder from the constants
# defined in the hyper-parameter cell.
encoder = Encoder(
    ENC_UNITS,
    ENC_EMBEDDING,
    SRC_VOCAB_SIZE,
    ENC_NUM_LAYERS,
)
decoder = AttentionDecoder(
    DEC_UNITS,
    DEC_EMBEDDING,
    TGT_VOCAB_SIZE,
    DEC_NUM_LAYERS,
    ATTN_LAYER,
    ATTN_HIDDEN_SIZE,
)

# Wrap both halves into the full sequence-to-sequence model.
seq2seq = Seq2Seq(
    encoder,
    decoder,
    TEACHER_FORCING,
    MAX_LENGTH,
    SOS_TOKEN,
)

# Sanity check on model size.
print(f'The model has {count_parameters(seq2seq):,} trainable parameters')

The model has 3,985,278 trainable parameters


In [53]:
# Print the model's text representation to inspect its sub-modules and sizes.
print(seq2seq)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(10063, 128)
    (gru): GRU(128, 128, batch_first=True)
  )
  (decoder): AttentionDecoder(
    (embedding): Embedding(9469, 128)
    (gru): GRU(256, 128, batch_first=True)
    (attention): GeneralAttention(
      (W_a): Bilinear(in1_features=128, in2_features=128, out_features=1, bias=True)
    )
    (fc): Linear(in_features=128, out_features=9469, bias=True)
  )
)


In [54]:
# Adam over all of the model's parameters, with the optimizer's default settings.
optimizer = optim.Adam(seq2seq.parameters())

# The loss is given the target vocabulary's padding token
# (presumably so padded positions are left out of the loss - confirm in
# MaskedCrossEntropyLoss).
criterion = MaskedCrossEntropyLoss(pad_tok=tgt_vocab.PAD_token)

# Train

In [55]:
# Training schedule: number of passes over the data, and the CLIP value
# forwarded to train() (presumably a gradient-clipping threshold - TODO confirm).
N_EPOCHS = 2
CLIP = 1

# Each epoch runs one training pass and one validation pass, then reports
# both losses together with exp(loss) as the perplexity.
for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch+1:02}')
    
    # NOTE(review): train() receives the source PAD token while evaluate()
    # does not - verify this matches the helpers' signatures.
    train_loss = train(seq2seq, train_loader, optimizer, criterion, CLIP, src_vocab.PAD_token)
    valid_loss = evaluate(seq2seq, valid_loader, criterion)
    
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')


Epoch: 01



  0%|          | 0/751 [00:00<?, ?it/s][A[A[A


  0%|          | 1/751 [00:03<39:38,  3.17s/it][A[A[A

KeyboardInterrupt: 

In [60]:
# Pick one training pair by position. Slicing with idx:idx + 1 (rather than
# indexing) keeps the leading dimension, and [0] then takes the first entry
# of what to_string returns.
idx = 5

# tensors[0] holds the source side, tensors[1] the target side.
src_sentence = src_vocab.to_string(
    train_loader.dataset.tensors[0][idx:idx + 1],
    remove_special=True,
)[0]
tgt_sentence = tgt_vocab.to_string(
    train_loader.dataset.tensors[1][idx:idx + 1],
    remove_special=True,
)[0]

In [61]:
# Run the current model on the sample sentence. translate() returns two
# values: the translated text (printed below) and the attention data that is
# later passed to plot_attention.
translation, attention = translate(src_sentence, seq2seq, src_vocab, tgt_vocab, src_vocab.PAD_token)

In [62]:
# Show the source (>), the reference translation (=) and the model output (<),
# one per line.
print(
    f"> {src_sentence}",
    f"= {tgt_sentence}",
    f"< {translation}",
    sep="\n",
)

> ein junger mann in einem roten gewand lachelt .
= a young man in a red robe is smiling .
< motorists rat doll pattern section handing hoist canes fair-skinned pledge frying pledge scientific-looking barefooted themselves themselves slouching device arab slinging mache mache uniform te tipped specialty tech workman points clarinets trails walter perspiring foods tumbling israeli chain headgear corduroy seated seated protective lubricates


In [None]:
# Visualise the attention returned by translate() against the source sentence
# and the produced translation.
plot_attention(attention, src_sentence, translation)