In [None]:
# Working-directory prefix. Paths below are built by plain string concatenation,
# so a non-empty BASE_DIR must end with '/'.
BASE_DIR = ''
DATA_DIR = BASE_DIR + 'data/'
MODELS_DIR = BASE_DIR + 'models/'

# Load data: build (input, output) sentence pairs

In [None]:
# Task prefixes prepended to model inputs ("<task>: <text>" / "<task> <lang>: <text>").
IMPROVE_TOKEN = "improve_english: "
IMPROVE_TOKEN_MULTI = "improve_english"

# Prefixes for the reverse direction (English -> translated text). get_reverse_token
# needs these names; nothing in this notebook defined them, so preparing a reverse
# model raised NameError.
# NOTE(review): the exact strings are assumed to mirror the improve prefixes — confirm
# they match whatever inference code is used for reverse models.
TRANSFORM_TOKEN = "transform_english: "
TRANSFORM_TOKEN_MULTI = "transform_english"


def _task_prefix(single_token, multi_token, language, use_single_token):
    """Return the input prefix for one task direction.

    ``single_token`` (which already ends in ": ") is returned when ``use_single_token``
    is set or ``language`` is falsy; otherwise the language code is embedded as
    ``f"{multi_token} {language}: "``.
    """
    if use_single_token or not language:
        return single_token
    return f'{multi_token} {language}: '


def get_improve_token(language=None, use_single_token=False):
    """Prefix for the forward model (translated text -> English), e.g. ``'improve_english pt: '``."""
    return _task_prefix(IMPROVE_TOKEN, IMPROVE_TOKEN_MULTI, language, use_single_token)


def get_reverse_token(language=None, use_single_token=False):
    """Prefix for the reverse model (English -> translated text), e.g. ``'transform_english pt: '``."""
    return _task_prefix(TRANSFORM_TOKEN, TRANSFORM_TOKEN_MULTI, language, use_single_token)

def prepare_sentences(df, limit=1000, language=None, use_single_token=False, reverse_model=False):
    """Build (input, target) sentence pairs from the first ``limit`` rows of ``df``.

    Forward (improve) model: input is the improve prefix + the translated sentence
    (column ``trans_{language}``, or ``trans`` when no language is given); target is ``en``.
    Reverse model: input is the reverse prefix + ``en``; target is the translated sentence.

    Args:
        df: DataFrame with an ``en`` column and ``trans`` / ``trans_{language}`` columns.
        limit: maximum number of rows to use. ``ceil(limit)`` rows are taken (matching the
            old row-counting loop for fractional limits); ``None`` uses every row and
            values <= 0 use none (previously one row was still returned).
        language: language code selecting ``trans_{language}`` and the prefix; falsy
            selects the ``trans`` column.
        use_single_token: use the shared, language-agnostic prefix.
        reverse_model: swap the direction (English -> translated).

    Returns:
        ``(sentences_inputs, sentences_outputs)``: two lists of equal length.
    """
    import math

    if limit is None:
        rows = df
    else:
        rows = df.head(max(math.ceil(limit), 0))

    translated_col = f'trans_{language}' if language else 'trans'

    if reverse_model:
        prefix = get_reverse_token(language, use_single_token)
        source_col, target_col = 'en', translated_col
    else:
        prefix = get_improve_token(language, use_single_token)
        source_col, target_col = translated_col, 'en'

    # String concatenation (not vectorized Series "+") so a non-string cell still raises
    # a TypeError instead of silently becoming NaN.
    sentences_inputs = [prefix + text for text in rows[source_col]]
    sentences_outputs = rows[target_col].tolist()
    return sentences_inputs, sentences_outputs

# Multilingual data

In [None]:
import pandas as pd

In [None]:
from random import shuffle

def prepare_sentences_multilingual(df, languages, size, use_single_token=False, reverse_model=False):
    """Concatenate sentence pairs for every language in ``languages``.

    Each language draws pairs from the first ``size * 1000`` rows of ``df`` via
    ``prepare_sentences`` (so ``size`` counts thousands of rows per language). Results
    are appended language by language, in the order given.

    Returns:
        ``(sentences_inputs, sentences_outputs)``: two lists of equal length.
    """
    rows_per_language = size * 1000
    per_language = [
        prepare_sentences(df, rows_per_language, language, use_single_token, reverse_model)
        for language in languages
    ]
    all_inputs = [sentence for inputs, _ in per_language for sentence in inputs]
    all_outputs = [sentence for _, outputs in per_language for sentence in outputs]
    return all_inputs, all_outputs

