In [1]:
import os
import re
import csv
import time
import random
import numpy as np
from tqdm import tqdm
from itertools import islice
from torchinfo import summary as modelinfo
# import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import GPTNeoForCausalLM, GPT2Tokenizer


try: 
    from rouge_score import rouge_scorer
except:
    %pip install rouge-score==0.1.2
    from rouge_score import rouge_scorer
    
try:
    from peft import get_peft_model, LoraConfig, TaskType
except:
    %pip install peft==0.13.2
    from peft import get_peft_model, LoraConfig, TaskType

Collecting rouge-score==0.1.2
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=62ef6af4800dd945b91767316c147b01fbb6bf34f901a841aa838e80fac5b53c
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2
Note: you may need to restart the kernel to use updated packages.
Collecting peft==0.13.2
  Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)
Downloading peft-0.13.2-py3-none-any.whl (320 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.7/320.7 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: pe

## Config

In [2]:
train_filepath = "/kaggle/input/cnn_dailymail/train.csv"
val_filepath = "/kaggle/input/cnn_dailymail/validation.csv"
test_filepath = "/kaggle/input/cnn_dailymail/test.csv"

# train_filepath = "/kaggle/input/newspaper-text-summarization-cnn-dailymail/cnn_dailymail/train.csv"
# val_filepath = "/kaggle/input/newspaper-text-summarization-cnn-dailymail/cnn_dailymail/validation.csv"
# test_filepath = "/kaggle/input/newspaper-text-summarization-cnn-dailymail/cnn_dailymail/test.csv"

max_train_samples = 5000
max_val_samples = 1000
max_test_samples = 1000

In [3]:
to_train = True
model_name = "EleutherAI/gpt-neo-125m"
tuning_type = "last"

lr = 6e-5
epochs = 5
mini_batch_size = 32

# For LORA
r = 16
alpha = 32

In [4]:
assert tuning_type in ["none", "last", "lora"], "Invalid tuning type"
if tuning_type == "none": to_train=False
os.makedirs("models", exist_ok=True)

In [5]:
random_seed = 42

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

print("Using Random Seed:", random_seed)

Using Random Seed: 42


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


## Utils

In [7]:
dm_single_close_quote = "\u2019"  # unicode
dm_double_close_quote = "\u201d"

# acceptable ways to end a sentence
END_TOKENS = [
    ".",
    "!",
    "?",
    "'",
    "`",
    '"',
    dm_single_close_quote,
    dm_double_close_quote,
    ")",
]

In [8]:
def remove_period(line):
    if line[-1] in END_TOKENS:
        return line[:-1]
    return line


def remove_punctuations(line):
    return re.sub(r"[^\w\s]", " ", line)


def clean_data(data):
    for i in range(len(data)):
        data[i]["article"] = remove_punctuations(data[i]["article"])
        data[i]["highlights"] = remove_punctuations(data[i]["highlights"])

        data[i]["article"] = re.sub(r"\s+", " ", data[i]["article"]).strip()
        data[i]["highlights"] = re.sub(r"\s+", " ", data[i]["highlights"]).strip()

    return data

In [9]:
def read_data(filepath, max_length=None):
    with open(filepath, "r") as f:
        reader = csv.DictReader(f)
        if max_length is not None:
            rows = islice(reader, max_length)
        else:
            rows = reader
        data = list(rows)
    
    return clean_data(data)

In [10]:
def freeze_last_layer(model):
    assert tuning_type == "last", "Only last layer fine-tuning is supported"

    for param in model.parameters():
        param.requires_grad = False
    
    for param in model.lm_head.parameters():
        param.requires_grad = True

def lora(model):
    assert tuning_type == "lora"

    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=r,
        lora_alpha=alpha,
        bias="none"
    )

    model = get_peft_model(model, config)

    return model

In [11]:
def make_tokenizer(model_name):
    """Returns GPT2 tokenizer after adding separator and padding tokens"""
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    special_tokens = {"pad_token": "<pad>", "sep_token": "<sep>"}
    tokenizer.add_special_tokens(special_tokens)
    return tokenizer


