# BART Fine-Tuning


## Get Data

In [1]:
import pickle

In [2]:
# ArXiv

arxiv_input=pickle.load(open("articles_inputs.pickle","rb"))
arxiv_target=pickle.load(open("articles_targets.pickle","rb"))

pubmed_input=pickle.load(open("pubMed_inputs.pickle","rb"))
pubmed_target=pickle.load(open("pubMed_targets.pickle","rb"))

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)#CAUTION: RUN THIS C

In [4]:
device

device(type='cuda')

In [5]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import re
import pandas as pd
import numpy as np
import torch
import os
import json
from transformers import BartTokenizer, BartModel, BartForConditionalGeneration
import math

In [6]:
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

In [7]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

In [8]:
arxiv_target["input_ids"].shape

torch.Size([202902, 1024])

In [9]:
texts=[]
from tqdm import tqdm
for i in tqdm(arxiv_target["input_ids"][0:10000, ]):
    text=tokenizer.decode(i,skip_special_tokens=True).replace("<S>","").replace("</S>","")
    texts.append(text)
for i in tqdm(pubmed_target["input_ids"][0:5000, ]):
    text=tokenizer.decode(i,skip_special_tokens=True).replace("<S>","").replace("</S>","")
    texts.append(text)

100%|██████████| 10000/10000 [02:08<00:00, 77.73it/s]
100%|██████████| 5000/5000 [01:04<00:00, 77.31it/s]


In [10]:
abstract_targerts=tokenizer(texts, max_length=1024, 
        truncation=True, 
        padding="max_length", 
        return_tensors="pt")

In [11]:
arxiv_input["input_ids"]=arxiv_input["input_ids"].cpu()

In [12]:
arxiv_input["attention_mask"]=arxiv_input["attention_mask"].cpu()

In [13]:
pubmed_input["input_ids"]=pubmed_input["input_ids"].cpu()
pubmed_input["attention_mask"]=pubmed_input["attention_mask"].cpu()


In [14]:
from torch.utils.data import Dataset

class SummarizationDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs["input_ids"])

    def __getitem__(self, idx):
        return {
            'input_ids': self.inputs["input_ids"][idx],
            'attention_mask': self.inputs["attention_mask"][idx],
            'labels': self.targets["input_ids"][idx]
        }

In [15]:
inputs={"input_ids":None,"attention_mask":None}
inputs["input_ids"]=torch.cat([arxiv_input["input_ids"],pubmed_input["input_ids"]],dim=0)
inputs["attention_mask"]=torch.cat([arxiv_input["attention_mask"],pubmed_input["attention_mask"]],dim=0)

In [16]:
inputs["input_ids"].shape

torch.Size([15000, 1024])

In [17]:
dataset = SummarizationDataset(inputs, abstract_targerts)

In [18]:
for name, param in model.named_parameters():
    if "shared.weight" in name or "embed_positions.weight" in name \
    or "encoder.layers." in name or "encoder.layernorm_embedding." in name \
    or "decoder.embed_positions" in name or any(f"decoders.layers.{i}." in name for i in range(0,10)):
        param.requires_grad=False

In [19]:
from torch.utils.data import DataLoader

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

In [20]:
del arxiv_target

In [21]:
from transformers import AdamW
from torch.nn import functional as F
import torch


model = model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
def train_epoch(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0

    for batch in tqdm(data_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            labels=labels
        )
        loss = outputs.loss
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss / len(data_loader)

def evaluate_epoch(model, data_loader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(data_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids, 
                attention_mask=attention_mask, 
                labels=labels
            )
            total_loss += outputs.loss.item()

    return total_loss / len(data_loader)



In [22]:
import torch
torch.cuda.empty_cache()

In [43]:
EPOCHS = 10

for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = evaluate_epoch(model, val_loader, device)
    if epoch==4:
        torch.save(model.state_dict(), "BART-E5-ALL.pth")
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

torch.save(model.state_dict(), "BART-E10-ALL.pth")

100%|██████████| 1500/1500 [40:31<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 1, Train Loss: 1.1911534604231517, Val Loss: 0.9314441401163737


100%|██████████| 1500/1500 [44:09<00:00,  1.77s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 2, Train Loss: 0.8501952592134475, Val Loss: 0.7798139524459838


100%|██████████| 1500/1500 [40:34<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 3, Train Loss: 0.7106391162077585, Val Loss: 0.7159524720509847


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 4, Train Loss: 0.6270521755417188, Val Loss: 0.6812958683172862


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 5, Train Loss: 0.5642623227238656, Val Loss: 0.6651042557557424


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 6, Train Loss: 0.514181921839714, Val Loss: 0.6679914069970448


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 7, Train Loss: 0.4707191417316596, Val Loss: 0.6816804750760397


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 8, Train Loss: 0.43113634472091994, Val Loss: 0.6961883805592854


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 9, Train Loss: 0.39462132079402606, Val Loss: 0.722692139228185


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 10, Train Loss: 0.3601629015703996, Val Loss: 0.7502380991776785


In [23]:
EPOCHS = 5

for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = evaluate_epoch(model, val_loader, device)
    
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

torch.save(model.state_dict(), "BART-E5-ALL.pth")

100%|██████████| 1500/1500 [40:31<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 1, Train Loss: 1.189143114288648, Val Loss: 0.9200942738850911


100%|██████████| 1500/1500 [40:34<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 2, Train Loss: 0.8486903121272723, Val Loss: 0.7762309277852376


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 3, Train Loss: 0.7126234473983447, Val Loss: 0.7050663736661276


100%|██████████| 1500/1500 [40:35<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 4, Train Loss: 0.625435917754968, Val Loss: 0.6725757858753204


100%|██████████| 1500/1500 [40:36<00:00,  1.62s/it]
100%|██████████| 375/375 [04:32<00:00,  1.38it/s]


Epoch 5, Train Loss: 0.5634584678212802, Val Loss: 0.6621964476903279


In [21]:
#torch.save(model.state_dict(), "BART-E3-P.pth")