<a href="https://colab.research.google.com/github/Nanda654/HEADS/blob/main/Longformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# Recent versions of transformers, datasets, and accelerate are needed for the
# TrainingArguments parameters used below. If they are missing or outdated,
# uncomment the next line and run this cell once.
# !pip install --upgrade transformers datasets accelerate

# Ensure the NLTK sentence-splitting data used by sent_tokenize is available.
# Both the 'punkt' and the 'punkt_tab' resources are checked.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' ships per-language directories of plain-text tables, not an
    # 'english.pickle' file, so probing for the pickle always raised
    # LookupError and re-triggered the download on every run. Probe for the
    # English directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Load the ROUGE metric. A failed first attempt is retried once; if the retry
# also fails, its exception propagates and stops the cell.
try:
    rouge_metric = evaluate.load("rouge")
except Exception:
    print("Downloading 'rouge' metric...")
    rouge_metric = evaluate.load("rouge")
    print("'rouge' loaded.")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Directory holding the generated dataset (train.json / validation.json /
# test.json), i.e. the output directory of the earlier dataset-generation script.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Token budget for a single sentence. Longformer accepts up to 4096 tokens, but
# every sentence is classified on its own, so a much smaller limit is used.
SENTENCE_MAX_LENGTH = 512

# Columns that the label-generation step reads from every split.
REQUIRED_COLUMNS = ("original_text", "extractive_summary")

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
data_files = {
    "train": os.path.join(GENERATED_DATASET_DIR, "train.json"),
    "validation": os.path.join(GENERATED_DATASET_DIR, "validation.json"),
    "test": os.path.join(GENERATED_DATASET_DIR, "test.json"),
}

# Only the load itself is guarded, so the "check the directory" hint is shown
# for load failures and not for unrelated errors.
try:
    govreport_data_splits = load_dataset("json", data_files=data_files)
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # Re-raise rather than call exit(): the run still stops here, but the
    # traceback is kept and an interactive session is not shut down.
    raise

print("\nGovReport dataset loaded successfully!")
print(govreport_data_splits)

# Validate the schema of every split, because label generation is mapped over
# all of them (not just "train").
for split_name, split_data in govreport_data_splits.items():
    missing_columns = [
        column for column in REQUIRED_COLUMNS if column not in split_data.column_names
    ]
    if missing_columns:
        raise ValueError(
            "Dataset must contain 'original_text' and 'extractive_summary' columns "
            f"(split '{split_name}' is missing: {missing_columns})."
        )

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Turn a batch of documents into per-sentence classification examples.

    Each document in ``examples["original_text"]`` is split into sentences.
    Every sentence is scored with ROUGE-L against the matching entry of
    ``examples["extractive_summary"]`` and labelled 1 when the score reaches
    ``ROUGE_THRESHOLD`` (0 otherwise). Every sentence is then tokenized to
    ``SENTENCE_MAX_LENGTH`` tokens and emitted as its own example.

    Documents with an empty text, an empty reference summary, or no sentences
    are skipped, so the output usually has a different number of rows than the
    input batch.

    Args:
        examples: Batch as a dict of column lists; must provide the
            "original_text" and "extractive_summary" columns.

    Returns:
        Dict with the keys "input_ids", "attention_mask",
        "global_attention_mask" and "labels", each a list holding one entry
        per emitted sentence.
    """
    # Sentences scoring at or above this ROUGE-L value count as summary
    # sentences. Adjust it to change the density of positive labels.
    ROUGE_THRESHOLD = 0.20

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalise all sentences and the reference once per document. The
        # reference does not change between sentences, so re-processing it for
        # every sentence was wasted work.
        processed_sentences, processed_reference = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in ONE metric call. use_aggregator=False asks
        # for one score per (prediction, reference) pair instead of a single
        # aggregate, which avoids the per-call setup cost of invoking the
        # metric once per sentence.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sentences,
            references=processed_reference * len(processed_sentences),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        # Binary labels in original sentence order: 1 = summary, 0 = non-summary.
        labels_for_document = [
            1 if score >= ROUGE_THRESHOLD else 0 for score in sentence_scores
        ]

        # Tokenize all sentences of the document in one batch call; each
        # sentence is truncated/padded to SENTENCE_MAX_LENGTH and becomes a
        # separate model input.
        encoded = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",
        )

        processed_examples["input_ids"].extend(encoded["input_ids"])
        processed_examples["attention_mask"].extend(encoded["attention_mask"])
        # Global attention only on the first token (CLS) of each sentence.
        processed_examples["global_attention_mask"].extend(
            [1] + [0] * (len(token_ids) - 1) for token_ids in encoded["input_ids"]
        )
        processed_examples["labels"].extend(labels_for_document)

    return processed_examples

# Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
def postprocess_text_for_rouge(preds, labels):
    """
    Normalise predictions and references before ROUGE scoring.

    Every text is stripped of surrounding whitespace, split into sentences
    with ``sent_tokenize`` and re-joined with one sentence per line.

    Args:
        preds: List of prediction strings.
        labels: List of reference strings.

    Returns:
        Tuple ``(preds, labels)`` of the two normalised lists.
    """
    def _one_sentence_per_line(texts):
        # Strip, split into sentences, then join the sentences with newlines.
        return ["\n".join(sent_tokenize(text.strip())) for text in texts]

    return _one_sentence_per_line(preds), _one_sentence_per_line(labels)

# --- 4. Apply Label Generation to Dataset Splits ---
# Every sentence of every document is scored and tokenized here, so expect
# this step to be slow.
print("\nGenerating extractive labels for dataset splits (this may take a while)...")

# The original columns are dropped once the per-sentence features exist.
original_column_names = govreport_data_splits["train"].column_names

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    num_proc=os.cpu_count() or 1, # One worker per CPU; fall back to 1 if the count is unknown
    remove_columns=original_column_names,
    batched=True, # The mapper receives a batch (dict of column lists), not a single row
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Compute binary-classification metrics for an evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of the logits over the last axis.

    Returns:
        Dict with "accuracy", "precision", "recall" and "f1". Precision,
        recall and F1 are computed for the positive class (label 1, the
        summary sentences) and fall back to 0 when undefined.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    class_scores, gold_labels = eval_pred
    predicted_labels = np.argmax(class_scores, axis=-1)

    # The fourth returned value (support) is not reported.
    positive_class_stats = precision_recall_fscore_support(
        gold_labels, predicted_labels, average='binary', pos_label=1, zero_division=0
    )

    metrics = {"accuracy": accuracy_score(gold_labels, predicted_labels)}
    metrics.update(zip(("precision", "recall", "f1"), positive_class_stats[:3]))
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# num_labels=2 for binary classification (summary/non-summary)
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# If GPU is available, move model to GPU
if torch.cuda.is_available():
    model.to("cuda")
    print("LongformerForSequenceClassification model moved to GPU.")
else:
    print("No GPU found, model running on CPU.")

# --- 7. Set up Training Arguments ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3, # Adjust number of epochs
    per_device_train_batch_size=8, # Adjust batch size based on GPU memory
    per_device_eval_batch_size=8,
    warmup_steps=500, # Number of warmup steps for learning rate scheduler
    weight_decay=0.01, # Strength of weight decay
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    # 'eval_strategy' is the current name of this option; the older
    # 'evaluation_strategy' spelling is deprecated in recent transformers
    # releases, which the install note in this notebook asks for.
    eval_strategy="epoch", # Evaluate at the end of each epoch
    save_strategy="epoch", # Save model at the end of each epoch (must match the evaluation schedule for load_best_model_at_end)
    load_best_model_at_end=True, # Load the best model based on evaluation metric
    metric_for_best_model="f1", # Key returned by compute_metrics that ranks checkpoints
    greater_is_better=True,
    report_to="none", # Disable reporting to W&B, MLflow etc.
)

# --- 8. Create Trainer ---
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")
trainer.train()
print("\nTraining complete! Best model saved to:", trainer.state.best_model_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Loading generated GovReport dataset from: ./govreport_tfidf_vscode2


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]


GovReport dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 17517
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
})


vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]


Generating extractive labels for dataset splits (this may take a while)...


Map (num_proc=2):   0%|          | 0/17517 [00:00<?, ? examples/s]

Process ForkPoolWorker-1:
Process ForkPoolWorker-2:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.11/dist-packages/datasets/utils/py_utils.py", line 586, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
  File "/usr/local/lib/python3.11/dist-pac

TimeoutError: 

In [2]:
cd drive/MyDrive/

/content/drive/MyDrive


In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# Recent versions of transformers, datasets, and accelerate are needed for the
# TrainingArguments parameters used below. If they are missing or outdated,
# uncomment the next line and run this cell once.
# !pip install --upgrade transformers datasets accelerate

# Ensure the NLTK sentence-splitting data used by sent_tokenize is available.
# Both the 'punkt' and the 'punkt_tab' resources are checked.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' ships per-language directories of plain-text tables, not an
    # 'english.pickle' file, so probing for the pickle always raised
    # LookupError and re-triggered the download on every run. Probe for the
    # English directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Load the ROUGE metric. A failed first attempt is retried once; if the retry
# also fails, its exception propagates and stops the cell.
try:
    rouge_metric = evaluate.load("rouge")
except Exception:
    print("Downloading 'rouge' metric...")
    rouge_metric = evaluate.load("rouge")
    print("'rouge' loaded.")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Directory holding the generated dataset (train.json / validation.json /
# test.json), i.e. the output directory of the earlier dataset-generation script.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Token budget for a single sentence. Longformer accepts up to 4096 tokens, but
# every sentence is classified on its own, so a much smaller limit is used.
SENTENCE_MAX_LENGTH = 512

# Columns that the label-generation step reads from every split.
REQUIRED_COLUMNS = ("original_text", "extractive_summary")

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
data_files = {
    "train": os.path.join(GENERATED_DATASET_DIR, "train.json"),
    "validation": os.path.join(GENERATED_DATASET_DIR, "validation.json"),
    "test": os.path.join(GENERATED_DATASET_DIR, "test.json"),
}

# Only the load itself is guarded, so the "check the directory" hint is shown
# for load failures and not for unrelated errors.
try:
    govreport_data_splits = load_dataset("json", data_files=data_files)
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # Re-raise rather than call exit(): the run still stops here, but the
    # traceback is kept and an interactive session is not shut down.
    raise

print("\nGovReport dataset loaded successfully!")
print(govreport_data_splits)

# --- IMPORTANT: Select a small subset for quick testing/debugging ---
# Comment out or adjust these lines for full dataset training
print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
govreport_data_splits["train"] = govreport_data_splits["train"].select(range(10))
govreport_data_splits["validation"] = govreport_data_splits["validation"].select(range(5))
govreport_data_splits["test"] = govreport_data_splits["test"].select(range(5))
print("Subset selected:")
print(govreport_data_splits)
# --- END SUBSET SELECTION ---

# Validate the schema of every split, because label generation is mapped over
# all of them (not just "train").
for split_name, split_data in govreport_data_splits.items():
    missing_columns = [
        column for column in REQUIRED_COLUMNS if column not in split_data.column_names
    ]
    if missing_columns:
        raise ValueError(
            "Dataset must contain 'original_text' and 'extractive_summary' columns "
            f"(split '{split_name}' is missing: {missing_columns})."
        )

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Turn a batch of documents into per-sentence classification examples.

    Each document in ``examples["original_text"]`` is split into sentences.
    Every sentence is scored with ROUGE-L against the matching entry of
    ``examples["extractive_summary"]`` and labelled 1 when the score reaches
    ``ROUGE_THRESHOLD`` (0 otherwise). Every sentence is then tokenized to
    ``SENTENCE_MAX_LENGTH`` tokens and emitted as its own example.

    Documents with an empty text, an empty reference summary, or no sentences
    are skipped, so the output usually has a different number of rows than the
    input batch.

    Args:
        examples: Batch as a dict of column lists; must provide the
            "original_text" and "extractive_summary" columns.

    Returns:
        Dict with the keys "input_ids", "attention_mask",
        "global_attention_mask" and "labels", each a list holding one entry
        per emitted sentence.
    """
    # Sentences scoring at or above this ROUGE-L value count as summary
    # sentences. Adjust it to change the density of positive labels.
    ROUGE_THRESHOLD = 0.20

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalise all sentences and the reference once per document. The
        # reference does not change between sentences, so re-processing it for
        # every sentence was wasted work.
        processed_sentences, processed_reference = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in ONE metric call. use_aggregator=False asks
        # for one score per (prediction, reference) pair instead of a single
        # aggregate, which avoids the per-call setup cost of invoking the
        # metric once per sentence.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sentences,
            references=processed_reference * len(processed_sentences),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        # Binary labels in original sentence order: 1 = summary, 0 = non-summary.
        labels_for_document = [
            1 if score >= ROUGE_THRESHOLD else 0 for score in sentence_scores
        ]

        # Tokenize all sentences of the document in one batch call; each
        # sentence is truncated/padded to SENTENCE_MAX_LENGTH and becomes a
        # separate model input.
        encoded = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",
        )

        processed_examples["input_ids"].extend(encoded["input_ids"])
        processed_examples["attention_mask"].extend(encoded["attention_mask"])
        # Global attention only on the first token (CLS) of each sentence.
        processed_examples["global_attention_mask"].extend(
            [1] + [0] * (len(token_ids) - 1) for token_ids in encoded["input_ids"]
        )
        processed_examples["labels"].extend(labels_for_document)

    return processed_examples

# Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
def postprocess_text_for_rouge(preds, labels):
    """
    Normalise predictions and references before ROUGE scoring.

    Every text is stripped of surrounding whitespace, split into sentences
    with ``sent_tokenize`` and re-joined with one sentence per line.

    Args:
        preds: List of prediction strings.
        labels: List of reference strings.

    Returns:
        Tuple ``(preds, labels)`` of the two normalised lists.
    """
    def _one_sentence_per_line(texts):
        # Strip, split into sentences, then join the sentences with newlines.
        return ["\n".join(sent_tokenize(text.strip())) for text in texts]

    return _one_sentence_per_line(preds), _one_sentence_per_line(labels)

# --- 4. Apply Label Generation to Dataset Splits ---
# Every sentence of every document is scored and tokenized here, so expect
# this step to be slow.
print("\nGenerating extractive labels for dataset splits (this may take a while)...")

# The original columns are dropped once the per-sentence features exist.
original_column_names = govreport_data_splits["train"].column_names

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    num_proc=os.cpu_count() or 1, # One worker per CPU; fall back to 1 if the count is unknown
    remove_columns=original_column_names,
    batched=True, # The mapper receives a batch (dict of column lists), not a single row
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Compute binary-classification metrics for an evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of the logits over the last axis.

    Returns:
        Dict with "accuracy", "precision", "recall" and "f1". Precision,
        recall and F1 are computed for the positive class (label 1, the
        summary sentences) and fall back to 0 when undefined.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    class_scores, gold_labels = eval_pred
    predicted_labels = np.argmax(class_scores, axis=-1)

    # The fourth returned value (support) is not reported.
    positive_class_stats = precision_recall_fscore_support(
        gold_labels, predicted_labels, average='binary', pos_label=1, zero_division=0
    )

    metrics = {"accuracy": accuracy_score(gold_labels, predicted_labels)}
    metrics.update(zip(("precision", "recall", "f1"), positive_class_stats[:3]))
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output classes: 0 = non-summary sentence, 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Place the model on the GPU when one is present; otherwise it stays on the CPU.
if not torch.cuda.is_available():
    print("No GPU found, model running on CPU.")
else:
    model.to("cuda")
    print("LongformerForSequenceClassification model moved to GPU.")

# --- 7. Set up Training Arguments ---
training_args = TrainingArguments(
    # Where checkpoints and logs are written
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    report_to="none", # No external experiment trackers (W&B, MLflow, ...)
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=8, # Lower this if GPU memory runs out
    per_device_eval_batch_size=8,
    warmup_steps=500, # Learning-rate warmup steps
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch, then restore the best checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1", # Key returned by compute_metrics
    greater_is_better=True,
)

# --- 8. Create Trainer ---
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")
trainer.train()
print("\nTraining complete! Best model saved to:", trainer.state.best_model_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(eval_dataset=tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Loading generated GovReport dataset from: ./govreport_tfidf_vscode2

GovReport dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 17517
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
})

Selecting a small subset of the dataset for quick testing. Adjust or remove for full training.
Subset selected:
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 5
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'a

Map (num_proc=2):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5 [00:00<?, ? examples/s]


Extractive labels generated and tokenized.
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 2074
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 775
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 1838
    })
})


Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


No GPU found, model running on CPU.

Starting Longformer extractive summarization training...


Epoch,Training Loss,Validation Loss


In [2]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# Recent versions of transformers, datasets, and accelerate are needed for the
# TrainingArguments parameters used below. If they are missing or outdated,
# uncomment the next line and run this cell once.
# !pip install --upgrade transformers datasets accelerate

# Ensure the NLTK sentence-splitting data used by sent_tokenize is available.
# Both the 'punkt' and the 'punkt_tab' resources are checked.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' ships per-language directories of plain-text tables, not an
    # 'english.pickle' file, so probing for the pickle always raised
    # LookupError and re-triggered the download on every run. Probe for the
    # English directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Load the ROUGE metric. A failed first attempt is retried once; if the retry
# also fails, its exception propagates and stops the cell.
try:
    rouge_metric = evaluate.load("rouge")
except Exception:
    print("Downloading 'rouge' metric...")
    rouge_metric = evaluate.load("rouge")
    print("'rouge' loaded.")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Directory holding the generated dataset (train.json / validation.json /
# test.json), i.e. the output directory of the earlier dataset-generation script.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Token budget for a single sentence. Longformer accepts up to 4096 tokens, but
# every sentence is classified on its own, so a much smaller limit is used.
SENTENCE_MAX_LENGTH = 512

# Columns that the label-generation step reads from every split.
REQUIRED_COLUMNS = ("original_text", "extractive_summary")

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
data_files = {
    "train": os.path.join(GENERATED_DATASET_DIR, "train.json"),
    "validation": os.path.join(GENERATED_DATASET_DIR, "validation.json"),
    "test": os.path.join(GENERATED_DATASET_DIR, "test.json"),
}

# Only the load itself is guarded, so the "check the directory" hint is shown
# for load failures and not for unrelated errors.
try:
    govreport_data_splits = load_dataset("json", data_files=data_files)
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # Re-raise rather than call exit(): the run still stops here, but the
    # traceback is kept and an interactive session is not shut down.
    raise

print("\nGovReport dataset loaded successfully!")
print(govreport_data_splits)

# --- IMPORTANT: Select a small subset for quick testing/debugging ---
# Comment out or adjust these lines for full dataset training
print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
govreport_data_splits["train"] = govreport_data_splits["train"].select(range(10))
govreport_data_splits["validation"] = govreport_data_splits["validation"].select(range(5))
govreport_data_splits["test"] = govreport_data_splits["test"].select(range(5))
print("Subset selected:")
print(govreport_data_splits)
# --- END SUBSET SELECTION ---

# Validate the schema of every split, because label generation is mapped over
# all of them (not just "train").
for split_name, split_data in govreport_data_splits.items():
    missing_columns = [
        column for column in REQUIRED_COLUMNS if column not in split_data.column_names
    ]
    if missing_columns:
        raise ValueError(
            "Dataset must contain 'original_text' and 'extractive_summary' columns "
            f"(split '{split_name}' is missing: {missing_columns})."
        )

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Expand a batch of documents into per-sentence classification examples.

    Each document in ``examples["original_text"]`` is split into sentences with
    ``sent_tokenize``. A sentence gets label 1 (summary) when its ROUGE-L F1
    against the matching ``examples["extractive_summary"]`` entry is at least
    ``ROUGE_THRESHOLD``, and label 0 otherwise. Documents whose text or
    reference summary is empty are skipped.

    Args:
        examples: Batch dict with parallel ``"original_text"`` and
            ``"extractive_summary"`` lists (the shape ``Dataset.map`` passes
            when ``batched=True``).

    Returns:
        Dict of parallel lists ``input_ids``, ``attention_mask``,
        ``global_attention_mask`` and ``labels`` holding one entry per
        sentence, so the output normally has more rows than the input batch.
    """
    # Heuristic: a sentence counts as a "summary" sentence when its ROUGE-L F1
    # reaches this value. Adjust based on desired summary density/quality.
    ROUGE_THRESHOLD = 0.20

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalise all candidate sentences, and the reference exactly once per
        # document instead of once per sentence.
        processed_sents, processed_ref = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in a single call. use_aggregator=False yields one
        # F1 per (prediction, reference) pair rather than a bootstrap aggregate,
        # which is the same value the one-pair-per-call form produces.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sents,
            references=processed_ref * len(processed_sents),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        # Tokenize all sentences of the document together; plain Python lists
        # are requested because the results are stored as lists anyway.
        encoded = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",  # Pad all inputs to SENTENCE_MAX_LENGTH
        )

        # Each sentence becomes a separate input example for the model
        for input_ids, attention_mask, score in zip(
            encoded["input_ids"], encoded["attention_mask"], sentence_scores
        ):
            processed_examples["input_ids"].append(input_ids)
            processed_examples["attention_mask"].append(attention_mask)
            # Global attention only on the first token (CLS) for sentence classification
            processed_examples["global_attention_mask"].append(
                [1] + [0] * (len(input_ids) - 1)
            )
            processed_examples["labels"].append(int(score >= ROUGE_THRESHOLD))

    return processed_examples

# Helper for ROUGE: normalises texts to one sentence per line before scoring
def postprocess_text_for_rouge(preds, labels):
    """
    Strip every text and rejoin its sentences with newlines.

    Args:
        preds: List of prediction strings.
        labels: List of reference strings.

    Returns:
        Tuple ``(preds, labels)`` of the normalised lists, in input order.
    """
    def _one_sentence_per_line(text):
        return "\n".join(sent_tokenize(text.strip()))

    return (
        [_one_sentence_per_line(pred) for pred in preds],
        [_one_sentence_per_line(label) for label in labels],
    )

# --- 4. Apply Label Generation to Dataset Splits ---
# Every sentence is scored and tokenized here, so this step is slow.
print("\nGenerating extractive labels for dataset splits (this may take a while)...")

# The raw columns are dropped once documents have been expanded into sentences.
original_column_names = govreport_data_splits["train"].column_names
worker_count = os.cpu_count() or 1  # cpu_count() can return None

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    batched=True,
    remove_columns=original_column_names,
    num_proc=worker_count,
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Score a round of evaluation predictions.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of the logits over the last axis.

    Returns:
        Dict with ``accuracy`` plus ``precision``, ``recall`` and ``f1``
        computed for the summary class (label 1).
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Binary averaging with pos_label=1 reports how well summary sentences are found;
    # zero_division=0 yields 0 instead of a warning when the class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', pos_label=1, zero_division=0
    )

    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output labels: 0 = non-summary sentence, 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Place the model on the GPU when one is present
gpu_available = torch.cuda.is_available()
if not gpu_available:
    print("No GPU found, model running on CPU.")
else:
    model.to("cuda")
    print("LongformerForSequenceClassification model moved to GPU.")

# --- 7. Set up Training Arguments ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    report_to="none",  # no W&B / MLflow reporting
    # Schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Memory: small per-device batches, 4 accumulation steps (effective batch of 8),
    # gradient checkpointing, and mixed precision only when a GPU is present
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=gpu_available,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint with the best F1
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

# --- 8. Create Trainer ---
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")
trainer.train()
print("\nTraining complete! Best model saved to:", trainer.state.best_model_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading generated GovReport dataset from: ./govreport_tfidf_vscode2

GovReport dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 17517
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
})

Selecting a small subset of the dataset for quick testing. Adjust or remove for full training.
Subset selected:
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 5
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'a

Map (num_proc=2):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5 [00:00<?, ? examples/s]


Extractive labels generated and tokenized.
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 2074
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 775
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 1838
    })
})


Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


No GPU found, model running on CPU.

Starting Longformer extractive summarization training...


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars
from collections import Counter # To count class occurrences

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Ensure NLTK 'punkt' and 'punkt_tab' tokenizers are available for sentence splitting
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' is distributed as a directory of per-language tables, not as
    # an 'english.pickle' file. Probing the pickle path never succeeds and
    # triggers a download attempt on every run, so probe the English
    # directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Ensure ROUGE metric is available
try:
    rouge_metric = evaluate.load("rouge")
except Exception:
    # Best-effort: announce and retry the load once; a second failure propagates.
    print("Downloading 'rouge' metric...")
    rouge_metric = evaluate.load("rouge")
    print("'rouge' loaded.")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Path to your generated dataset with combined extractive summaries
# This should be the output directory from your first script
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Max length for a single sentence input to LongformerForSequenceClassification
# Longformer's max input is 4096, but here we process sentence by sentence.
# A typical sentence max length is 512, but can be adjusted.
SENTENCE_MAX_LENGTH = 512

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
try:
    data_files = {
        "train": os.path.join(GENERATED_DATASET_DIR, "train.json"),
        "validation": os.path.join(GENERATED_DATASET_DIR, "validation.json"),
        "test": os.path.join(GENERATED_DATASET_DIR, "test.json"),
    }

    # Load the dataset. Every split should contain 'original_text' and 'extractive_summary'.
    govreport_data_splits = load_dataset("json", data_files=data_files)

    print("\nGovReport dataset loaded successfully!")
    print(govreport_data_splits)

    # --- IMPORTANT: Select a small subset for quick testing/debugging ---
    # Comment out or adjust these lines for full dataset training
    print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
    govreport_data_splits["train"] = govreport_data_splits["train"].select(range(10))
    govreport_data_splits["validation"] = govreport_data_splits["validation"].select(range(5))
    govreport_data_splits["test"] = govreport_data_splits["test"].select(range(5))
    print("Subset selected:")
    print(govreport_data_splits)
    # --- END SUBSET SELECTION ---

    # Ensure the required columns exist in every split: label generation is
    # mapped over train, validation and test alike, not just over train.
    required_columns = ("original_text", "extractive_summary")
    for split in govreport_data_splits.values():
        if any(column not in split.column_names for column in required_columns):
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # Stop with a non-zero status and keep the original error as the cause.
    # The bare exit() helper is meant for interactive sessions and signals
    # success (status 0), which hides the failure from callers.
    raise SystemExit(1) from e

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Expand a batch of documents into per-sentence classification examples.

    Each document in ``examples["original_text"]`` is split into sentences with
    ``sent_tokenize``. A sentence gets label 1 (summary) when its ROUGE-L F1
    against the matching ``examples["extractive_summary"]`` entry is at least
    ``ROUGE_THRESHOLD``, and label 0 otherwise. Documents whose text or
    reference summary is empty are skipped.

    Args:
        examples: Batch dict with parallel ``"original_text"`` and
            ``"extractive_summary"`` lists (the shape ``Dataset.map`` passes
            when ``batched=True``).

    Returns:
        Dict of parallel lists ``input_ids``, ``attention_mask``,
        ``global_attention_mask`` and ``labels`` holding one entry per
        sentence, so the output normally has more rows than the input batch.
    """
    # Heuristic: a sentence counts as a "summary" sentence when its ROUGE-L F1
    # reaches this value. Adjust based on desired summary density/quality.
    ROUGE_THRESHOLD = 0.20

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalise all candidate sentences, and the reference exactly once per
        # document instead of once per sentence.
        processed_sents, processed_ref = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in a single call. use_aggregator=False yields one
        # F1 per (prediction, reference) pair rather than a bootstrap aggregate,
        # which is the same value the one-pair-per-call form produces.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sents,
            references=processed_ref * len(processed_sents),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        # Tokenize all sentences of the document together; plain Python lists
        # are requested because the results are stored as lists anyway.
        encoded = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",  # Pad all inputs to SENTENCE_MAX_LENGTH
        )

        # Each sentence becomes a separate input example for the model
        for input_ids, attention_mask, score in zip(
            encoded["input_ids"], encoded["attention_mask"], sentence_scores
        ):
            processed_examples["input_ids"].append(input_ids)
            processed_examples["attention_mask"].append(attention_mask)
            # Global attention only on the first token (CLS) for sentence classification
            processed_examples["global_attention_mask"].append(
                [1] + [0] * (len(input_ids) - 1)
            )
            processed_examples["labels"].append(int(score >= ROUGE_THRESHOLD))

    return processed_examples

# Helper for ROUGE: puts each sentence of a text on its own line before scoring
def postprocess_text_for_rouge(preds, labels):
    """
    Normalise predictions and references for ROUGE.

    Every string is stripped of surrounding whitespace, split into sentences
    with ``sent_tokenize`` and rejoined with newline separators.

    Returns:
        Tuple ``(preds, labels)`` with the normalised lists, in input order.
    """
    normalized = []
    for texts in (preds, labels):
        normalized.append(["\n".join(sent_tokenize(text.strip())) for text in texts])
    return tuple(normalized)

# --- 4. Apply Label Generation to Dataset Splits ---
# Every sentence is scored and tokenized here, so this step is slow.
print("\nGenerating extractive labels for dataset splits (this may take a while)...")

# The raw columns are dropped once documents have been expanded into sentences.
original_column_names = govreport_data_splits["train"].column_names

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    batched=True,
    remove_columns=original_column_names,
    num_proc=os.cpu_count() or 1,  # cpu_count() can return None
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- Calculate Class Weights for Imbalanced Data ---
# Tally how many training sentences carry each label.
train_labels = tokenized_datasets["train"]["labels"]
label_counts = Counter(train_labels)
num_class_0 = label_counts[0]  # non-summary sentences (a Counter yields 0 for absent keys)
num_class_1 = label_counts[1]  # summary sentences
total_samples = len(train_labels)

