# Train the Model on the Dataset

In [1]:
%%capture
%run ../config/config.ipynb
%run ./data-loader.ipynb
%run ./Transformer.ipynb

In [2]:
from torch.optim import Adam
from datetime import datetime
import torch 
from tqdm import tqdm

In [3]:
# Build the Transformer from the values provided by the config notebook,
# move it to the target device, and put it in training mode.
model = Transformer(
    # vocabulary sizes and padding ids
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    src_pad_token=src_pad_token,
    trg_pad_token=trg_pad_token,
    # architecture hyper-parameters
    d_model=d_model,
    n_head=n_head,
    n_layers=n_layers,
    ffn_hidden=ffn_hidden,
    max_len=max_len,
    drop_prob=drop_prob,
    # runtime placement
    device=device,
)
model = model.to(device)

model.train()

def count_parameters(model):
    """Return the total element count of all parameters that require gradients.

    Args:
        model: Any object exposing ``parameters()`` that yields items with a
            ``requires_grad`` flag and a ``numel()`` method.

    Returns:
        int: Sum of ``numel()`` over the parameters whose ``requires_grad`` is truthy.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many trainable parameters the freshly built model has.
trainable_param_count = count_parameters(model)
print(f'model parameter #: {trainable_param_count}')


model parameter #: 6356745


In [4]:
# Set up the optimizer over all model parameters.
optimizer = Adam(params=model.parameters(), lr=init_lr, weight_decay=weight_decay, eps=eps)

# Set up the loss function used for training and evaluation.
# The loss is computed against the tokenized English (target-side) ids, so the
# padding id to ignore must be the target pad token, not the source one.
# `torch.nn` is spelled out because this notebook imports `torch` but never
# imports `nn` itself.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=trg_pad_token)


In [8]:
def train_epoch(epoch_num):
    """Run one optimization pass over ``train_dataloader`` and return the mean batch loss.

    Relies on notebook-level globals: ``model``, ``optimizer``, ``loss_func``,
    ``train_dataloader``, ``kr_tokenizer``, ``en_tokenizer``, ``max_len``,
    ``device`` and ``clip``.

    Args:
        epoch_num: Epoch index; only used to label the progress log lines.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    # evaluate() switches the model to eval mode, so training mode must be
    # re-enabled at the start of every epoch; otherwise every epoch after the
    # first would run with eval-mode layer behaviour.
    model.train()

    train_epoch_loss = 0.0
    num_batches = 0

    # Wrap the dataloader itself (not the enumerate object) so tqdm can pick up
    # its length and display a total when one is available.
    for step, (kr_sentences, en_sentences) in enumerate(tqdm(train_dataloader)):

        # Source side: tokenize the Korean sentences into a padded id tensor.
        kr_tokenized = kr_tokenizer(
            kr_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
        ).input_ids.to(device)

        # Target side: the tokenizer has no dedicated start-of-sentence token,
        # so '</s>' is prepended and used as the start marker.
        en_sentences = ['</s> ' + s for s in en_sentences]
        en_tokenized = en_tokenizer(
            en_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
        ).input_ids.to(device)

        # Teacher forcing: the decoder is fed every column except the last and
        # is trained to predict every column except the first (the start marker).
        decoder_input = en_tokenized[:, :-1]
        target = en_tokenized[:, 1:]

        optimizer.zero_grad()

        # assumes the model returns (batch, seq_len, dec_voc_size) logits,
        # as the original comments state -- TODO confirm against Transformer
        logits = model(kr_tokenized, decoder_input)

        # CrossEntropyLoss wants the class dimension second: (batch, vocab, seq_len).
        loss = loss_func(logits.permute(0, 2, 1), target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        batch_loss = loss.item()
        train_epoch_loss += batch_loss
        num_batches += 1

        if step % 10 == 0:
            print(f'EPOCH #{epoch_num} STEP #{step} | loss: {batch_loss}, avg_loss: {train_epoch_loss / num_batches}')

    if num_batches == 0:
        raise ValueError('train_dataloader produced no batches; cannot compute a training loss')

    return train_epoch_loss / num_batches
    



In [9]:
# evaluate the model 
def evaluate():
    """Compute the mean batch loss of ``model`` over ``test_dataloader`` without gradients.

    Relies on notebook-level globals: ``model``, ``loss_func``,
    ``test_dataloader``, ``kr_tokenizer``, ``en_tokenizer``, ``max_len`` and
    ``device``.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so a following training
    epoch is not silently run in eval mode.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``test_dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()

    test_epoch_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for kr_sentences, en_sentences in tqdm(test_dataloader):
                # Source side: tokenize the Korean sentences into a padded id tensor.
                kr_tokenized = kr_tokenizer(
                    kr_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
                ).input_ids.to(device)

                # Target side: the tokenizer has no dedicated start-of-sentence
                # token, so '</s>' is prepended and used as the start marker.
                en_sentences = ['</s> ' + s for s in en_sentences]
                en_tokenized = en_tokenizer(
                    en_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
                ).input_ids.to(device)

                # FIXME (carried over from the original): dropping the last
                # column presumably only removes the end marker for the longest
                # sequence in a padded batch -- verify against the tokenizer.
                logits = model(kr_tokenized, en_tokenized[:, :-1])

                # Targets skip the leading start marker, which the model never predicts.
                target = en_tokenized[:, 1:]

                # CrossEntropyLoss wants the class dimension second: (batch, vocab, seq_len).
                loss = loss_func(logits.permute(0, 2, 1), target)
                test_epoch_loss += loss.item()
                num_batches += 1

                # TODO: calculate the BLEU score
    finally:
        # Put the model back into whichever mode it was in before evaluation.
        model.train(was_training)

    if num_batches == 0:
        raise ValueError('test_dataloader produced no batches; cannot compute a test loss')

    return test_epoch_loss / num_batches

In [10]:
# Timestamp shared by every checkpoint written during this run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# Start from +inf so the first evaluated epoch always counts as the best so far.
best_vloss = float('inf')

for epoch in range(epochs):
    train_loss = train_epoch(epoch)
    test_loss = evaluate()

    print(f'Epoch {epoch}: Train Loss {train_loss}, Test Loss {test_loss}')

    # Checkpoint only when the test loss improves on the best value seen so far.
    if test_loss < best_vloss:
        best_vloss = test_loss
        model_path = model_dir / f'model_{timestamp}_{epoch}'
        # torch.save does not create missing directories, so make sure the
        # checkpoint folder exists first.
        # assumes model_dir is a pathlib.Path (it is joined with `/` above) -- TODO confirm
        model_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), model_path)

1it [00:01,  1.96s/it]

EPOCH #0 STEP #0 | loss: 8.861778259277344, avg_loss: 8.861778259277344


11it [00:19,  1.84s/it]

EPOCH #0 STEP #10 | loss: 6.259122848510742, avg_loss: 7.432485016909513


21it [00:36,  1.79s/it]

EPOCH #0 STEP #20 | loss: 5.9041032791137695, avg_loss: 6.7086379414512995


31it [00:53,  1.88s/it]

EPOCH #0 STEP #30 | loss: 5.693465232849121, avg_loss: 6.399760230895011


41it [01:10,  1.63s/it]

EPOCH #0 STEP #40 | loss: 5.653689384460449, avg_loss: 6.234608510645424


51it [01:30,  1.67s/it]

EPOCH #0 STEP #50 | loss: 5.683628082275391, avg_loss: 6.132588797924566


61it [01:46,  1.60s/it]

EPOCH #0 STEP #60 | loss: 5.643075466156006, avg_loss: 6.057823720525523


71it [02:05,  2.12s/it]

EPOCH #0 STEP #70 | loss: 5.691100597381592, avg_loss: 6.006386736748924


81it [02:24,  1.95s/it]

EPOCH #0 STEP #80 | loss: 5.619980812072754, avg_loss: 5.959612398971746


91it [02:42,  1.99s/it]

EPOCH #0 STEP #90 | loss: 5.46574068069458, avg_loss: 5.913847640320495


101it [03:00,  1.82s/it]

EPOCH #0 STEP #100 | loss: 5.295058727264404, avg_loss: 5.861514979069776


111it [03:19,  1.66s/it]

EPOCH #0 STEP #110 | loss: 5.204499244689941, avg_loss: 5.808224274231507


121it [03:40,  1.82s/it]

EPOCH #0 STEP #120 | loss: 5.075936317443848, avg_loss: 5.753798177419615


131it [04:01,  2.44s/it]

EPOCH #0 STEP #130 | loss: 4.932798385620117, avg_loss: 5.704099640591454


141it [04:20,  2.15s/it]

EPOCH #0 STEP #140 | loss: 5.02657413482666, avg_loss: 5.657996397491888


151it [04:39,  1.77s/it]

EPOCH #0 STEP #150 | loss: 4.899193286895752, avg_loss: 5.609607368115558


161it [04:59,  2.19s/it]

EPOCH #0 STEP #160 | loss: 4.976208686828613, avg_loss: 5.567313961360766


171it [05:19,  2.21s/it]

EPOCH #0 STEP #170 | loss: 4.771450996398926, avg_loss: 5.52278839356718


181it [05:41,  2.13s/it]

EPOCH #0 STEP #180 | loss: 4.778852462768555, avg_loss: 5.482393836448206


191it [06:02,  2.19s/it]

EPOCH #0 STEP #190 | loss: 4.630377769470215, avg_loss: 5.441534421830902


201it [06:30,  2.88s/it]

EPOCH #0 STEP #200 | loss: 4.811519145965576, avg_loss: 5.4042598430197035


201it [06:33,  1.96s/it]


KeyboardInterrupt: 