<a href="https://colab.research.google.com/github/kalidasuu/Movie-Recommender-System/blob/main/longformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Make sure the NLTK sentence-tokenizer data needed by sent_tokenize is
# installed. Older NLTK releases read the pickled 'punkt' models, newer ones
# read the 'punkt_tab' tables, so both are fetched when missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' is a directory of per-language tables and contains no
    # 'english.pickle'; probing for that file never matches and triggers a
    # download on every run. Probe the language directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Load the ROUGE metric. evaluate.load() already fetches the metric script
# when it is not cached, so one call is enough. A failure here (for example a
# missing 'rouge_score' package) should surface immediately instead of being
# swallowed and retried with the identical call.
rouge_metric = evaluate.load("rouge")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer checkpoint that is fine-tuned below
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Directory holding train.json / validation.json / test.json written by the
# earlier preprocessing script.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Max token length for a single sentence fed to the classifier. Inputs are
# individual sentences, so this is far below the 4096 in the checkpoint name.
SENTENCE_MAX_LENGTH = 512

# Columns every split must provide for label generation.
REQUIRED_COLUMNS = ("original_text", "extractive_summary")

# Number of documents kept per split for quick testing/debugging.
# Set a value to None to keep that split in full.
DEBUG_SUBSET_SIZES = {"train": 10, "validation": 5, "test": 5}

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
try:
    data_files = {
        split_name: os.path.join(GENERATED_DATASET_DIR, f"{split_name}.json")
        for split_name in ("train", "validation", "test")
    }

    # Load the dataset. It should contain 'original_text' and 'extractive_summary'
    govreport_data_splits = load_dataset("json", data_files=data_files)

    print("\nGovReport dataset loaded successfully!")
    print(govreport_data_splits)

    # Validate every split up front so a malformed file fails here rather
    # than deep inside the label-generation map.
    for split_name, split in govreport_data_splits.items():
        if any(column not in split.column_names for column in REQUIRED_COLUMNS):
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")

    # --- IMPORTANT: Select a small subset for quick testing/debugging ---
    # Adjust DEBUG_SUBSET_SIZES (or set entries to None) for full dataset training
    print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
    for split_name, limit in DEBUG_SUBSET_SIZES.items():
        if limit is None:
            continue
        split = govreport_data_splits[split_name]
        # Clamp to the split size so short splits do not raise on select().
        govreport_data_splits[split_name] = split.select(range(min(limit, len(split))))
    print("Subset selected:")
    print(govreport_data_splits)
    # --- END SUBSET SELECTION ---
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # exit() is an interactive helper from the site module and reports
    # success (status 0); raise SystemExit with a failing status instead and
    # keep the original error chained for debugging.
    raise SystemExit(1) from e

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Turn a batch of documents into per-sentence classification examples.

    Each document in ``examples["original_text"]`` is split into sentences
    with ``sent_tokenize``. Every sentence is scored with ROUGE-L F1 (with
    stemming) against the matching ``examples["extractive_summary"]`` entry
    and labelled 1 when the score is at least ``ROUGE_THRESHOLD``, else 0.
    Each sentence is tokenized on its own, padded/truncated to
    ``SENTENCE_MAX_LENGTH``, and given global attention on its first token.

    Args:
        examples: Batched mapping of column name to list of values, as passed
            by ``Dataset.map(..., batched=True)``. Must contain the
            "original_text" and "extractive_summary" columns.

    Returns:
        dict with "input_ids", "attention_mask", "global_attention_mask" and
        "labels" lists, one entry per sentence. Documents with an empty text
        or summary, or that produce no sentences, contribute no rows, so the
        number of output rows generally differs from the input batch size.
    """
    # Minimum ROUGE-L F1 for a sentence to be labelled as a summary sentence.
    # NOTE(review): F1 of a single sentence against a multi-sentence summary
    # is presumably dragged down by the length mismatch; the logged runs in
    # this notebook report zero precision/recall, which suggests very few
    # positives. Consider a different score or top-k selection.
    ROUGE_THRESHOLD = 0.20

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalise all sentences and the reference once per document; the
        # reference does not change from sentence to sentence.
        processed_sents, processed_refs = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in a single call. use_aggregator=False returns
        # one F1 value per prediction and skips bootstrap resampling, which
        # adds nothing when each score is for a single sentence.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sents,
            references=processed_refs * len(processed_sents),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        # Binary labels in original sentence order.
        labels_for_document = [
            int(score >= ROUGE_THRESHOLD) for score in sentence_scores
        ]

        # Tokenize the sentences as one batch; each sentence becomes its own
        # model input, padded to SENTENCE_MAX_LENGTH.
        encodings = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",
        )

        processed_examples["input_ids"].extend(encodings["input_ids"])
        processed_examples["attention_mask"].extend(encodings["attention_mask"])
        # Global attention only on the first token of every sentence.
        processed_examples["global_attention_mask"].extend(
            [1] + [0] * (len(token_ids) - 1) for token_ids in encodings["input_ids"]
        )
        processed_examples["labels"].extend(labels_for_document)

    return processed_examples

# Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
def postprocess_text_for_rouge(preds, labels):
    """
    Strip each text and re-join its sentences with newlines.

    The same treatment is applied to the predictions and to the references.
    Returns the two processed lists as a ``(preds, labels)`` tuple.
    """
    def _one_sentence_per_line(texts):
        # Trim surrounding whitespace, split into sentences, one per line.
        return ["\n".join(sent_tokenize(text.strip())) for text in texts]

    return _one_sentence_per_line(preds), _one_sentence_per_line(labels)

# --- 4. Apply Label Generation to Dataset Splits ---
# Every sentence of every document is scored and tokenized here, so this is
# the slow part of the preprocessing.
print("\nGenerating extractive labels for dataset splits (this may take a while)...")

# The raw text columns are dropped once the per-sentence rows are produced.
original_column_names = govreport_data_splits["train"].column_names

# One worker per reported CPU core; fall back to a single worker when the
# core count is unavailable.
worker_count = os.cpu_count() or 1

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    num_proc=worker_count,
    remove_columns=original_column_names,
    batched=True,
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Compute accuracy plus precision/recall/F1 of the positive (label 1) class.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class of
    each row is the argmax over the last axis of the logits. Precision,
    recall and F1 are reported as 0 when they would otherwise be undefined.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    # Binary averaging: only the 'summary' class (1) counts as positive.
    precision, recall, f1, _support = precision_recall_fscore_support(
        labels,
        predicted_classes,
        pos_label=1,
        average='binary',
        zero_division=0,
    )

    metrics = {"accuracy": accuracy_score(labels, predicted_classes)}
    metrics["precision"] = precision
    metrics["recall"] = recall
    metrics["f1"] = f1
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output classes: 0 = non-summary sentence, 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Place the model on the GPU when one is present.
use_gpu = torch.cuda.is_available()
if not use_gpu:
    print("No GPU found, model running on CPU.")
else:
    model.to("cuda")
    print("LongformerForSequenceClassification model moved to GPU.")

# --- 7. Set up Training Arguments ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    report_to="none", # No external experiment trackers
    logging_steps=100,
    # Schedule
    num_train_epochs=3,
    # NOTE(review): the logged debug run saved a checkpoint at step 260, so
    # 500 warmup steps may cover most of a short run - confirm this is intended.
    warmup_steps=500,
    weight_decay=0.01,
    # Memory: 2 samples per step x 4 accumulation steps = effective batch of 8
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True, # Recompute activations to save memory
    fp16=use_gpu, # Mixed precision only when a GPU is present
    # Evaluate and checkpoint once per epoch, then restore the best-F1 checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

# --- 8. Create Trainer ---
# Trains on the per-sentence "train" rows and evaluates on "validation".
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    eval_dataset=tokenized_datasets["validation"],
    train_dataset=tokenized_datasets["train"],
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")
trainer.train()
best_checkpoint = trainer.state.best_model_checkpoint
print("\nTraining complete! Best model saved to:", best_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
# Same metrics as during training, computed on the held-out "test" rows.
print("\nEvaluating on test set...")
test_results = trainer.evaluate(eval_dataset=tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Loading generated GovReport dataset from: ./govreport_tfidf_vscode2


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]


GovReport dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 17517
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
})

Selecting a small subset of the dataset for quick testing. Adjust or remove for full training.
Subset selected:
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 5
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 5
    })
})


vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]


Generating extractive labels for dataset splits (this may take a while)...


Map (num_proc=2):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5 [00:00<?, ? examples/s]


Extractive labels generated and tokenized.
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 2074
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 775
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 1838
    })
})


pytorch_model.bin:   0%|          | 0.00/597M [00:00<?, ?B/s]

Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LongformerForSequenceClassification model moved to GPU.

Starting Longformer extractive summarization training...


  trainer = Trainer(


model.safetensors:   0%|          | 0.00/597M [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.11,0.14224,0.978065,0.0,0.0,0.0
2,0.0727,0.133154,0.978065,0.0,0.0,0.0
3,0.1018,0.124034,0.978065,0.0,0.0,0.0



Training complete! Best model saved to: ./longformer_extractive_govreport/checkpoint-260

Evaluating on test set...


Test Results: {'eval_loss': 0.018999800086021423, 'eval_accuracy': 0.9972796517954298, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_runtime': 143.9305, 'eval_samples_per_second': 12.77, 'eval_steps_per_second': 6.385, 'epoch': 3.0}


# Training with 100 samples

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Make sure the NLTK sentence-tokenizer data needed by sent_tokenize is
# installed. Older NLTK releases read the pickled 'punkt' models, newer ones
# read the 'punkt_tab' tables, so both are fetched when missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' is a directory of per-language tables and contains no
    # 'english.pickle'; probing for that file never matches and triggers a
    # download on every run. Probe the language directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Load the ROUGE metric. evaluate.load() already fetches the metric script
# when it is not cached, so one call is enough. A failure here (for example a
# missing 'rouge_score' package) should surface immediately instead of being
# swallowed and retried with the identical call.
rouge_metric = evaluate.load("rouge")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer checkpoint that is fine-tuned below
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Directory holding train.json / validation.json / test.json written by the
# earlier preprocessing script.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Max token length for a single sentence fed to the classifier. Inputs are
# individual sentences, so this is far below the 4096 in the checkpoint name.
SENTENCE_MAX_LENGTH = 512

# Columns every split must provide for label generation.
REQUIRED_COLUMNS = ("original_text", "extractive_summary")

# Number of documents kept per split for quick testing/debugging.
# Set a value to None to keep that split in full.
DEBUG_SUBSET_SIZES = {"train": 100, "validation": 10, "test": 10}

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
try:
    data_files = {
        split_name: os.path.join(GENERATED_DATASET_DIR, f"{split_name}.json")
        for split_name in ("train", "validation", "test")
    }

    # Load the dataset. It should contain 'original_text' and 'extractive_summary'
    govreport_data_splits = load_dataset("json", data_files=data_files)

    print("\nGovReport dataset loaded successfully!")
    print(govreport_data_splits)

    # Validate every split up front so a malformed file fails here rather
    # than deep inside the label-generation map.
    for split_name, split in govreport_data_splits.items():
        if any(column not in split.column_names for column in REQUIRED_COLUMNS):
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")

    # --- IMPORTANT: Select a small subset for quick testing/debugging ---
    # Adjust DEBUG_SUBSET_SIZES (or set entries to None) for full dataset training
    print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
    for split_name, limit in DEBUG_SUBSET_SIZES.items():
        if limit is None:
            continue
        split = govreport_data_splits[split_name]
        # Clamp to the split size so short splits do not raise on select().
        govreport_data_splits[split_name] = split.select(range(min(limit, len(split))))
    print("Subset selected:")
    print(govreport_data_splits)
    # --- END SUBSET SELECTION ---
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # exit() is an interactive helper from the site module and reports
    # success (status 0); raise SystemExit with a failing status instead and
    # keep the original error chained for debugging.
    raise SystemExit(1) from e

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Turn a batch of documents into per-sentence classification examples.

    Each document in ``examples["original_text"]`` is split into sentences
    with ``sent_tokenize``. Every sentence is scored with ROUGE-L F1 (with
    stemming) against the matching ``examples["extractive_summary"]`` entry
    and labelled 1 when the score is at least ``ROUGE_THRESHOLD``, else 0.
    Each sentence is tokenized on its own, padded/truncated to
    ``SENTENCE_MAX_LENGTH``, and given global attention on its first token.

    Args:
        examples: Batched mapping of column name to list of values, as passed
            by ``Dataset.map(..., batched=True)``. Must contain the
            "original_text" and "extractive_summary" columns.

    Returns:
        dict with "input_ids", "attention_mask", "global_attention_mask" and
        "labels" lists, one entry per sentence. Documents with an empty text
        or summary, or that produce no sentences, contribute no rows, so the
        number of output rows generally differs from the input batch size.
    """
    # Minimum ROUGE-L F1 for a sentence to be labelled as a summary sentence.
    # NOTE(review): F1 of a single sentence against a multi-sentence summary
    # is presumably dragged down by the length mismatch; the logged runs in
    # this notebook report zero precision/recall, which suggests very few
    # positives. Consider a different score or top-k selection.
    ROUGE_THRESHOLD = 0.20

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalise all sentences and the reference once per document; the
        # reference does not change from sentence to sentence.
        processed_sents, processed_refs = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in a single call. use_aggregator=False returns
        # one F1 value per prediction and skips bootstrap resampling, which
        # adds nothing when each score is for a single sentence.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sents,
            references=processed_refs * len(processed_sents),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        # Binary labels in original sentence order.
        labels_for_document = [
            int(score >= ROUGE_THRESHOLD) for score in sentence_scores
        ]

        # Tokenize the sentences as one batch; each sentence becomes its own
        # model input, padded to SENTENCE_MAX_LENGTH.
        encodings = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",
        )

        processed_examples["input_ids"].extend(encodings["input_ids"])
        processed_examples["attention_mask"].extend(encodings["attention_mask"])
        # Global attention only on the first token of every sentence.
        processed_examples["global_attention_mask"].extend(
            [1] + [0] * (len(token_ids) - 1) for token_ids in encodings["input_ids"]
        )
        processed_examples["labels"].extend(labels_for_document)

    return processed_examples

# Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
def postprocess_text_for_rouge(preds, labels):
    """
    Strip each text and re-join its sentences with newlines.

    The same treatment is applied to the predictions and to the references.
    Returns the two processed lists as a ``(preds, labels)`` tuple.
    """
    def _one_sentence_per_line(texts):
        # Trim surrounding whitespace, split into sentences, one per line.
        return ["\n".join(sent_tokenize(text.strip())) for text in texts]

    return _one_sentence_per_line(preds), _one_sentence_per_line(labels)

# --- 4. Apply Label Generation to Dataset Splits ---
# Every sentence of every document is scored and tokenized here, so this is
# the slow part of the preprocessing.
print("\nGenerating extractive labels for dataset splits (this may take a while)...")

# The raw text columns are dropped once the per-sentence rows are produced.
original_column_names = govreport_data_splits["train"].column_names

# One worker per reported CPU core; fall back to a single worker when the
# core count is unavailable.
worker_count = os.cpu_count() or 1

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    num_proc=worker_count,
    remove_columns=original_column_names,
    batched=True,
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Compute accuracy plus precision/recall/F1 of the positive (label 1) class.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class of
    each row is the argmax over the last axis of the logits. Precision,
    recall and F1 are reported as 0 when they would otherwise be undefined.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    # Binary averaging: only the 'summary' class (1) counts as positive.
    precision, recall, f1, _support = precision_recall_fscore_support(
        labels,
        predicted_classes,
        pos_label=1,
        average='binary',
        zero_division=0,
    )

    metrics = {"accuracy": accuracy_score(labels, predicted_classes)}
    metrics["precision"] = precision
    metrics["recall"] = recall
    metrics["f1"] = f1
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output classes: 0 = non-summary sentence, 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Place the model on the GPU when one is present.
use_gpu = torch.cuda.is_available()
if not use_gpu:
    print("No GPU found, model running on CPU.")
else:
    model.to("cuda")
    print("LongformerForSequenceClassification model moved to GPU.")

# --- 7. Set up Training Arguments ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    report_to="none", # No external experiment trackers
    logging_steps=100,
    # Schedule
    num_train_epochs=3,
    # NOTE(review): on small debug subsets 500 warmup steps may cover a large
    # share of all optimizer steps - confirm this is intended.
    warmup_steps=500,
    weight_decay=0.01,
    # Memory: 2 samples per step x 4 accumulation steps = effective batch of 8
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True, # Recompute activations to save memory
    fp16=use_gpu, # Mixed precision only when a GPU is present
    # Evaluate and checkpoint once per epoch, then restore the best-F1 checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

