In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import TrainingArguments, Trainer
from transformers import BartTokenizer, BartForConditionalGeneration, EarlyStoppingCallback
import pickle

# Select the accelerator. Fall back to CPU so the notebook still runs (slowly) on
# machines without CUDA instead of crashing in get_device_name(0).
# NOTE(review): `device` is not referenced in the visible cells; the Trainer below
# handles model/batch placement on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('GPU:', torch.cuda.get_device_name(0))
else:
    print('GPU: none available, using CPU')

import warnings
# NOTE(review): this silences *all* warnings, including deprecation notices from
# transformers (e.g. renamed TrainingArguments fields). Consider narrowing the filter.
warnings.filterwarnings('ignore')

GPU: A100-SXM4-40GB


In [2]:
# Load the pretrained summarization checkpoint together with its matching tokenizer.
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# Freeze the encoder stack and the shared token-embedding module so their weights are
# not updated; the remaining parameters (presumably mainly the decoder) are fine-tuned.
for frozen_module in (model.base_model.encoder, model.base_model.shared):
    frozen_module.requires_grad_(False)

In [3]:
# Load the pre-tokenized train/dev splits saved by an earlier preprocessing step.
# map_location='cpu' keeps every tensor on the host regardless of the device it was
# saved from: CUDA-saved tensors would otherwise fail on CPU-only hosts. The Trainer
# moves each batch to the training device itself.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load trusted files.
# On torch>=2.6 it also defaults to weights_only=True, which may reject non-tensor
# containers (e.g. a tokenizer BatchEncoding). Confirm what these files contain.
train_source_tokenized = torch.load('train_source_tokenized.pt', map_location='cpu')
train_target_tokenized = torch.load('train_target_tokenized.pt', map_location='cpu')
dev_source_tokenized = torch.load('dev_source_tokenized.pt', map_location='cpu')
dev_target_tokenized = torch.load('dev_target_tokenized.pt', map_location='cpu')

In [4]:
# custom dataset class for data loading
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized sources with tokenized targets for seq2seq training.

    Each item supplies encoder inputs, teacher-forced decoder inputs and shifted labels:
    the decoder receives target tokens ``[:-1]`` and is trained to predict tokens ``[1:]``.

    Args:
        source: mapping with at least ``'input_ids'`` and ``'attention_mask'``; ``val[idx]``
            selects one example (presumably 2-D tensors of shape (num_examples, seq_len),
            e.g. tokenizer output; confirm).
        target: mapping with the same two keys for the target sequences.
        pad_token_id: token id whose label positions are replaced by -100 so the loss
            ignores them. ``None`` (default) reads the global ``tokenizer.pad_token_id``
            at access time, matching the notebook's original behaviour.
    """

    def __init__(self, source, target, pad_token_id=None):
        self.source = source
        self.target = target
        self.pad_token_id = pad_token_id

    def __getitem__(self, idx):
        """Return the model-input dict for example ``idx``. Every tensor is a fresh copy."""
        pad_id = self.pad_token_id if self.pad_token_id is not None else tokenizer.pad_token_id

        target_ids = self.target['input_ids'][idx]
        target_mask = self.target['attention_mask'][idx]

        # Labels are the target shifted left by one; padded positions become -100
        # (the default ignore_index of PyTorch's cross-entropy loss).
        labels = target_ids[1:].detach().clone()
        labels[labels == pad_id] = -100

        # Clone exactly once per returned tensor so callers can mutate items freely
        # without touching the underlying dataset.
        return {
            'input_ids': self.source['input_ids'][idx].detach().clone(),
            'attention_mask': self.source['attention_mask'][idx].detach().clone(),
            'decoder_input_ids': target_ids[:-1].detach().clone(),
            'decoder_attention_mask': target_mask[:-1].detach().clone(),
            'labels': labels,
        }

    def __len__(self):
        """Number of examples, taken from the source ``input_ids``."""
        return len(self.source["input_ids"])

# Wrap each pre-tokenized split in the dataset class defined above.
train_dataset = CustomDataset(source=train_source_tokenized, target=train_target_tokenized)
val_dataset = CustomDataset(source=dev_source_tokenized, target=dev_target_tokenized)

In [None]:
# Hugging Face Trainer configuration and training run.
# EarlyStoppingCallback relies on load_best_model_at_end=True. load_best_model_at_end
# in turn needs a checkpoint saved at each evaluation to restore the best model from.
# The save schedule therefore mirrors the evaluation schedule; with save_strategy="no"
# there is nothing to reload, and newer transformers versions reject the mismatch.
EVAL_STEPS = 1000

args = TrainingArguments(
    output_dir="output",
    num_train_epochs=5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_strategy="steps",
    save_steps=EVAL_STEPS,
    # Bound disk use; with load_best_model_at_end the best checkpoint is always kept.
    save_total_limit=2,
    metric_for_best_model='eval_loss',
    load_best_model_at_end=True,
    fp16=True,
    # Presumably tuned by a hyperparameter search; confirm provenance.
    learning_rate=1.97622e-05,
    weight_decay=0.0298317,
    seed=21)


trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Stop after 2 consecutive evaluations without eval_loss improvement.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)])

trainer.train()

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
1000,1.8295,1.773201,671.2418,148.978
2000,1.783,1.738948,671.4276,148.936
3000,1.7615,1.719797,666.2657,150.09
4000,1.7468,1.705974,666.1667,150.113
5000,1.7369,1.69564,666.1614,150.114
6000,1.639,1.692161,666.6972,149.993
7000,1.6433,1.687195,668.212,149.653
8000,1.6402,1.681399,668.9341,149.492
9000,1.637,1.674053,668.7481,149.533
10000,1.6311,1.671935,666.4172,150.056
