In [1]:
import copy
import json
import os
import numpy as np
from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import T5Tokenizer, MT5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup
from datasets import WebNLGDataset

# Google's Official Preprocess Codes
# https://github.com/google-research/language/blob/master/language/totto/baseline_preprocessing/preprocess_utils.py
from preprocess_utils import get_highlighted_subtable, linearize_subtable
from nltk.translate.bleu_score import sentence_bleu

In [2]:
# Training configuration shared by every cell below.
device = torch.device("cpu")

# Peak learning rate handed to the optimizer.
lr = 1e-4

# Batches are built with `batch_size` examples and gradients are accumulated over
# `accumulation_steps` batches before each optimizer step.
# Use batch_size=3 for 't5-large' and make 'accumulation_steps' larger.
batch_size = 4
accumulation_steps = 3

# Number of passes over the training dataloaders.
epochs = 10

# When > 0, the English training set is subsampled to this many examples.
LOW_RESOURCE_SIZE = 0

In [3]:
# Checkpoint used for both the tokenizer and the model.
PRETRAINED_NAME = 'google/mt5-base'

# Pre-trained tokenizer, extended with the table-tag symbols as special tokens.
table_tags = ['|', ':']
tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_NAME)
tokenizer.add_special_tokens({'additional_special_tokens': table_tags})

# Pre-trained mT5 model, moved to the configured device.
model = MT5ForConditionalGeneration.from_pretrained(PRETRAINED_NAME)
model = model.to(device)

# Resize the embedding layer to the tokenizer's current vocabulary size.
# Kept as the cell's last expression so the resulting module is displayed.
model.resize_token_embeddings(len(tokenizer))

Embedding(250100, 768)

In [4]:
# English training data.
dataset_train_en = WebNLGDataset(tokenizer=tokenizer, language='en')

if LOW_RESOURCE_SIZE > 0:
    # Low-resource setting: keep examples at evenly spaced positions of the full
    # training set. The subset size is clamped to the dataset length so that a
    # too-large LOW_RESOURCE_SIZE does not produce duplicated indices.
    num_sampled = min(LOW_RESOURCE_SIZE, len(dataset_train_en))
    sampled_train_indices = np.linspace(0, len(dataset_train_en)-1, num=num_sampled, dtype=int)
    # Plain Python ints, so the dataset is indexed the same way as without Subset.
    subset_train = torch.utils.data.Subset(dataset_train_en, sampled_train_indices.tolist())
    dataloader_train_en = DataLoader(subset_train, batch_size=batch_size, shuffle=True, collate_fn=dataset_train_en.collate_fn)
else:
    dataloader_train_en = DataLoader(dataset_train_en, batch_size=batch_size, shuffle=True, collate_fn=dataset_train_en.collate_fn)

# English dev data; not shuffled, so batches come out in dataset order.
dataset_dev_en = WebNLGDataset(tokenizer=tokenizer, language='en', split='dev')
dataloader_dev_en = DataLoader(dataset_dev_en, batch_size=batch_size, shuffle=False, collate_fn=dataset_dev_en.collate_fn)

# NOTE: len() of a DataLoader is its number of batches, so the figures printed
# below count batches of `batch_size` examples rather than individual samples.
print(f'Initialized English train dataloader with {len(dataloader_train_en)} samples.')
print(f'Initialized English dev dataloader with {len(dataloader_dev_en)} samples.')


Initialized English train dataloader with 3658 samples.
Initialized English dev dataloader with 198 samples.


In [5]:
# Russian training data, reshuffled on every pass.
dataset_train_ru = WebNLGDataset(tokenizer=tokenizer, language='ru')
dataloader_train_ru = DataLoader(
    dataset_train_ru,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=dataset_train_ru.collate_fn,
)

# Russian dev data, iterated in dataset order.
dataset_dev_ru = WebNLGDataset(tokenizer=tokenizer, language='ru', split='dev')
dataloader_dev_ru = DataLoader(
    dataset_dev_ru,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=dataset_dev_ru.collate_fn,
)

# len() of a DataLoader counts batches, which is what these messages report.
print(f'Initialized Russian train dataloader with {len(dataloader_train_ru)} samples.')
print(f'Initialized Russian dev dataloader with {len(dataloader_dev_ru)} samples.')


Initialized Russian train dataloader with 3658 samples.
Initialized Russian dev dataloader with 198 samples.


In [6]:
# Reference sentences for BLEU: one whitespace-tokenised list per line of the file.
references = []
with open("../webnlg_data/references/reference-en0", "r") as reference_file:
    references.extend(raw_line.strip().split() for raw_line in reference_file)

