In [1]:
from src.seq2seq import *
from src.attention import *
from src.utils import *
from src.layers import MaskedCrossEntropyLoss
import torch 
import torch.optim as optim

# Setup

In [2]:
# OPTIONS: 
# ENGLISH - en, 
# GERMAN - de, 
# FRENCH - fr, 
# CZECH - cs

lang1 = 'de'
lang2 = 'en'

# train_sentences, test_sentences = load_data(lang1, lang2)
# train_sentences = (train_sentences[0][:3000], train_sentences[1][:3000])
train_sentences = load_summary()

In [3]:
TEST_SIZE=0.2
BATCH_SIZE=64
VALID_BATCH_SIZE=64
MAX_VOCAB=20000

src_vocab, tgt_vocab, train_loader, valid_loader = make_dataset(train_sentences, train_sentences, BATCH_SIZE, VALID_BATCH_SIZE, MAX_VOCAB)

In [4]:
print(f"Number of training examples: {len(train_loader.dataset)}")
print(f"Number of testing examples: {len(valid_loader.dataset)}")
print(f"Training Batches {len(train_loader)}\tValidation Batches {len(valid_loader)}")

Number of training examples: 2000
Number of testing examples: 2000
Training Batches 32	Validation Batches 32


In [5]:
print(f"Unique tokens in source ({lang1}) vocabulary: {len(src_vocab)}")
print(f"Unique tokens in target ({lang2}) vocabulary: {len(tgt_vocab)}")

Unique tokens in source (de) vocabulary: 6132
Unique tokens in target (en) vocabulary: 3100


# Make the Model

In [6]:
# ENCODER ARGS
ENC_UNITS = 128
ENC_EMBEDDING = 256
SRC_VOCAB_SIZE = len(src_vocab)
ENC_NUM_LAYERS = 1

# DECODER ARGS
DEC_UNITS = ENC_UNITS
DEC_EMBEDDING = ENC_EMBEDDING
TGT_VOCAB_SIZE = len(tgt_vocab)
DEC_NUM_LAYERS = ENC_NUM_LAYERS

# SEQ2SEQ ARGS
TEACHER_FORCING = 1.0
MAX_LENGTH = train_loader.dataset.tensors[1].size(-1) + 1
SOS_TOKEN = tgt_vocab.SOS_token

In [7]:
encoder = Encoder(ENC_UNITS, ENC_EMBEDDING, SRC_VOCAB_SIZE, ENC_NUM_LAYERS)
decoder = Decoder(DEC_UNITS, DEC_EMBEDDING, TGT_VOCAB_SIZE, DEC_NUM_LAYERS)

seq2seq = Seq2Seq(encoder, decoder, TEACHER_FORCING, MAX_LENGTH, SOS_TOKEN)

print(f'The model has {count_parameters(seq2seq):,} trainable parameters')

The model has 3,059,740 trainable parameters


In [8]:
print(seq2seq)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(6132, 256)
    (gru): GRU(256, 128, batch_first=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(3100, 256)
    (gru): GRU(256, 128, batch_first=True)
    (fc): Linear(in_features=128, out_features=3100, bias=True)
  )
)


In [9]:
criterion = MaskedCrossEntropyLoss(pad_tok=tgt_vocab.PAD_token)
optimizer = optim.Adam(seq2seq.parameters())

# Train

In [10]:
tgt_vocab.SOS_token, tgt_vocab.EOS_token

(2, 3)

In [11]:
valid_loss = evaluate(seq2seq, valid_loader, criterion)

100%|██████████| 32/32 [00:01<00:00, 16.31it/s]


In [12]:
valid_loss

8.052558198571205

In [32]:
N_EPOCHS = 100
CLIP = 1

seq2seq.teacher_forcing = 0.

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    print(f'Epoch: {epoch+1:02}')
    
    train_loss = train(seq2seq, train_loader, optimizer, criterion, CLIP, src_vocab.PAD_token)
    valid_loss = evaluate(seq2seq, valid_loader, criterion)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(seq2seq.state_dict(), 'models/seq2seq.pt')
    
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')


Epoch: 01
100%|██████████| 32/32 [00:10<00:00,  3.76it/s]
100%|██████████| 32/32 [00:01<00:00, 19.45it/s]
	Train Loss: 5.162 | Train PPL: 174.546
	 Val. Loss: 5.053 |  Val. PPL: 156.557
Epoch: 02
100%|██████████| 32/32 [00:09<00:00,  3.71it/s]
100%|██████████| 32/32 [00:01<00:00, 19.44it/s]
	Train Loss: 5.085 | Train PPL: 161.517
	 Val. Loss: 4.978 |  Val. PPL: 145.135
Epoch: 03
100%|██████████| 32/32 [00:10<00:00,  3.69it/s]
100%|██████████| 32/32 [00:01<00:00, 18.91it/s]
	Train Loss: 5.004 | Train PPL: 149.074
	 Val. Loss: 4.908 |  Val. PPL: 135.386
Epoch: 04
100%|██████████| 32/32 [00:10<00:00,  3.47it/s]
100%|██████████| 32/32 [00:01<00:00, 18.34it/s]
	Train Loss: 4.936 | Train PPL: 139.232
	 Val. Loss: 4.816 |  Val. PPL: 123.435