# --- 8. Create Trainer ---
# Trains on the per-sentence "train" rows and evaluates on "validation".
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    eval_dataset=tokenized_datasets["validation"],
    train_dataset=tokenized_datasets["train"],
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")
trainer.train()
best_checkpoint = trainer.state.best_model_checkpoint
print("\nTraining complete! Best model saved to:", best_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
# Same metrics as during training, computed on the held-out "test" rows.
print("\nEvaluating on test set...")
test_results = trainer.evaluate(eval_dataset=tokenized_datasets["test"])
print("Test Results:", test_results)


ModuleNotFoundError: No module named 'evaluate'

In [None]:
!pip install evaluate
!pip install rouge_score

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=cae0770bd476b8aee30148b1d62409d019cf8cbc34285eb983b6be59a73d383d
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


# Longformer training with checkpointing - 1000 samples

In [9]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk # Added load_from_disk
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm # Direct import of tqdm
from collections import Counter # To calculate class weights
from sklearn.metrics import precision_recall_fscore_support, accuracy_score # For compute_metrics

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Ensure NLTK 'punkt' and 'punkt_tab' tokenizers are available for sentence splitting
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' is installed as a directory of per-language tables, not as a
    # pickle. Probing for 'english.pickle' therefore never succeeded and the
    # package was re-downloaded on every run; probe the English directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Ensure ROUGE metric is available
# NOTE(review): the handler simply retries the same call once; a second
# failure propagates to the caller.
try:
    rouge_metric = evaluate.load("rouge")
except Exception:
    print("Downloading 'rouge' metric...")
    rouge_metric = evaluate.load("rouge")
    print("'rouge' loaded.")


# --- 0. Configuration ---
# Checkpoint name used for both the tokenizer and the classification model.
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
# Trainer output directory (checkpoints and logs); created eagerly.
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Path to your raw dataset with combined extractive summaries (input for label generation).
# The loader below expects train.json / validation.json / test.json inside it.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Path to store the processed (labeled and tokenized) dataset.
# NOTE(review): presumably requires Google Drive to be mounted first; otherwise
# makedirs would just create a local directory at this path — confirm.
PROCESSED_DATA_DIR = "/content/drive/My Drive/longformer_processed_govreport" # Ensure this is a Drive path for persistence
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

# Max length (in tokens) for a single sentence; sentences are truncated/padded to this size.
SENTENCE_MAX_LENGTH = 512

# --- 1. Load or Generate Processed Dataset ---
tokenized_datasets = None
govreport_data_splits = None # Initialize to None

# Check if processed dataset already exists on disk
if os.path.exists(PROCESSED_DATA_DIR) and \
   os.path.exists(os.path.join(PROCESSED_DATA_DIR, "train")) and \
   os.path.exists(os.path.join(PROCESSED_DATA_DIR, "validation")) and \
   os.path.exists(os.path.join(PROCESSED_DATA_DIR, "test")):
    print(f"Loading processed dataset from: {PROCESSED_DATA_DIR}")
    try:
        tokenized_datasets = DatasetDict.load_from_disk(PROCESSED_DATA_DIR)
        print("\nProcessed dataset loaded successfully!")
        print(tokenized_datasets)
    except Exception as e:
        print(f"\nError loading processed dataset from {PROCESSED_DATA_DIR}: {e}")
        print("Attempting to regenerate dataset as loading failed...")
        tokenized_datasets = None # Reset to trigger regeneration

# Regenerate whenever nothing usable was loaded. This used to be the `else:` of
# the check above, so a failed load announced a regeneration that never ran and
# left `tokenized_datasets` as None for the training code further down.
if tokenized_datasets is None:
    # --- Load the Generated GovReport Dataset (RAW) ---
    print(f"Loading raw GovReport dataset from: {GENERATED_DATASET_DIR}")
    try:
        data_files = {
            "train": os.path.join(GENERATED_DATASET_DIR, "train.json"),
            "validation": os.path.join(GENERATED_DATASET_DIR, "validation.json"),
            "test": os.path.join(GENERATED_DATASET_DIR, "test.json"),
        }
        govreport_data_splits = load_dataset("json", data_files=data_files)
        print("\nRaw GovReport dataset loaded successfully!")
        print(govreport_data_splits)

        # --- IMPORTANT: Select a small subset for quick testing/debugging ---
        print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
        govreport_data_splits["train"] = govreport_data_splits["train"].select(range(1000))
        govreport_data_splits["validation"] = govreport_data_splits["validation"].select(range(100))
        govreport_data_splits["test"] = govreport_data_splits["test"].select(range(100))
        print("Subset selected:")
        print(govreport_data_splits)
        # --- END SUBSET SELECTION ---

        # Ensure the required columns exist
        if "original_text" not in govreport_data_splits["train"].column_names or \
           "extractive_summary" not in govreport_data_splits["train"].column_names:
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
    except Exception as e:
        print(f"\nError loading raw GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
        print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
        exit() # Exit if raw data loading fails

    # --- Initialize Longformer Tokenizer (needed for generate_extractive_labels) ---
    # This tokenizer is specifically for the data preprocessing step
    tokenizer_for_preprocessing = LongformerTokenizer.from_pretrained(MODEL_NAME)

    # --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
    def generate_extractive_labels(examples):
        """
        Turn a batch of documents into per-sentence classification examples.

        Each document is split into sentences. A sentence is labelled 1 when
        its ROUGE-L score against the document's extractive summary is at
        least ROUGE_THRESHOLD, otherwise 0. Every sentence is tokenized to
        SENTENCE_MAX_LENGTH with global attention on the first token.

        Args:
            examples: batch dict with parallel lists under "original_text"
                and "extractive_summary".

        Returns:
            Dict with "input_ids", "attention_mask", "global_attention_mask"
            and "labels" lists holding one entry per sentence. Documents with
            an empty text or summary contribute nothing, so the output
            usually has a different number of rows than the input batch.
        """
        ROUGE_THRESHOLD = 0.20 # Adjust this threshold based on desired summary density/quality

        processed_examples = {
            "input_ids": [],
            "attention_mask": [],
            "global_attention_mask": [],
            "labels": [],
        }

        documents = zip(examples["original_text"], examples["extractive_summary"])
        for document_text, extractive_summary_reference in tqdm(
            documents, total=len(examples["original_text"]), desc="Generating labels"
        ):
            if not document_text or not extractive_summary_reference:
                continue

            sentences = sent_tokenize(document_text)
            if not sentences:
                continue

            # The reference is identical for every sentence of the document,
            # so normalise it once instead of re-splitting it per sentence.
            _, processed_ref_summary = postprocess_text_for_rouge(
                [], [extractive_summary_reference]
            )

            labels_for_document = []
            for sentence in sentences:
                processed_sent, _ = postprocess_text_for_rouge([sentence], [])
                score = rouge_metric.compute(
                    predictions=processed_sent,
                    references=processed_ref_summary,
                    rouge_types=["rougeL"],
                    use_stemmer=True,
                )
                labels_for_document.append(1 if score["rougeL"] >= ROUGE_THRESHOLD else 0)

            for sentence, label in zip(sentences, labels_for_document):
                inputs = tokenizer_for_preprocessing( # Use the tokenizer defined for preprocessing
                    sentence,
                    truncation=True,
                    max_length=SENTENCE_MAX_LENGTH,
                    padding="max_length",
                    return_tensors="pt"
                )

                # Global attention only on the first token of each sentence.
                global_attention_mask = torch.zeros_like(inputs["input_ids"])
                global_attention_mask[:, 0] = 1

                processed_examples["input_ids"].append(inputs["input_ids"].squeeze(0).tolist())
                processed_examples["attention_mask"].append(inputs["attention_mask"].squeeze(0).tolist())
                processed_examples["global_attention_mask"].append(global_attention_mask.squeeze(0).tolist())
                processed_examples["labels"].append(label)

        return processed_examples

    # Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
    def postprocess_text_for_rouge(preds, labels):
        """Strip each text and rewrite it with one sentence per line; returns (preds, labels)."""
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
        preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(sent_tokenize(label)) for label in labels]
        return preds, labels

    # --- 4. Apply Label Generation to Dataset Splits ---
    print("\nGenerating extractive labels for dataset splits (this may take a while)...")
    # govreport_data_splits is guaranteed to be loaded here (the except above exits otherwise)
    original_column_names = govreport_data_splits["train"].column_names

    tokenized_datasets = govreport_data_splits.map(
        generate_extractive_labels,
        batched=True,
        remove_columns=original_column_names,
        num_proc=os.cpu_count() or 1
    )
    print("\nExtractive labels generated and tokenized.")
    print(tokenized_datasets)

    # --- Save the Processed Dataset ---
    print(f"\nSaving processed dataset to: {PROCESSED_DATA_DIR}")
    tokenized_datasets.save_to_disk(PROCESSED_DATA_DIR)
    print("Processed dataset saved successfully!")

# Ensure tokenizer is initialized globally for the Trainer, regardless of data loading path
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Score sentence-level predictions for the Trainer.

    Args:
        eval_pred: pair of (logits, labels) as handed over by the Trainer.

    Returns:
        Dict with "accuracy", "precision", "recall" and "f1". Precision,
        recall and F1 are computed for the positive class (label 1) and
        reported as 0 when they would otherwise be undefined.
    """
    scores, references = eval_pred
    # The predicted class is the index of the highest logit.
    predicted = np.argmax(scores, axis=-1)

    prf = precision_recall_fscore_support(
        references, predicted, average='binary', pos_label=1, zero_division=0
    )

    metrics = {"accuracy": accuracy_score(references, predicted)}
    # The fourth element (support) is not reported.
    metrics["precision"], metrics["recall"], metrics["f1"] = prf[:3]
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output classes: 0 = sentence not in summary, 1 = sentence in summary.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# --- Calculate Class Weights for Imbalanced Data ---
train_labels = tokenized_datasets["train"]["labels"]
total_samples = len(train_labels)
label_counts = Counter(train_labels)
num_class_0, num_class_1 = label_counts.get(0, 0), label_counts.get(1, 0)

print(f"\nLabel distribution in training set: Class 0 (Non-summary): {num_class_0}, Class 1 (Summary): {num_class_1}")

if num_class_0 == 0 or num_class_1 == 0:
    # A missing class would make the weight formula divide by zero.
    print("Warning: One or both classes are missing in the training data. Using uniform weights.")
    class_weights = torch.tensor([1.0, 1.0], dtype=torch.float)
else:
    # weight_c = total / (2 * count_c), so the rarer class gets the larger weight.
    weight_for_class_0 = total_samples / (2 * num_class_0)
    weight_for_class_1 = total_samples / (2 * num_class_1)
    class_weights = torch.tensor([weight_for_class_0, weight_for_class_1], dtype=torch.float)
    print(f"Calculated class weights: Class 0: {weight_for_class_0:.4f}, Class 1: {weight_for_class_1:.4f}")


# Move class weights to the same device as the model
if not torch.cuda.is_available():
    print("No GPU found, model and class weights running on CPU.")
else:
    class_weights = class_weights.to("cuda")
    model.to("cuda")
    print("LongformerForSequenceClassification model and class weights moved to GPU.")

# Override the default loss function to use class weights
class CustomTrainer(Trainer):
    """Trainer that replaces the default loss with class-weighted cross-entropy.

    The per-class weights come from the module-level ``class_weights`` tensor.
    """

    # Signature accepts num_items_in_batch and other kwargs passed by the Trainer
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the weighted cross-entropy loss for one batch.

        Args:
            model: model to run the forward pass with.
            inputs: batch dict; its "labels" entry is removed before the
                forward pass so the model does not compute its own loss.
            return_outputs: when true, return ``(loss, outputs)``.
            num_items_in_batch: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Follow the logits' device rather than relying on the one-off CUDA
        # check at module level: the weights must sit on whatever device the
        # model actually runs on. This is a no-op when the devices already match.
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- 7. Set up Training Arguments ---
# Checkpoints are written and evaluation is run once per epoch; the checkpoint
# with the highest validation "f1" (as returned by compute_metrics) is reloaded
# when training ends.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4, # 2 x 4 = 8 examples per optimizer step per device
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(), # Enable mixed precision training if GPU is available
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch", # matches eval_strategy
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none",
)

# --- 8. Create Trainer ---
# CustomTrainer applies the class-weighted loss defined above.
# NOTE(review): the notebook output shows a warning attributed to this call;
# presumably the `tokenizer` argument is deprecated in the installed
# transformers release — confirm before changing it.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")

# Check for existing checkpoints to resume from.
# Only directories named exactly "checkpoint-<step>" are considered, so stray
# files or entries such as "checkpoint-best" can neither crash the step parsing
# nor be picked as the resume point.
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    checkpoint_prefix = "checkpoint-"
    checkpoints = [
        d for d in os.listdir(training_args.output_dir)
        if d.startswith(checkpoint_prefix)
        and d[len(checkpoint_prefix):].isdecimal()
        and os.path.isdir(os.path.join(training_args.output_dir, d))
    ]
    if checkpoints:
        # The highest step number is the most recent checkpoint.
        last_checkpoint = os.path.join(
            training_args.output_dir,
            max(checkpoints, key=lambda x: int(x[len(checkpoint_prefix):])),
        )
        print(f"Found existing checkpoint: {last_checkpoint}. Resuming training from here.")

# resume_from_checkpoint=None starts a fresh run.
trainer.train(resume_from_checkpoint=last_checkpoint)

print("\nTraining complete! Best model saved to:", trainer.state.best_model_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Loading processed dataset from: /content/drive/My Drive/longformer_processed_govreport

Processed dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 259432
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 27214
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 28855
    })
})


Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Label distribution in training set: Class 0 (Non-summary): 253686, Class 1 (Summary): 5746
Calculated class weights: Class 0: 0.5113, Class 1: 22.5750
LongformerForSequenceClassification model and class weights moved to GPU.

Starting Longformer extractive summarization training...


  trainer = CustomTrainer(


Found existing checkpoint: ./longformer_extractive_govreport/checkpoint-6268. Resuming training from here.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

# Data tokenization for 1000/100/100 (train/validation/test) samples with checkpointing

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import LongformerTokenizer
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm # Direct import of tqdm
# For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Ensure NLTK 'punkt' and 'punkt_tab' tokenizers are available for sentence splitting
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # 'punkt_tab' is installed as a directory of per-language tables, not as a
    # pickle. Probing for 'english.pickle' therefore never succeeded and the
    # package was re-downloaded on every run; probe the English directory instead.
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Ensure ROUGE metric is available
# NOTE(review): the handler simply retries the same call once; a second
# failure propagates to the caller.
try:
    rouge_metric = evaluate.load("rouge")
except Exception:
    print("Downloading 'rouge' metric...")
    rouge_metric = evaluate.load("rouge")
    print("'rouge' loaded.")


# --- 0. Configuration ---
# Checkpoint name; in this cell it is only used to load the tokenizer.
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for tokenizer
# Path to your raw dataset with combined extractive summaries.
# The loader below expects train.json / validation.json / test.json inside it.
# IMPORTANT: If your raw data is also in Google Drive, update this path accordingly
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# Path to store the processed (labeled and tokenized) dataset
# IMPORTANT: Use a Google Drive path for persistence in Colab
# NOTE(review): presumably requires Google Drive to be mounted first — confirm.
PROCESSED_DATA_DIR = "/content/drive/My Drive/longformer_processed_govreport"
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

# Max length (in tokens) for a single sentence; sentences are truncated/padded to this size.
SENTENCE_MAX_LENGTH = 512

# --- 1. Load the Raw GovReport Dataset ---
# Reads the three JSON splits, trims them to a small debugging subset and
# checks that the columns needed for label generation are present. Any
# failure is reported and ends the run.
print(f"Loading raw GovReport dataset from: {GENERATED_DATASET_DIR}")
try:
    data_files = {
        split_name: os.path.join(GENERATED_DATASET_DIR, f"{split_name}.json")
        for split_name in ("train", "validation", "test")
    }
    govreport_data_splits = load_dataset("json", data_files=data_files)

    print("\nRaw GovReport dataset loaded successfully!")
    print(govreport_data_splits)

    # --- IMPORTANT: Select a small subset for quick testing/debugging ---
    # For full dataset processing, comment out or adjust the sizes below.
    print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full processing.")
    for split_name, subset_size in (("train", 1000), ("validation", 100), ("test", 100)):
        govreport_data_splits[split_name] = govreport_data_splits[split_name].select(range(subset_size))
    print("Subset selected:")
    print(govreport_data_splits)
    # --- END SUBSET SELECTION ---

    # Ensure the required columns exist
    train_columns = govreport_data_splits["train"].column_names
    if not ("original_text" in train_columns and "extractive_summary" in train_columns):
        raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
except Exception as load_error:
    print(f"\nError loading raw GovReport dataset from {GENERATED_DATASET_DIR}: {load_error}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    exit()

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
# This function also performs the tokenization of sentences
def generate_extractive_labels(examples):
    """
    Turn a batch of documents into per-sentence classification examples.

    Each document is split into sentences. A sentence is labelled 1 when its
    ROUGE-L score against the document's extractive summary is at least
    ROUGE_THRESHOLD, otherwise 0. Every sentence is tokenized to
    SENTENCE_MAX_LENGTH with global attention on the first token.

    Args:
        examples: batch dict with parallel lists under "original_text" and
            "extractive_summary".

    Returns:
        Dict with "input_ids", "attention_mask", "global_attention_mask" and
        "labels" lists holding one entry per sentence. Documents with an
        empty text or summary contribute nothing, so the output usually has
        a different number of rows than the input batch.
    """
    ROUGE_THRESHOLD = 0.20 # Adjust this threshold based on desired summary density/quality

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    # Use tqdm explicitly; `total` is passed because zip() has no length.
    for document_text, extractive_summary_reference in tqdm(
        documents,
        total=len(examples["original_text"]),
        desc="Generating labels and tokenizing",
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # The reference is identical for every sentence of the document, so
        # normalise it once instead of re-splitting it per sentence.
        _, processed_ref_summary = postprocess_text_for_rouge(
            [], [extractive_summary_reference]
        )

        # Labels depend only on each sentence's own score, so no ranking or
        # sorting of the scores is needed.
        labels_for_document = []
        for sentence in sentences:
            processed_sent, _ = postprocess_text_for_rouge([sentence], [])
            score = rouge_metric.compute(
                predictions=processed_sent,
                references=processed_ref_summary,
                rouge_types=["rougeL"],
                use_stemmer=True,
            )
            labels_for_document.append(1 if score["rougeL"] >= ROUGE_THRESHOLD else 0)

        for sentence, label in zip(sentences, labels_for_document):
            inputs = tokenizer( # This is where tokenization happens
                sentence,
                truncation=True,
                max_length=SENTENCE_MAX_LENGTH,
                padding="max_length", # Pad all inputs to SENTENCE_MAX_LENGTH
                return_tensors="pt"
            )

            global_attention_mask = torch.zeros_like(inputs["input_ids"])
            global_attention_mask[:, 0] = 1 # Set global attention on the first token (CLS)

            processed_examples["input_ids"].append(inputs["input_ids"].squeeze(0).tolist())
            processed_examples["attention_mask"].append(inputs["attention_mask"].squeeze(0).tolist())
            processed_examples["global_attention_mask"].append(global_attention_mask.squeeze(0).tolist())
            processed_examples["labels"].append(label)

    return processed_examples

# Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
def postprocess_text_for_rouge(preds, labels):
    """
    Normalise texts for ROUGE scoring.

    Each text is stripped of surrounding whitespace and rewritten with one
    sentence per line.

    Args:
        preds: iterable of predicted texts.
        labels: iterable of reference texts.

    Returns:
        Tuple ``(preds, labels)`` of lists of normalised strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(sent_tokenize(text.strip()))

    return (
        [_one_sentence_per_line(pred) for pred in preds],
        [_one_sentence_per_line(label) for label in labels],
    )

# --- 4. Apply Label Generation and Tokenization to Dataset Splits ---
# generate_extractive_labels emits one row per sentence, so the mapped output
# has a different row count than the input. That is why the map is batched and
# why every original column is dropped.
print("\nApplying label generation and tokenization to dataset splits (this may take a while)...")
original_column_names = govreport_data_splits["train"].column_names

tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    batched=True, # Process examples in batches for efficiency
    remove_columns=original_column_names, # Remove original columns after processing
    num_proc=1 # IMPORTANT: Set num_proc=1 for more reliable saving to Google Drive
)
print("\nLabels generated and sentences tokenized.")
print(tokenized_datasets)

