# Testing the base model

In [None]:
import os
import jsonlines
import torch
import multiprocessing
from transformers import LongformerModel, LongformerTokenizerFast
import spacy
import evaluate
from typing import List, Dict
from tqdm.auto import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

# ============================================================
# Force 32 CPU cores for all libraries
# ============================================================
# Thread budget shared by PyTorch and the native math runtimes.
NUM_CORES = 32

# Size PyTorch's intra-op thread pool first, then publish the same budget
# (plus the tokenizer parallelism switch) through the process environment.
# NOTE(review): torch is imported before these variables are set; confirm the
# OpenMP/MKL runtimes still honour them when assigned this late.
torch.set_num_threads(NUM_CORES)
os.environ.update(
    {
        "OMP_NUM_THREADS": str(NUM_CORES),
        "MKL_NUM_THREADS": str(NUM_CORES),
        "TOKENIZERS_PARALLELISM": "true",
    }
)

# ============================================================
# Spacy sentence tokenizer
# ============================================================
def _load_spacy_pipeline(model_name="en_core_web_sm"):
    """Load the spaCy pipeline used for sentence splitting.

    The ``ner``, ``tagger`` and ``lemmatizer`` components are disabled on
    load. If loading raises ``OSError`` (handled here as "model not
    installed"), the model is downloaded once and the load is retried.
    """
    disabled_components = ["ner", "tagger", "lemmatizer"]
    try:
        return spacy.load(model_name, disable=disabled_components)
    except OSError:
        print("Downloading spacy model 'en_core_web_sm'...")
        from spacy.cli import download
        download(model_name)
        return spacy.load(model_name, disable=disabled_components)


nlp = _load_spacy_pipeline()

def spacy_sent_tokenize(text):
    """Split ``text`` into sentences with the module-level spaCy pipeline.

    Args:
        text: Raw text to segment.

    Returns:
        The sentence strings in document order, each stripped of surrounding
        whitespace; sentences that are empty after stripping are dropped.
    """
    # A single ``nlp(text)`` call processes the document in the current
    # process; no ``nlp.pipe`` / ``n_process`` parallelism is involved here.
    stripped_sentences = (sent.text.strip() for sent in nlp(text).sents)
    return [sentence for sentence in stripped_sentences if sentence]

# ============================================================
# Config
# ============================================================
# Directory that holds the dataset splits.
DATA_DIR: str = "govreport_tfidf_vscode2"
# Upper bound on the number of validation records passed to the evaluator.
NUM_VAL_SAMPLES: int = 2000
# Minimum token-level F1 against a summary sentence for a positive label.
F1_THRESHOLD: float = 0.4
# Tokenizer truncation length (in tokens) for the document text.
MAX_LENGTH: int = 4096
# NOTE: despite the ".json" suffix, this file is read with ``jsonlines``.
validation_file: str = DATA_DIR + "/validation.json"

# ============================================================
# Load JSONL
# ============================================================
def load_jsonl_data(file_path, max_samples=None) -> List[Dict]:
    """Read records from a JSON Lines file.

    Args:
        file_path: Path of the JSON Lines file to read.
        max_samples: Maximum number of records to return. ``None`` (the
            default) reads the whole file. ``0`` or a negative value returns
            an empty list; previously ``0`` was treated as "no limit" because
            the limit was tested for truthiness.

    Returns:
        The parsed records, in file order.
    """
    from itertools import islice

    # Clamp so a negative limit yields no records instead of making islice raise.
    limit = None if max_samples is None else max(max_samples, 0)
    with jsonlines.open(file_path) as reader:
        # islice stops after ``limit`` records without parsing the next one.
        return list(islice(reader, limit))

# ============================================================
# Token-level F1 score
# ============================================================
def get_token_f1_score(candidate_tokens, reference_tokens):
    """Compute the F1 overlap between two token collections, ignoring repeats.

    Both inputs are reduced to sets, so token order and multiplicity do not
    affect the score.

    Args:
        candidate_tokens: Tokens of the candidate sentence.
        reference_tokens: Tokens of the reference sentence.

    Returns:
        The harmonic mean of set precision and set recall, or ``0.0`` when
        either side is empty or the two sides share no token.
    """
    candidate_vocab = set(candidate_tokens)
    reference_vocab = set(reference_tokens)
    if not (candidate_vocab and reference_vocab):
        return 0.0

    shared = len(candidate_vocab.intersection(reference_vocab))
    precision = shared / len(candidate_vocab)
    recall = shared / len(reference_vocab)

    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)