Epoch: 05
100%|██████████| 32/32 [00:10<00:00,  3.65it/s]
100%|██████████| 32/32 [00:01<00:00, 18.41it/s]
	Train Loss: 4.867 | Train PPL: 129.941
	 Val. Loss: 4.757 |  Val. PPL: 116.438
Epoch: 06
100%|██████████| 32/32 [00:10<00:00,  3.58it/s]
100%|███████

100%|██████████| 32/32 [00:10<00:00,  3.58it/s]
100%|██████████| 32/32 [00:01<00:00, 18.92it/s]
	Train Loss: 2.580 | Train PPL:  13.194
	 Val. Loss: 2.474 |  Val. PPL:  11.867
Epoch: 46
100%|██████████| 32/32 [00:10<00:00,  3.00it/s]
100%|██████████| 32/32 [00:01<00:00, 16.32it/s]
	Train Loss: 2.508 | Train PPL:  12.284
	 Val. Loss: 2.434 |  Val. PPL:  11.410
Epoch: 47
100%|██████████| 32/32 [00:10<00:00,  3.24it/s]
100%|██████████| 32/32 [00:01<00:00, 18.09it/s]
	Train Loss: 2.474 | Train PPL:  11.873
	 Val. Loss: 2.395 |  Val. PPL:  10.969
Epoch: 48
100%|██████████| 32/32 [00:11<00:00,  2.59it/s]
100%|██████████| 32/32 [00:02<00:00, 12.88it/s]
	Train Loss: 2.435 | Train PPL:  11.415
	 Val. Loss: 2.340 |  Val. PPL:  10.382
Epoch: 49
100%|██████████| 32/32 [00:10<00:00,  3.56it/s]
100%|██████████| 32/32 [00:01<00:00, 18.56it/s]
	Train Loss: 2.415 | Train PPL:  11.193
	 Val. Loss: 2.318 |  Val. PPL:  10.154
Epoch: 50
100%|██████████| 32/32 [00:10<00:00,  3.57it/s]
100%|██████████| 32/32

100%|██████████| 32/32 [00:10<00:00,  3.77it/s]
100%|██████████| 32/32 [00:01<00:00, 19.55it/s]
	Train Loss: 1.060 | Train PPL:   2.885
	 Val. Loss: 0.995 |  Val. PPL:   2.703
Epoch: 90
100%|██████████| 32/32 [00:10<00:00,  3.73it/s]
100%|██████████| 32/32 [00:01<00:00, 19.59it/s]
	Train Loss: 1.047 | Train PPL:   2.848
	 Val. Loss: 0.968 |  Val. PPL:   2.633
Epoch: 91
100%|██████████| 32/32 [00:10<00:00,  3.59it/s]
100%|██████████| 32/32 [00:01<00:00, 18.85it/s]
	Train Loss: 0.992 | Train PPL:   2.697
	 Val. Loss: 0.936 |  Val. PPL:   2.550
Epoch: 92
100%|██████████| 32/32 [00:10<00:00,  3.60it/s]
100%|██████████| 32/32 [00:01<00:00, 18.95it/s]
	Train Loss: 0.967 | Train PPL:   2.629
	 Val. Loss: 0.913 |  Val. PPL:   2.492
Epoch: 93
100%|██████████| 32/32 [00:10<00:00,  3.61it/s]
100%|██████████| 32/32 [00:01<00:00, 18.93it/s]
	Train Loss: 0.935 | Train PPL:   2.548
	 Val. Loss: 0.890 |  Val. PPL:   2.434
Epoch: 94
100%|██████████| 32/32 [00:11<00:00,  3.14it/s]
100%|██████████| 32/32

# Translate

In [22]:
seq2seq.load_state_dict(torch.load('models/seq2seq.pt'))

<All keys matched successfully>

In [66]:
idx = 3

src_sentence = train_loader.dataset.tensors[0][idx:idx+1]
tgt_sentence = train_loader.dataset.tensors[1][idx:idx+1]

src_sentence = src_vocab.to_string(src_sentence, remove_special=True)[0]
tgt_sentence = tgt_vocab.to_string(tgt_sentence, remove_special=True)[0]

In [67]:
translation, attention = translate(src_sentence, seq2seq, src_vocab, tgt_vocab, src_vocab.PAD_token)

tensor([[   4,  100,  389,   14,  191,  675, 1063, 3670,  390,  391,  389,   19,
            9, 3671,  414,    9, 3672,   29,  191, 3673,   12,  960,   29, 3674,
         3675,   10,  415,    3,    5]])
tensor([[1497, 1498, 1073,  752, 1499,   79,  662,  662,  906,  522,    5,  522,
            3,    3,    3,    3]])


In [68]:
print(f"> {src_sentence}")
print(f"= {tgt_sentence}")
print(f"< {translation}")

> south korea s nuclear envoy kim sook urged north korea monday to restart work to disable its nuclear plants and stop its typical brinkmanship in negotiations .
= envoy urges north korea to restart nuclear disablement
< aga khan pours his wealth into islamic islamic sites syria in syria
