# Train the Model Using the Dataset

In [1]:
%%capture
%run ../config/config.ipynb
%run ./data-loader.ipynb
%run ./Transformer.ipynb

In [2]:
from torch.optim import Adam
from datetime import datetime
import torch 
from tqdm import tqdm

In [3]:
# Prepare the model.
# NOTE(review): none of the hyper-parameters below are defined in this notebook;
# presumably they come from the notebooks %run in the first cell — confirm.
model = Transformer(
    # vocabulary sizes and padding ids
    src_pad_token=src_pad_token,
    trg_pad_token=trg_pad_token,
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    # architecture
    d_model=d_model,
    n_head=n_head,
    n_layers=n_layers,
    ffn_hidden=ffn_hidden,
    max_len=max_len,
    drop_prob=drop_prob,
    device=device,
).to(device)

# Switch train/eval-dependent layers (e.g. dropout) into training mode.
model.train()

def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are
    skipped. A model with no trainable parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Count trainable parameters once (it walks every parameter) and report the
# same message to stdout and to the log.
n_trainable_params = count_parameters(model)
param_msg = f'model parameter #: {n_trainable_params}'
print(param_msg)
logger.info(param_msg)


model parameter #: 53940457


In [4]:
# Setup optimizer.
# betas=(0.9, 0.98) match the original Transformer paper; init_lr, weight_decay
# and eps are not defined here (presumably provided by the config notebook).
optimizer = Adam(params=model.parameters(), lr=init_lr, weight_decay=weight_decay, eps=eps, betas=(0.9, 0.98))

# Set Noam scheduler (warm-up schedule driven by d_model and warmup_steps).
# NOTE(review): LRScheduler is not defined in this notebook; presumably it comes
# from one of the notebooks %run in the first cell.
scheduler = LRScheduler(optimizer, d_model, warmup_steps)

# Setup loss function for training.
# The loss targets are English (target-side) token ids, so padding must be
# masked with the *target* pad id; using the source pad id would leave target
# padding positions contributing to the loss whenever the two ids differ.
loss_func = nn.CrossEntropyLoss(ignore_index=trg_pad_token)




In [None]:
# Training histories, each appended once per epoch by the epoch loop:
#   lr_history         -> learning rate at the end of each training epoch
#   train_loss_history -> mean training loss of each epoch
lr_history, train_loss_history = [], []