# ============================================================
# Prepare single example
# ============================================================
def _find_token_span(offset_mapping, start_offset, end_offset, search_from=0):
    """Locate the tokens that cover characters ``[start_offset, end_offset)``.

    Args:
        offset_mapping: Per-token ``(start_char, end_char)`` pairs, where
            ``end_char`` is exclusive.
        start_offset: Character index at which the sentence starts.
        end_offset: Character index one past the sentence's last character.
        search_from: Token index at which to start scanning.

    Returns:
        ``(start_token_idx, end_token_idx)`` with an exclusive end, or
        ``None`` when no token starts exactly at ``start_offset`` or no later
        token ends exactly at ``end_offset`` (e.g. the sentence was truncated).
    """
    start_token_idx = -1
    for j in range(search_from, len(offset_mapping)):
        start_char, end_char = offset_mapping[j]
        if start_char == end_char:
            # Fast tokenizers report special tokens with a zero-width (0, 0)
            # offset; they must not be mistaken for text starting at char 0.
            continue
        if start_token_idx == -1 and start_char == start_offset:
            start_token_idx = j
        # Token end offsets are exclusive, so the sentence's last token ends
        # at ``end_offset`` itself, not at ``end_offset - 1``.
        if start_token_idx != -1 and end_char == end_offset:
            return start_token_idx, j + 1
    return None


def prepare_single_example(document, summary, tokenizer):
    """Tokenize a document and derive per-sentence token spans and labels.

    The document is split into sentences, re-joined with single spaces and
    tokenized with truncation at ``MAX_LENGTH``. Every sentence that can be
    mapped onto a contiguous token span receives a binary label: ``1.0`` when
    its best token-level F1 against any summary sentence reaches
    ``F1_THRESHOLD``, otherwise ``0.0``. Sentences that cannot be mapped
    (for instance because they fall beyond the truncation point) are skipped,
    so ``sentence_spans`` and ``labels`` may be shorter than the number of
    sentences in the document.

    Args:
        document: Full source text.
        summary: Reference summary text.
        tokenizer: Tokenizer providing ``tokenize`` and a call interface that
            accepts ``return_offsets_mapping=True``.

    Returns:
        A dict with ``input_ids``, ``attention_mask``, ``sentence_spans``
        (list of ``[start, end)`` token index pairs), ``labels`` (floats
        parallel to ``sentence_spans``), ``document_text`` and
        ``summary_text``.
    """
    doc_sentences = spacy_sent_tokenize(document)
    summary_token_lists = [tokenizer.tokenize(s) for s in spacy_sent_tokenize(summary)]

    chunk_text = " ".join(doc_sentences)
    tokenized_chunk = tokenizer(chunk_text, truncation=True, max_length=MAX_LENGTH, return_offsets_mapping=True)
    offsets = tokenized_chunk['offset_mapping']
    # One past the last character that survived truncation.
    covered_chars = max((end_char for _, end_char in offsets), default=0)

    sentence_spans = []
    sentence_labels = []
    char_cursor = 0   # start of the current sentence inside chunk_text
    token_cursor = 0  # sentences are in order, so scanning never needs to go back

    for sentence in doc_sentences:
        # chunk_text is the single-space join of doc_sentences, so each
        # sentence starts exactly at the running cursor; no search is needed.
        start_offset = char_cursor
        end_offset = start_offset + len(sentence)
        char_cursor = end_offset + 1

        if start_offset >= covered_chars:
            # Everything from here on was truncated away.
            break

        span = _find_token_span(offsets, start_offset, end_offset, token_cursor)
        if span is None:
            continue
        token_cursor = span[1]
        sentence_spans.append([span[0], span[1]])

        # Only tokenize sentences that actually received a span.
        sent_tokens = tokenizer.tokenize(sentence)
        max_f1 = max(
            (get_token_f1_score(sent_tokens, sum_tokens) for sum_tokens in summary_token_lists),
            default=0.0,
        )
        sentence_labels.append(1.0 if max_f1 >= F1_THRESHOLD else 0.0)

    return {
        "input_ids": tokenized_chunk['input_ids'],
        "attention_mask": tokenized_chunk['attention_mask'],
        "sentence_spans": sentence_spans,
        "labels": sentence_labels,
        "document_text": document,
        "summary_text": summary
    }

