In [1]:
# System imports
import os
import requests
import json
from typing import Dict, Any, Optional, Union
from pathlib import Path
import psutil
from datetime import datetime
from dotenv import load_dotenv
import regex as re

# External imports
from tqdm.auto import tqdm
import yaml
from loguru import logger
import tiktoken
import wandb


import numpy as np
import tensorflow as tf

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm
2025-07-31 22:20:26.281618: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754014826.359482    5661 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754014826.381467    5661 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754014826.527953    5661 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754014826.527989    5661 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754014826.527993    5661

In [2]:
# Populate os.environ from a local .env file, then authenticate with Weights &
# Biases using the key found there, so the secret never appears in the notebook.
load_dotenv()
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Currently logged in as: [33mnilesh-auradkar[0m ([33mnilesh-auradkar-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
"""
This script downloads the GPT-2 model weights and loads them into a dictionary.
"""

def download_and_load_gpt2(model_size="124M", models_dir="../model_weights/"):
    """Download (if needed) an OpenAI GPT-2 checkpoint and load it into NumPy arrays.

    Args:
        model_size: one of "124M", "355M", "774M", "1558M".
        models_dir: directory under which a ``<model_size>/`` folder is created.

    Returns:
        (settings, params): the parsed ``hparams.json`` dict and the nested
        parameter dict built by ``load_gpt2_params_from_tf_ckpt``.

    Raises:
        ValueError: if ``model_size`` is not one of the published sizes.
        FileNotFoundError: if no TF checkpoint is present after downloading.
    """
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    print(f"Downloading GPT-2 {model_size} model to {model_dir}")

    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        # download_file skips files whose local size already matches the server's.
        download_file(f"{base_url}/{model_size}/{filename}", os.path.join(model_dir, filename))

    print("Download completed. Loading model parameters...")

    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if tf_ckpt_path is None:
        # latest_checkpoint returns None instead of raising; fail with a clear message.
        raise FileNotFoundError(f"No TensorFlow checkpoint found in {model_dir}")

    # Context manager so the hparams file handle is closed deterministically.
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as hparams_file:
        settings = json.load(hparams_file)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    print("Model loaded successfully!")
    print(f"Model configuration: {settings}")

    return settings, params

def download_file(url, destination, verify=True):
    """Stream ``url`` to ``destination`` with a progress bar, skipping up-to-date files.

    A file is considered up to date when it already exists and its size equals
    the server-reported ``content-length`` (which must be > 0). Data is written
    to ``<destination>.part`` and atomically moved into place only after the
    whole body has been received, so an interrupted download never leaves a
    truncated file at ``destination``.

    Args:
        url: HTTP(S) URL to fetch.
        destination: local file path to write.
        verify: TLS certificate verification flag forwarded to ``requests.get``.
            Defaults to True; disabling it exposes the download to tampering.

    Raises:
        requests.exceptions.RequestException: on any network/HTTP error
            (after printing a short diagnostic).
    """
    try:
        # The context manager closes the streamed connection on every exit path,
        # including the early "already up-to-date" return below.
        with requests.get(url, stream=True, verify=verify) as response:
            response.raise_for_status()

            # Total file size from headers; 0 when the server does not report it.
            file_size = int(response.headers.get("content-length", 0))

            if os.path.exists(destination):
                file_size_local = os.path.getsize(destination)
                if file_size == file_size_local and file_size > 0:
                    print(f"File already exists and is up-to-date: {destination}")
                    return

            block_size = 1024 * 1024  # 1 MiB chunks; checkpoint shards are hundreds of MB
            partial_path = f"{destination}.part"
            progress_bar_description = url.split("/")[-1]  # filename part of the URL

            with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
                with open(partial_path, "wb") as file:
                    for chunk in response.iter_content(block_size):
                        if chunk:  # skip keep-alive chunks
                            progress_bar.update(len(chunk))
                            file.write(chunk)

            os.replace(partial_path, destination)

        print(f"Downloaded: {destination}")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading the file: {e}")
        print(f"Please check the URL: {url}")
        raise

def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    """Load every variable of a GPT-2 TF checkpoint into a nested dict of NumPy arrays.

    Args:
        ckpt_path: checkpoint path accepted by ``tf.train.list_variables``.
        settings: hparams dict; only ``settings["n_layer"]`` is read.

    Returns:
        ``{"blocks": [one dict per layer], ...top-level arrays}``.

        A variable ``model/h<i>/a/b/leaf`` is stored at
        ``params["blocks"][i]["a"]["b"]["leaf"]``. For any other variable, the
        first component after ``model/`` is dropped too. For example,
        ``model/ln_f/g`` becomes ``params["g"]`` and ``model/wte`` becomes
        ``params["wte"]``.

        Size-1 axes are squeezed out of every array.
    """
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    for var_name, _shape in tf.train.list_variables(ckpt_path):
        array = np.squeeze(tf.train.load_variable(ckpt_path, var_name))

        path = var_name.split("/")[1:]  # drop the leading 'model' scope
        scope = path[0]

        if scope.startswith("h"):
            node = params["blocks"][int(scope[1:])]
        else:
            node = params

        # Walk/create intermediate dicts, then store the array under the leaf name.
        for key in path[1:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = array

    return params

# if __name__ == "__main__":
#     try:
#         # Download and load GPT-2 model
#         settings, params = download_and_load_gpt2(model_size="124M")
        
#         # Print some basic info about the loaded model
#         print("\nModel details:")
#         print(f"- Vocabulary size: {settings.get('n_vocab', 'Unknown')}")
#         print(f"- Number of layers: {settings.get('n_layer', 'Unknown')}")
#         print(f"- Number of attention heads: {settings.get('n_head', 'Unknown')}")
#         print(f"- Embedding dimension: {settings.get('n_embd', 'Unknown')}")
#         print(f"- Context length: {settings.get('n_ctx', 'Unknown')}")
        
#     except Exception as e:
#         print(f"Error: {e}")
#         print("Make sure you have the required dependencies installed:")
#         print("pip install tensorflow requests tqdm numpy")

In [4]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Uses the biased (population) variance and a fixed eps of 1e-5. Parameters
    are named ``scale``/``shift`` rather than ``weight``/``bias``.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.scale + self.shift

class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """

    # sqrt(2/pi) precomputed once as a Python float instead of allocating a
    # new tensor on every forward call.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def __init__(self):
        super().__init__()

    def forward(self, x):
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x embedding_dim, GELU, project back.

    ``self.layers`` is an ``nn.Sequential`` whose indices 0 and 2 are the two
    Linear layers (index 1 is the activation).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["embedding_dim"]
        hidden_dim = 4 * emb_dim
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        return self.layers(x)
    
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with separate Q/K/V projections.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        context_length: size of the precomputed causal mask (maximum sequence length).
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads; each gets ``d_out // num_heads`` dims.
        qkv_bias: whether the Q/K/V projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated heads
        self.dropout = nn.Dropout(dropout)

        # Float upper-triangular ones: 1 marks a future position to be masked.
        causal_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", causal_mask)

    def forward(self, x):
        batch, n_tokens, _ = x.shape

        def split_heads(projected):
            # (batch, n_tokens, d_out) -> (batch, num_heads, n_tokens, head_dim)
            return projected.view(batch, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        keys = split_heads(self.W_key(x))
        queries = split_heads(self.W_query(x))
        values = split_heads(self.W_value(x))

        scores = queries @ keys.transpose(2, 3)
        blocked = self.mask.bool()[:n_tokens, :n_tokens]
        scores.masked_fill_(blocked, -torch.inf)

        weights = self.dropout(torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1))

        # (batch, num_heads, n_tokens, head_dim) -> (batch, n_tokens, d_out)
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, n_tokens, self.d_out)
        return self.out_proj(merged)
    
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block (GPT-2 layout).

    x -> x + Dropout(Attention(LN1(x))) -> x + Dropout(FeedForward(LN2(x)))

    ``cfg`` keys read: embedding_dim, context_length, num_heads, drop_rate, qkv_bias.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["embedding_dim"]
        self.mask_attn = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["num_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ffn_block = FeedForward(cfg)
        self.norm_1 = LayerNorm(emb_dim)
        self.norm_2 = LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def _residual(self, x, norm, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``."""
        return self.drop_shortcut(sublayer(norm(x))) + x

    def forward(self, x):
        x = self._residual(x, self.norm_1, self.mask_attn)
        return self._residual(x, self.norm_2, self.ffn_block)

class GPT2ModelClone(nn.Module):
    """GPT-2 style decoder-only language model.

    The pipeline is:
    token embeddings + learned positional embeddings -> dropout ->
    ``num_layers`` transformer blocks -> final LayerNorm ->
    bias-free linear projection to vocabulary logits.

    ``cfg`` keys read here: vocab_size, embedding_dim, context_length, drop_rate,
    num_layers. Each block additionally reads num_heads and qkv_bias.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, emb_dim = cfg["vocab_size"], cfg["embedding_dim"]

        self.tok_embeddings = nn.Embedding(vocab_size, emb_dim)
        self.pos_embeddings = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_embeddings = nn.Dropout(cfg["drop_rate"])

        blocks = [TransformerBlock(cfg) for _ in range(cfg["num_layers"])]
        self.transformer_blocks = nn.Sequential(*blocks)

        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map token ids of shape (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.drop_embeddings(self.tok_embeddings(in_idx) + self.pos_embeddings(positions))
        hidden = self.final_norm(self.transformer_blocks(hidden))
        return self.out_head(hidden)

In [5]:
class LoadModelConfig:
    """Read the ``model_configs`` mapping from a YAML file and look up entries by name."""

    def __init__(self, yaml_file: Optional[Union[str, Path]] = "../config/model_config.yaml"):
        """
        Initialize the model config class.

        Args:
            yaml_file: Path to the YAML configuration file. If None, this is
                ``<parent of this module's directory>/config/model_config.yaml``
                when ``__file__`` is defined, otherwise (e.g. inside a notebook)
                ``../config/model_config.yaml``.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the file is not valid YAML.
            IOError: for any other error while reading it.
        """
        if yaml_file is None:
            try:
                yaml_file = Path(__file__).parent.parent / "config" / "model_config.yaml"
            except NameError:
                # Notebooks have no __file__; fall back to the relative default.
                yaml_file = Path("../config/model_config.yaml")

        self.yaml_file = Path(yaml_file)
        self.config_data = self.load_yaml()

    def load_yaml(self) -> Dict:
        """Return the ``model_configs`` section of the YAML file ({} if missing or empty)."""
        # Checked outside the try so callers receive FileNotFoundError as-is
        # (still an IOError/OSError subclass) instead of a re-wrapped IOError.
        if not self.yaml_file.exists():
            raise FileNotFoundError(f"The file {self.yaml_file} does not exist.")

        try:
            with self.yaml_file.open("r", encoding="utf-8") as f:
                # safe_load returns None for an empty document.
                data = yaml.safe_load(f) or {}
            return data.get("model_configs") or {}
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing YAML file: {e}") from e
        except Exception as e:
            raise IOError(f"Error reading {self.yaml_file}: {e}") from e

    def list_all_models(self) -> list:
        """Return a list of all model keys."""
        return list(self.config_data.keys())

    def get_model_config(self, model_name: str) -> Optional[Dict[str, Any]]:
        """
        Extract the model configuration for a given model name.

        Args:
            model_name: The name of the model to extract configuration for.

        Returns:
            Model configuration dictionary, or None if the model name is not found.
        """
        return self.config_data.get(model_name, None)

In [6]:
class PrepareModelWithPreTrainedWeights:
    """Build a ``GPT2ModelClone`` and load OpenAI GPT-2 checkpoint weights into it.

    On construction, this class does the following:
    - downloads/loads the TF checkpoint,
    - reads the model config named ``model_name`` via ``LoadModelConfig``,
    - instantiates the model and copies every weight into it,
    - moves the model to ``device``.

    The model is left in eval mode.

    Args:
        model_name: key into the YAML ``model_configs`` mapping.
        device: torch device the finished model is moved to.
        model_size: GPT-2 checkpoint size to download ("124M", "355M", ...).
            Must match the architecture described by ``model_name``.
        models_dir: directory the checkpoint files are cached in.
    """

    def __init__(self, model_name: str, device: str = "cuda",
                 model_size: str = "124M", models_dir: str = "./model_weights/"):
        self.model_name = model_name
        self.device = device
        self.model_size = model_size
        self.models_dir = models_dir
        self._load_all()

    def _load_all(self):
        self.settings, self.params = self._get_settings_and_params()
        self.model_config = self._get_model_config()
        self.model = GPT2ModelClone(self.model_config)
        self.model.eval()
        self._load_gpt2_weights_into_model()
        self.model.to(self.device)

    def _assign_params(self, left, right):
        """Return NumPy array ``right`` as a new Parameter after checking it matches ``left``.

        Raises:
            ValueError: if ``left`` is None (the layer was built without that
                parameter) or if the shapes differ.
        """
        if left is None:
            raise ValueError(
                "Target parameter is None; the layer was built without it "
                "(e.g. qkv_bias=False, but GPT-2 checkpoints include QKV biases)."
            )
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch: {left.shape} != {right.shape}")
        # torch.tensor copies, so the Parameter owns its storage.
        return torch.nn.Parameter(torch.tensor(right))

    def _load_linear(self, linear, weight, bias):
        """Assign a TF (in_features, out_features) matrix and bias vector to an nn.Linear."""
        linear.weight = self._assign_params(linear.weight, weight.T)
        linear.bias = self._assign_params(linear.bias, bias)

    def _load_layer_norm(self, norm, norm_params):
        """Assign checkpoint gain ``g`` / bias ``b`` to a LayerNorm's scale / shift."""
        norm.scale = self._assign_params(norm.scale, norm_params["g"])
        norm.shift = self._assign_params(norm.shift, norm_params["b"])

    def _get_settings_and_params(self):
        settings, params = download_and_load_gpt2(model_size=self.model_size, models_dir=self.models_dir)
        print(f"Settings: {settings}")
        print(f"Params: {params.keys()}")
        return settings, params

    def _get_model_config(self):
        config = LoadModelConfig()
        print(config.list_all_models())
        model_config = config.get_model_config(model_name=self.model_name)
        print(f"Returned Model config for {self.model_name}: {model_config}")
        return model_config

    def _load_gpt2_weights_into_model(self):
        """Copy every array in ``self.params`` into the matching ``self.model`` parameter.

        Checkpoint matrices are stored as (in_features, out_features) and are
        transposed for ``nn.Linear``. The output head reuses the token-embedding
        matrix ``wte`` (as a copy, not a tied parameter).
        """
        model, params = self.model, self.params
        model.pos_embeddings.weight = self._assign_params(model.pos_embeddings.weight, params["wpe"])
        model.tok_embeddings.weight = self._assign_params(model.tok_embeddings.weight, params["wte"])

        for block_idx, block_params in enumerate(params["blocks"]):
            block = model.transformer_blocks[block_idx]
            attn_params, mlp_params = block_params["attn"], block_params["mlp"]

            # The checkpoint fuses Q, K and V into one c_attn projection; split
            # it along the last axis into three equal parts.
            q_w, k_w, v_w = np.split(attn_params["c_attn"]["w"], 3, axis=-1)
            q_b, k_b, v_b = np.split(attn_params["c_attn"]["b"], 3, axis=-1)
            self._load_linear(block.mask_attn.W_query, q_w, q_b)
            self._load_linear(block.mask_attn.W_key, k_w, k_b)
            self._load_linear(block.mask_attn.W_value, v_w, v_b)
            self._load_linear(block.mask_attn.out_proj,
                              attn_params["c_proj"]["w"], attn_params["c_proj"]["b"])

            # Feed-forward: layers[0] expands, layers[2] projects back (layers[1] is GELU).
            self._load_linear(block.ffn_block.layers[0],
                              mlp_params["c_fc"]["w"], mlp_params["c_fc"]["b"])
            self._load_linear(block.ffn_block.layers[2],
                              mlp_params["c_proj"]["w"], mlp_params["c_proj"]["b"])

            self._load_layer_norm(block.norm_1, block_params["ln_1"])
            self._load_layer_norm(block.norm_2, block_params["ln_2"])

        # The checkpoint loader stores model/ln_f/{g,b} as top-level "g"/"b".
        model.final_norm.scale = self._assign_params(model.final_norm.scale, params["g"])
        model.final_norm.shift = self._assign_params(model.final_norm.shift, params["b"])
        model.out_head.weight = self._assign_params(model.out_head.weight, params["wte"])

In [7]:
class MaskedBillSumDatasetNew(Dataset):
    """BillSum dataset for summarization fine-tuning with the article masked out of the loss.

    Each kept sample is built from three pieces:
    - the tokens of ``"ARTICLE: <text>"``,
    - the tokens of ``"SUMMARY: <summary>"``,
    - ``tokenizer.eot_token``.

    The result is right-padded with ``tokenizer.eot_token`` to ``max_length``.
    Labels are -100 on article tokens and padding, so only the summary and the
    end-of-text token contribute to the loss. Labels are aligned with
    ``input_ids`` (no shift is applied here).

    Articles that do not fit are shortened by section-level truncation, with
    token-level truncation as the fallback.

    Args:
        dataset: sized iterable of dict-like samples with "text" and "summary".
        tokenizer: tiktoken-style encoding exposing
            ``encode(text, allowed_special=...)``, ``decode(tokens)`` and an
            integer ``eot_token``.
        max_length: length of every produced ``input_ids``/``labels`` tensor.
        min_summary_length: minimum summary length, in characters.
    """

    # Minimum article size worth keeping. Applied both to characters (raw and
    # truncated text) and to tokens (reserved article budget, final length).
    MIN_ARTICLE_LEN = 50

    def __init__(self, dataset, tokenizer, max_length, min_summary_length=1):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed_data = []
        self.min_summary_length = min_summary_length
        self.dataset = dataset
        self.pad_token_id = tokenizer.eot_token

        logger.info("Processing and masking dataset for summarization...")
        logger.info(f"Original Dataset Size: {len(self.dataset)}")
        skipped_count = 0
        truncate_stats = {
            "full": 0,
            "section_trunc": 0,
            "token_trunc": 0,
            "sliding_window": 0,
        }

        for idx, sample in enumerate(tqdm(self.dataset)):
            try:
                example = self._build_example(sample, truncate_stats)
            except Exception as e:
                # Best-effort: one malformed sample must not abort the whole build.
                logger.warning(f"Error processing samples: {idx}: {str(e)}")
                skipped_count += 1
                continue
            if example is None:
                skipped_count += 1
            else:
                self.processed_data.append(example)

        total_samples = len(self.dataset)
        logger.info(f"Processed {len(self.processed_data)}/{total_samples} samples")
        logger.info(f"Skipped: {skipped_count} | Truncated: {truncate_stats}")
        logger.info(f"Section trunc: {truncate_stats['section_trunc']} | Token trunc: {truncate_stats['token_trunc']}")

    def _encode(self, text):
        """Encode ``text``, letting special-token strings in the data map to special tokens."""
        return self.tokenizer.encode(text, allowed_special="all")

    def _build_example(self, sample, truncate_stats):
        """Return padded ``{"input_ids", "labels"}`` tensors for ``sample``, or None to skip it.

        Increments the matching counter in ``truncate_stats`` for kept articles.
        """
        if not sample.get("text") or not sample.get("summary"):
            return None

        article_text = str(sample.get("text")).strip()
        summary_text = str(sample.get("summary")).strip()
        if len(article_text) <= self.MIN_ARTICLE_LEN or len(summary_text) < self.min_summary_length:
            return None

        clean_text = self._remove_boilerplate(article_text)
        summary_tokens = self._encode(f"SUMMARY: {summary_text}")

        # Room for the summary plus the trailing end-of-text token; require at
        # least MIN_ARTICLE_LEN tokens of article context to remain.
        required_summary_space = len(summary_tokens) + 1
        if required_summary_space > self.max_length - self.MIN_ARTICLE_LEN:
            return None
        available_article_tokens = self.max_length - required_summary_space

        article_tokens = self._encode(f"ARTICLE: {clean_text}")
        if len(article_tokens) <= available_article_tokens:
            truncate_stats["full"] += 1
        else:
            truncated_text = self._truncate_sections(clean_text, available_article_tokens)
            if not truncated_text or len(truncated_text) < self.MIN_ARTICLE_LEN:
                truncated_text = self._token_level_truncation(clean_text, available_article_tokens)
                truncate_stats["token_trunc"] += 1
            else:
                truncate_stats["section_trunc"] += 1
            # The truncation helpers are approximate; hard-cap after re-encoding.
            article_tokens = self._encode(f"ARTICLE: {truncated_text}")[:available_article_tokens]

        if len(article_tokens) < self.MIN_ARTICLE_LEN:
            return None

        eot = self.tokenizer.eot_token
        input_ids = article_tokens + summary_tokens + [eot]
        labels = [-100] * len(article_tokens) + summary_tokens + [eot]

        # By construction len(input_ids) <= max_length; the slice is a safety net.
        input_ids = input_ids[:self.max_length]
        labels = labels[:self.max_length]

        padding_length = self.max_length - len(input_ids)
        input_ids += [self.pad_token_id] * padding_length
        labels += [-100] * padding_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def _remove_boilerplate(self, text):
        """Strip enacting-clause boilerplate and "(a)"-style sections from a bill."""
        # Drop the enacting clause up to the first "SECTION <n>" heading. The
        # ``.*?`` sits outside the alternation so it applies to both phrases.
        text = re.sub(
            r'(?:Be it enacted by the Senate and House of Representatives|'
            r'The Congress of the United States).*?(?=\n\s*SECTION\s+\d)',
            '',
            text,
            flags=re.DOTALL | re.IGNORECASE
        )
        # Drop every section whose body opens with an enumerated "(a)" clause.
        # NOTE(review): this removes all such sections wherever they appear,
        # not only "middle" sections; confirm this is the intended heuristic.
        text = re.sub(
            r'\nSECTION\s+\d+\.\s*\(a\)\s*.*?(?=\n\s*SECTION|\Z)',
            '\n',
            text,
            flags=re.DOTALL
        )
        return text.strip()

    @staticmethod
    def _join_sections(sections, front, back):
        """Join the preamble, the first ``front`` and last ``back`` (header, body) pairs.

        ``sections`` is the output of ``re.split`` with a capturing header group:
        ``[preamble, header_1, body_1, header_2, body_2, ...]``. Overlapping
        front/back ranges never duplicate a section.
        """
        front_end = 1 + 2 * front
        back_start = max(front_end, len(sections) - 2 * back)
        return "".join(sections[:front_end] + sections[back_start:]).strip()

    def _truncate_sections(self, text, max_tokens):
        """Keep leading and trailing sections; return None if the text has no section headers."""
        sections = re.split(
            r'(\n\s*SECTION\s+\d+\.?|\n\s*Sec\.\s+\d+\.?)',
            text,
            flags=re.IGNORECASE
        )

        if len(sections) <= 2:
            return None

        start_content = min(4, max(2, len(sections) // 4))
        end_content = min(3, max(1, len(sections) // 5))

        candidate = self._join_sections(sections, start_content, end_content)
        if len(self._encode(f"ARTICLE: {candidate}")) <= max_tokens:
            return candidate

        return self._reduce_sections(sections, max_tokens)

    def _reduce_sections(self, sections, max_tokens):
        """Progressively keep fewer sections until the token limit is met."""
        for front in range(min(4, len(sections) // 3), 1, -1):
            for back in range(min(3, len(sections) // 4), 0, -1):
                candidate = self._join_sections(sections, front, back)
                if len(self._encode(f"ARTICLE: {candidate}")) <= max_tokens:
                    return candidate

        # Final fallback: preamble + first and last sections. This may still
        # exceed max_tokens; the caller hard-caps the token count.
        return self._join_sections(sections, 1, 1)

    def _token_level_truncation(self, text, max_tokens):
        """Keep ~65% of the budget from the head and ~30% from the tail of ``text``.

        The head is cut back to its last complete sentence when one exists.
        The result decodes to at most ``max_tokens`` tokens.
        """
        tokens = self._encode(text)
        if len(tokens) <= max_tokens:
            return text

        # Budgets are fractions of max_tokens (not of the full text), so the
        # tail survives the final cap.
        front_budget = int(max_tokens * 0.65)
        back_budget = int(max_tokens * 0.3)
        front_tokens = tokens[:front_budget]
        back_tokens = tokens[len(tokens) - back_budget:]

        front_text = self.tokenizer.decode(front_tokens)
        sentence_ends = [m.end() for m in re.finditer(r'[.!?](?:\s|$)', front_text)]
        if sentence_ends:
            front_tokens = self._encode(front_text[:sentence_ends[-1]])

        combined_tokens = (front_tokens + back_tokens)[:max_tokens]
        return self.tokenizer.decode(combined_tokens)

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        return self.processed_data[idx]


In [8]:
class MaskedBillSumDataset(Dataset):
    """BillSum summarization dataset with the article portion masked out of the loss.

    Every kept example is built from three pieces:
    - the tokens of ``"ARTICLE: <text> "``,
    - the tokens of ``"SUMMARY: <summary>"``,
    - ``tokenizer.eot_token``.

    The result is right-padded with ``tokenizer.eot_token`` to ``max_length``.
    Labels are -100 on article tokens and padding.

    Over-long examples are handled by truncating the article, leaving a
    50-token buffer, as long as more than 100 article tokens remain. Otherwise
    the sample is skipped.

    Raises:
        ValueError: if no sample survives preprocessing.
    """

    def __init__(self, dataset, tokenizer, max_length, min_summary_length=5):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_summary_length = min_summary_length
        self.processed_data = []
        self.dataset = dataset
        self.pad_token_id = tokenizer.eot_token

        logger.info("Processing and masking dataset for summarization...")
        logger.info(f"Original dataset size: {len(self.dataset)}")

        skipped_count = 0
        for idx, sample in enumerate(tqdm(self.dataset)):
            try:
                example = self._encode_sample(sample)
            except Exception as e:
                logger.warning(f"Error processing sample {idx}: {e}")
                skipped_count += 1
                continue

            if example is None:
                skipped_count += 1
                continue
            self.processed_data.append(example)

        logger.info("Dataset processing complete:")
        logger.info(f"  Original samples: {len(self.dataset)}")
        logger.info(f"  Processed samples: {len(self.processed_data)}")
        logger.info(f"  Skipped samples: {skipped_count}")

        if not self.processed_data:
            raise ValueError("No valid samples found in dataset! Check your data format and tokenization.")

    def _encode_sample(self, sample):
        """Return ``{"input_ids", "labels"}`` long tensors for one sample, or None to skip it."""
        if not sample.get('text') or not sample.get('summary'):
            return None

        article_text = str(sample['text']).strip()
        summary_text = str(sample['summary']).strip()
        if len(article_text) < 10 or len(summary_text) < 10:
            return None

        article_tokens = self.tokenizer.encode(f"ARTICLE: {article_text} ", allowed_special="all")
        summary_tokens = self.tokenizer.encode(f"SUMMARY: {summary_text}", allowed_special="all")
        eos_token = [self.tokenizer.eot_token]

        if len(summary_tokens) < self.min_summary_length:
            return None

        if len(article_tokens) + len(summary_tokens) + len(eos_token) > self.max_length:
            # Keep the whole summary; shrink the article, leaving a 50-token buffer.
            available_space = self.max_length - len(summary_tokens) - len(eos_token) - 50
            if available_space <= 100:
                return None
            article_tokens = article_tokens[:available_space]

        input_ids = article_tokens + summary_tokens + eos_token
        labels = [-100] * len(article_tokens) + summary_tokens + eos_token

        padding_length = self.max_length - len(input_ids)
        if padding_length > 0:
            input_ids = input_ids + [self.pad_token_id] * padding_length
            labels = labels + [-100] * padding_length

        if sum(1 for tok in labels if tok != -100) < self.min_summary_length:
            return None

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        return self.processed_data[idx]

In [9]:
def calculate_metrics(logits, targets, ignore_index=-100):
    """Calculates loss, accuracy, and perplexity score for a batch.

    Args:
        logits: Prediction scores whose last dimension is the vocabulary; the leading
            dimensions must line up with ``targets`` (e.g. (batch, seq_len, vocab)).
        targets: Integer label tensor. Positions equal to ``ignore_index`` are excluded
            from every metric.
        ignore_index: Label value to ignore (default -100, the value MaskedBillSumDataset
            writes for padded positions).

    Returns:
        ``(loss, accuracy, perplexity)`` as scalar tensors on ``targets.device``.
        ``loss`` is the mean cross-entropy over active positions. ``accuracy`` is the
        top-1 accuracy over active positions. ``perplexity`` is ``exp(loss)`` with the
        loss clamped at 10 (so at most e**10) to avoid overflow.
        If no position is active, the loss is a zero that is still attached to ``logits``'
        autograd graph (``backward()`` then yields zero gradients), with accuracy 0 and
        perplexity 1.
    """
    logits = logits.to(targets.device)
    # reshape (not view): logits may be non-contiguous, e.g. after slicing/transposing.
    logits_flat = logits.reshape(-1, logits.size(-1))
    targets_flat = targets.reshape(-1)
    active_mask = targets_flat != ignore_index

    if not active_mask.any():
        # A fresh leaf tensor has no path to the model parameters, so backward() would leave
        # them without gradients (torch.amp GradScaler.step asserts on that). Multiplying the
        # logits by zero keeps the graph connected and produces all-zero gradients instead.
        zero_loss = logits_flat.float().sum() * 0.0
        return zero_loss, torch.tensor(0.0, device=logits.device), torch.tensor(1.0, device=logits.device)

    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=ignore_index)

    with torch.no_grad():
        # Take the argmax over all rows first, then mask. This avoids copying the active rows
        # of the (tokens x vocab) logits matrix just to score accuracy.
        predicted_labels = logits_flat.argmax(dim=-1)
        accuracy = (predicted_labels[active_mask] == targets_flat[active_mask]).float().mean()
        perplexity = torch.exp(torch.clamp(loss, max=10.0))

    return loss, accuracy, perplexity

In [10]:
"""This script fine-tunes GPT-2 for summarization. It:
    1. initializes a new MLP head (SummarizationNNHead) for summarization.
    2. freezes all pre-trained GPT-2 weights.
    3. runs the frozen GPT-2 body (embeddings -> transformer blocks -> final norm) and
       feeds its hidden states to the new head, which produces the vocabulary logits
       (GPT-2's own output head is not called).

NOTE(review): the last transformer block is NOT unfrozen here. Every GPT-2 parameter is
frozen and the body runs under torch.no_grad(), so only the summarization head is trained."""

# Package-style imports for names this notebook defines inline
# (presumably used when the code runs from the repository modules instead of the notebook):
# from src.model.prepare_for_fine_tune import PrepareModelWithPreTrainedWeights
# from utils.load_dataset import MaskedBillSumDataset
# from utils.util import calculate_metrics

class SummarizationNNHead(nn.Module):
    """A Multi-Layer Perceptron Head for Summarization.

    Maps hidden states to vocabulary logits through
    Linear(embedding_dim -> embedding_dim * hidden_dim_factor) -> GELU -> Dropout
    -> Linear(-> vocab_size).

    Note: with a large vocabulary the second Linear dominates the parameter count
    (hidden_dim * vocab_size weights).

    Args:
        embedding_dim: Size of the last dimension of the incoming hidden states.
        vocab_size: Number of logits produced per position.
        hidden_dim_factor: Width multiplier for the hidden layer (default 4).
        dropout: Dropout probability applied after the activation (default 0.1).
    """
    def __init__(self, embedding_dim, vocab_size, hidden_dim_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = embedding_dim * hidden_dim_factor
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, x):
        """Return logits with the same leading dimensions as ``x`` and last dimension ``vocab_size``."""
        return self.layers(x)
    
class MemoryUsageLogger(Callback):
    """Lightning callback that logs RAM and CUDA memory usage at the end of each training epoch.

    Only the global-zero process logs. Values are in GiB (1024**3 bytes):
    ``memory/ram_used_gb`` is ``psutil.virtual_memory().used``, and ``memory/vram_used_gb``
    is ``torch.cuda.memory_allocated`` on the module's CUDA device (0 when it is not on CUDA).
    """
    def on_train_epoch_end(self, trainer, pl_module):
        if not trainer.is_global_zero:
            return

        bytes_per_gib = 1024 ** 3
        ram_used_gb = psutil.virtual_memory().used / bytes_per_gib

        # Query the device the module actually lives on instead of assuming that
        # trainer.local_rank is a valid CUDA index (e.g. a CPU run on a CUDA machine).
        vram_used_gb = 0
        if pl_module.device.type == "cuda":
            vram_used_gb = torch.cuda.memory_allocated(pl_module.device) / bytes_per_gib

        metrics = {"memory/ram_used_gb": ram_used_gb, "memory/vram_used_gb": vram_used_gb}
        # trainer.logger is None when the Trainer is created with logger=False.
        if trainer.logger is not None:
            trainer.logger.log_metrics(metrics, step=trainer.global_step)
        logger.info(f"Memory @ epoch {trainer.current_epoch}: RAM: {ram_used_gb:.2f}GB | VRAM: {vram_used_gb:.2f}GB")


class SummarizationDataModule(pl.LightningDataModule):
    """Enhanced Data Module with better error handling.

    Serves masked BillSum splits for summarization fine-tuning:
      - fit / validate: the "train" split divided 70/30 into train/val (seed 47).
      - test: the "ca_test" split.

    Args:
        model_config: Dict providing ``context_length`` (max tokens per sample).
        train_config: Dict providing ``batch_size``. An optional ``num_workers`` key
            overrides the per-split DataLoader worker defaults below.
    """

    # DataLoader worker counts used when train_config has no "num_workers" key.
    _DEFAULT_NUM_WORKERS = {"train": 0, "val": 5, "test": 5}

    def __init__(self, model_config, train_config):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = tiktoken.get_encoding("gpt2")

        # Built lazily by setup(); None means "not prepared yet".
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Log tokenizer facts so vocabulary mismatches with the model are easy to spot.
        logger.info(f"Tokenizer vocab size: {self.tokenizer.n_vocab}")
        logger.info(f"EOT token: {self.tokenizer.eot_token}")

    def _build_dataset(self, raw_split):
        """Tokenize and mask one raw split with the model's context length."""
        return MaskedBillSumDataset(
            raw_split,
            self.tokenizer,
            self.hparams.model_config['context_length']
        )

    def setup(self, stage=None):
        """Load and process the splits required by ``stage`` ("fit", "validate", "test" or None).

        Splits that are already built are reused. The training script calls
        ``setup("fit")`` itself (to size the LR schedule) before ``trainer.fit`` calls it
        again, and re-tokenizing the full dataset a second time is expensive.
        """
        try:
            needs_train_val = stage in ("fit", "validate", None)
            needs_test = stage in ("test", None)

            if needs_train_val and (self.train_dataset is None or self.val_dataset is None):
                logger.info("Loading datasets...")
                full_dataset = load_dataset("FiscalNote/billsum", split="train")
                # Carve the validation split out of "train" with a fixed seed so it is reproducible.
                split_dataset = full_dataset.train_test_split(test_size=0.3, seed=47)
                train_set = split_dataset["train"]
                val_set = split_dataset["test"]

                logger.info(f"Loaded train set: {len(train_set)} samples")
                logger.info(f"Loaded val set: {len(val_set)} samples")

                # Smaller subset for debugging
                # train_set = train_set.select(range(min(1000, len(train_set))))
                # val_set = val_set.select(range(min(200, len(val_set))))

                logger.info("Processing training dataset...")
                self.train_dataset = self._build_dataset(train_set)

                logger.info("Processing validation dataset...")
                self.val_dataset = self._build_dataset(val_set)

            if needs_test and self.test_dataset is None:
                test_set = load_dataset("FiscalNote/billsum", split="ca_test")
                logger.info(f"Loaded test set: {len(test_set)} samples")
                self.test_dataset = self._build_dataset(test_set)

        except Exception as e:
            logger.error(f"Error in setup: {e}")
            raise

    def _make_dataloader(self, dataset, split, label, shuffle, drop_last):
        """Build the DataLoader for one split after checking that its dataset was prepared.

        Raises:
            ValueError: if the dataset was never built or is empty.
        """
        if dataset is None or len(dataset) == 0:
            raise ValueError(f"{label} dataset is empty or not initialized!")

        logger.info(f"Creating {split} dataloader with {len(dataset)} samples")

        num_workers = self.hparams.train_config.get("num_workers", self._DEFAULT_NUM_WORKERS[split])
        return DataLoader(
            dataset,
            batch_size=self.hparams.train_config["batch_size"],
            num_workers=num_workers,
            drop_last=drop_last,
            shuffle=shuffle,
            pin_memory=True
        )

    def train_dataloader(self):
        # Shuffled; drop_last keeps every training batch full-sized.
        return self._make_dataloader(self.train_dataset, "train", "Training", shuffle=True, drop_last=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, "val", "Validation", shuffle=False, drop_last=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, "test", "Test", shuffle=False, drop_last=False)
    
class SummarizationFineTuneModel(pl.LightningModule):
    """Trains a new summarization head on top of a frozen, pre-trained GPT-2 body.

    The GPT-2 weights are loaded lazily in ``setup()`` (on CPU). Every GPT-2 parameter is
    frozen and the body runs under ``torch.no_grad()``; only ``summarization_head`` is
    handed to the optimizer.

    NOTE(review): ``gpt2_base`` is a submodule once loaded, so checkpoints contain its
    weights, but ``__init__`` does not create it. ``load_from_checkpoint`` with strict
    loading will presumably report unexpected ``gpt2_base.*`` keys. Confirm before reuse.

    Args:
        model_name: Model key passed to ``PrepareModelWithPreTrainedWeights``
            (e.g. "gpt2-small (124M)").
        model_config: Dict providing ``embedding_dim`` and ``vocab_size``.
        train_config: Dict providing ``learning_rate``.
        num_training_steps: Total optimizer steps for the cosine schedule (10% warmup).
    """

    def __init__(self, model_name, model_config, train_config, num_training_steps):
        super().__init__()
        self.save_hyperparameters()

        # The pre-trained body is attached in setup() so construction stays cheap.
        self.gpt2_base = None
        self.model_loaded = False

        self.summarization_head = SummarizationNNHead(
            embedding_dim=self.hparams.model_config['embedding_dim'],
            vocab_size=self.hparams.model_config['vocab_size']
        )

        self._init_weights()

    def setup(self, stage=None):
        """Load the pre-trained GPT-2 body once, freeze it and put it in eval mode.

        Raises:
            ValueError: if the loader returns no model.
        """
        if self.model_loaded:
            return

        logger.info(f"Loading {self.hparams.model_name} model in setup()")
        model_loader = PrepareModelWithPreTrainedWeights(
            model_name=self.hparams.model_name,
            device="cpu"
        )
        if model_loader.model is None:
            logger.error("Model loading failed from setup()")
            raise ValueError("Model loading failed from setup()")

        self.gpt2_base = model_loader.model
        logger.info("Model loaded Successfully! from setup()")

        logger.info(f"Freezing all parameters of the {self.hparams.model_name} model")
        for param in self.gpt2_base.parameters():
            param.requires_grad = False
        # The body is a fixed feature extractor: eval mode disables its dropout
        # (presumably drop_embeddings and any dropout inside the transformer blocks).
        self.gpt2_base.eval()

        self.model_loaded = True

    def on_train_epoch_start(self):
        """Re-assert eval mode on the frozen body at the start of each training epoch."""
        # NOTE(review): depending on the Lightning version, switching the LightningModule
        # back to train mode (e.g. after validation) may also flip its submodules.
        if self.gpt2_base is not None:
            self.gpt2_base.eval()

    def _init_weights(self):
        """Initializing weights with small values to prevent instability in training."""
        for module in self.summarization_head.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, batch):
        """Compute vocabulary logits for ``batch['input_ids']``.

        Args:
            batch: Mapping with an integer ``input_ids`` tensor, assumed (batch, seq_len).

        Returns:
            Logits from the summarization head. They are returned even if they contain
            NaN/Inf; the training and validation steps skip such batches via their loss check.

        Raises:
            ValueError: if the GPT-2 body is not loaded or a token id is outside the vocabulary.
        """
        in_idx = batch['input_ids']

        if self.gpt2_base is None:
            raise ValueError("GPT-2 model not loaded.")

        # Token ids are integers, so NaN/Inf checks can never fire. An out-of-range id would
        # instead fail inside the embedding lookup with an opaque (device-side) error.
        vocab_size = self.hparams.model_config['vocab_size']
        if in_idx.numel() > 0 and (in_idx.min() < 0 or in_idx.max() >= vocab_size):
            raise ValueError("Invalid input tokens detected")

        with torch.no_grad():
            tok_embeds = self.gpt2_base.tok_embeddings(in_idx)
            positions = torch.arange(in_idx.shape[1], device=in_idx.device)
            pos_embeds = self.gpt2_base.pos_embeddings(positions)
            x = self.gpt2_base.drop_embeddings(tok_embeds + pos_embeds)
            x = self.gpt2_base.transformer_blocks(x)
            hidden_states = self.gpt2_base.final_norm(x)

        logits = self.summarization_head(hidden_states)

        # Do not substitute torch.zeros_like(logits): that tensor is detached from the graph,
        # so a loss computed from it cannot be back-propagated.
        if not torch.isfinite(logits).all():
            logger.warning("NaN or Inf detected in logits")

        return logits

    def training_step(self, batch, batch_idx):
        """Compute and log train loss/accuracy/perplexity; return None to skip a non-finite loss."""
        logits = self(batch)
        loss, acc, perplexity = calculate_metrics(logits, batch['labels'])

        # Returning None tells Lightning to skip the optimizer step for this batch.
        if not torch.isfinite(loss):
            logger.warning(f"NaN/Inf loss detected at step {batch_idx}")
            return None

        self.log('train_loss', loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True,
                 logger=True, sync_dist=True)
        self.log("train_perplexity", perplexity, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log val loss/accuracy/perplexity; non-finite losses are skipped, not logged."""
        logits = self(batch)
        loss, acc, perplexity = calculate_metrics(logits, batch['labels'])

        if not torch.isfinite(loss):
            logger.warning(f"NaN/Inf validation loss detected at step {batch_idx}. Skipping...")
            return None

        self.log('val_loss', loss, on_epoch=True, prog_bar=True,
                 logger=True, sync_dist=True)
        self.log('val_acc', acc, on_epoch=True, logger=True, sync_dist=True)
        self.log('val_perplexity', perplexity, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """AdamW on the summarization head only, with a per-step cosine schedule and 10% warmup."""
        optimizer = torch.optim.AdamW(
            self.summarization_head.parameters(),
            lr=self.hparams.train_config['learning_rate'],
            weight_decay=0.01,
            eps=1e-8
        )
        num_training_steps = self.hparams.num_training_steps
        num_warmup_steps = int(0.1 * num_training_steps)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}

# Debug function to test dataset loading
def debug_dataset_loading(max_length: int = 1024, num_samples: int = 5) -> bool:
    """Function to debug dataset loading issues.

    Smoke-tests BillSum loading, GPT-2 tokenization and MaskedBillSumDataset processing
    on a handful of samples.

    Args:
        max_length: ``max_length`` passed to MaskedBillSumDataset (default 1024).
        num_samples: Number of raw training samples pushed through dataset processing;
            capped at the size of the split.

    Returns:
        True if every step succeeded. False if any step raised; the error and its
        traceback are logged.
    """
    try:
        logger.info("Testing basic dataset loading...")
        train_set = load_dataset("FiscalNote/billsum", split="train")
        logger.info(f"Successfully loaded {len(train_set)} training samples")

        # Inspect the raw schema of the first sample
        raw_sample = train_set[0]
        logger.info(f"Sample keys: {raw_sample.keys()}")
        logger.info(f"Text length: {len(raw_sample.get('text', ''))}")
        logger.info(f"Summary length: {len(raw_sample.get('summary', ''))}")

        tokenizer = tiktoken.get_encoding("gpt2")
        logger.info(f"Tokenizer loaded, vocab size: {tokenizer.n_vocab}")

        test_text = "This is a test."
        tokens = tokenizer.encode(test_text)
        logger.info(f"Test tokenization: '{test_text}' -> {tokens}")

        logger.info("Testing dataset processing...")
        # Guard against splits smaller than num_samples (select() would fail on a bad index).
        small_dataset = train_set.select(range(min(num_samples, len(train_set))))

        processed_dataset = MaskedBillSumDataset(
            small_dataset,
            tokenizer,
            max_length=max_length
        )

        logger.info(f"Successfully processed {len(processed_dataset)} samples")

        if len(processed_dataset) > 0:
            processed_sample = processed_dataset[0]
            logger.info(f"Processed sample keys: {processed_sample.keys()}")
            logger.info(f"Input IDs shape: {processed_sample['input_ids'].shape}")
            logger.info(f"Labels shape: {processed_sample['labels'].shape}")
            logger.info(f"Active labels count: {(processed_sample['labels'] != -100).sum()}")

        return True

    except Exception as e:
        # logger.exception records the message together with the full traceback.
        logger.exception(f"Dataset loading test failed: {e}")
        return False

if __name__ == "__main__":
    # Fail fast if the dataset/tokenizer pipeline is broken, before touching configs or GPUs.
    if debug_dataset_loading():
        logger.info("Dataset loading test passed!")
    else:
        logger.error("Dataset loading test failed!")
        # Not `exit(1)`: that name is injected by `site` and is not guaranteed to exist, and
        # inside a Jupyter kernel it is IPython's exit helper, which presumably does not stop
        # the rest of this cell. Raising SystemExit always aborts here.
        raise SystemExit(1)

    # Load model and training configs
    with open("../config/training_config.yaml", "r") as f:
        train_config_full = yaml.safe_load(f)

    with open("../config/model_config.yaml", "r") as f:
        model_config_full = yaml.safe_load(f)

    # Extract model and training config
    model_name = "gpt2-small (124M)"
    model_config = model_config_full["model_configs"][model_name]
    train_config = train_config_full["train_config"]
    wandb_config = train_config_full["wandb_config"]
    experiment_path = train_config_full["experiment_path"]

    # Seed Python/NumPy/PyTorch (and dataloader workers) so head init and shuffling are reproducible.
    pl.seed_everything(train_config.get("seed", 47), workers=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    wandb_project_name = f"{wandb_config['project']}_V2"
    wandb_run_name = f"fine-tune-{model_name}-{timestamp}"

    os.makedirs(os.path.join(experiment_path, "checkpoints"), exist_ok=True)

    # Initialize Data Module and Lightning Module.
    data_module = SummarizationDataModule(model_config, train_config)
    # Set up eagerly: the LR schedule length below needs the number of training samples.
    data_module.setup(stage="fit")

    # Full batches per epoch (the train loader uses drop_last=True) x epochs; assumes
    # accumulate_grad_batches=1 as configured below.
    num_training_steps = (len(data_module.train_dataset) // train_config["batch_size"]) * train_config["num_epochs"]
    model_module = SummarizationFineTuneModel(model_name, model_config, train_config, num_training_steps)

    wandb_logger = WandbLogger(
        name=wandb_run_name,
        project=wandb_project_name,
        log_model="all",
        config={"model_config": model_config, "train_config": train_config},
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(experiment_path, "checkpoints"),
        filename="summarization-gpt2-finetune-Epoch-{epoch:02d}-val_loss-{val_loss:.2f}",
        save_top_k=3,
        monitor="val_loss",
        mode="min",
    )

    early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, verbose=True, mode="min")
    memory_logger_callback = MemoryUsageLogger()

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1, # -1 for all available GPUs
        # strategy="ddp_notebook",
        max_epochs=train_config["num_epochs"],
        logger=wandb_logger,
        callbacks=[checkpoint_callback, memory_logger_callback, early_stop_callback],
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        accumulate_grad_batches=1,
        precision="16-mixed",
        # Autograd anomaly detection is a debugging aid that slows every step; enable it
        # from the config only when hunting NaNs.
        detect_anomaly=train_config.get("detect_anomaly", False),
    )

    trainer.fit(model_module, datamodule=data_module)

    if trainer.is_global_zero:
        wandb.finish()

    print("Training complete!")

[32m2025-07-31 22:21:14.117[0m | [1mINFO    [0m | [36m__main__[0m:[36mdebug_dataset_loading[0m:[36m283[0m - [1mTesting basic dataset loading...[0m
[32m2025-07-31 22:21:16.586[0m | [1mINFO    [0m | [36m__main__[0m:[36mdebug_dataset_loading[0m:[36m285[0m - [1mSuccessfully loaded 18949 training samples[0m
[32m2025-07-31 22:21:16.589[0m | [1mINFO    [0m | [36m__main__[0m:[36mdebug_dataset_loading[0m:[36m289[0m - [1mSample keys: dict_keys(['text', 'summary', 'title'])[0m
[32m2025-07-31 22:21:16.590[0m | [1mINFO    [0m | [36m__main__[0m:[36mdebug_dataset_loading[0m:[36m290[0m - [1mText length: 5026[0m
[32m2025-07-31 22:21:16.591[0m | [1mINFO    [0m | [36m__main__[0m:[36mdebug_dataset_loading[0m:[36m291[0m - [1mSummary length: 1561[0m
[32m2025-07-31 22:21:18.407[0m | [1mINFO    [0m | [36m__main__[0m:[36mdebug_dataset_loading[0m:[36m295[0m - [1mTokenizer loaded, vocab size: 50257[0m
[32m2025-07-31 22:21:18.409[0m | [1mI

[32m2025-07-31 22:21:52.423[0m | [1mINFO    [0m | [36m__main__[0m:[36msetup[0m:[36m57[0m - [1mLoading datasets...[0m
[32m2025-07-31 22:21:53.459[0m | [1mINFO    [0m | [36m__main__[0m:[36msetup[0m:[36m66[0m - [1mLoaded train set: 13264 samples[0m
[32m2025-07-31 22:21:53.460[0m | [1mINFO    [0m | [36m__main__[0m:[36msetup[0m:[36m67[0m - [1mLoaded val set: 5685 samples[0m
[32m2025-07-31 22:21:53.461[0m | [1mINFO    [0m | [36m__main__[0m:[36msetup[0m:[36m75[0m - [1mProcessing training dataset...[0m
[32m2025-07-31 22:21:53.463[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m10[0m - [1mProcessing and masking dataset for summarization...[0m
[32m2025-07-31 22:21:53.464[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m11[0m - [1mOriginal dataset size: 13264[0m
100%|██████████| 13264/13264 [00:20<00:00, 650.94it/s]
[32m2025-07-31 22:22:13.844[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[

Downloading GPT-2 124M model to ./model_weights/124M




File already exists and is up-to-date: ./model_weights/124M/checkpoint




File already exists and is up-to-date: ./model_weights/124M/encoder.json




File already exists and is up-to-date: ./model_weights/124M/hparams.json




File already exists and is up-to-date: ./model_weights/124M/model.ckpt.data-00000-of-00001




File already exists and is up-to-date: ./model_weights/124M/model.ckpt.index




File already exists and is up-to-date: ./model_weights/124M/model.ckpt.meta




File already exists and is up-to-date: ./model_weights/124M/vocab.bpe
Download completed. Loading model parameters...
Model loaded successfully!
Model configuration: {'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}
Settings: {'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}
Params: dict_keys(['blocks', 'b', 'g', 'wpe', 'wte'])
['gpt2-small (124M)', 'gpt2-medium (355M)', 'gpt2-large (774M)', 'gpt2-xl (1558M)']
Returned Model config for gpt2-small (124M): {'vocab_size': 50257, 'context_length': 1024, 'embedding_dim': 768, 'num_layers': 12, 'num_heads': 12, 'drop_rate': 0.1, 'qkv_bias': True}


[32m2025-07-31 22:22:26.325[0m | [1mINFO    [0m | [36m__main__[0m:[36msetup[0m:[36m176[0m - [1mModel laoded Successfully! from setup()[0m
[32m2025-07-31 22:22:26.326[0m | [1mINFO    [0m | [36m__main__[0m:[36msetup[0m:[36m181[0m - [1mFreezing all parameters of the gpt2-small (124M) model[0m
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                | Params | Mode 
-------------------------------------------------------------------
0 | summarization_head | SummarizationNNHead | 156 M  | train
1 | gpt2_base          | GPT2ModelClone      | 163 M  | eval 
-------------------------------------------------------------------
156 M     Trainable params
163 M     Non-trainable params
319 M     Total params
1,279.357 Total estimated model params size (MB)
6         Modules in train mode
187       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

[32m2025-07-31 22:22:27.248[0m | [1mINFO    [0m | [36m__main__[0m:[36mval_dataloader[0m:[36m123[0m - [1mCreating val dataloader with 5668 samples[0m


                                                                           

[32m2025-07-31 22:22:37.754[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_dataloader[0m:[36m108[0m - [1mCreating train dataloader with 13218 samples[0m
/media/cosmic-muffin/wd_black/git/envs/fine-tune-tf/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0:   1%|          | 55/4406 [09:11<12:07:02,  0.10it/s, v_num=19w1, train_loss_step=17.30]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined