## Import Libraries and Download Data

In [None]:
!pip install wandb --quiet -q
!pip install torchsummaryX -q
!pip install datasets -q
!pip install zstandard -q
!pip install tiktoken -q
!pip install rouge -q
!pip install torch nltk


In [None]:
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchsummaryX import summary
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as tat
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from rouge import Rouge
import nltk
from nltk.translate.bleu_score import corpus_bleu

from sklearn.metrics import accuracy_score
import gc

import zipfile
import pandas as pd
from tqdm import tqdm
import os
import datetime
import zstandard
import datasets
import tiktoken
import random
import wandb
import math

import warnings
warnings.filterwarnings('ignore')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

Device:  cuda


In [None]:
### If you are using colab, you can import google drive to save model checkpoints in a folder.
### This is used when connecting to GCE VMs, but the user still wants to connect to Google Drive
import os.path as path
if not path.exists("/content/drive"):
  !sudo add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
  !sudo apt-get update -qq 2>&1 > /dev/null
  !sudo apt -y install -qq google-drive-ocamlfuse 2>&1 > /dev/null
  !google-drive-ocamlfuse

  !sudo apt-get install -qq w3m # to act as web browser
  !xdg-settings set default-web-browser w3m.desktop # to set default browser
  %cd /content
  !mkdir drive
  %cd drive
  !mkdir MyDrive
  %cd ..
  %cd ..
  !google-drive-ocamlfuse /content/drive/MyDrive

In [68]:
config = {
    'epochs'        : 5,
    'batch_size'    : 64,
    'init_lr'       : 3e-5,
    'block_size'    : 256,
    'dropout'       : 0.1,
    'vocab_size'    : 50257,
    'bias'          : True,
    'n_layer'       : 12,
    'n_head'        : 10,
    'n_embd'        : 250,
    'end_token'     : 50256,
    'summary_length' : 30
}

## Blocks

In [None]:
class LayerNorm(nn.Module):
  def __init__(self, ndim, bias):
    super().__init__()
    self.weight = torch.nn.Parameter(torch.ones(ndim))
    self.bias = torch.nn.Parameter(torch.zeros(ndim)) if bias else None

  def forward(self, input):
    return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

In [None]:
class CausalSelfAttention(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.c_attn = nn.Linear(config['n_embd'], 3 * config['n_embd'], bias=config['bias'])
    self.c_proj = nn.Linear(config['n_embd'], config['n_embd'], bias=config['bias'])

    self.attn_dropout = nn.Dropout(config['dropout'])
    self.resid_dropout = nn.Dropout(config['dropout'])
    self.n_head = config['n_head']
    self.n_embd = config['n_embd']
    self.dropout = config['dropout']

  def forward(self, x):
    B, T, C = x.size()

    q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)

    # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    # att = F.softmax(att, dim=-1)
    # att = self.attn_dropout(att)
    # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs

    y = y.transpose(1, 2).contiguous().view(B, T, C)

    y = self.resid_dropout(self.c_proj(y))
    return y

In [None]:
class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config['n_embd'], 4 * config['n_embd'], bias=config['bias'])
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config['n_embd'], config['n_embd'], bias=config['bias'])
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [None]:
class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config['n_embd'], bias=config['bias'])
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config['n_embd'], bias=config['bias'])
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

## Model