# ============================================================
# Evaluator
# ============================================================
class BaseModelRougeEvaluator:
    """Baseline extractive-summarization evaluator built on Longformer.

    Each sentence is represented by the mean of its token hidden states and
    scored by a single linear layer. That layer is created fresh in
    ``__init__`` and no weights are loaded into it, so the sentence ranking
    reflects a randomly initialised head (and is not seeded here).
    """

    def __init__(self):
        """Load the tokenizer, encoder and ROUGE metric and pick a device."""
        self.tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096', add_prefix_space=True)
        self.model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
        self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1)
        self.rouge_metric = evaluate.load("rouge")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.classifier.to(self.device)

    @staticmethod
    def _mean(values) -> float:
        """Return the arithmetic mean of ``values``, or NaN when it is empty."""
        return sum(values) / len(values) if values else float("nan")

    @staticmethod
    def _top_k_mask(scores: torch.Tensor, k: int) -> List[int]:
        """Return a 0/1 list marking the ``k`` highest entries of 1-D ``scores``."""
        count = scores.shape[0]
        if count == 0 or k <= 0:
            return [0] * count
        _, top_indices = torch.topk(scores, k=min(k, count))
        # Build the lookup once instead of once per sentence.
        chosen = set(top_indices.tolist())
        return [1 if i in chosen else 0 for i in range(count)]

    def _score_sentences(self, sequence_output, sentence_spans) -> torch.Tensor:
        """Mean-pool every ``[start, end)`` token span and return one logit per span."""
        pooled = [sequence_output[0, start:end, :].mean(dim=0) for start, end in sentence_spans]
        return self.classifier(torch.stack(pooled)).squeeze(-1)

    def run_evaluation(self, file_path, num_samples, top_k=3) -> Dict[str, float]:
        """Evaluate on a JSON Lines file and print ROUGE plus label metrics.

        Every record must provide ``original_text`` and
        ``extractive_summary``. For each record the ``top_k`` highest-scoring
        sentences form the predicted summary, which is compared against the
        reference with ROUGE; the 0/1 selection is also compared against the
        sentence labels with precision, recall and F1.

        Args:
            file_path: Path of the validation file.
            num_samples: Maximum number of records to evaluate.
            top_k: Number of sentences to select per document.

        Returns:
            The ROUGE scores plus ``precision``, ``recall`` and ``f1``
            (averaged over records that have labels; NaN when none do).
            Empty when no record was loaded.
        """
        validation_data_raw = load_jsonl_data(file_path, max_samples=num_samples)

        predictions, references = [], []
        all_precisions, all_recalls, all_f1s = [], [], []

        self.model.eval()
        with torch.no_grad():
            for item in tqdm(validation_data_raw, desc="Evaluating ROUGE & Metrics"):
                document = item['original_text']
                summary = item['extractive_summary']

                processed_item = prepare_single_example(document, summary, self.tokenizer)
                token_ids = processed_item['input_ids']
                sentence_spans = processed_item['sentence_spans']

                if sentence_spans:
                    input_ids = torch.as_tensor([token_ids], dtype=torch.long, device=self.device)
                    attention_mask = torch.as_tensor(
                        [processed_item['attention_mask']], dtype=torch.long, device=self.device
                    )
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                    sentence_scores = self._score_sentences(outputs.last_hidden_state, sentence_spans)
                    predictions_for_doc = self._top_k_mask(sentence_scores, top_k)
                else:
                    predictions_for_doc = []

                # Recover each selected sentence from the exact token span that
                # was scored. Spans can skip sentences, so indexing a fresh
                # sentence split of the document by position could pick the
                # wrong text.
                extracted_sentences = [
                    self.tokenizer.decode(
                        token_ids[start:end],
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    ).strip()
                    for (start, end), picked in zip(sentence_spans, predictions_for_doc)
                    if picked
                ]

                predictions.append(" ".join(extracted_sentences))
                references.append(summary)

                labels = processed_item["labels"]
                if labels:
                    all_precisions.append(precision_score(labels, predictions_for_doc, zero_division=0))
                    all_recalls.append(recall_score(labels, predictions_for_doc, zero_division=0))
                    all_f1s.append(f1_score(labels, predictions_for_doc, zero_division=0))

        if not predictions:
            print("No validation samples were loaded; nothing to evaluate.")
            return {}

        rouge_results = self.rouge_metric.compute(predictions=predictions, references=references)

        results = dict(rouge_results)
        results["precision"] = self._mean(all_precisions)
        results["recall"] = self._mean(all_recalls)
        results["f1"] = self._mean(all_f1s)

        print("\n--- Base Model Scores ---")
        for key, value in rouge_results.items():
            print(f"  {key}: {value:.4f}")
        print(f"  Precision: {results['precision']:.4f}")
        print(f"  Recall:    {results['recall']:.4f}")
        print(f"  F1:        {results['f1']:.4f}")
        print("-----------------------------\n")

        return results

if __name__ == "__main__":
    # Build the evaluator once, then score the configured validation file.
    evaluator = BaseModelRougeEvaluator()
    evaluator.run_evaluation(file_path=validation_file, num_samples=NUM_VAL_SAMPLES)


--- Base Model Scores ---
  - rouge1: 0.2057
  - rouge2: 0.0677
  - rougeL: 0.1304
  - rougeLsum: 0.1303
  - Precision: 0.1877
  - Recall:    0.0292
  - F1 score:  0.0471