In [5]:
def train_epoch(epoch_num, max_grad_norm=None):
    """Run one training epoch over ``train_dataloader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``scheduler``, ``loss_func``,
    ``kr_tokenizer`` / ``en_tokenizer``, ``device`` and ``max_len``. The optimizer
    and the LR scheduler are both stepped once per batch.

    Args:
        epoch_num: epoch index, used only in progress messages.
        max_grad_norm: if not None, gradients are clipped to this total L2 norm
            before each optimizer step. The default ``None`` keeps the original
            unclipped behaviour.

    Returns:
        Mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches (previously this
            surfaced as an UnboundLocalError on ``step``).
    """
    model.train()
    train_epoch_loss = 0.0
    num_steps = 0

    # Wrap the dataloader itself (not the enumerate object) so tqdm can read
    # len() and show a progress bar with a total.
    for step, (kr_sentences, en_sentences) in enumerate(tqdm(train_dataloader, desc=f'epoch {epoch_num}')):
        # Tokenize the Korean source sentences.
        kr_tokenized = kr_tokenizer(kr_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt").input_ids

        # Tokenize the English target sentences. The tokenizer has no SOS token,
        # so '</s>' is prepended as a stand-in start-of-sentence token.
        en_sentences = ['</s> ' + s for s in en_sentences]
        en_tokenized = en_tokenizer(en_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt").input_ids

        kr_tokenized = kr_tokenized.to(device)
        en_tokenized = en_tokenized.to(device)

        # Teacher forcing: the decoder input drops the last position and the
        # target drops the first (the stand-in start token), so position i of
        # the output is scored against token i + 1.
        out = model(kr_tokenized, en_tokenized[:, :-1])
        target = en_tokenized[:, 1:]

        # CrossEntropyLoss with (batch, L) targets needs scores shaped
        # (batch, dec_voc_size, L), so the vocabulary axis is moved to dim 1.
        loss = loss_func(out.permute(0, 2, 1), target)

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        # Read the scalar once per step (each .item() forces a device sync).
        batch_loss = loss.item()
        train_epoch_loss += batch_loss
        num_steps = step + 1

        if step % 200 == 0:
            msg = f'    EPOCH #{epoch_num} STEP #{step} | loss: {batch_loss}, avg_loss: {train_epoch_loss / num_steps}'
            print(msg)
            logger.info(msg)

    if num_steps == 0:
        raise ValueError('train_dataloader yielded no batches')

    return train_epoch_loss / num_steps

In [6]:
# evaluate the model
def evaluate():
    """Compute the mean per-batch loss of ``model`` on ``test_dataloader``.

    Runs with ``model.eval()`` under ``torch.no_grad()`` and uses the same
    tokenization and teacher-forcing shift as the training loop. The model is
    left in eval mode; ``train_epoch`` calls ``model.train()`` at its start.

    Returns:
        Mean of the per-batch losses over the test set.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches (previously this
            surfaced as an UnboundLocalError on ``step``).
    """
    model.eval()
    test_epoch_loss = 0.0
    num_steps = 0

    with torch.no_grad():
        # Wrap the dataloader itself so tqdm can read len() and show a total.
        for kr_sentences, en_sentences in tqdm(test_dataloader, desc='evaluate'):
            # Tokenize the Korean source sentences.
            kr_tokenized = kr_tokenizer(kr_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt").input_ids

            # Tokenize the English targets with '</s>' prepended as a stand-in
            # start token (the tokenizer has no SOS token).
            en_sentences = ['</s> ' + s for s in en_sentences]
            en_tokenized = en_tokenizer(en_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt").input_ids

            kr_tokenized = kr_tokenized.to(device)
            en_tokenized = en_tokenized.to(device)

            # Teacher forcing: decoder input drops the last position, target
            # drops the first. For sentences shorter than the batch maximum the
            # dropped last position is padding, so their '</s>' stays in the
            # decoder input; the prediction made from that position is aligned
            # with a pad target and does not count toward the loss provided
            # loss_func ignores the target pad id.
            out = model(kr_tokenized, en_tokenized[:, :-1])
            target = en_tokenized[:, 1:]

            # Move the vocabulary axis to dim 1 as CrossEntropyLoss expects.
            loss = loss_func(out.permute(0, 2, 1), target)
            test_epoch_loss += loss.item()
            num_steps += 1

            # TODO: calculate BLEU (not implemented yet).

    if num_steps == 0:
        raise ValueError('test_dataloader yielded no batches')

    return test_epoch_loss / num_steps

In [7]:
# One timestamp shared by every checkpoint written during this run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

for epoch in range(epochs):
    start_msg = f'Epoch #{epoch} Start: current LR {optimizer.param_groups[0]["lr"]}'
    print(start_msg)
    logger.info(start_msg)

    train_loss = train_epoch(epoch)
    test_loss = evaluate()

    # Record the LR reached at the end of this epoch and the epoch's mean train loss.
    lr_history.append(optimizer.param_groups[0]["lr"])
    train_loss_history.append(train_loss)

    # Report the epoch result on stdout as well as in the log, matching the
    # start-of-epoch and per-step messages.
    end_msg = f'Epoch #{epoch} End: Train Loss {train_loss}, Test Loss {test_loss}'
    print(end_msg)
    logger.info(end_msg)

    # Checkpoint the model weights after every epoch.
    # NOTE(review): only model.state_dict() is saved; optimizer and scheduler
    # state are not, so resuming from a checkpoint restarts the LR schedule.
    model_path = model_dir / f'model_{timestamp}_{epoch}'
    torch.save(model.state_dict(), model_path)


Epoch #0 Start: current LR 0.0


1it [00:04,  4.01s/it]

    EPOCH #0 STEP #0 | loss: 11.198524475097656, avg_loss: 11.198524475097656


201it [10:51,  3.17s/it]

    EPOCH #0 STEP #200 | loss: 6.86619234085083, avg_loss: 9.200099686485025


202it [10:56,  3.25s/it]


KeyboardInterrupt: 