In [None]:
def prepare_data(data_frame, language, size, reverse_model=False, languages_multi=None):
    """Build and shuffle the training pairs for one model configuration.

    Args:
        data_frame: parallel-corpus DataFrame passed to ``prepare_sentences_multilingual``.
        language: a language code (e.g. ``'pt'``), ``'multi'`` (every language in
            ``languages_multi`` with per-language prefixes) or ``'all'`` (every language in
            ``languages_multi`` with one shared prefix).
        size: thousands of rows drawn per language.
        reverse_model: build English -> translated pairs instead.
        languages_multi: language codes used for ``'all'`` / ``'multi'``.

    Returns:
        ``(sentences_inputs, sentences_outputs)`` as two equal-length tuples, shuffled in
        unison so pairs stay aligned; ``((), ())`` when no pairs were produced.

    Raises:
        ValueError: ``language`` is ``'all'``/``'multi'`` but ``languages_multi`` is empty.
    """
    use_single_token = language == 'all'

    if language in ('all', 'multi'):
        if not languages_multi:
            raise ValueError(f"language={language!r} requires a non-empty languages_multi list")
        languages = languages_multi
    else:
        languages = [language]

    sentences_inputs, sentences_outputs = prepare_sentences_multilingual(
        data_frame, languages, size, use_single_token, reverse_model)

    grouped = list(zip(sentences_inputs, sentences_outputs))
    if not grouped:
        # zip(*[]) would yield nothing and the unpack below would fail.
        return (), ()

    # Uses the module-level `random` RNG; seed it beforehand for reproducible order.
    shuffle(grouped)
    sentences_inputs, sentences_outputs = zip(*grouped)
    return sentences_inputs, sentences_outputs

# Prepare data

In [None]:
import torch
from transformers import T5Tokenizer
from torch.utils.data import Dataset, random_split, IterableDataset

class LazyLoadDataset(IterableDataset):
    """Iterable dataset yielding raw ``(input_sentence, output_sentence)`` pairs.

    Tokenization is left to the DataLoader's ``collate_fn``; ``tokenizer`` is stored but
    not used by this class. All pairs are materialized in memory at construction.
    ``len()`` reports ``length`` exactly as passed in; it is not checked against the
    number of pairs.

    NOTE(review): as an IterableDataset, every DataLoader worker would iterate the full
    list if ``num_workers > 0`` — confirm single-process loading is intended.
    """

    def __init__(self, tokenizer, inputs, outputs, length):
        super().__init__()
        self.tokenizer = tokenizer
        self.sentences = [pair for pair in zip(inputs, outputs)]
        self.length = length

    def __len__(self):
        return self.length

    def get_next_sentence(self):
        """Generator over the stored ``(input, output)`` pairs, in order."""
        yield from self.sentences

    def __iter__(self):
        return self.get_next_sentence()
    