# --- 5. Save the Processed Dataset ---
print(f"\nSaving processed dataset to: {PROCESSED_DATA_DIR}")
try:
    tokenized_datasets.save_to_disk(PROCESSED_DATA_DIR)
    print("Processed dataset saved successfully!")

    # --- Post-save verification ---
    # Each split must exist as a non-empty directory under PROCESSED_DATA_DIR.
    expected_splits = ["train", "validation", "test"]
    all_splits_saved = True
    for split in expected_splits:
        split_path = os.path.join(PROCESSED_DATA_DIR, split)
        if not os.path.exists(split_path) or not os.listdir(split_path):
            print(f"Warning: Directory for split '{split}' not found or empty at {split_path}")
            all_splits_saved = False
    if all_splits_saved:
        print("All expected dataset splits verified on disk.")
    else:
        print("Warning: Some dataset splits might be missing or incomplete. Please check the directory.")

except Exception as e:
    # Fixed: this handler referenced the undefined name PROSED_DATA_DIR, so a
    # failed save raised NameError here and hid the real error.
    print(f"Error saving processed dataset to {PROCESSED_DATA_DIR}: {e}")
    print("The save operation might have been interrupted or failed. Please check your Google Drive connection and disk space.")


# NOTE(review): these messages are printed even when the save above failed.
print("\nThis script has completed the data preparation (tokenization and label generation).")
print(f"The processed dataset is saved to: {PROCESSED_DATA_DIR}")
print("You can now use this dataset for training your Longformer model without regenerating labels.")


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Loading raw GovReport dataset from: ./govreport_tfidf_vscode2


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]


Raw GovReport dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 17517
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 973
    })
})

Selecting a small subset of the dataset for quick testing. Adjust or remove for full processing.
Subset selected:
DatasetDict({
    train: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 100
    })
    test: Dataset({
        features: ['original_text', 'extractive_summary', 'abstractive_summary'],
        num_rows: 100
    })
})


vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]


Applying label generation and tokenization to dataset splits (this may take a while)...


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]


