In [None]:
import torch, requests
import json, os
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import set_seed, T5ForConditionalGeneration

#### General Information


- T5 paper: [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/pdf/1910.10683v3)

In [None]:
# Static settings for the T5 fine-tuning run.
# Only 'base_model' is read by the cells visible in this notebook; the remaining
# entries are presumably shared with the preprocessing notebook -- TODO confirm.
configs = dict(
    max_input_embedding_length = 512,
    max_output_embedding_length = 128,
    task_prefix = "summarize: ",
    tokenizer = 't5-small',
    ignore_ids = -100,
    padding_ids = 0,
    base_model = 't5-small',  # checkpoint name handed to from_pretrained
)

#### Dataset description

- The dataset must be preprocessed before running this training notebook.
- To preprocess it yourself, follow the attached notebook, `preprocess.ipynb`.
- Alternatively, the preprocessed data can be downloaded from here: [Link](https://www.kaggle.com/datasets/eddyvo/t5-base-tokens-cnn-daily)

In [None]:
# Location of the preprocessed (already encoded) dataset files.
root = '/kaggle/input/t5-base-tokens-cnn-daily'
train_file = os.path.join(root, 'train_ds_encoded.json')
val_file = os.path.join(root, 'val_ds_encoded.json')

# Fail fast, and say which file is missing, before any expensive work starts.
assert os.path.exists(train_file), f'training data not found: {train_file}'
assert os.path.exists(val_file), f'validation data not found: {val_file}'

In [None]:
def _read_json(path):
    """Parse the JSON document stored at `path` and return the decoded object."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Decode both splits fully into memory.
train_list = _read_json(train_file)
val_list = _read_json(val_file)

In [None]:
class CNNDaily(Dataset):
    """Map-style dataset over a sequence of already-encoded records.

    Each record is expected to be a mapping with the keys 'input_ids',
    'attention_mask' and 'labels', each convertible by ``torch.LongTensor``.
    """

    def __init__(self, elements):
        """Store `elements`, an indexable, sized collection of records."""
        self.elements = elements

    def __len__(self):
        """Return the number of stored records."""
        return len(self.elements)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, labels)`` as LongTensors.

        If the requested record cannot be converted, the following records are
        tried in turn (wrapping around at the end), so a few malformed entries
        do not abort an epoch. The search is bounded: at most ``len(self)``
        records are attempted.

        Raises:
            IndexError: if the dataset is empty.
            RuntimeError: if no attempted record could be loaded; the last
                underlying error is attached as ``__cause__``.
        """
        total = len(self.elements)
        if total == 0:
            raise IndexError('cannot index an empty CNNDaily dataset')

        candidate = index
        last_error = None
        # Iterate instead of recursing so that a dataset made entirely of bad
        # records terminates with a clear error rather than a RecursionError.
        for attempt in range(total):
            try:
                res = self.elements[candidate]
                return (
                    torch.LongTensor(res['input_ids']),
                    torch.LongTensor(res['attention_mask']),
                    torch.LongTensor(res['labels']),
                )
            except Exception as err:
                last_error = err
                fallback = (candidate + 1) % total
                if attempt + 1 < total:
                    print('Exception raised while loading item', candidate, '\nTrying to load', fallback)
                candidate = fallback

        raise RuntimeError(
            f'no loadable item found after {total} attempts starting at index {index}'
        ) from last_error

In [None]:
# Wrap each decoded split in the Dataset adapter consumed by the DataLoaders.
train_ds = CNNDaily(train_list)
val_ds = CNNDaily(val_list)

In [None]:
# Tunable settings for the training run.
hyperparameters = dict(
    learning_rate = 1e-5,
    num_epochs = 2,
    train_batch_size = 24,
    eval_batch_size = 32,
    seed = 42,
    # Early stopping: number of consecutive non-improving epochs tolerated.
    # NOTE(review): with num_epochs = 2 a patience of 3 can never be reached.
    patience = 3,
    # NOTE(review): not read by any cell visible in this notebook.
    output_dir = "/content/",
)

In [None]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_ds,
    batch_size = hyperparameters['train_batch_size'],
    shuffle = True
)

# Validation is iterated in a fixed order: shuffling brings no benefit when no
# gradients are taken, and a stable order keeps the running 'eval_batch_loss'
# entries comparable from one epoch to the next.
val_loader = DataLoader(
    val_ds,
    batch_size = hyperparameters['eval_batch_size'],
    shuffle = False
)

In [None]:
# Loss history collected during training; every key starts with its own empty list.
logs = {
    series: []
    for series in (
        'train_batch_loss',
        'eval_batch_loss',
        'train_epoch_loss',
        'eval_epoch_loss',
    )
}

# A running batch-loss entry is recorded every `batch_log_interval` batches.
batch_log_interval = 500