def make_model(model_name, len_tokenizer, tuning_type="last"):
    model = GPTNeoForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len_tokenizer)

    if tuning_type == "last":
        freeze_last_layer(model)
    elif tuning_type == "lora":
        model = lora(model)

    model.to(device)
    return model

In [12]:
class SummarizationDataset(Dataset):
    def __init__(self, data, tokenizer, article_max_length=512, summary_max_length=128, type="train"):
        assert type in ["train", "val", "test"], "Invalid dataset type"

        self.type = type
        if type == "test":
            summary_max_length = 0

        self.tokenizer = tokenizer
        self.article_max_length = article_max_length
        self.summary_max_length = summary_max_length
        self.instruction_tokens = self.tokenizer.encode("summarize: ")
        self.sep_token = self.tokenizer.encode(" " + self.tokenizer.sep_token + " ")
        self.max_length = article_max_length + summary_max_length + len(self.instruction_tokens) + len(self.sep_token)
        
        self.processed_data = self._process_all_data(data, 4)

    def _attention_mask(self, padding_length):
        return [1] * (self.max_length - padding_length) + [0] * padding_length

    def _process_data(self, data):
        article_ids = self.tokenizer.encode(
            data["article"], truncation=True, max_length=self.article_max_length
        )
        if self.type != "test":
            abstract_ids = self.tokenizer.encode(
                data["highlights"], truncation=True, max_length=self.summary_max_length
            )
        else:
            abstract_ids = []

        # Combine all components
        content = self.instruction_tokens + article_ids + self.sep_token + abstract_ids

        if self.type != "test":
            padding_length = self.max_length - len(content)
            padded_content = content + [self.tokenizer.pad_token_id] * padding_length
        else:
            padding_length = self.max_length - len(content)
            padded_content = [self.tokenizer.pad_token_id] * padding_length + content

        return {
            "text": padded_content,
            "sep_idx": len(article_ids),
            "article_len": len(article_ids),
            "summary_len": len(abstract_ids),
            "attention_mask": self._attention_mask(padding_length),
            "highlights": data["highlights"],
        }
    
    def _process_all_data(self, data, num_workers):
        processed_data = []

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            # Submit tasks to the thread pool
            futures = [executor.submit(self._process_data, item) for item in data]
            
            # Use tqdm to track the progress
            for future in tqdm(as_completed(futures), total=len(data), desc="Processing Data"):
                processed_data.append(future.result())

        return processed_data

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        processed_item = self.processed_data[idx]
        return {
            "article": torch.tensor(processed_item["text"]),
            "sep_idx": processed_item["sep_idx"],
            "article_len": processed_item["article_len"],
            "summary_len": processed_item["summary_len"],
            "attention_mask": torch.tensor(processed_item["attention_mask"]),
            "highlights": processed_item["highlights"],
        }

