# Train the Model on the Dataset

In [1]:
%%capture
%run ../config/config.ipynb
%run ./data-loader.ipynb
%run ./Transformer.ipynb

In [2]:
from torch.optim import Adam
from datetime import datetime
import torch 
from tqdm import tqdm

In [3]:
# Build the Transformer from the hyperparameters made available by the
# notebooks executed with %run at the top of this notebook.
transformer_kwargs = dict(
    src_pad_token=src_pad_token,
    trg_pad_token=trg_pad_token,
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    n_head=n_head,
    max_len=max_len,
    d_model=d_model,
    ffn_hidden=ffn_hidden,
    n_layers=n_layers,
    drop_prob=drop_prob,
    device=device,
)
model = Transformer(**transformer_kwargs)

# Move the parameters onto the target device, then put the module in training mode.
model = model.to(device)
model.train()

def count_parameters(model):
    """Count the scalar elements of every trainable parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``; each yielded parameter must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        The summed ``numel()`` of all parameters whose ``requires_grad`` is truthy.
    """
    trainable_total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad is False) are not counted.
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many trainable parameters the freshly built model has.
trainable_param_count = count_parameters(model)
print(f'model parameter #: {trainable_param_count}')


model parameter #: 6356745


In [8]:

class NoamScheduler:
    """Warm-up / inverse-square-root learning-rate schedule.

    After ``n`` calls to :meth:`step` the learning rate is::

        LR_scale * d_model ** -0.5 * min(n ** -0.5, n * warmup_steps ** -1.5)

    i.e. it grows linearly for ``warmup_steps`` steps and then decays
    proportionally to ``1 / sqrt(n)``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim`` optimizer.
        d_model: Model dimension used to scale the learning rate; must be > 0.
        warmup_steps: Number of warm-up steps; must be > 0.
        LR_scale: Constant multiplier applied to the schedule (default 1).

    Raises:
        ValueError: If ``d_model`` or ``warmup_steps`` is not positive.
    """

    def __init__(self, optimizer, d_model, warmup_steps, LR_scale=1):
        # Reject values that would otherwise surface later as a ZeroDivisionError
        # (0 raised to a negative power) or as a complex-valued learning rate.
        if d_model <= 0:
            raise ValueError(f'd_model must be positive, got {d_model}')
        if warmup_steps <= 0:
            raise ValueError(f'warmup_steps must be positive, got {warmup_steps}')

        self.optimizer = optimizer
        self.step_count = 0
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.LR_scale = LR_scale
        # Step-independent part of the formula, computed once.
        self._d_model_factor = self.LR_scale * (self.d_model ** -0.5)

    def step(self):
        """Advance the schedule by one step and write the new rate to every param group."""
        self.step_count += 1
        lr = self.calculate_learning_rate()
        # Update all groups, not only the first one, so optimizers built with
        # several parameter groups do not keep a stale learning rate.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def calculate_learning_rate(self):
        """Return the learning rate for the current ``step_count``.

        Returns 0.0 before the first :meth:`step` call, which is the value of
        the linear warm-up term at step 0 (``0 ** -0.5`` is undefined).
        """
        if self.step_count <= 0:
            return 0.0
        minimum_factor = min(
            self.step_count ** -0.5,
            self.step_count * self.warmup_steps ** -1.5,
        )
        return self._d_model_factor * minimum_factor

# Set up the optimizer.
# NOTE: NoamScheduler.step() overwrites the optimizer's 'lr', so init_lr is only
# in effect until the scheduler has been stepped for the first time.
optimizer = Adam(params=model.parameters(), lr=init_lr, weight_decay=weight_decay, eps=eps)

# Noam warm-up / inverse-square-root learning-rate schedule.
scheduler = NoamScheduler(optimizer, d_model, warmup_steps)

# Loss function for training and evaluation.
# The loss targets are target-language (English) token ids, so the padding id to
# ignore is the target pad token, not the source one.
# torch.nn is referenced explicitly because this notebook imports torch but not nn.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=trg_pad_token)