Generating labels and tokenizing:   0%|          | 0/1000 [00:00<?, ?it/s][A
Generating labels and tokenizing:   0%|          | 1/1000 [00:24<6:49:16, 24.58s/it][A
Generating labels and tokenizing:   0%|          | 2/1000 [00:48<6:41:05, 24.11s/it][A
Generating labels and tokenizing:   0%|          | 3/1000 [01:03<5:33:04, 20.04s/it][A
Generating labels and tokenizing:   0%|          | 4/1000 [01:17<4:51:46, 17.58s/it][A
Generating labels and tokenizing:   1%|          | 6/1000 [01:29<3:10:47, 11.52s/it][A
Generating labels and tokenizing:   1%|          | 7/1000 [01:38<2:58:36, 10.79s/it][A
Generating labels and tokenizing:   1%|          | 8/1000 [01:57<3:38:19, 13.20s/it][A
Generating labels and tokenizing:   1%|          | 9/1000 [02:08<3:26:25, 12.50s/it][A
Generating labels and tokenizing:   1%|          | 10/1000 [02:32<4:20:21, 15.78s/it][A
Generating labels and tokenizing:   1%|          | 11/1000 [02:54<4:52:44, 17.76s/it][A
Generating labels and tokenizing:   1%|

Map:   0%|          | 0/100 [00:00<?, ? examples/s]


Generating labels and tokenizing:   0%|          | 0/100 [00:00<?, ?it/s][A
Generating labels and tokenizing:   1%|          | 1/100 [00:03<06:27,  3.91s/it][A
Generating labels and tokenizing:   2%|▏         | 2/100 [00:12<10:46,  6.60s/it][A
Generating labels and tokenizing:   3%|▎         | 3/100 [00:20<11:52,  7.34s/it][A
Generating labels and tokenizing:   4%|▍         | 4/100 [00:31<14:10,  8.85s/it][A
Generating labels and tokenizing:   5%|▌         | 5/100 [00:46<17:08, 10.82s/it][A
Generating labels and tokenizing:   6%|▌         | 6/100 [01:00<18:56, 12.09s/it][A
Generating labels and tokenizing:   7%|▋         | 7/100 [01:20<22:26, 14.48s/it][A
Generating labels and tokenizing:   8%|▊         | 8/100 [01:37<23:49, 15.53s/it][A
Generating labels and tokenizing:   9%|▉         | 9/100 [01:51<22:28, 14.82s/it][A
Generating labels and tokenizing:  10%|█         | 10/100 [02:02<20:50, 13.90s/it][A
Generating labels and tokenizing:  11%|█         | 11/100 [02:09<17:19,

Map:   0%|          | 0/100 [00:00<?, ? examples/s]


Generating labels and tokenizing:   0%|          | 0/100 [00:00<?, ?it/s][A
Generating labels and tokenizing:   1%|          | 1/100 [00:29<48:53, 29.63s/it][A
Generating labels and tokenizing:   2%|▏         | 2/100 [00:53<42:52, 26.25s/it][A
Generating labels and tokenizing:   3%|▎         | 3/100 [01:17<40:24, 24.99s/it][A
Generating labels and tokenizing:   4%|▍         | 4/100 [01:39<38:39, 24.16s/it][A
Generating labels and tokenizing:   5%|▌         | 5/100 [02:03<37:53, 23.93s/it][A
Generating labels and tokenizing:   6%|▌         | 6/100 [02:08<27:25, 17.51s/it][A
Generating labels and tokenizing:   7%|▋         | 7/100 [02:20<24:24, 15.75s/it][A
Generating labels and tokenizing:   8%|▊         | 8/100 [02:48<30:12, 19.71s/it][A
Generating labels and tokenizing:   9%|▉         | 9/100 [03:06<28:51, 19.02s/it][A
Generating labels and tokenizing:  10%|█         | 10/100 [03:37<34:08, 22.76s/it][A
Generating labels and tokenizing:  11%|█         | 11/100 [03:43<26:18,


Labels generated and sentences tokenized.
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 259432
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 27214
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 28855
    })
})

Saving processed dataset to: /content/drive/My Drive/longformer_processed_govreport


Saving the dataset (0/4 shards):   0%|          | 0/259432 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/27214 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/28855 [00:00<?, ? examples/s]

Processed dataset saved successfully!
All expected dataset splits verified on disk.

This script has completed the data preparation (tokenization and label generation).
The processed dataset is saved to: /content/drive/My Drive/longformer_processed_govreport
You can now use this dataset for training your Longformer model without regenerating labels.


# Step 1: Data preparation and label generation

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import LongformerTokenizer
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm # For progress bars

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Ensure the NLTK sentence-tokenizer data needed by sent_tokenize is available.
# 'punkt' is the legacy pickle-based package; 'punkt_tab' is its tab-separated
# replacement, which unpacks to per-language directories
# (tokenizers/punkt_tab/english/). Probing for 'english.pickle' inside punkt_tab
# therefore never succeeds and re-triggers the download on every run, so the
# directory is probed instead.
for resource_path, package_name, purpose in (
    ("tokenizers/punkt", "punkt", "sentence splitting"),
    ("tokenizers/punkt_tab/english/", "punkt_tab", "ROUGE evaluation"),
):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        print(f"Downloading '{package_name}' NLTK data for {purpose}...")
        nltk.download(package_name)
        print(f"'{package_name}' downloaded.")

# Load the ROUGE metric once. evaluate.load() fetches the metric script itself
# when needed, so retrying the identical call in an except branch adds nothing;
# a failure here is allowed to propagate with its real traceback.
rouge_metric = evaluate.load("rouge")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
# Path to your generated dataset with combined extractive summaries
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"
# Directory where the processed (tokenized and labeled) dataset will be saved
PROCESSED_DATA_DIR = "./longformer_processed_govreport"
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

# Max length for a single sentence input to LongformerForSequenceClassification
SENTENCE_MAX_LENGTH = 512

# --- 1. Load the Generated GovReport Dataset ---
print(f"Loading generated GovReport dataset from: {GENERATED_DATASET_DIR}")
try:
    # One JSON file per split, named after the split.
    data_files = {
        split: os.path.join(GENERATED_DATASET_DIR, f"{split}.json")
        for split in ("train", "validation", "test")
    }

    govreport_data_splits = load_dataset("json", data_files=data_files)

    print("\nGovReport dataset loaded successfully!")
    print(govreport_data_splits)

    # --- IMPORTANT: Select a small subset for quick testing/debugging ---
    # For full dataset processing, comment out or adjust these lines.
    # Keep this active for initial runs to ensure the pipeline works.
    print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full processing.")
    for split, max_rows in (("train", 100), ("validation", 10), ("test", 10)):
        split_data = govreport_data_splits[split]
        # Clamp to the split size so a short split does not make select() raise.
        subset_size = min(max_rows, len(split_data))
        govreport_data_splits[split] = split_data.select(range(subset_size))
    print("Subset selected:")
    print(govreport_data_splits)
    # --- END SUBSET SELECTION ---

    # Ensure the required columns exist
    train_columns = govreport_data_splits["train"].column_names
    if any(column not in train_columns for column in ("original_text", "extractive_summary")):
        raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
except Exception as e:
    print(f"\nError loading GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
    print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
    # Stop with a non-zero status and keep the original exception chained;
    # the bare exit() helper reports success (status 0) and hides the cause.
    raise SystemExit(1) from e

# --- 2. Initialize Longformer Tokenizer ---
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
def generate_extractive_labels(examples):
    """
    Build per-sentence training rows with oracle extractive labels.

    Each document in the batch is split into sentences. Every sentence is scored
    with ROUGE-L (F-measure, stemming enabled) against the document's combined
    extractive summary and labeled 1 when the score reaches ROUGE_THRESHOLD,
    otherwise 0. Each sentence is then tokenized and padded/truncated to
    SENTENCE_MAX_LENGTH, with global attention on its first token only.

    Args:
        examples: Batch dict (as passed by ``Dataset.map(batched=True)``) with
            parallel lists under "original_text" and "extractive_summary".

    Returns:
        Dict with "input_ids", "attention_mask", "global_attention_mask" and
        "labels" lists holding one entry per sentence. Documents with an empty
        text, an empty summary, or no detected sentences contribute no rows.
    """
    ROUGE_THRESHOLD = 0.20 # Adjust this threshold based on desired summary density/quality

    processed_examples = {
        "input_ids": [],
        "attention_mask": [],
        "global_attention_mask": [],
        "labels": [],
    }

    documents = zip(examples["original_text"], examples["extractive_summary"])
    for document_text, extractive_summary_reference in tqdm(
        documents, total=len(examples["original_text"]), desc="Generating labels"
    ):
        if not document_text or not extractive_summary_reference:
            continue

        sentences = sent_tokenize(document_text)
        if not sentences:
            continue

        # Normalize all sentences and the reference once per document.
        processed_sents, processed_ref = postprocess_text_for_rouge(
            sentences, [extractive_summary_reference]
        )

        # Score every sentence in a single call. use_aggregator=False yields one
        # F-measure per prediction and skips the bootstrap aggregation, which on
        # a single score only reproduces that score at considerable cost.
        sentence_scores = rouge_metric.compute(
            predictions=processed_sents,
            references=processed_ref * len(processed_sents),
            rouge_types=["rougeL"],
            use_stemmer=True,
            use_aggregator=False,
        )["rougeL"]

        labels_for_document = [
            1 if score >= ROUGE_THRESHOLD else 0 for score in sentence_scores
        ]

        # Tokenize the whole document's sentences as one batch of plain lists;
        # no tensor round-trip is needed because the rows are stored as lists.
        encodings = tokenizer(
            sentences,
            truncation=True,
            max_length=SENTENCE_MAX_LENGTH,
            padding="max_length",
        )

        processed_examples["input_ids"].extend(encodings["input_ids"])
        processed_examples["attention_mask"].extend(encodings["attention_mask"])
        # Global attention on the first token of each sequence only.
        processed_examples["global_attention_mask"].extend(
            [1] + [0] * (len(token_ids) - 1) for token_ids in encodings["input_ids"]
        )
        processed_examples["labels"].extend(labels_for_document)

    return processed_examples

# Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
def postprocess_text_for_rouge(preds, labels):
    """
    Normalize two lists of texts for ROUGE scoring.

    Every text is stripped of surrounding whitespace, split into sentences with
    sent_tokenize, and re-joined with one sentence per line.

    Returns:
        Tuple ``(preds, labels)`` of the normalized lists, in input order.
    """
    def _one_sentence_per_line(text):
        return "\n".join(sent_tokenize(text.strip()))

    normalized_preds = list(map(_one_sentence_per_line, preds))
    normalized_labels = list(map(_one_sentence_per_line, labels))
    return normalized_preds, normalized_labels

# --- 4. Apply Label Generation to Dataset Splits ---
print("\nGenerating extractive labels for dataset splits (this may take a while)...")
# The raw columns are dropped because each document expands into many sentence rows.
original_column_names = govreport_data_splits["train"].column_names

# One worker per reported CPU, falling back to a single worker when the count is unknown.
worker_count = os.cpu_count() or 1
tokenized_datasets = govreport_data_splits.map(
    generate_extractive_labels,
    batched=True,
    remove_columns=original_column_names,
    num_proc=worker_count,
)
print("\nExtractive labels generated and tokenized.")
print(tokenized_datasets)

# --- 5. Save the Processed Dataset ---
print(f"\nSaving processed dataset to: {PROCESSED_DATA_DIR}")
tokenized_datasets.save_to_disk(PROCESSED_DATA_DIR)
print("Processed dataset saved successfully!")


# Step 2: Model training

In [8]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk # Added load_from_disk
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm # Direct import of tqdm
from collections import Counter # To calculate class weights
from sklearn.metrics import precision_recall_fscore_support, accuracy_score # For compute_metrics

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Ensure the NLTK sentence-tokenizer data needed by sent_tokenize is available.
# 'punkt' is the legacy pickle-based package; 'punkt_tab' is its tab-separated
# replacement, which unpacks to per-language directories
# (tokenizers/punkt_tab/english/). Probing for 'english.pickle' inside punkt_tab
# therefore never succeeds and re-triggers the download on every run, so the
# directory is probed instead.
for resource_path, package_name, purpose in (
    ("tokenizers/punkt", "punkt", "sentence splitting"),
    ("tokenizers/punkt_tab/english/", "punkt_tab", "ROUGE evaluation"),
):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        print(f"Downloading '{package_name}' NLTK data for {purpose}...")
        nltk.download(package_name)
        print(f"'{package_name}' downloaded.")

# Load the ROUGE metric once. evaluate.load() fetches the metric script itself
# when needed, so retrying the identical call in an except branch adds nothing;
# a failure here is allowed to propagate with its real traceback.
rouge_metric = evaluate.load("rouge")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
OUTPUT_DIR = "./longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Path to your raw dataset with combined extractive summaries (input for label generation)
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# New: Path to store the processed (labeled and tokenized) dataset
PROCESSED_DATA_DIR = "/content/drive/My Drive/longformer_processed_govreport" # Ensure this is a Drive path for persistence
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

# Max length for a single sentence input to LongformerForSequenceClassification
SENTENCE_MAX_LENGTH = 512

# --- 1. Load or Generate Processed Dataset ---
tokenized_datasets = None
govreport_data_splits = None # Initialize to None

# Reuse the processed dataset only when all three split directories are on disk.
processed_split_dirs = [
    os.path.join(PROCESSED_DATA_DIR, split) for split in ("train", "validation", "test")
]
if all(os.path.exists(split_dir) for split_dir in processed_split_dirs):
    print(f"Loading processed dataset from: {PROCESSED_DATA_DIR}")
    try:
        tokenized_datasets = DatasetDict.load_from_disk(PROCESSED_DATA_DIR)
        print("\nProcessed dataset loaded successfully!")
        print(tokenized_datasets)
    except Exception as e:
        print(f"\nError loading processed dataset from {PROCESSED_DATA_DIR}: {e}")
        print("Attempting to regenerate dataset as loading failed...")
        tokenized_datasets = None # Reset to trigger regeneration

# Regenerate whenever no processed dataset is in hand: either it was never saved
# or loading it failed above. This must be a separate `if` rather than an `else`
# of the existence check, otherwise a failed load would skip regeneration and
# leave tokenized_datasets as None for the training code below.
if tokenized_datasets is None:
    # --- Load the Generated GovReport Dataset (RAW) ---
    print(f"Loading raw GovReport dataset from: {GENERATED_DATASET_DIR}")
    try:
        # One JSON file per split, named after the split.
        data_files = {
            split: os.path.join(GENERATED_DATASET_DIR, f"{split}.json")
            for split in ("train", "validation", "test")
        }
        govreport_data_splits = load_dataset("json", data_files=data_files)
        print("\nRaw GovReport dataset loaded successfully!")
        print(govreport_data_splits)

        # --- IMPORTANT: Select a small subset for quick testing/debugging ---
        print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
        for split, max_rows in (("train", 1000), ("validation", 100), ("test", 100)):
            split_data = govreport_data_splits[split]
            # Clamp to the split size so a short split does not make select() raise.
            subset_size = min(max_rows, len(split_data))
            govreport_data_splits[split] = split_data.select(range(subset_size))
        print("Subset selected:")
        print(govreport_data_splits)
        # --- END SUBSET SELECTION ---

        # Ensure the required columns exist
        train_columns = govreport_data_splits["train"].column_names
        if any(column not in train_columns for column in ("original_text", "extractive_summary")):
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
    except Exception as e:
        print(f"\nError loading raw GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
        print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
        # Stop with a non-zero status and keep the original exception chained;
        # the bare exit() helper reports success (status 0) and hides the cause.
        raise SystemExit(1) from e

    # --- Initialize Longformer Tokenizer (needed for generate_extractive_labels) ---
    # This tokenizer is specifically for the data preprocessing step
    tokenizer_for_preprocessing = LongformerTokenizer.from_pretrained(MODEL_NAME)

    # --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
    def generate_extractive_labels(examples):
        """
        Build per-sentence training rows with oracle extractive labels.

        Each document in the batch is split into sentences. Every sentence is
        scored with ROUGE-L (F-measure, stemming enabled) against the document's
        combined extractive summary and labeled 1 when the score reaches
        ROUGE_THRESHOLD, otherwise 0. Each sentence is then tokenized and
        padded/truncated to SENTENCE_MAX_LENGTH, with global attention on its
        first token only.

        Args:
            examples: Batch dict (as passed by ``Dataset.map(batched=True)``)
                with parallel lists under "original_text" and
                "extractive_summary".

        Returns:
            Dict with "input_ids", "attention_mask", "global_attention_mask"
            and "labels" lists holding one entry per sentence. Documents with
            an empty text, an empty summary, or no detected sentences
            contribute no rows.
        """
        ROUGE_THRESHOLD = 0.20 # Adjust this threshold based on desired summary density/quality

        processed_examples = {
            "input_ids": [],
            "attention_mask": [],
            "global_attention_mask": [],
            "labels": [],
        }

        documents = zip(examples["original_text"], examples["extractive_summary"])
        for document_text, extractive_summary_reference in tqdm(
            documents, total=len(examples["original_text"]), desc="Generating labels"
        ):
            if not document_text or not extractive_summary_reference:
                continue

            sentences = sent_tokenize(document_text)
            if not sentences:
                continue

            # Normalize all sentences and the reference once per document.
            processed_sents, processed_ref = postprocess_text_for_rouge(
                sentences, [extractive_summary_reference]
            )

            # Score every sentence in a single call. use_aggregator=False yields
            # one F-measure per prediction and skips the bootstrap aggregation,
            # which on a single score only reproduces that score at
            # considerable cost.
            sentence_scores = rouge_metric.compute(
                predictions=processed_sents,
                references=processed_ref * len(processed_sents),
                rouge_types=["rougeL"],
                use_stemmer=True,
                use_aggregator=False,
            )["rougeL"]

            labels_for_document = [
                1 if score >= ROUGE_THRESHOLD else 0 for score in sentence_scores
            ]

            # Tokenize the whole document's sentences as one batch of plain
            # lists; no tensor round-trip is needed because the rows are stored
            # as lists.
            encodings = tokenizer_for_preprocessing( # Use the tokenizer defined for preprocessing
                sentences,
                truncation=True,
                max_length=SENTENCE_MAX_LENGTH,
                padding="max_length",
            )

            processed_examples["input_ids"].extend(encodings["input_ids"])
            processed_examples["attention_mask"].extend(encodings["attention_mask"])
            # Global attention on the first token of each sequence only.
            processed_examples["global_attention_mask"].extend(
                [1] + [0] * (len(token_ids) - 1) for token_ids in encodings["input_ids"]
            )
            processed_examples["labels"].extend(labels_for_document)

        return processed_examples

    # Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
    def postprocess_text_for_rouge(preds, labels):
        """
        Normalize two lists of texts for ROUGE scoring.

        Every text is stripped of surrounding whitespace, split into sentences
        with sent_tokenize, and re-joined with one sentence per line.

        Returns:
            Tuple ``(preds, labels)`` of the normalized lists, in input order.
        """
        def _one_sentence_per_line(text):
            return "\n".join(sent_tokenize(text.strip()))

        normalized_preds = list(map(_one_sentence_per_line, preds))
        normalized_labels = list(map(_one_sentence_per_line, labels))
        return normalized_preds, normalized_labels

    # --- 4. Apply Label Generation to Dataset Splits ---
    print("\nGenerating extractive labels for dataset splits (this may take a while)...")
    # The raw columns are dropped because each document expands into many sentence rows.
    original_column_names = govreport_data_splits["train"].column_names

    # One worker per reported CPU, falling back to a single worker when the count is unknown.
    worker_count = os.cpu_count() or 1
    tokenized_datasets = govreport_data_splits.map(
        generate_extractive_labels,
        batched=True,
        remove_columns=original_column_names,
        num_proc=worker_count,
    )
    print("\nExtractive labels generated and tokenized.")
    print(tokenized_datasets)

    # --- Save the Processed Dataset ---
    print(f"\nSaving processed dataset to: {PROCESSED_DATA_DIR}")
    tokenized_datasets.save_to_disk(PROCESSED_DATA_DIR)
    print("Processed dataset saved successfully!")

# Ensure tokenizer is initialized globally for the Trainer, regardless of data loading path.
# The preprocessing tokenizer above is only created when the dataset is regenerated,
# so the Trainer gets its own instance loaded from the same MODEL_NAME checkpoint.
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """
    Compute accuracy and positive-class precision/recall/F1 for an eval pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the argmax
            of the logits over the last axis.

    Returns:
        Dict with "accuracy", "precision", "recall" and "f1". The last three
        are binary metrics for class 1 and are 0 when undefined
        (``zero_division=0``).
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    prf_scores = precision_recall_fscore_support(
        labels, predicted_classes, average='binary', pos_label=1, zero_division=0
    )

    metrics = {"accuracy": accuracy_score(labels, predicted_classes)}
    # The fourth element returned by sklearn (support) is not reported.
    metrics.update(zip(("precision", "recall", "f1"), prf_scores[:3]))
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Two output labels: 0 = non-summary sentence, 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# --- Calculate Class Weights for Imbalanced Data ---
train_labels = tokenized_datasets["train"]["labels"]
label_counts = Counter(train_labels)
# A Counter reports 0 for labels that never occur.
num_class_0 = label_counts[0]
num_class_1 = label_counts[1]
total_samples = len(train_labels)

print(f"\nLabel distribution in training set: Class 0 (Non-summary): {num_class_0}, Class 1 (Summary): {num_class_1}")

if min(num_class_0, num_class_1) > 0:
    # Balanced weighting: total / (number_of_classes * class_count).
    weight_for_class_0 = total_samples / (2 * num_class_0)
    weight_for_class_1 = total_samples / (2 * num_class_1)
    class_weights = torch.tensor([weight_for_class_0, weight_for_class_1], dtype=torch.float)
    print(f"Calculated class weights: Class 0: {weight_for_class_0:.4f}, Class 1: {weight_for_class_1:.4f}")
else:
    print("Warning: One or both classes are missing in the training data. Using uniform weights.")
    class_weights = torch.tensor([1.0, 1.0], dtype=torch.float)


# Keep the class weights on the same device as the model.
if not torch.cuda.is_available():
    print("No GPU found, model and class weights running on CPU.")
else:
    class_weights = class_weights.to("cuda")
    model.to("cuda")
    print("LongformerForSequenceClassification model and class weights moved to GPU.")

# Override the default loss function to use class weights
class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level class_weights."""

    # Updated signature to accept num_items_in_batch and other kwargs
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """
        Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: Model to run the forward pass with.
            inputs: Batch mapping that includes a "labels" entry; it is left
                unmodified.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs["labels"]
        # Forward everything except the labels so the model does not compute its
        # own unweighted loss, without popping from the caller's batch.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.logits
        # Align the weights with the logits' device instead of relying on the
        # global tensor having been moved to the right device beforehand.
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- 7. Set up Training Arguments ---
# Mixed precision is requested only when a CUDA device is present.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    # Output and logging
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    report_to="none",
    # Optimization schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Batch size and memory: 2 sequences per device, accumulated over 4 steps
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=use_fp16,
    # Evaluate and checkpoint once per epoch; keep the checkpoint with the best F1
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

# --- 8. Create Trainer ---
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- 9. Start Training ---
from transformers.trainer_utils import get_last_checkpoint

print("\nStarting Longformer extractive summarization training...")

# Check for existing checkpoints to resume from.
# get_last_checkpoint only considers directories named exactly "checkpoint-<step>"
# and returns None when there are none, so stray entries such as
# "checkpoint-best" or "checkpoint-100.zip" can no longer break the step parsing.
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None:
        print(f"Found existing checkpoint: {last_checkpoint}. Resuming training from here.")

trainer.train(resume_from_checkpoint=last_checkpoint)

print("\nTraining complete! Best model saved to:", trainer.state.best_model_checkpoint)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt' NLTK data for sentence splitting...


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


'punkt' downloaded.
Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script: 0.00B [00:00, ?B/s]

Loading processed dataset from: /content/drive/My Drive/longformer_processed_govreport

Processed dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 259432
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 27214
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 28855
    })
})


vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/597M [00:00<?, ?B/s]

Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/597M [00:00<?, ?B/s]


Label distribution in training set: Class 0 (Non-summary): 253686, Class 1 (Summary): 5746
Calculated class weights: Class 0: 0.5113, Class 1: 22.5750
LongformerForSequenceClassification model and class weights moved to GPU.

Starting Longformer extractive summarization training...
Found existing checkpoint: ./longformer_extractive_govreport/checkpoint-6268. Resuming training from here.


  trainer = CustomTrainer(


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

# Step 2 (updated): Model training with checkpoints saved to Google Drive

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk # Added load_from_disk
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate # Hugging Face's evaluate library for metrics
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm # Direct import of tqdm
from collections import Counter # To calculate class weights
from sklearn.metrics import precision_recall_fscore_support, accuracy_score # For compute_metrics

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# This ensures you have the latest compatible versions of transformers, datasets, and accelerate
# which are necessary for the TrainingArguments parameters used.
# !pip install --upgrade transformers datasets accelerate
# If the above line is commented out, uncomment it and run this cell.

# Ensure the NLTK sentence-tokenizer data needed by sent_tokenize is available.
# 'punkt' is the legacy pickle-based package; 'punkt_tab' is its tab-separated
# replacement, which unpacks to per-language directories
# (tokenizers/punkt_tab/english/). Probing for 'english.pickle' inside punkt_tab
# therefore never succeeds and re-triggers the download on every run, so the
# directory is probed instead.
for resource_path, package_name, purpose in (
    ("tokenizers/punkt", "punkt", "sentence splitting"),
    ("tokenizers/punkt_tab/english/", "punkt_tab", "ROUGE evaluation"),
):
    try:
        nltk.data.find(resource_path)
    except LookupError:
        print(f"Downloading '{package_name}' NLTK data for {purpose}...")
        nltk.download(package_name)
        print(f"'{package_name}' downloaded.")

# Load the ROUGE metric once. evaluate.load() fetches the metric script itself
# when needed, so retrying the identical call in an except branch adds nothing;
# a failure here is allowed to propagate with its real traceback.
rouge_metric = evaluate.load("rouge")


# --- 0. Configuration ---
MODEL_NAME = "allenai/longformer-base-4096" # Base Longformer model for extractive summarization
# IMPORTANT: Change OUTPUT_DIR to a Google Drive path for persistence
OUTPUT_DIR = "/content/drive/My Drive/longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Path to your raw dataset with combined extractive summaries (input for label generation)
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"

# New: Path to store the processed (labeled and tokenized) dataset
PROCESSED_DATA_DIR = "/content/drive/My Drive/longformer_processed_govreport" # Ensure this is a Drive path for persistence
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)

# Max length for a single sentence input to LongformerForSequenceClassification
SENTENCE_MAX_LENGTH = 512

# --- 1. Load or Generate Processed Dataset ---
tokenized_datasets = None
govreport_data_splits = None # Initialize to None

# Reuse the processed dataset only when all three split directories are on disk.
processed_split_dirs = [
    os.path.join(PROCESSED_DATA_DIR, split) for split in ("train", "validation", "test")
]
if all(os.path.exists(split_dir) for split_dir in processed_split_dirs):
    print(f"Loading processed dataset from: {PROCESSED_DATA_DIR}")
    try:
        tokenized_datasets = DatasetDict.load_from_disk(PROCESSED_DATA_DIR)
        print("\nProcessed dataset loaded successfully!")
        print(tokenized_datasets)
    except Exception as e:
        print(f"\nError loading processed dataset from {PROCESSED_DATA_DIR}: {e}")
        print("Attempting to regenerate dataset as loading failed...")
        tokenized_datasets = None # Reset to trigger regeneration

# Regenerate whenever no processed dataset is in hand: either it was never saved
# or loading it failed above. This must be a separate `if` rather than an `else`
# of the existence check, otherwise a failed load would skip regeneration and
# leave tokenized_datasets as None for the training code below.
if tokenized_datasets is None:
    # --- Load the Generated GovReport Dataset (RAW) ---
    print(f"Loading raw GovReport dataset from: {GENERATED_DATASET_DIR}")
    try:
        # One JSON file per split, named after the split.
        data_files = {
            split: os.path.join(GENERATED_DATASET_DIR, f"{split}.json")
            for split in ("train", "validation", "test")
        }
        govreport_data_splits = load_dataset("json", data_files=data_files)
        print("\nRaw GovReport dataset loaded successfully!")
        print(govreport_data_splits)

        # --- IMPORTANT: Select a small subset for quick testing/debugging ---
        print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
        for split, max_rows in (("train", 1000), ("validation", 100), ("test", 100)):
            split_data = govreport_data_splits[split]
            # Clamp to the split size so a short split does not make select() raise.
            subset_size = min(max_rows, len(split_data))
            govreport_data_splits[split] = split_data.select(range(subset_size))
        print("Subset selected:")
        print(govreport_data_splits)
        # --- END SUBSET SELECTION ---

        # Ensure the required columns exist
        train_columns = govreport_data_splits["train"].column_names
        if any(column not in train_columns for column in ("original_text", "extractive_summary")):
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
    except Exception as e:
        print(f"\nError loading raw GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
        print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
        # Stop with a non-zero status and keep the original exception chained;
        # the bare exit() helper reports success (status 0) and hides the cause.
        raise SystemExit(1) from e

    # --- Initialize Longformer Tokenizer (needed for generate_extractive_labels) ---
    # This tokenizer is specifically for the data preprocessing step
    tokenizer_for_preprocessing = LongformerTokenizer.from_pretrained(MODEL_NAME)

    # --- 3. Function to Generate Extractive Labels (Oracle Summaries) ---
    def generate_extractive_labels(examples):
        """
        Build per-sentence training rows with oracle extractive labels.

        Each document in the batch is split into sentences. Every sentence is
        scored with ROUGE-L (F-measure, stemming enabled) against the document's
        combined extractive summary and labeled 1 when the score reaches
        ROUGE_THRESHOLD, otherwise 0. Each sentence is then tokenized and
        padded/truncated to SENTENCE_MAX_LENGTH, with global attention on its
        first token only.

        Args:
            examples: Batch dict (as passed by ``Dataset.map(batched=True)``)
                with parallel lists under "original_text" and
                "extractive_summary".

        Returns:
            Dict with "input_ids", "attention_mask", "global_attention_mask"
            and "labels" lists holding one entry per sentence. Documents with
            an empty text, an empty summary, or no detected sentences
            contribute no rows.
        """
        ROUGE_THRESHOLD = 0.20 # Adjust this threshold based on desired summary density/quality

        processed_examples = {
            "input_ids": [],
            "attention_mask": [],
            "global_attention_mask": [],
            "labels": [],
        }

        documents = zip(examples["original_text"], examples["extractive_summary"])
        for document_text, extractive_summary_reference in tqdm(
            documents, total=len(examples["original_text"]), desc="Generating labels"
        ):
            if not document_text or not extractive_summary_reference:
                continue

            sentences = sent_tokenize(document_text)
            if not sentences:
                continue

            # Normalize all sentences and the reference once per document.
            processed_sents, processed_ref = postprocess_text_for_rouge(
                sentences, [extractive_summary_reference]
            )

            # Score every sentence in a single call. use_aggregator=False yields
            # one F-measure per prediction and skips the bootstrap aggregation,
            # which on a single score only reproduces that score at
            # considerable cost.
            sentence_scores = rouge_metric.compute(
                predictions=processed_sents,
                references=processed_ref * len(processed_sents),
                rouge_types=["rougeL"],
                use_stemmer=True,
                use_aggregator=False,
            )["rougeL"]

            labels_for_document = [
                1 if score >= ROUGE_THRESHOLD else 0 for score in sentence_scores
            ]

            # Tokenize the whole document's sentences as one batch of plain
            # lists; no tensor round-trip is needed because the rows are stored
            # as lists.
            encodings = tokenizer_for_preprocessing( # Use the tokenizer defined for preprocessing
                sentences,
                truncation=True,
                max_length=SENTENCE_MAX_LENGTH,
                padding="max_length",
            )

            processed_examples["input_ids"].extend(encodings["input_ids"])
            processed_examples["attention_mask"].extend(encodings["attention_mask"])
            # Global attention on the first token of each sequence only.
            processed_examples["global_attention_mask"].extend(
                [1] + [0] * (len(token_ids) - 1) for token_ids in encodings["input_ids"]
            )
            processed_examples["labels"].extend(labels_for_document)

        return processed_examples

    # Helper for ROUGE (ensures consistent sentence splitting for metric calculation)
    def postprocess_text_for_rouge(preds, labels):
        """
        Normalize two lists of texts for ROUGE scoring.

        Every text is stripped of surrounding whitespace, split into sentences
        with sent_tokenize, and re-joined with one sentence per line.

        Returns:
            Tuple ``(preds, labels)`` of the normalized lists, in input order.
        """
        def _one_sentence_per_line(text):
            return "\n".join(sent_tokenize(text.strip()))

        normalized_preds = list(map(_one_sentence_per_line, preds))
        normalized_labels = list(map(_one_sentence_per_line, labels))
        return normalized_preds, normalized_labels

    # --- 4. Apply Label Generation to Dataset Splits ---
    print("\nGenerating extractive labels for dataset splits (this may take a while)...")
    # The raw columns are dropped because each document expands into many sentence rows.
    original_column_names = govreport_data_splits["train"].column_names

    # One worker per reported CPU, falling back to a single worker when the count is unknown.
    worker_count = os.cpu_count() or 1
    tokenized_datasets = govreport_data_splits.map(
        generate_extractive_labels,
        batched=True,
        remove_columns=original_column_names,
        num_proc=worker_count,
    )
    print("\nExtractive labels generated and tokenized.")
    print(tokenized_datasets)

    # --- Save the Processed Dataset ---
    print(f"\nSaving processed dataset to: {PROCESSED_DATA_DIR}")
    tokenized_datasets.save_to_disk(PROCESSED_DATA_DIR)
    print("Processed dataset saved successfully!")

# Ensure tokenizer is initialized globally for the Trainer, regardless of data loading path.
# The preprocessing tokenizer above is only created when the dataset is regenerated,
# so the Trainer gets its own instance loaded from the same MODEL_NAME checkpoint.
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

# --- 5. Define Metrics for Training ---
def compute_metrics(eval_pred):
    """Compute accuracy plus positive-class precision, recall and F1.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class is the
    argmax over the last axis of ``logits``. Class 1 is the positive class and
    undefined ratios are reported as 0 (``zero_division=0``).

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    # The fourth element returned by sklearn (support) is dropped by zip below.
    prf_scores = precision_recall_fscore_support(
        labels, predicted_classes, average='binary', pos_label=1, zero_division=0
    )
    metrics = {"accuracy": accuracy_score(labels, predicted_classes)}
    metrics.update(zip(("precision", "recall", "f1"), prf_scores))
    return metrics


# --- 6. Initialize LongformerForSequenceClassification Model ---
# Binary classifier: label 0 = non-summary sentence, label 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# --- Calculate Class Weights for Imbalanced Data ---
train_labels = tokenized_datasets["train"]["labels"]
total_samples = len(train_labels)
label_counts = Counter(train_labels)
num_class_0 = label_counts.get(0, 0)
num_class_1 = label_counts.get(1, 0)

print(f"\nLabel distribution in training set: Class 0 (Non-summary): {num_class_0}, Class 1 (Summary): {num_class_1}")

if min(num_class_0, num_class_1) > 0:
    # weight_c = N / (2 * n_c): the rarer a class, the larger its weight.
    weight_for_class_0 = total_samples / (2 * num_class_0)
    weight_for_class_1 = total_samples / (2 * num_class_1)
    class_weights = torch.tensor([weight_for_class_0, weight_for_class_1], dtype=torch.float)
    print(f"Calculated class weights: Class 0: {weight_for_class_0:.4f}, Class 1: {weight_for_class_1:.4f}")
else:
    # A missing class would cause a division by zero above, so fall back to equal weights.
    print("Warning: One or both classes are missing in the training data. Using uniform weights.")
    class_weights = torch.tensor([1.0, 1.0], dtype=torch.float)


# Move class weights to the same device as the model
if not torch.cuda.is_available():
    print("No GPU found, model and class weights running on CPU.")
else:
    class_weights = class_weights.to("cuda")
    model.to("cuda")
    print("LongformerForSequenceClassification model and class weights moved to GPU.")

# Override the default loss function to use class weights
class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``class_weights``."""

    # Updated signature to accept num_items_in_batch and other kwargs
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the class-weighted cross-entropy loss for one batch.

        ``labels`` is popped from ``inputs`` (mutating the dict), so the forward
        pass runs without it. ``num_items_in_batch`` and any extra keyword
        arguments are accepted for signature compatibility and are not used.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, otherwise ``loss``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on the logits' device. The module-level tensor is only
        # moved to CUDA, so it would mismatch if the model runs on another device.
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- 7. Set up Training Arguments ---
training_args = TrainingArguments(
    # Output location and logging
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    report_to="none",
    # Optimisation schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Batching and memory
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA GPU is available
    # Evaluation and checkpointing are both step-based
    eval_strategy="steps",
    eval_steps=500,  # evaluate every 500 training steps
    save_strategy="steps",
    save_steps=50,  # save a checkpoint every 50 training steps
    save_total_limit=2,  # keep the last 2 checkpoints (current and previous)
    # Best-model loading is disabled, so metric_for_best_model / greater_is_better are not set.
    load_best_model_at_end=False,
)

# --- 8. Create Trainer ---
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # `tokenizer=` is deprecated in recent transformers releases and emits a
    # warning on this call; `processing_class=` is its replacement.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

# --- 9. Start Training ---
print("\nStarting Longformer extractive summarization training...")

# Check for existing checkpoints to resume from
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    checkpoints = [d for d in os.listdir(training_args.output_dir) if d.startswith("checkpoint-")]
    if checkpoints:
        # Sort by step number to get the latest checkpoint (max number of steps completed)
        last_checkpoint = os.path.join(training_args.output_dir, max(checkpoints, key=lambda x: int(x.split('-')[1])))
        print(f"Found existing checkpoint: {last_checkpoint}. Resuming training from here.")

try:
    trainer.train(resume_from_checkpoint=last_checkpoint)
except Exception as e:
    print(f"Error during training or resuming: {e}")
    print("If this is an 'AttributeError: 'NoneType' object has no attribute 'load_state_dict'' related to fp16,")
    print("it might mean the previous checkpoint was saved inconsistently (e.g., GPU disconnected mid-save).")
    print(f"Consider deleting the '{OUTPUT_DIR}' folder manually and restarting training from scratch.")
    # Re-raise so a failed run is not reported as complete and then evaluated.
    raise


# No best-model tracking is configured (load_best_model_at_end=False), so
# report the checkpoint directory rather than trainer.state.best_model_checkpoint.
print("\nTraining complete! Checkpoints saved under:", training_args.output_dir)

# --- 10. Evaluate on Test Set (Optional) ---
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


Downloading 'punkt_tab' NLTK data for ROUGE evaluation...
'punkt_tab' downloaded.


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Loading processed dataset from: /content/drive/My Drive/longformer_processed_govreport

Processed dataset loaded successfully!
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 259432
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 27214
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
        num_rows: 28855
    })
})


Some weights of LongformerForSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Label distribution in training set: Class 0 (Non-summary): 253686, Class 1 (Summary): 5746
Calculated class weights: Class 0: 0.5113, Class 1: 22.5750
LongformerForSequenceClassification model and class weights moved to GPU.


  trainer = CustomTrainer(



Starting Longformer extractive summarization training...
Found existing checkpoint: /content/drive/My Drive/longformer_extractive_govreport/checkpoint-6268. Resuming training from here.


	eval_steps: 50 (from args) != 500 (from trainer_state.json)
	save_steps: 50 (from args) != 500 (from trainer_state.json)


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
6500,0.2529,0.121341,0.988278,0.0,0.0,0.0
7000,0.2305,0.110604,0.988278,0.0,0.0,0.0
7500,0.186,0.123311,0.988278,0.0,0.0,0.0
8000,0.1266,0.137786,0.988278,0.0,0.0,0.0


# Training on a 500/50/50 (train/validation/test) subset with BERTScore-based labels

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# --- Install/Upgrade Libraries (Run this cell first in Colab) ---
# !pip install --upgrade transformers datasets accelerate
# !pip install bert-score

# Ensure NLTK 'punkt' and 'punkt_tab' tokenizers are available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("Downloading 'punkt' NLTK data for sentence splitting...")
    nltk.download('punkt')
    print("'punkt' downloaded.")

try:
    # Look up the resource directory itself. punkt_tab has no 'english.pickle',
    # so checking for that file always failed and re-ran the download even when
    # the package was already installed.
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    print("Downloading 'punkt_tab' NLTK data for ROUGE evaluation...")
    nltk.download('punkt_tab')
    print("'punkt_tab' downloaded.")

# Load the BERTScore metric; on any failure the same load is attempted once more
# and a second failure propagates.
try:
    bertscore_metric = evaluate.load("bertscore")
except Exception:
    print("Downloading 'bertscore' metric...")
    bertscore_metric = evaluate.load("bertscore")
    print("'bertscore' loaded.")


# --- 0. Configuration ---
# Pretrained checkpoint name used for both the tokenizer and the classifier.
MODEL_NAME = "allenai/longformer-base-4096"
# Trainer output directory (checkpoints and logs); created if it does not exist.
OUTPUT_DIR = "/content/drive/My Drive/longformer_extractive_govreport"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Directory expected to hold the raw train.json / validation.json / test.json files.
GENERATED_DATASET_DIR = "./govreport_tfidf_vscode2"
# Where the labelled, tokenized DatasetDict is saved and reloaded from; created if missing.
PROCESSED_DATA_DIR = "/content/drive/My Drive/longformer_processed_govreport"
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
# Token length every sentence is truncated/padded to during preprocessing.
SENTENCE_MAX_LENGTH = 512

# --- IMPORTANT: FORCING A RERUN WITH NEW LABELS ---
# If you are changing the labeling metric (e.g., from ROUGE to BERTScore),
# you MUST delete the old processed dataset to force the script to re-label your data.
# Uncomment the line below to delete the old processed data directory.
# import shutil
# if os.path.exists(PROCESSED_DATA_DIR):
#    print(f"Deleting old processed dataset at: {PROCESSED_DATA_DIR}")
#    shutil.rmtree(PROCESSED_DATA_DIR)
#    print("Old processed dataset deleted. New one will be generated.")

# --- 1. Load or Generate Processed Dataset ---
tokenized_datasets = None
govreport_data_splits = None

# The saved dataset is considered present only when the directory itself and all
# three split sub-directories exist.
_cached_paths = [PROCESSED_DATA_DIR] + [
    os.path.join(PROCESSED_DATA_DIR, split_name)
    for split_name in ("train", "validation", "test")
]
if all(os.path.exists(cached_path) for cached_path in _cached_paths):
    print(f"Loading processed dataset from: {PROCESSED_DATA_DIR}")
    try:
        tokenized_datasets = DatasetDict.load_from_disk(PROCESSED_DATA_DIR)
        print("\nProcessed dataset loaded successfully!")
        print(tokenized_datasets)
    except Exception as e:
        # Report the failure and reset the variable to None.
        print(f"\nError loading processed dataset from {PROCESSED_DATA_DIR}: {e}")
        print("Attempting to regenerate dataset as loading failed...")
        tokenized_datasets = None
# Build the dataset from the raw files whenever nothing usable was loaded above:
# either no saved copy exists, or loading it failed and left tokenized_datasets
# as None. (A plain `else:` skipped regeneration after a failed load.)
if tokenized_datasets is None:
    print(f"Loading raw GovReport dataset from: {GENERATED_DATASET_DIR}")
    try:
        data_files = {
            "train": os.path.join(GENERATED_DATASET_DIR, "train.json"),
            "validation": os.path.join(GENERATED_DATASET_DIR, "validation.json"),
            "test": os.path.join(GENERATED_DATASET_DIR, "test.json"),
        }
        govreport_data_splits = load_dataset("json", data_files=data_files)
        print("\nRaw GovReport dataset loaded successfully!")
        print(govreport_data_splits)

        print("\nSelecting a small subset of the dataset for quick testing. Adjust or remove for full training.")
        # Cap each size at the split length so a short split does not raise.
        subset_sizes = {"train": 500, "validation": 50, "test": 50}
        for split_name, subset_size in subset_sizes.items():
            split = govreport_data_splits[split_name]
            govreport_data_splits[split_name] = split.select(range(min(subset_size, len(split))))
        print("Subset selected:")
        print(govreport_data_splits)

        if "original_text" not in govreport_data_splits["train"].column_names or \
           "extractive_summary" not in govreport_data_splits["train"].column_names:
            raise ValueError("Dataset must contain 'original_text' and 'extractive_summary' columns.")
    except Exception as e:
        print(f"\nError loading raw GovReport dataset from {GENERATED_DATASET_DIR}: {e}")
        print("Please ensure the directory exists and contains 'train.json', 'validation.json', 'test.json'.")
        # Re-raise rather than exit(): the cell stops with the original traceback
        # instead of asking the interpreter/kernel to shut down.
        raise

    tokenizer_for_preprocessing = LongformerTokenizer.from_pretrained(MODEL_NAME)

    def generate_extractive_labels(examples):
        """Turn a batch of documents into per-sentence classification examples.

        Used as the batched ``map`` function. Each document in
        ``examples["original_text"]`` is split into sentences. A sentence is
        labelled 1 when its BERTScore F1 against the document's
        ``extractive_summary`` is at least ``BERTSCORE_THRESHOLD``, else 0.
        Documents with empty text, an empty reference or no sentences are
        skipped.

        Returns:
            dict of equal-length lists ``input_ids``, ``attention_mask``,
            ``global_attention_mask`` and ``labels``, one entry per sentence.
        """
        BERTSCORE_THRESHOLD = 0.80

        processed_examples = {
            "input_ids": [],
            "attention_mask": [],
            "global_attention_mask": [],
            "labels": [],
        }

        documents = zip(examples["original_text"], examples["extractive_summary"])
        for document_text, extractive_summary_reference in tqdm(
            documents, total=len(examples["original_text"]), desc="Generating labels"
        ):
            if not document_text or not extractive_summary_reference:
                continue

            sentences = sent_tokenize(document_text)
            if not sentences:
                continue

            # Score all sentences of the document in one metric call, not one
            # call per sentence. The reference is repeated to pair with each.
            scores = bertscore_metric.compute(
                predictions=sentences,
                references=[extractive_summary_reference] * len(sentences),
                lang="en",
            )
            labels_for_document = [
                1 if f1 >= BERTSCORE_THRESHOLD else 0 for f1 in scores["f1"]
            ]

            # Tokenize the whole document's sentences at once. Without
            # return_tensors the tokenizer returns plain Python lists.
            encoded = tokenizer_for_preprocessing(
                sentences,
                truncation=True,
                max_length=SENTENCE_MAX_LENGTH,
                padding="max_length",
            )

            for input_ids, attention_mask, label in zip(
                encoded["input_ids"], encoded["attention_mask"], labels_for_document
            ):
                processed_examples["input_ids"].append(input_ids)
                processed_examples["attention_mask"].append(attention_mask)
                # Global attention on the first token only.
                processed_examples["global_attention_mask"].append(
                    [1] + [0] * (len(input_ids) - 1)
                )
                processed_examples["labels"].append(label)

        return processed_examples

    def postprocess_text_for_rouge(preds, labels):
        """Re-format predictions and references as one sentence per line.

        Every text is stripped of surrounding whitespace, split into sentences
        with ``sent_tokenize`` and re-joined with newline characters.

        Returns:
            The processed ``(preds, labels)`` pair, each a list of strings.
        """
        def _one_sentence_per_line(text):
            return "\n".join(sent_tokenize(text.strip()))

        return (
            [_one_sentence_per_line(pred) for pred in preds],
            [_one_sentence_per_line(label) for label in labels],
        )

    print("\nGenerating extractive labels for dataset splits (this may take a while)...")
    # The raw columns are removed by the map call; only the generated features remain.
    original_column_names = govreport_data_splits["train"].column_names
    # One worker per reported CPU, or a single worker when the count is unavailable.
    worker_count = os.cpu_count() or 1

    tokenized_datasets = govreport_data_splits.map(
        generate_extractive_labels,
        batched=True,
        remove_columns=original_column_names,
        num_proc=worker_count,
    )
    print("\nExtractive labels generated and tokenized.")
    print(tokenized_datasets)

    # Persist the labelled dataset to PROCESSED_DATA_DIR.
    print(f"\nSaving processed dataset to: {PROCESSED_DATA_DIR}")
    tokenized_datasets.save_to_disk(PROCESSED_DATA_DIR)
    print("Processed dataset saved successfully!")

# Tokenizer handed to the trainer; created at top level so it exists whether the
# dataset was loaded from disk or generated in the branch above.
tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

def compute_metrics(eval_pred):
    """Compute accuracy plus positive-class precision, recall and F1.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class is the
    argmax over the last axis of ``logits``. Before computing the metrics the
    function prints how many samples, true positives-class labels and predicted
    positive-class labels it sees. Class 1 is the positive class and undefined
    ratios are reported as 0 (``zero_division=0``).

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    # Debug output: shows whether the model predicts the positive class at all.
    sample_count = len(labels)
    actual_positive_count = np.sum(labels == 1)
    predicted_positive_count = np.sum(predicted_classes == 1)
    print("\n--- Metrics Debugging ---")
    print(f"Total validation samples: {sample_count}")
    print(f"Number of true positive labels (1s) in data: {actual_positive_count}")
    print(f"Number of predicted positive labels (1s) by model: {predicted_positive_count}")
    print("--- End Debugging ---")

    # The fourth element returned by sklearn (support) is dropped by zip below.
    prf_scores = precision_recall_fscore_support(
        labels, predicted_classes, average='binary', pos_label=1, zero_division=0
    )
    metrics = {"accuracy": accuracy_score(labels, predicted_classes)}
    metrics.update(zip(("precision", "recall", "f1"), prf_scores))
    return metrics


# Binary classifier: label 0 = non-summary sentence, label 1 = summary sentence.
model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Count the training labels to derive per-class loss weights.
train_labels = tokenized_datasets["train"]["labels"]
total_samples = len(train_labels)
label_counts = Counter(train_labels)
num_class_0 = label_counts.get(0, 0)
num_class_1 = label_counts.get(1, 0)

print(f"\nLabel distribution in training set: Class 0 (Non-summary): {num_class_0}, Class 1 (Summary): {num_class_1}")

if min(num_class_0, num_class_1) > 0:
    # weight_c = N / (2 * n_c): the rarer a class, the larger its weight.
    weight_for_class_0 = total_samples / (2 * num_class_0)
    weight_for_class_1 = total_samples / (2 * num_class_1)
    class_weights = torch.tensor([weight_for_class_0, weight_for_class_1], dtype=torch.float)
    print(f"Calculated class weights: Class 0: {weight_for_class_0:.4f}, Class 1: {weight_for_class_1:.4f}")
else:
    # A missing class would cause a division by zero above, so fall back to equal weights.
    print("Warning: One or both classes are missing in the training data. Using uniform weights.")
    class_weights = torch.tensor([1.0, 1.0], dtype=torch.float)

# Put the model and the loss weights on the GPU when one is available.
if not torch.cuda.is_available():
    print("No GPU found, model and class weights running on CPU.")
else:
    class_weights = class_weights.to("cuda")
    model.to("cuda")
    print("LongformerForSequenceClassification model and class weights moved to GPU.")

class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``class_weights``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the class-weighted cross-entropy loss for one batch.

        ``labels`` is popped from ``inputs`` (mutating the dict), so the forward
        pass runs without it. ``num_items_in_batch`` and any extra keyword
        arguments are accepted for signature compatibility and are not used.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, otherwise ``loss``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on the logits' device. The module-level tensor is only
        # moved to CUDA, so it would mismatch if the model runs on another device.
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

training_args = TrainingArguments(
    # Output location and logging
    output_dir=OUTPUT_DIR,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=100,
    report_to="none",
    # Optimisation schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Batching and memory
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA GPU is available
    # Evaluation and checkpointing share the same 50-step cadence
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,  # keep at most two checkpoints on disk
    load_best_model_at_end=False,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # `tokenizer=` is deprecated in recent transformers releases;
    # `processing_class=` is its replacement.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

print("\nStarting Longformer extractive summarization training...")
# Resume from the checkpoint with the highest step number found in the output directory, if any.
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    checkpoints = [d for d in os.listdir(training_args.output_dir) if d.startswith("checkpoint-")]
    if checkpoints:
        last_checkpoint = os.path.join(training_args.output_dir, max(checkpoints, key=lambda x: int(x.split('-')[1])))
        print(f"Found existing checkpoint: {last_checkpoint}. Resuming training from here.")

try:
    # Note: If you want to use the new BERTScore-labeled data, do NOT resume from a checkpoint here.
    # Instead, let it start training from scratch on the newly generated dataset.
    trainer.train(resume_from_checkpoint=last_checkpoint)
except Exception as e:
    print(f"Error during training or resuming: {e}")
    print("If this is an 'AttributeError: 'NoneType' object has no attribute 'load_state_dict'' related to fp16,")
    print("it might mean the previous checkpoint was saved inconsistently. Consider deleting the folder and restarting.")
    # Re-raise so a failed run is not reported as complete and then evaluated.
    raise

# No best-model tracking is configured (load_best_model_at_end=False), so
# report the checkpoint directory rather than trainer.state.best_model_checkpoint.
print("\nTraining complete! Checkpoints saved under:", training_args.output_dir)

print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_datasets["test"])
print("Test Results:", test_results)


# Installing pip packages and mounting Google Drive

In [1]:
!pip install rouge_score
!pip install evaluate

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=5ef17922c29c8bb1f2ae15c84fda55d13e00514e2206ad97a465a6d68cd4cdc4
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2
Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5


In [1]:
# Mount Google Drive at /content/drive so the Drive-based paths used in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
cd drive/MyDrive/

[0m[01;34m'Colab Notebooks'[0m/                     [01;34mlongformer_processed_govreport[0m/
 [01;34mgovreport_tfidf_vscode2[0m/              longformer_processed_govreport.zip
 [01;34mlongformer_extractive_govreport[0m/     [01;34m'reviews media'[0m/
 longformer_extractive_govreport.zip


In [7]:
import zipfile

# Unpack the zipped processed dataset; the empty target path makes members
# extract relative to the current working directory.
with zipfile.ZipFile("longformer_processed_govreport.zip", mode='r') as archive:
    archive.extractall(path="")