In [None]:
class GPT(nn.Module):
  def __init__(self, config):
    super().__init__()

    self.config = config

    self.transformer = nn.ModuleDict(dict(
        wte = nn.Embedding(config['vocab_size'], config['n_embd']),
        wpe = nn.Embedding(config['block_size'], config['n_embd']),
        drop = nn.Dropout(config['dropout']),
        h = nn.ModuleList([Block(config) for _ in range(config['n_layer'])]),
        ln_f = LayerNorm(config['n_embd'], bias=config['bias']),
    ))

    self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)

    self.transformer.wte.weight = self.lm_head.weight
    self.apply(self._init_weights)

    for pn, p in self.named_parameters():
      if pn.endswith('c_proj.weight'):
        torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config['n_layer']))


  def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
      device = idx.device
      b, t = idx.size()
      pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

      # forward the GPT model itself
      tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
      pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
      x = self.transformer.drop(tok_emb + pos_emb)
      for block in self.transformer.h:
          x = block(x)
      x = self.transformer.ln_f(x)

      logits = self.lm_head(x)

      return logits

  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= self.config['block_size'] else idx[:, -self.config['block_size']:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = self(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

## Load CNN/Daily Mail Dataset

In [None]:
from datasets import load_dataset

dataset = load_dataset("cnn_dailymail", '3.0.0')

In [None]:
print(dataset.keys())

dict_keys(['train', 'validation', 'test'])


In [None]:
END_OF_TEXT = 50256
START_OF_TEXT = 50255

enc = tiktoken.get_encoding("gpt2")
summarization_prompt = enc.encode_ordinary("Summarize this article:")

In [None]:
def split_into_chunks(encoded_article, article_index, chunk_size=250):
    end_range = len(encoded_article) - len(encoded_article) % chunk_size
    return [(summarization_prompt + encoded_article[i:i + chunk_size], article_index) for i in range(0, end_range, chunk_size)]


In [None]:
def pad_with_eos(text, length=config['summary_length']):
  if length > len(text):
    pad_len = length - len(text)
    text += [END_OF_TEXT] * pad_len
    return text
  else:
    return text[:length]

In [None]:
def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
      flat_list.extend(row)
    return flat_list

# Dataset and Dataloader

In [None]:
# Dataset class to load train and validation data

class CNNDailyMailDataset(torch.utils.data.Dataset):

    def __init__(self, prefix, encoder):

        data = dataset[prefix]

        self.enc = encoder

        self.inputs = [split_into_chunks(self.enc.encode_ordinary(data[i]["article"]), i)
            if (data[i]["article"] != None) else [] for i in range(len(data))]

        self.inputs = flatten_extend(self.inputs)

        self.targets = np.array([pad_with_eos(enc.encode_ordinary(data[i]["highlights"])) for i in range(len(data))])

        self.length = len(self.inputs)

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        article_chunk = self.inputs[ind][0]
        article_index = self.inputs[ind][1]
        return torch.tensor(article_chunk), torch.tensor(self.targets[article_index]), article_index


In [None]:
class CNNDailyMailTestDataset(torch.utils.data.Dataset):

    def __init__(self, prefix, encoder):

        data = dataset[prefix]

        self.enc = encoder

        self.inputs = [split_into_chunks(self.enc.encode_ordinary(data[i]["article"]), i)
            if (data[i]["article"] != None) else [] for i in range(len(data))]

        self.targets = np.array([data[i]['highlights'] for i in range(len(data))])

        self.length = len(self.inputs)

    def __len__(self):
        return self.length

    def __getitem__(self, ind):
        return self.inputs[ind], self.targets[ind]


In [None]:
train_data = CNNDailyMailDataset(prefix="train", encoder=enc)
val_data = CNNDailyMailDataset(prefix="validation",encoder=enc)
test_data = CNNDailyMailTestDataset(prefix="test",encoder=enc)

In [None]:
import multiprocessing

train_loader = torch.utils.data.DataLoader(
     dataset     = train_data,
     num_workers = 1,
     batch_size  = config['batch_size'],
     pin_memory  = True,
     drop_last   = True,
     shuffle     = True
)

val_loader = torch.utils.data.DataLoader(
     dataset     = val_data,
     num_workers = 1,
     batch_size  = config['batch_size'],
     pin_memory  = True,
     drop_last   = True,
     shuffle     = False
)

test_loader = torch.utils.data.DataLoader(
    dataset     = test_data,
    num_workers = 1,
    batch_size  = 1,
    pin_memory  = True,
    shuffle     = False
)

In [None]:
# Testing code to check if your data loaders are working
for i, data in enumerate(train_loader):
     x, y, article_idx = data
     print(x.shape, y.shape)
     print(article_idx)
     print(x, y)
     break

torch.Size([64, 256]) torch.Size([64, 30])
tensor([153355, 124372, 191885, 248379, 175910, 116445,  22201, 104562, 142908,
         59902, 175479, 231670, 272493, 126000, 217487, 263987, 177404, 132113,
        173588, 224715,  95684, 275608, 105229,  23827, 192049, 215502,  30395,
        172546, 212987, 269283, 264306,  57238,  31544,  30474, 173563, 176491,
         60742, 148873, 135951, 225883,  33081, 243449,  62388, 230082, 179230,
        274306, 261521,  90175, 199905,  62483, 146215, 169031, 270411,  34565,
        264669,  92047, 183236,  41192, 273252, 161119, 108714, 132067, 245069,
        215919])
tensor([[13065,  3876,  1096,  ..., 12526,  1683,  2826],
        [13065,  3876,  1096,  ...,  9074, 26618,  2087],
        [13065,  3876,  1096,  ...,  2863,    11,   772],
        ...,
        [13065,  3876,  1096,  ...,  2263,  5986,   286],
        [13065,  3876,  1096,  ...,  1521,   326, 30597],
        [13065,  3876,  1096,  ...,    11,   705,  7109]]) tensor([[31407,   

In [None]:
for i, data in enumerate(test_loader):
  x, y = data
  print(x[1][1])
  print(y)
  break


tensor([0])
['Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .\nIsrael and the United States opposed the move, which could open the door to war crimes investigations against Israelis .']


# Load Pretrained Model from Checkpoint / Optimizer / Criterion

In [None]:
# prompt: Instantiate and load a model from a checkpoint file

model = GPT(config).to(device)
checkpoint_path = './pretrained_model_checkpoint.pth'
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [69]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=END_OF_TEXT)

optimizer = torch.optim.AdamW(model.parameters(), lr= config['init_lr']) #Defining Optimizer
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.75)
scaler = torch.cuda.amp.GradScaler()

In [107]:
gc.collect()
torch.cuda.empty_cache()



In [None]:
def generate_text_logits(model, seq, max_new_tokens=config['summary_length']):
  text_logits = None

  for i in range(max_new_tokens):
      # if the sequence context is growing too long we must crop it at block_size
      seq_cond = seq if seq.size(1) <= config['block_size'] else seq[:, -config['block_size']:]
      # forward the model to get the logits for the index in the sequence
      logits = model(seq_cond)

      if i == max_new_tokens-1:
        text_logits = logits[:, -max_new_tokens:, :] # batch size, max_new_tokens, vocab_size

      # pluck the logits at the final step and scale by desired temperature
      logits = logits[:, -1, :]
      # apply softmax to convert logits to (normalized) probabilities
      probs = F.softmax(logits, dim=-1)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1)
      # append sampled index to the running sequence and continue
      seq = torch.cat((seq, idx_next), dim=1)

  return text_logits, seq[:, -max_new_tokens:]

In [123]:
def calculate_bleu_score(reference, candidate):
    """
    Calculate the average BLEU score for a batch of translations.

    Args:
    - references_batch: A list of lists, where each inner list contains a single reference translation.
    - candidates_batch: A list of candidate translations.

    Returns:
    - average_bleu_score: The average BLEU score for the entire batch.
    """

    # Tokenize the strings
    reference_tokenized = nltk.word_tokenize(reference)
    candidate_tokenized = nltk.word_tokenize(candidate)

    # Calculate BLEU score for each translation in the batch
    bleu_score = corpus_bleu([reference_tokenized], [candidate_tokenized])

    return bleu_score

# Train, Eval, Test

In [None]:
def train(model, dataloader, optimizer, criterion):

    model.train()
    tloss = 0 # Monitoring loss and accuracy
    batch_bar   = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')

    scaler = torch.cuda.amp.GradScaler()
    counter = 0

    for i, (inputs, targets, article_idx) in enumerate(dataloader):

        counter += 1

        ### Initialize Gradients
        optimizer.zero_grad()

        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
          ### Move Data to Device (Ideally GPU)
          inputs      = inputs.to(device)
          targets    = targets.to(device)

          ### Forward Propagation
          logits, _ = generate_text_logits(model, inputs)

          B, T, C = logits.shape

          logits = logits.reshape(B*T, -1)
          targets = targets.reshape(-1)

          loss =  criterion(logits, targets)


        ### Backward Propagation
        scaler.scale(loss).backward()

        ### Gradient Descent
        scaler.step(optimizer)

        scaler.update()

        tloss   += loss.item()

        batch_bar.set_postfix(loss="{:.04f}".format(float(tloss / (i + 1))))
        batch_bar.update()

        if counter % 50 == 0:
          wandb.log({'train_loss': (tloss/counter), 'lr': curr_lr})


        ### Release memory
        del inputs, targets, logits
        torch.cuda.empty_cache()

    batch_bar.close()
    tloss   /= len(dataloader)

    return tloss

In [83]:
def eval(model, dataloader):

    model.eval() # set model in evaluation mode
    vloss = 0 # Monitoring loss, accuracy, and distance
    batch_bar   = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val')

    counter = 0

    for i, (inputs, targets, article_idx) in enumerate(dataloader):

        counter += 1

        ### Move data to device (ideally GPU)
        inputs      = inputs.to(device)
        targets    = targets.to(device)

        # makes sure that there are no gradients computed as we are not training the model now
        with torch.inference_mode():
            logits, text = generate_text_logits(model, inputs)

            B, T, C = logits.shape

            logits = logits.reshape(B*T, -1)
            targets = targets.reshape(-1)

            loss =  criterion(logits, targets)

        # strip
        vloss   += loss.item()

        batch_bar.set_postfix(loss="{:.07f}".format(float(vloss / (i + 1))))

        batch_bar.update()

        if counter % 50 == 0:
          wandb.log({'val_loss': (vloss/counter), 'dist/lr': curr_lr})

        ### Release memory
        del inputs, targets, logits
        torch.cuda.empty_cache()

    batch_bar.close()
    vloss   /= len(dataloader)

    return vloss

In [136]:
def test(model, dataloader):

  model.eval()

  batch_bar   = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Test')

  counter = 0

  bleu_score = 0.0

  for i, (inputs, targets) in enumerate(dataloader):

    counter += 1

    generated_summary = []

    for chunk in inputs:
      chunk = torch.tensor([chunk[0]]).to(device)
      _ , text = generate_text_logits(model, chunk)
      truncated_text = []
      for token in text[0]:
        if token == END_OF_TEXT:
          break
        truncated_text.append(token)
      generated_summary.extend(truncated_text)

    generated_summary = enc.decode(generated_summary)


    bleu_score += calculate_bleu_score(targets[0], generated_summary)

    print(bleu_score/(i+1))

    batch_bar.set_postfix(score="{:.07f}".format(float(bleu_score / (i + 1))))
    batch_bar.update()

  batch_bar.close()
  bleu_score /= len(dataloader)

  return bleu_score

# WandB

In [None]:
wandb.login(key="9312acc23a6389a925ba54b1bdf81ff99fe4d2e4") # API key for the project

[34m[1mwandb[0m: Currently logged in as: [33mkkmittal[0m ([33midl-f23[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Create your wandb run
run = wandb.init(
    name    = "summarization-finetuning", ### Wandb last name initializer
    reinit  = True, ### Allows reinitalizing runs when you re-run this cell
    project = "hw5-finetune", ### Project should be created in WandB
    config  = config ### Wandb Config for your run
)

# Training Loop

In [None]:
# Iterate over number of epochs to train and evaluate your model
best_val_loss = float('inf')

torch.cuda.empty_cache()
gc.collect()

for epoch in range(config['epochs']):

    print("\nEpoch {}/{}".format(epoch, config['epochs']))

    curr_lr      = float(optimizer.param_groups[0]['lr'])
    train_loss   = train(model, train_loader, optimizer, criterion)
    val_loss     = eval(model, val_loader)

    print("\tTrain Loss {:.07f}\t Learning Rate {:.07f}".format(train_loss, curr_lr))
    print("\tVal Loss {:.07f}\t".format(val_loss))

    wandb.log({'train_loss': train_loss, 'valid_loss': val_loss, 'lr': curr_lr})

    ### Highly Recommended: Save checkpoint in drive and/or wandb if accuracy is better than your current best
    torch.save({'model_state_dict' : model.state_dict(), 'optimizer_state_dict' : optimizer.state_dict()}, '/content/Finetuning/Summarization/checkpoint_epoch_' + str(epoch) + '.pth')

    if val_loss < best_val_loss:
      best_val_loss = val_loss
      torch.save({'model_state_dict' : model.state_dict(), 'optimizer_state_dict' : optimizer.state_dict()}, '/content/Finetuning/Summarization/checkpoint_best.pth')

    scheduler.step(val_loss)

# Testing / Calculate Bleu Score

In [139]:
bleu_score = test(model, test_loader)
print(bleu_score)

5.5612842184142e-32