In [7]:
def evaluate_model(model, epoch, max_batches=2):
    """Generate dev-set outputs for both languages and score them with sentence BLEU.

    For each of the English and Russian dev dataloaders the model generates text
    with beam search, writes one generation per line to
    ``{language}_epoch{epoch}_lowsize{LOW_RESOURCE_SIZE}.txt`` and appends the
    language's mean sentence-BLEU to ``blue_scores_{language}_lowsize{...}.txt``.

    Args:
        model: Model to evaluate; it is moved to ``device`` and left in eval mode.
        epoch: Epoch index, used in the output file name and in the score log line.
        max_batches: Number of batches to evaluate per language. The default of 2
            keeps the previous hard-coded cut-off; pass ``None`` to evaluate the
            whole dev set.

    Returns:
        Mean sentence-BLEU over every generated sentence of both languages, or
        0.0 if nothing was generated.

    Raises:
        IndexError: If ``references`` has fewer entries than evaluated samples.
    """
    model = model.to(device)
    model.eval()

    output_dir = '../webnlg_data/multilingual_fine_tuning'
    total_bleu = 0.0
    total_sentences = 0

    with torch.no_grad():
        for language, dataloader in [('english', dataloader_dev_en), ('russian', dataloader_dev_ru)]:
            language_bleu = 0.0
            language_sentences = 0
            generation_path = f'{output_dir}/{language}_epoch{epoch}_lowsize{LOW_RESOURCE_SIZE}.txt'

            # Mode 'w' truncates a file left by an earlier run; UTF-8 is explicit
            # because the Russian generations are not ASCII.
            with open(generation_path, 'w', encoding='utf-8') as generation_file:
                for idx, (data, attn_mask, _) in enumerate(dataloader):
                    if max_batches is not None and idx >= max_batches:
                        break
                    if (idx+1)%100==0: print(batch_size*(idx+1), 'generated')

                    data=data.to(device)
                    attn_mask=attn_mask.to(device)

                    # Beam Search (mask passed explicitly so padding is ignored)
                    outputs = model.generate(
                        data,
                        attention_mask=attn_mask,
                        max_length=300,
                        num_beams=5,
                        early_stopping=True,
                    )

                    for generation in tokenizer.batch_decode(outputs, skip_special_tokens=True):
                        # The dev loaders are not shuffled, so the running sample
                        # count is the position of this sample in the dev set.
                        # NOTE(review): `references` is read from a file named
                        # "reference-en0" but is used for Russian as well —
                        # confirm Russian scoring against it is intended.
                        reference = references[language_sentences]
                        candidate = generation.strip().split()
                        # sentence_bleu expects a *list* of references, each a token list.
                        language_bleu += sentence_bleu([reference], candidate)
                        language_sentences += 1
                        generation_file.write(generation + '\n')

            # Log the mean over everything evaluated for this language.
            language_mean = language_bleu / language_sentences if language_sentences else 0.0
            score_path = f'{output_dir}/blue_scores_{language}_lowsize{LOW_RESOURCE_SIZE}.txt'
            with open(score_path, 'a', encoding='utf-8') as score_file:
                score_file.write('BLUE after ' + str(epoch) + ' epochs: ' + str(language_mean) + '\n')

            total_bleu += language_bleu
            total_sentences += language_sentences

    return total_bleu / total_sentences if total_sentences else 0.0

In [None]:
# Optim, Scheduler
optimizer=AdamW(model.parameters(), lr=lr)

train_dataloaders = [('english', dataloader_train_en), ('russian', dataloader_train_ru)]

# len(DataLoader) already counts batches, so it must not be divided by batch_size.
# The loop below takes one optimizer step per `accumulation_steps` batches, with
# the batch counter restarting for each dataloader.
optimizer_steps_per_epoch = sum(len(dataloader)//accumulation_steps for _, dataloader in train_dataloaders)
scheduler=get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=epochs*optimizer_steps_per_epoch
)

# TensorBoard: Logging
writer=SummaryWriter()
step_global=0

# Start below any possible BLEU so the first epoch is always kept as a checkpoint,
# even when its score is exactly 0.
best_bleu_score = float('-inf')
best_model = None
best_model_name = None

for epoch in range(epochs):
    # Train Phase
    model.train()
    model.to(device)

    loss_train=0
    # Also drops gradients of leftover batches from the previous epoch.
    optimizer.zero_grad()

    for language, dataloader in train_dataloaders:
        for step, (data, attn_mask, label) in enumerate(dataloader):
            data=data.to(device)
            attn_mask=attn_mask.to(device)
            label=label.to(device)

            outputs=model(input_ids=data, attention_mask=attn_mask, labels=label)

            # Scale so the accumulated gradient is the mean over the accumulated batches.
            loss=outputs[0]/accumulation_steps
            loss.backward()

            loss_train+=loss.item()

            if (step+1)%accumulation_steps==0:
                step_global+=1

                # TensorBoard
                writer.add_scalar(
                    f'loss_train/MT5-base_Fine-Tuning_lr{lr}_batch{int(accumulation_steps*batch_size)}_epoch{epochs}',
                    loss_train,
                    step_global
                )
                # Console
                if step_global%1000==0:
                    print(f'epoch {epoch+1} step {step_global} loss_train {loss_train:.4f}')
                # Set Loss to 0
                loss_train=0

                optimizer.step()
                scheduler.step()

                optimizer.zero_grad()

    # Dev-set evaluation; the model is switched back to train mode at the top of the loop.
    blue_score = evaluate_model(model, epoch)
    print(f'BLUE of {blue_score}')

    if blue_score > best_bleu_score:
        # nn.Module has no .deepcopy(); snapshot the weights with copy.deepcopy so
        # later epochs do not overwrite the best checkpoint.
        best_model = copy.deepcopy(model)
        best_model_name = f'MT5-base_Fine-Tuning_lr{lr}_batch{int(accumulation_steps*batch_size)}_lowsize{LOW_RESOURCE_SIZE}_epoch{epoch+1}of{epochs}.pt'
        best_bleu_score = blue_score

writer.flush()

# best_model is only None when no epoch ran (epochs == 0).
if best_model is not None:
    best_model.to(torch.device('cpu'))
    # Create the target directory so a long run is not lost to a missing folder.
    os.makedirs('../model', exist_ok=True)
    torch.save(best_model, f'../model/{best_model_name}')