In [None]:
# Load the checkpoint named in `configs` and move it to the GPU in one step.
model = T5ForConditionalGeneration.from_pretrained(configs['base_model']).cuda()

# Seeding happens after the model has been loaded, as in the original cell order.
set_seed(hyperparameters["seed"])

In [None]:
def push_log(key, value, timeout = 10):
    """Send one metric sample to the remote log collector (best effort).

    Args:
        key: name of the metric series.
        value: value to record for that series.
        timeout: seconds to wait for the server before giving up, so a stalled
            collector cannot block the training loop indefinitely.

    Network and HTTP errors are reported on stdout and otherwise ignored:
    a failed telemetry call should not abort a training run.
    """
    # NOTE(review): the endpoint is plain HTTP, so values are sent unencrypted.
    try:
        response = requests.get(
            'http://ndtran.tech/logs/submit/',
            params = {
                'key': key,
                'room': 't5-small-finetune-tmp-1',
                'value': value
            },
            timeout = timeout
        )
        response.raise_for_status()
    except requests.RequestException as err:
        print('push_log failed for', key, '-', err)

In [None]:
# AdamW over all model parameters, using the configured learning rate.
optimizer = AdamW(params = model.parameters(), lr = hyperparameters["learning_rate"])

# Learning-rate scheduling is currently disabled; kept here for reference.
# scheduler = torch.optim.lr_scheduler.LinearLR(
#     optimizer,
#     total_iters = hyperparameters['num_epochs']
# )

In [None]:
# for param in model.parameters():
#     param.requires_grad = False

# for param in model.lm_head.parameters():
#     param.requires_grad = True
    
# pytorch_total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# pytorch_total_params = sum(p.numel() for p in model.parameters())
# print('Trainable params:', pytorch_total_trainable_params)
# print('Total params:', pytorch_total_params)
# print('Percentage:', pytorch_total_trainable_params * 100 / pytorch_total_params, '(%)')

In [None]:
def _run_epoch(loader, desc, batch_log_key, train):
    """Run one full pass over `loader` and return the mean per-batch loss.

    Args:
        loader: iterable yielding (input_ids, attention_mask, labels) batches.
        desc: label shown on the tqdm progress bar.
        batch_log_key: key in `logs` that receives the running mean loss every
            `batch_log_interval` batches.
        train: when True the model is put in training mode and an optimizer
            step is taken per batch; when False the model is put in eval mode
            and gradients are disabled.
    """
    # Send batches to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device
    model.train(train)

    running_loss = 0.0
    with torch.set_grad_enabled(train):
        for i, (input_ids, attention_mask, labels) in tqdm(
            enumerate(loader),
            total = len(loader),
            desc = desc
        ):
            out = model(
                input_ids = input_ids.to(device),
                attention_mask = attention_mask.to(device),
                labels = labels.to(device)
            )
            loss = out.loss

            if train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            running_loss += loss.item()

            if i % batch_log_interval == 0:
                logs[batch_log_key].append(running_loss / (1 + i))
                # push_log(batch_log_key, running_loss / (1 + i))

    return running_loss / len(loader)


epochs_no_improve = 0
# Infinity guarantees that the first validation loss always counts as an improvement.
min_val_loss = float('inf')

for epoch in range(hyperparameters['num_epochs']):
    epoch_tag = f'{str(epoch + 1).zfill(2)} / {hyperparameters["num_epochs"]}'

    train_loss = _run_epoch(train_loader, f'Training {epoch_tag}', 'train_batch_loss', train = True)
    logs['train_epoch_loss'].append(train_loss)
    # push_log('train_epoch_loss', train_loss)

    val_loss = _run_epoch(val_loader, f'Evaluating {epoch_tag}', 'eval_batch_loss', train = False)
    logs['eval_epoch_loss'].append(val_loss)
    # push_log('eval_epoch_loss', val_loss)

    # Early stopping: count consecutive epochs without a new best validation loss.
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

        if epochs_no_improve >= hyperparameters['patience']:
            # Report the 1-based epoch number, matching the progress-bar labels.
            print('Early stopping at epoch', epoch + 1)
            break
            


In [None]:
from datetime import datetime

# Each run writes its weights to a separate directory named after the current local time.
run_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
out_dir = os.path.join('weights', run_stamp)
model.save_pretrained(out_dir)

# Write the collected loss history to logs.json in the working directory.
with open('logs.json', 'w') as fp:
    fp.write(json.dumps(logs, indent = 4))

#### For the last step

- Save the model locally, or upload it to the Hugging Face Hub and load it from there whenever needed.
- Here is our already fine-tuned model: [Link](https://huggingface.co/ndtran/t5-small_cnn-daily-mails). Feel free to use it.
- Use the code snippet below to load it:

```python
  model = T5ForConditionalGeneration.from_pretrained('ndtran/t5-small_cnn-daily-mails')
```