In [13]:
def evaluate(model, val_loader, loss_fn, summary_max_length=128):
    """
    Evaluate the model on validation/test data

    Args:
        model: The model to evaluate
        val_loader: DataLoader for validation data
        loss_fn: Loss function (typically CrossEntropyLoss with ignore_index set to pad_token_id)

    Returns:
        float: Average loss per batch
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            inputs = batch["article"].to(device)
            sep_idx = batch["sep_idx"].squeeze()
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(inputs, attention_mask=attention_mask, labels=inputs)
            loss = outputs.loss
            # decoded_output = tokenizer.decode(outputs.logits[0].argmax(dim=-1))
            # decoded_input = tokenizer.decode(inputs[0])
            # print(decoded_input, decoded_output, sep="\n\n")

            # Get logits and labels for summary portion only
            shift_logits = outputs.logits[..., sep_idx:-1, :].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = inputs[..., sep_idx + 1 :].contiguous()
            shift_labels = shift_labels.view(-1)

            shift_logits = shift_logits[:summary_max_length]
            shift_labels = shift_labels[:summary_max_length]
            
            # print()
            # print(tokenizer.decode(shift_logits.argmax(dim=-1)))
            # print(shift_labels)

            loss = loss_fn(shift_logits, shift_labels)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / (num_batches if num_batches > 0 else 1)

In [15]:
def train(
    model,
    optimiser,
    loss_fn,
    train_loader,
    val_loader,
    num_epochs,
    max_grad_norm=1.0,
    mini_batch_size=4,
    summary_max_length=128,
    save_path="models/model.pt",
    print_every=5,
):
    global_step = 0
    best_val_loss = float("inf")
    tr_loss = 0.0
    last_tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()

    for epoch in range(num_epochs):
        model.train()
        epoch_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(epoch_iterator):
            inputs = batch["article"].to(device)
            sep_idx = batch["sep_idx"].squeeze()
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(inputs, labels=inputs, attention_mask=attention_mask)

            shift_logits = outputs.logits[..., sep_idx:-1, :].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = inputs[..., sep_idx + 1 :].contiguous()
            shift_labels = shift_labels.view(-1)

            shift_logits = shift_logits[:summary_max_length]
            shift_labels = shift_labels[:summary_max_length]

            loss = loss_fn(shift_logits, shift_labels)
            loss /= mini_batch_size
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            tr_loss += loss.item()

            if (step + 1) % mini_batch_size == 0:
                optimiser.step()
                model.zero_grad()
                global_step += 1

                if global_step % print_every == 0:
                    log_loss = (tr_loss - logging_loss) / mini_batch_size
                    logging_loss = tr_loss

                    print(f"Step: {global_step}, Mini-Batch Loss: {log_loss:.4f}")

        print(
            f"Epoch: {epoch+1}, Avg Training Loss: {(tr_loss - last_tr_loss) / len(train_loader):.4f}/sample",
            end=", ",
        )
        last_tr_loss = tr_loss

        val_loss = evaluate(model, val_loader, loss_fn)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)

        print(f"Validation Loss: {val_loss:.4f}")

In [16]:
def predict(
    model,
    tokenizer,
    text,
    article_max_length=512,
    max_length=128,
    preprocess=True,
    sep_idx=None,
):
    model.eval()
    if preprocess:
        text_ids = (
            tokenizer.encode("summarize: ")
            + tokenizer.encode(text)[:article_max_length]
            + tokenizer.encode(tokenizer.sep_token)
        )
        sep_idx = len(text_ids) - 1
        inputs = torch.tensor(text_ids).unsqueeze(0).to(device)
    else:
        assert sep_idx is not None, "sep_idx must be provided if preprocess is False"
        assert isinstance(sep_idx, torch.Tensor), "sep_idx must be a list"
        inputs = text.to(device)

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_length,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            num_return_sequences=1,
            no_repeat_ngram_size=3,
        )

    if preprocess:
        generated_summary = tokenizer.decode(
            outputs[0][sep_idx + 1 :], skip_special_tokens=True
        )
        return generated_summary
    else:
        generated_summaries = []
        for i in range(len(sep_idx)):
            sep_idx_i = sep_idx[i].item()
            generated_summary = tokenizer.decode(outputs[i][sep_idx_i + 1 :], skip_special_tokens=True)
            generated_summaries.append(generated_summary)

        return generated_summaries

In [17]:
def test_score(model, tokenizer, test_loader, article_max_length=512, max_length=128, print_every=False):
    rouge_scorer_obj = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"])
    rouge_scores = {"rouge1": [], "rouge2": [], "rougeL": []}

    for i in tqdm(test_loader):
        text = i["article"]
        summary = i["highlights"]

        generated_summary = predict(
            model,
            tokenizer,
            text,
            article_max_length,
            max_length,
            preprocess=False,
            sep_idx=i["sep_idx"],
        )

        for j in range(len(summary)):
            scores = rouge_scorer_obj.score(summary[j], generated_summary[j])
            for key in rouge_scores:
                rouge_scores[key].append(scores[key].fmeasure)
            
            if print_every:
                print(scores)

    for key in rouge_scores:
        rouge_scores[key] = np.mean(rouge_scores[key])

    return rouge_scores

## Main

In [19]:
train_data = read_data(train_filepath, max_train_samples)
val_data = read_data(val_filepath, max_val_samples)
test_data = read_data(test_filepath, max_test_samples)

In [None]:
tokenizer = make_tokenizer(model_name)
len_tokenizer = len(tokenizer)

model = make_model(model_name, len_tokenizer, tuning_type)

In [21]:
modelinfo(model)

Layer (type:depth-idx)                                  Param #
GPTNeoForCausalLM                                       --
├─GPTNeoModel: 1-1                                      --
│    └─Embedding: 2-1                                   38,598,912
│    └─Embedding: 2-2                                   (1,572,864)
│    └─Dropout: 2-3                                     --
│    └─ModuleList: 2-4                                  --
│    │    └─GPTNeoBlock: 3-1                            (7,085,568)
│    │    └─GPTNeoBlock: 3-2                            (7,085,568)
│    │    └─GPTNeoBlock: 3-3                            (7,085,568)
│    │    └─GPTNeoBlock: 3-4                            (7,085,568)
│    │    └─GPTNeoBlock: 3-5                            (7,085,568)
│    │    └─GPTNeoBlock: 3-6                            (7,085,568)
│    │    └─GPTNeoBlock: 3-7                            (7,085,568)
│    │    └─GPTNeoBlock: 3-8                            (7,085,568)
│    │    └─GPTNeoBlo

In [22]:
if to_train:
    train_dataset = SummarizationDataset(train_data, tokenizer, type="train")
    val_dataset = SummarizationDataset(val_data, tokenizer, type="val")
test_dataset = SummarizationDataset(test_data, tokenizer, type="test")

if to_train:
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Processing Data: 100%|██████████| 5000/5000 [01:31<00:00, 54.66it/s]
Processing Data: 100%|██████████| 1000/1000 [00:17<00:00, 56.62it/s]
Processing Data: 100%|██████████| 1000/1000 [00:16<00:00, 62.32it/s]


In [23]:
optimiser = optim.AdamW(model.parameters(), lr=lr)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [24]:
if to_train:
    start_time = time.time()
    train(
        model,
        optimiser,
        loss_fn,
        train_loader,
        val_loader,
        num_epochs=epochs,
        mini_batch_size=mini_batch_size,
        save_path=f"models/model_{tuning_type}.pt",
        print_every=40,
    )
    end_time = time.time()
    total_training_time = end_time - start_time
    print(f"Total Training Time: {total_training_time / 60:.2f} minutes\n")

if os.path.exists(f"models/model_{tuning_type}.pt"):
    model.load_state_dict(torch.load(f"models/model_{tuning_type}.pt", map_location=device, weights_only=True))
    print("Model loaded successfully")
else:
    print("No model checkpoint found")

Epoch 1:  26%|██▌       | 1282/5000 [02:08<06:09, 10.05it/s]

Step: 40, Mini-Batch Loss: 23.5291


Epoch 1:  51%|█████     | 2562/5000 [04:15<04:02, 10.04it/s]

Step: 80, Mini-Batch Loss: 22.2773


Epoch 1:  77%|███████▋  | 3842/5000 [06:22<01:55, 10.04it/s]

Step: 120, Mini-Batch Loss: 21.6129


Epoch 1: 100%|██████████| 5000/5000 [08:17<00:00, 10.05it/s]


Epoch: 1, Avg Training Loss: 0.5541/sample, 

Evaluating: 100%|██████████| 1000/1000 [00:43<00:00, 22.73it/s]


Validation Loss: 17.2898


Epoch 2:   3%|▎         | 130/5000 [00:12<08:06, 10.01it/s]

Step: 160, Mini-Batch Loss: 21.2895


Epoch 2:  28%|██▊       | 1410/5000 [02:20<05:57, 10.05it/s]

Step: 200, Mini-Batch Loss: 20.8317


Epoch 2:  54%|█████▍    | 2690/5000 [04:27<03:50, 10.02it/s]

Step: 240, Mini-Batch Loss: 20.3651


Epoch 2:  79%|███████▉  | 3969/5000 [06:34<01:44,  9.91it/s]

Step: 280, Mini-Batch Loss: 20.2127


Epoch 2: 100%|██████████| 5000/5000 [08:17<00:00, 10.05it/s]


Epoch: 2, Avg Training Loss: 0.5099/sample, 

Evaluating: 100%|██████████| 1000/1000 [00:43<00:00, 22.77it/s]


Validation Loss: 16.2494


Epoch 3:   5%|▌         | 258/5000 [00:25<07:53, 10.01it/s]

Step: 320, Mini-Batch Loss: 20.0472


Epoch 3:  31%|███       | 1538/5000 [02:33<05:45, 10.03it/s]

Step: 360, Mini-Batch Loss: 19.9306


Epoch 3:  56%|█████▋    | 2818/5000 [04:40<03:37, 10.04it/s]

Step: 400, Mini-Batch Loss: 19.5970


Epoch 3:  82%|████████▏ | 4098/5000 [06:47<01:30, 10.01it/s]

Step: 440, Mini-Batch Loss: 19.5742


Epoch 3: 100%|██████████| 5000/5000 [08:17<00:00, 10.05it/s]


Epoch: 3, Avg Training Loss: 0.4930/sample, 

Evaluating: 100%|██████████| 1000/1000 [00:43<00:00, 22.78it/s]


Validation Loss: 16.3394


Epoch 4:   8%|▊         | 386/5000 [00:38<07:40, 10.02it/s]

Step: 480, Mini-Batch Loss: 19.9925


Epoch 4:  33%|███▎      | 1665/5000 [02:45<05:32, 10.02it/s]

Step: 520, Mini-Batch Loss: 19.7062


Epoch 4:  59%|█████▉    | 2946/5000 [04:52<03:25, 10.02it/s]

Step: 560, Mini-Batch Loss: 19.8407


Epoch 4:  85%|████████▍ | 4226/5000 [07:00<01:17, 10.04it/s]

Step: 600, Mini-Batch Loss: 19.5675


Epoch 4: 100%|██████████| 5000/5000 [08:17<00:00, 10.05it/s]


Epoch: 4, Avg Training Loss: 0.4921/sample, 

Evaluating: 100%|██████████| 1000/1000 [00:43<00:00, 22.77it/s]


Validation Loss: 15.6558


Epoch 5:  10%|█         | 514/5000 [00:51<07:27, 10.02it/s]

Step: 640, Mini-Batch Loss: 19.5028


Epoch 5:  36%|███▌      | 1794/5000 [02:58<05:19, 10.05it/s]

Step: 680, Mini-Batch Loss: 19.0562


Epoch 5:  61%|██████▏   | 3074/5000 [05:05<03:11, 10.05it/s]

Step: 720, Mini-Batch Loss: 18.5341


Epoch 5:  87%|████████▋ | 4354/5000 [07:13<01:04, 10.05it/s]

Step: 760, Mini-Batch Loss: 19.3672


Epoch 5: 100%|██████████| 5000/5000 [08:17<00:00, 10.05it/s]


Epoch: 5, Avg Training Loss: 0.4758/sample, 

Evaluating: 100%|██████████| 1000/1000 [00:43<00:00, 22.81it/s]


Validation Loss: 15.6085
Total Training Time: 45.18 minutes

Model loaded successfully


In [25]:
# # Predict first 5 examples from test data

# for i in range(5):
#     text = test_data[i]["article"]
#     summary = test_data[i]["highlights"]
#     generated_summary = predict(model, tokenizer, text)

#     print(f"Example {i+1}")
#     print("Text:", text)
#     print("Actual Summary:", summary)
#     print("Generated Summary:", generated_summary)
#     print("\n")

In [26]:
print("Test Scores:", test_score(model, tokenizer, test_loader, print_every=False))

100%|██████████| 63/63 [06:12<00:00,  5.92s/it]

Test Scores: {'rouge1': 0.19136570004823917, 'rouge2': 0.10484339052743428, 'rougeL': 0.15257491328665496}





In [27]:
print(f"Allocated Memory: {torch.cuda.memory_allocated() / (1024 ** 3):.4f} GB")
print(f"Max Allocated Memory: {torch.cuda.max_memory_allocated() / (1024 ** 3):.4f} GB")
print(f"Reserved Memory: {torch.cuda.memory_reserved() / (1024 ** 3):.4f} GB")
print(f"Max Reserved Memory: {torch.cuda.max_memory_reserved() / (1024 ** 3):.4f} GB")

Allocated Memory: 0.9669 GB
Max Allocated Memory: 4.6935 GB
Reserved Memory: 6.3906 GB
Max Reserved Memory: 6.3906 GB