class PairedDataset(Dataset):
    """Map-style dataset over pre-computed encodings.

    ``encodings`` maps a field name (must include ``'input_ids'``) to a per-example
    sequence; item ``idx`` is a dict of ``torch.tensor`` built from each field's entry.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return dict((name, torch.tensor(values[idx])) for name, values in self.encodings.items())

## Training data loader (no validation split is made)

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

def get_dataloader(dataset, tokenizer, split_size=1, batch_size = 32):
    """Build a DataLoader that tokenizes ``(input, output)`` sentence pairs per batch.

    Args:
        dataset: dataset yielding ``(input_sentence, output_sentence)`` items.
        tokenizer: callable tokenizer; called with ``truncation=True, padding='longest'``.
        split_size: unused; kept for call-site compatibility.
        batch_size: number of pairs per batch.

    Returns:
        A DataLoader whose batches are ``{'inputs': {...}, 'outputs': {...}}``, each inner
        dict mapping tokenizer output names (e.g. ``input_ids``, ``attention_mask``) to
        tensors padded to the longest sequence in the batch. Padding positions in the
        outputs are left as the tokenizer's pad ids (not masked here).
    """
    def _as_tensors(encoding):
        return {name: torch.tensor(ids) for name, ids in encoding.items()}

    def tokenize_batch(batch):
        sources = [pair[0] for pair in batch]
        targets = [pair[1] for pair in batch]
        encoded = {
            'inputs': tokenizer(sources, truncation=True, padding='longest'),
            'outputs': tokenizer(targets, truncation=True, padding='longest'),
        }
        return {side: _as_tensors(encoding) for side, encoding in encoded.items()}

    return DataLoader(dataset, batch_size=batch_size, collate_fn=tokenize_batch)

# Prepare training

In [None]:
from transformers import T5ForConditionalGeneration, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup

def prepare_training(model_name, tokenizer, data_size, epochs, device='cuda', lr=5e-5, num_warmup_steps=0):
    """Load a pretrained T5 model and build its optimizer and linear LR schedule.

    Args:
        model_name: Hugging Face model id or path passed to ``from_pretrained``.
        tokenizer: unused; kept for call-site compatibility.
        data_size: optimizer steps (batches) per epoch.
        epochs: number of epochs; the schedule spans ``data_size * epochs`` steps.
        device: device the model is moved to (default ``'cuda'``, as before).
        lr: peak learning rate for AdamW.
        num_warmup_steps: linear warmup steps before the linear decay.

    Returns:
        ``(model, optimizer, scheduler)``.
    """
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay=0.0
    # keeps the transformers default (torch's own default is 0.01).
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-8, weight_decay=0.0)

    total_steps = data_size * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=num_warmup_steps,
                                                num_training_steps=total_steps)

    return model, optimizer, scheduler

# Training loop

In [None]:
import time
import datetime
import random
import numpy as np

# Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) RNGs so the
# pair shuffling and any torch randomness are reproducible across runs.
seed_val = 42

for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed_val)

In [None]:
import wandb
from tqdm import tqdm

def train(run, model, optimizer, scheduler, epochs, model_name, dataloader, log_every=5):
    """Fine-tune a seq2seq model, log the running loss to W&B, then save it.

    Args:
        run: W&B run (from ``wandb.init``), used as a context manager and for ``run.log``.
        model: model returning an object with ``.loss`` when called with ``labels``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        epochs: number of passes over ``dataloader``.
        model_name: sub-directory under ``MODELS_DIR`` the model is saved to.
        dataloader: yields ``{'inputs': {...}, 'outputs': {...}}`` batches whose inner
            dicts hold ``input_ids`` and ``attention_mask`` tensors (as built by
            ``get_dataloader``).
        log_every: log the running mean loss every this many batches.

    Notes:
        Padded target positions are set to -100 so the loss ignores them. Reported loss
        values are therefore lower than those of runs that trained on padding.
    """
    # Use the device the model lives on rather than a global `device` defined in a later cell.
    device = next(model.parameters()).device

    with run:
        for epoch_i in range(epochs):
            model.train()
            progress_bar = tqdm(dataloader)
            total_train_loss = 0.0

            # `step` counts processed batches (1-based) so the running mean divides by
            # the right number of losses.
            for step, batch in enumerate(progress_bar, start=1):
                inputs = batch['inputs']
                targets = batch['outputs']

                b_input_ids = inputs['input_ids'].to(device)
                b_masks = inputs['attention_mask'].to(device)
                # -100 marks label positions the cross-entropy loss ignores; without this
                # the model is also trained to predict padding tokens.
                b_labels = targets['input_ids'].masked_fill(
                    targets['attention_mask'] == 0, -100).to(device)

                model.zero_grad()
                model_output = model(b_input_ids,
                                     labels=b_labels,
                                     attention_mask=b_masks)
                loss = model_output.loss
                total_train_loss += loss.item()

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                if step % log_every == 0:
                    train_loss = total_train_loss / step
                    progress_bar.set_postfix({'loss': round(train_loss, 2)})
                    run.log({
                        'epoch': epoch_i,
                        'step': step,
                        'train_loss': train_loss
                    })

        run.finish()

    model.save_pretrained(MODELS_DIR + model_name)

In [None]:
import gc

def main(data, model_versions):
    """Train and save one T5 model per config in ``model_versions``.

    Each config dict needs ``name``, ``size``, ``version``, ``base`` and ``epochs``. It may
    also set ``batch_size`` (default 32) and ``languages`` (default
    ``['pt', 'es', 'de', 'fr']``, used for ``all``/``multi`` models). The derived model
    name is ``f"{name}-s{size}-{version}"``. ``name`` must contain at least
    ``<prefix>-<type>-<language>``:
    - a ``<type>`` of ``reverse`` trains the English -> translated direction;
    - ``<language>`` is a code or ``all``/``multi``.

    Args:
        data: parallel-corpus DataFrame passed to ``prepare_data``.
        model_versions: list of model config dicts (see above).
    """
    tokenizer_cache = {}

    for train_model in model_versions:
        torch.cuda.empty_cache()

        model_name = f'{train_model["name"]}-s{train_model["size"]}-{train_model["version"]}'
        model_base = train_model['base']

        # Share one tokenizer between configs that use the same base checkpoint.
        if model_base not in tokenizer_cache:
            tokenizer_cache[model_base] = T5Tokenizer.from_pretrained(model_base)
        tokenizer = tokenizer_cache[model_base]

        epochs = train_model['epochs']
        size = train_model['size']
        batch_size = train_model.get('batch_size', 32)
        languages_multi = train_model.get('languages', ['pt', 'es', 'de', 'fr'])

        # model_name = "<prefix>-<type>-<language>-s<size>-<version>"
        name_parts = model_name.split('-')
        model_type = name_parts[1]
        model_language = name_parts[2]
        is_reverse = model_type == 'reverse'
        is_multilingual = model_language in ('all', 'multi')

        # Prepare data and model
        sentences_inputs, sentences_outputs = prepare_data(data, model_language, size, is_reverse, languages_multi)

        paired_dataset = LazyLoadDataset(tokenizer, sentences_inputs, sentences_outputs, len(sentences_inputs))
        training_dataloader = get_dataloader(paired_dataset, tokenizer, batch_size=batch_size)
        model, optimizer, scheduler = prepare_training(model_base, tokenizer, len(training_dataloader), epochs)

        print('\n')
        print('Training', model_name, 'Sentences', len(sentences_inputs), 'Batches', len(training_dataloader))
        print('=> Sample input:', sentences_inputs[0])
        print('=> Sample output:', sentences_outputs[0])
        time.sleep(1)

        project_name = '-'.join(name_parts[:-1])
        config_wandb = {
            'epochs': epochs,
            'size': size,
            'batch_size': batch_size,
            'base': model_base,
            'dataset_size': len(sentences_inputs),
        }

        # Record the language list for multilingual models. This is keyed on the
        # language segment of the name; the type segment (e.g. 'l1aware') is never 'all'/'multi'.
        if is_multilingual:
            config_wandb['languages'] = languages_multi

        run = wandb.init(reinit=True, project=project_name, config=config_wandb)
        train(run, model, optimizer, scheduler, epochs, model_name, training_dataloader)

        # Drop references so GPU memory can be reclaimed before the next model.
        del model, optimizer, scheduler, training_dataloader, paired_dataset, sentences_inputs, sentences_outputs
        gc.collect()
        time.sleep(10)
        torch.cuda.empty_cache()

In [None]:
# Parallel corpus; prepare_sentences reads its `en` and `trans_<lang>` columns.
df_training = pd.read_csv(DATA_DIR + 'pt-es-en-parallel-corpus.csv')

In [None]:
device = torch.device('cuda')

BATCH_SIZE = 16
EPOCHS = 3


def _model_config(model_id, name, base, size):
    """One entry of `models_config`: shared hyper-parameters plus per-model fields.

    ``size`` is in thousands of rows per language.
    """
    return {
        'id': model_id,
        'name': name,
        'base': base,
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'size': size,
        'languages': ['pt', 'es'],
        'version': 'v1',
    }


models_config = [
    _model_config(1, 't5sm-l1aware-multi', 't5-small', 260),
    # NOTE(review): the 't5lg' name is paired with the 't5-base' checkpoint — confirm intended.
    _model_config(2, 't5lg-l1aware-multi', 't5-base', 130),
]

train_ids = [1, 2]

to_train = [config for config in models_config if config['id'] in train_ids]
main(df_training, to_train)