In [9]:
def train_epoch(epoch_num):
    """Train the model for one pass over ``train_dataloader``.

    Args:
        epoch_num: Index of the current epoch; used only in the progress log lines.

    Returns:
        The mean of the per-batch training losses for this epoch.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    # evaluate() switches the model to eval mode and nothing switches it back,
    # so re-enable training mode at the start of every epoch.
    model.train()

    train_epoch_loss = 0.0
    num_batches = 0

    # Wrap the dataloader itself (not the enumerate object) so tqdm can use its
    # length for the progress bar when one is available.
    for step, (kr_sentences, en_sentences) in enumerate(tqdm(train_dataloader)):
        # Tokenize the source (Korean) sentences into a padded id tensor.
        kr_tokenized = kr_tokenizer(
            kr_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
        ).input_ids.to(device)

        # The target tokenizer has no SOS token, so '</s>' is prepended to each
        # target sentence to act as the start marker.
        en_sentences = ['</s> ' + s for s in en_sentences]
        en_tokenized = en_tokenizer(
            en_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
        ).input_ids.to(device)

        # Teacher forcing: the decoder sees every token except the last one and
        # is trained to predict the same sequence shifted left by one position
        # (i.e. without the prepended start marker).
        decoder_input = en_tokenized[:, :-1]
        target = en_tokenized[:, 1:]

        optimizer.zero_grad()

        # Per the original notes the model output is (batch, seq_len, dec_voc_size).
        out = model(kr_tokenized, decoder_input)

        # Move the vocabulary axis to position 1, which is where the loss
        # function expects the class dimension for sequence targets.
        loss = loss_func(out.permute(0, 2, 1), target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        train_epoch_loss += batch_loss
        num_batches += 1

        if step % 50 == 0:
            print(f'EPOCH #{epoch_num} STEP #{step} | loss: {batch_loss}, avg_loss: {train_epoch_loss / num_batches}')

    if num_batches == 0:
        raise ValueError('train_dataloader yielded no batches; cannot compute an average loss')

    return train_epoch_loss / num_batches
    



In [10]:
# Evaluate the model
def evaluate():
    """Compute the mean per-batch loss of the model over ``test_dataloader``.

    Puts the model in eval mode and runs without gradient tracking. The model
    is left in eval mode when this function returns.

    Returns:
        The mean of the per-batch losses over the test set.

    Raises:
        ValueError: If ``test_dataloader`` yields no batches.
    """
    model.eval()
    test_epoch_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        # Wrap the dataloader itself so tqdm can use its length when available.
        for kr_sentences, en_sentences in tqdm(test_dataloader):
            # Tokenize the source (Korean) sentences into a padded id tensor.
            kr_tokenized = kr_tokenizer(
                kr_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
            ).input_ids.to(device)

            # The target tokenizer has no SOS token, so '</s>' is prepended to
            # each target sentence to act as the start marker.
            en_sentences = ['</s> ' + s for s in en_sentences]
            en_tokenized = en_tokenizer(
                en_sentences, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
            ).input_ids.to(device)

            # Decoder input drops the last position; the loss target is the
            # same sequence shifted left by one (without the start marker).
            # FIXME (from the original author): this does not remove the eos token.
            decoder_input = en_tokenized[:, :-1]
            target = en_tokenized[:, 1:]

            out = model(kr_tokenized, decoder_input)

            # Move the vocabulary axis to position 1 for the loss function.
            loss = loss_func(out.permute(0, 2, 1), target)
            test_epoch_loss += loss.item()
            num_batches += 1

            # TODO: calculate the BLEU score

    if num_batches == 0:
        raise ValueError('test_dataloader yielded no batches; cannot compute an average loss')

    return test_epoch_loss / num_batches

In [None]:
# Timestamp shared by every checkpoint written during this run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Best (lowest) test loss seen so far; infinity guarantees the first epoch is saved.
best_vloss = float('inf')

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before spending an epoch on training.
# NOTE(review): assumes model_dir is a pathlib.Path (it is combined with `/` below) - confirm.
model_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(epochs):
    train_loss = train_epoch(epoch)
    test_loss = evaluate()

    print(f'Epoch {epoch}: Train Loss {train_loss}, Test Loss {test_loss}')

    # Keep a checkpoint only when the test loss improves on the best so far.
    if test_loss < best_vloss:
        best_vloss = test_loss
        model_path = model_dir / f'model_{timestamp}_{epoch}'
        torch.save(model.state_dict(), model_path)

1it [00:00,  1.34it/s]

EPOCH #0 STEP #0 | loss: 11.167913436889648, avg_loss: 11.167913436889648


51it [00:29,  1.64it/s]

EPOCH #0 STEP #50 | loss: 11.144428253173828, avg_loss: 11.170808137631884


101it [00:57,  2.01it/s]

EPOCH #0 STEP #100 | loss: 11.07529354095459, avg_loss: 11.145308201855952


151it [01:24,  1.69it/s]

EPOCH #0 STEP #150 | loss: 10.903339385986328, avg_loss: 11.100729613904132


201it [01:48,  1.95it/s]

EPOCH #0 STEP #200 | loss: 10.793999671936035, avg_loss: 11.041111689894947


217it [01:56,  2.19it/s]