print(f"\nLabel distribution in training set: Class 0 (Non-summary): {num_class_0}, Class 1 (Summary): {num_class_1}")

if num_class_0 == 0 or num_class_1 == 0:
    # A class can be absent altogether, e.g. in very small subsets; then no re-weighting is possible.
    print("Warning: One or both classes are missing in the training data. Using uniform weights.")
    class_weights = torch.tensor([1.0, 1.0], dtype=torch.float)
else:
    # Balanced weighting: weight_c = total_samples / (num_classes * count_c),
    # so the rarer class receives the larger weight.
    weight_for_class_0 = total_samples / (2 * num_class_0)
    weight_for_class_1 = total_samples / (2 * num_class_1)
    class_weights = torch.tensor([weight_for_class_0, weight_for_class_1], dtype=torch.float)
    print(f"Calculated class weights: Class 0: {weight_for_class_0:.4f}, Class 1: {weight_for_class_1:.4f}")

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Turn evaluation logits and gold labels into classification metrics.

    Args:
        eval_pred: Pair ``(logits, labels)``; predictions are the argmax of
            the logits along the last axis.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
        Precision, recall and F1 are measured for the summary class (label 1).
    """
    from sklearn.metrics import precision_recall_fscore_support, accuracy_score

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # The call returns (precision, recall, f1, support); support is not reported.
    positive_class_scores = precision_recall_fscore_support(
        labels, predictions, average='binary', pos_label=1, zero_division=0
    )

    metrics = {"accuracy": accuracy_score(labels, predictions)}
    # zip stops after the three names, which drops the trailing support value.
    metrics.update(zip(("precision", "recall", "f1"), positive_class_scores))
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output labels: 0 = non-summary sentence, 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Keep the class weights on the same device as the model
if not torch.cuda.is_available():
    print("No GPU found, model and class weights running on CPU.")
else:
    class_weights = class_weights.to("cuda")
    model.to("cuda")
    print("LongformerForSequenceClassification model and class weights moved to GPU.")

# Override the Trainer's loss computation so the module-level class_weights
# are applied to the cross-entropy loss.
class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``class_weights``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: Model to run the forward pass with.
            inputs: Batch dict; its ``"labels"`` entry is removed before the
                forward pass so the model does not compute its own loss.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Accepted because recent ``Trainer`` versions
                pass it as a keyword argument; it is not used here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Follow the logits' device rather than relying on where the global
        # weight tensor was placed earlier.
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- 7. Set up Training Arguments ---
training_config = dict(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    # Small per-device batches; 4 accumulation steps give an effective batch of 8
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,  # trade computation for memory
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint with the best F1
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none",  # no W&B / MLflow reporting
)
training_args = TrainingArguments(**training_config)

# --- 8. Create Trainer ---
# CustomTrainer is used so that the class weights enter the loss.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")
trainer.train()
print("\nTraining complete! Best model saved to:", trainer.state.best_model_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


In [1]:
cd drive/MyDrive/

/content/drive/MyDrive
