# Full pipeline (Longformer + Bridging mechanism + BART)

# In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import LongformerModel, LongformerTokenizer, BartTokenizer, BartForConditionalGeneration, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
import os
import nltk
from nltk.tokenize import sent_tokenize
import hashlib
import jsonlines
from rouge_score import rouge_scorer
import bert_score
from tqdm import tqdm
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import LogitsProcessorList, LogitsProcessor

# Make sure NLTK's Punkt sentence-tokenizer data is available for sent_tokenize.
# Older NLTK releases load the pickled "punkt" resource, while recent releases
# load "punkt_tab" instead (and raise LookupError if only "punkt" is present),
# so fetch whichever of the two is missing.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)

# Shared spaCy English pipeline: used for sentence splitting (split_sentences)
# and for entity / number / noun-chunk extraction (ConstraintExtractor).
def _load_spacy_pipeline(model_name="en_core_web_sm"):
    """Return ``spacy.load(model_name)``, downloading the package once if loading fails with OSError."""
    try:
        return spacy.load(model_name)
    except OSError:
        print("Downloading 'en_core_web_sm' model. This will happen only once.")
        spacy.cli.download(model_name)
        return spacy.load(model_name)


nlp = _load_spacy_pipeline()

# FactCC consistency classifier used by calculate_factcc_score. It is loaded
# once, switched to eval mode and placed on `device` (CUDA when available);
# `device` is also used by the rest of the pipeline.
print("Loading FactCC model from Hugging Face...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
factcc_model_path = 'manueldeprada/FactCC'
factcc_tokenizer = AutoTokenizer.from_pretrained(factcc_model_path)
factcc_model = AutoModelForSequenceClassification.from_pretrained(factcc_model_path).eval().to(device)
print(f"Using device: {device}")
print("FactCC model loaded successfully.")

# Helper Function for Sentence Splitting
def split_sentences(text):
    """Return the sentences of *text* as segmented by spaCy, stripped, with blank ones dropped."""
    sentences = []
    for span in nlp(text).sents:
        cleaned = span.text.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences

# Helper function to create overlapping document chunks
def create_document_chunks(document, tokenizer, max_length=512, overlap=50, reserved_tokens=50):
    """Split *document* into overlapping text windows sized for a paired-input model.

    The document is tokenized with ``tokenizer`` and cut into windows of
    ``max_length - reserved_tokens`` tokens. Consecutive windows share
    ``overlap`` tokens, and each window is converted back to a string. The
    reserved tokens leave room in a ``max_length`` input for the paired text
    (a summary sentence in calculate_factcc_score) and special tokens.

    Args:
        document: Text to split.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_string``.
        max_length: Total token budget of the downstream model input.
        overlap: Number of tokens shared by consecutive windows.
        reserved_tokens: Part of ``max_length`` kept free for the paired text.

    Returns:
        List of window strings; empty when ``document`` yields no tokens.

    Raises:
        ValueError: if the window is not positive or not larger than
            ``overlap``. In that case the window would never advance and the
            loop would not terminate.
    """
    window = max_length - reserved_tokens
    step = window - overlap
    if window <= 0 or step <= 0:
        raise ValueError(
            f"window of {window} tokens (max_length={max_length}, reserved_tokens={reserved_tokens}) "
            f"must be positive and larger than overlap={overlap}"
        )
    tokens = tokenizer.tokenize(document)
    chunks = []
    for start in range(0, len(tokens), step):
        end = min(start + window, len(tokens))
        chunks.append(tokenizer.convert_tokens_to_string(tokens[start:end]))
        if end == len(tokens):
            # The last window already reaches the end; further starts would only repeat its tail.
            break
    return chunks

# Key Sentence Extractor for Constraint Guidance
class KeySentenceExtractor:
    """Pick central yet mutually diverse sentences from a text.

    Each sentence gets a centrality score: the sum of its TF-IDF cosine
    similarities to every sentence of the text. Selection then proceeds greedily
    with Maximal Marginal Relevance (MMR), trading centrality against the
    highest similarity to any sentence already selected.
    """

    def __init__(self, extraction_percent=0.15, lambda_mmr=0.5):
        # Fraction of the text's sentences to return (at least one).
        self.extraction_percent = extraction_percent
        # MMR weight: 1.0 ranks purely by centrality, 0.0 purely by dissimilarity.
        self.lambda_mmr = lambda_mmr
        # Refit on every call, so one instance is reused across texts (not thread-safe).
        self.vectorizer = TfidfVectorizer(stop_words='english')

    def extract_key_sentences(self, text: str) -> list:
        """Return the selected key sentences of *text*, in document order.

        Texts with zero or one sentence are returned unchanged. If TF-IDF
        cannot build a vocabulary, the leading sentences are returned instead of
        raising. This happens when every sentence consists only of English stop
        words, punctuation or one-character tokens.
        """
        sentences = sent_tokenize(text)
        if len(sentences) <= 1:
            return sentences
        num_to_extract = min(len(sentences), max(1, int(len(sentences) * self.extraction_percent)))
        try:
            tfidf_matrix = self.vectorizer.fit_transform(sentences)
        except ValueError:
            # sklearn raises ValueError("empty vocabulary ...") when no term survives.
            return sentences[:num_to_extract]
        similarities = cosine_similarity(tfidf_matrix)
        centrality = similarities.sum(axis=1)

        remaining = list(range(len(sentences)))
        selected = [remaining.pop(int(np.argmax(centrality)))]
        while len(selected) < num_to_extract and remaining:
            # MMR: reward centrality, penalise redundancy with what is already chosen.
            mmr_scores = [
                self.lambda_mmr * centrality[i]
                - (1 - self.lambda_mmr) * max(similarities[i, j] for j in selected)
                for i in remaining
            ]
            selected.append(remaining.pop(int(np.argmax(mmr_scores))))
        return [sentences[i] for i in sorted(selected)]

# Constraint Extractor
class ConstraintExtractor:
    """Turn sentences into lexical constraint records using a spaCy pipeline.

    For each sentence it records the named entities (text and label), the
    number-like tokens and the noun chunks reported by the pipeline.
    """

    def __init__(self):
        # Reuse the module-level pipeline rather than loading another copy.
        self.nlp = nlp

    def _describe(self, sentence: str) -> dict:
        """Build the constraint record for a single sentence."""
        parsed = self.nlp(sentence)
        return {
            'sentence': sentence,
            'entities': [(span.text, span.label_) for span in parsed.ents],
            'numbers': [tok.text for tok in parsed if tok.like_num],
            'noun_phrases': [chunk.text for chunk in parsed.noun_chunks],
        }

    def extract_constraints(self, sentences: list) -> list:
        """Return one constraint dict per input sentence, in input order.

        Each dict has the keys ``'sentence'``, ``'entities'`` (list of
        ``(text, label)`` pairs), ``'numbers'`` and ``'noun_phrases'``.
        """
        return [self._describe(sentence) for sentence in sentences]

# Constrained Logits Processor
class ConstrainedLogitsProcessor(LogitsProcessor):
    """Add a constant bonus to the logits of a fixed set of token ids.

    Used in BART generation to nudge decoding toward tokens that occur in the
    extracted constraint phrases. Ids outside the current vocabulary are ignored.
    """

    def __init__(self, constraint_token_ids: list, boost_factor: float = 1.5):
        # Kept as a set (duplicates are irrelevant) for callers that inspect it.
        self.constraint_token_ids = set(constraint_token_ids)
        self.boost_factor = boost_factor
        # Built once here instead of looping over every id at each decoding step.
        self._id_tensor = torch.tensor(sorted(self.constraint_token_ids), dtype=torch.long)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with ``boost_factor`` added to every in-vocabulary constraint column."""
        vocab_size = scores.shape[-1]
        ids = self._id_tensor[self._id_tensor < vocab_size]
        if ids.numel() == 0:
            return scores
        # One bias row over the vocabulary, broadcast over all batch/beam rows.
        bonus = torch.zeros(vocab_size, dtype=scores.dtype, device=scores.device)
        bonus[ids.to(scores.device)] = self.boost_factor
        return scores + bonus

# Longformer Extractive Summarization Model
class LongformerExtractiveSummarizationModel(nn.Module):
    """Token-level extractive tagger on top of ``allenai/longformer-base-4096``.

    Dropout plus a single-unit linear head map every token's final hidden state
    to one logit. Presumably a higher logit means the token belongs to a
    summary-worthy sentence; that depends on how the checkpoint was trained —
    confirm.
    """

    def __init__(self, pos_weight=None):
        super().__init__()
        # These attribute names determine the state_dict keys used by load_model.
        self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096')
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(self.longformer.config.hidden_size, 1)
        # Plain attribute (not a registered buffer): it is absent from the state_dict
        # and is moved to the logits' device inside forward.
        self.pos_weight = torch.tensor(1.0) if pos_weight is None else pos_weight

    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None, labels=None):
        """Return per-token logits, or ``(loss, logits)`` when ``labels`` is given.

        The loss is ``BCEWithLogitsLoss`` with ``pos_weight`` over all positions
        of ``labels`` (cast to float).
        """
        hidden = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        ).last_hidden_state
        token_logits = self.classifier(self.dropout(hidden)).squeeze(-1)
        if labels is None:
            return token_logits
        criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight.to(token_logits.device))
        return criterion(token_logits, labels.float()), token_logits

# Setup: checkpoint path, shared tokenizers/models and the constraint helpers
# used by the pipeline functions below.
CHECKPOINT_PATH = "./extractive_summarization_results/checkpoint-3148"  # fine-tuned extractive Longformer
# Token budget per chunk: the Longformer input length in generate_extractive_summary
# and the sentence-grouping limit in process_summary.
CHUNK_SIZE = 4096

_LONGFORMER_NAME = 'allenai/longformer-base-4096'
_BART_NAME = "facebook/bart-large-cnn"

tokenizer_long = LongformerTokenizer.from_pretrained(_LONGFORMER_NAME)
sent_model = SentenceTransformer('all-MiniLM-L6-v2')  # sentence embeddings for get_sentence_coverage
tokenizer_bart = BartTokenizer.from_pretrained(_BART_NAME)
model_bart = BartForConditionalGeneration.from_pretrained(_BART_NAME).to(device)

key_sentence_extractor = KeySentenceExtractor()
constraint_extractor = ConstraintExtractor()

# Coverage score
def get_sentence_coverage(para1, para2, threshold=0.5):
    """Return the fraction of *para2*'s sentences that are semantically covered by *para1*.

    A sentence of ``para2`` counts as covered when its best cosine similarity
    to any sentence of ``para1`` is at least ``threshold``. Similarities are
    computed on ``sent_model`` embeddings. The result is rounded to 4 decimals.
    It is 0.0 when either text has no sentences.
    """
    sents1 = split_sentences(para1)
    sents2 = split_sentences(para2)
    if not sents1 or not sents2:
        # Nothing to cover, or nothing to cover it with. This also avoids
        # taking max() over an empty similarity row, which raises.
        return 0.0
    emb1 = sent_model.encode(sents1, convert_to_tensor=True)
    emb2 = sent_model.encode(sents2, convert_to_tensor=True)
    # (len(sents2), len(sents1)) similarity matrix -> best match per para2 sentence.
    best_match = util.cos_sim(emb2, emb1).max(dim=1).values
    matched = int((best_match >= threshold).sum().item())
    return round(matched / len(sents2), 4)

# Factual Consistency Evaluation Function with Sliding Window
def _factcc_correct_label_index():
    """Return the class index labelled "CORRECT" in ``factcc_model.config.label2id``, else 0.

    Falling back to 0 keeps the previously hard-coded behaviour for configs
    without such a label.
    """
    label2id = getattr(factcc_model.config, "label2id", None) or {}
    for label, idx in label2id.items():
        if str(label).upper() == "CORRECT":
            return int(idx)
    return 0


def calculate_factcc_score(original_document, summary, batch_size=16):
    """Average FactCC consistency of *summary* with respect to *original_document*.

    The document is cut into overlapping windows (create_document_chunks) so
    that each window plus one summary sentence fits FactCC's 512-token input.
    Each summary sentence is scored against every window and keeps its highest
    probability of the "CORRECT" class (see _factcc_correct_label_index).

    Args:
        original_document: Source text.
        summary: Text whose sentences are checked.
        batch_size: Number of document windows scored per forward pass.

    Returns:
        Mean of the per-sentence scores, rounded to 4 decimals. It is 0.0 when
        the summary has no sentences or the document yields no windows.
    """
    summary_sentences = split_sentences(summary)
    if not summary_sentences:
        return 0.0
    document_chunks = create_document_chunks(original_document, factcc_tokenizer)
    if not document_chunks:
        # No evidence to check against: every sentence would score 0.0.
        return 0.0
    correct_idx = _factcc_correct_label_index()
    batch_size = max(1, batch_size)
    consistency_scores = []
    with torch.no_grad():
        for sentence in summary_sentences:
            best_score = 0.0
            for start in range(0, len(document_chunks), batch_size):
                windows = document_chunks[start:start + batch_size]
                # Pair every window with the same sentence; only the document side is truncated.
                inputs = factcc_tokenizer(
                    windows,
                    [sentence] * len(windows),
                    return_tensors='pt',
                    padding=True,
                    truncation='only_first',
                    max_length=512,
                ).to(device)
                probs = torch.softmax(factcc_model(**inputs).logits, dim=-1)
                best_score = max(best_score, probs[:, correct_idx].max().item())
            consistency_scores.append(best_score)
    return round(sum(consistency_scores) / len(consistency_scores), 4)

# Load the Fine-tuned Model
def load_model(checkpoint_path):
    """Build the extractive model and load fine-tuned weights from *checkpoint_path*.

    ``pytorch_model.bin`` is preferred, then ``model.safetensors``. Weights are
    loaded onto the CPU and the model is returned in eval mode.

    Raises:
        FileNotFoundError: if neither weight file exists in the directory.
    """
    print(f"Loading model from {checkpoint_path}...")
    model = LongformerExtractiveSummarizationModel()
    candidates = ("pytorch_model.bin", "model.safetensors")
    model_file = next(
        (name for name in candidates if os.path.exists(os.path.join(checkpoint_path, name))),
        None,
    )
    if model_file is None:
        raise FileNotFoundError(f"No valid model file found in {checkpoint_path}. Expected 'pytorch_model.bin' or 'model.safetensors'.")
    weights_path = os.path.join(checkpoint_path, model_file)
    if model_file.endswith(".safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(weights_path)
    else:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Generate Extractive Summary
def _paragraph_sentences(document):
    """Split *document* into blank-line-separated paragraphs, each a list of NLTK sentences."""
    paragraphs = [p.strip() for p in document.split('\n\n') if p.strip()]
    return [sent_tokenize(p) for p in paragraphs]


def _build_longformer_chunks(paragraph_sentences):
    """Pack sentences into padded ``CHUNK_SIZE`` Longformer inputs.

    Every paragraph starts a new chunk. A paragraph that does not fit is split
    at sentence boundaries. A single sentence longer than the content budget is
    truncated. Each chunk is ``<s> tokens </s>`` followed by padding, with
    global attention on ``<s>`` only.

    Returns:
        ``(input_ids, attention_mask, global_attention_mask, spans)``. The first
        three are ``(num_chunks, CHUNK_SIZE)`` long tensors. ``spans`` holds, for
        every sentence in document order, ``(chunk_index, start, end)`` token
        positions of that sentence inside its chunk.
    """
    max_content = CHUNK_SIZE - 2  # room for <s> and </s>
    chunk_contents = []
    spans = []
    for sentences in paragraph_sentences:
        current = []
        for sentence in sentences:
            tokens = tokenizer_long.encode(sentence, add_special_tokens=False)[:max_content]
            if current and len(current) + len(tokens) > max_content:
                chunk_contents.append(current)
                current = []
            start = 1 + len(current)  # position 0 holds <s>
            current.extend(tokens)
            # `current` becomes chunk number len(chunk_contents) when it is flushed.
            spans.append((len(chunk_contents), start, start + len(tokens)))
        if current:
            chunk_contents.append(current)

    num_chunks = len(chunk_contents)
    input_ids = torch.full((num_chunks, CHUNK_SIZE), tokenizer_long.pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((num_chunks, CHUNK_SIZE), dtype=torch.long)
    for row, content in enumerate(chunk_contents):
        ids = [tokenizer_long.cls_token_id] + content + [tokenizer_long.sep_token_id]
        input_ids[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[row, :len(ids)] = 1
    global_attention_mask = torch.zeros_like(attention_mask)
    global_attention_mask[:, 0] = 1
    return input_ids, attention_mask, global_attention_mask, spans


def _score_longformer_chunks(input_ids, attention_mask, global_attention_mask, batch_size):
    """Run the global ``model`` over the chunks in mini-batches; return per-token sigmoid scores on the CPU."""
    model_device = next(model.parameters()).device
    batch_probs = []
    with torch.no_grad():
        # Mini-batches keep peak memory independent of the number of chunks.
        for start in range(0, input_ids.shape[0], batch_size):
            rows = slice(start, start + batch_size)
            logits = model(
                input_ids=input_ids[rows].to(model_device),
                attention_mask=attention_mask[rows].to(model_device),
                global_attention_mask=global_attention_mask[rows].to(model_device),
            )
            batch_probs.append(torch.sigmoid(logits).cpu())
    if not batch_probs:
        return torch.empty((0, input_ids.shape[1]))
    return torch.cat(batch_probs)


def generate_extractive_summary(document: str, ratio: float = 0.1, batch_size: int = 4) -> str:
    """Select key sentences of *document* with the fine-tuned Longformer tagger.

    Sentences come from NLTK on each blank-line-separated paragraph. A
    sentence's score is the maximum sigmoid token score over its own tokens.
    The ``max(1, int(n_sentences * ratio))`` best sentences are returned in
    descending score order, joined by single spaces. The global ``model`` must
    have been loaded (see the ``__main__`` block).

    Args:
        document: Source text.
        ratio: Fraction of sentences to keep (at least one).
        batch_size: Number of Longformer chunks scored per forward pass.

    Returns:
        The summary string. For an empty document it returns
        ``"Input document is empty."``; when no sentences are found it returns
        ``"No sentences found in the document."``.
    """
    if not document or not document.strip():
        return "Input document is empty."
    paragraph_sentences = _paragraph_sentences(document)
    document_sentences = [sent for sents in paragraph_sentences for sent in sents]
    if not document_sentences:
        return "No sentences found in the document."

    top_k = max(1, int(len(document_sentences) * ratio))
    input_ids, attention_mask, global_attention_mask, spans = _build_longformer_chunks(paragraph_sentences)
    token_probs = _score_longformer_chunks(input_ids, attention_mask, global_attention_mask, max(1, batch_size))
    # spans is aligned 1:1 with document_sentences; empty spans (no tokens) score 0.0.
    sentence_scores = [
        token_probs[chunk, start:end].max().item() if end > start else 0.0
        for chunk, start, end in spans
    ]
    selected_indices = np.argsort(sentence_scores)[-top_k:][::-1]
    return " ".join(document_sentences[i] for i in selected_indices)

# Rephrase Text with Constraint-Guided BART
def rephrase_text(input_text, original_text, boost_factor=1.5, extraction_percent=0.15):
    """Rewrite *input_text* sentence by sentence with BART, biased toward source facts.

    Constraint phrases are named entities, number-like tokens and noun phrases.
    They are taken from MMR key sentences of both ``input_text`` and
    ``original_text``. Their BART token ids get ``boost_factor`` extra logit
    during beam search. Only roughly the first 80-90% of the sentences are
    rewritten. That fraction is drawn from NumPy's global RNG on every call, so
    output is not deterministic unless that RNG is seeded.

    Args:
        input_text: Text to rephrase (a chunk of the extractive summary).
        original_text: Full source document, used as an extra constraint source.
        boost_factor: Logit bonus added to constraint tokens.
        extraction_percent: Fraction of sentences the key-sentence extractor keeps.

    Returns:
        The rephrased sentences joined by single spaces. When ``input_text``
        contains no sentences, it is returned unchanged.
    """
    sentences = sent_tokenize(input_text)
    if not sentences:
        return input_text
    target_count = max(1, int(len(sentences) * np.random.uniform(0.8, 0.9)))

    # Honour extraction_percent; reuse the shared extractor when it already matches.
    extractor = key_sentence_extractor
    if extraction_percent != extractor.extraction_percent:
        extractor = KeySentenceExtractor(extraction_percent=extraction_percent, lambda_mmr=extractor.lambda_mmr)
    key_sentences = set(extractor.extract_key_sentences(input_text))
    key_sentences.update(extractor.extract_key_sentences(original_text))
    constraints = constraint_extractor.extract_constraints(list(key_sentences))

    constraint_texts = set()
    for constraint in constraints:
        constraint_texts.update(ent_text for ent_text, _ in constraint['entities'])
        constraint_texts.update(constraint['numbers'])
        constraint_texts.update(constraint['noun_phrases'])

    constraint_token_ids = set()
    for text in constraint_texts:
        # BART's byte-level BPE encodes a word differently at the start of a
        # text than after a space. Generated words mid-sentence use the
        # space-prefixed form, so boost both variants.
        constraint_token_ids.update(tokenizer_bart.encode(text, add_special_tokens=False))
        constraint_token_ids.update(tokenizer_bart.encode(" " + text, add_special_tokens=False))
    logits_processor = LogitsProcessorList([
        ConstrainedLogitsProcessor(list(constraint_token_ids), boost_factor)
    ])

    rephrased_sentences = []
    for sentence in sentences[:target_count]:
        input_tokens = tokenizer_bart.encode(sentence, add_special_tokens=False)
        if not input_tokens:
            continue
        # Generation length bounds scale with the input sentence's token count.
        target_length = max(10, len(input_tokens) * 2)
        min_length = max(5, len(input_tokens) // 2)
        inputs = tokenizer_bart(sentence, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
        with torch.no_grad():
            outputs = model_bart.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=target_length,
                min_length=min_length,
                num_beams=8,
                logits_processor=logits_processor,
                early_stopping=True,
            )
        rephrased_sentences.append(tokenizer_bart.decode(outputs[0], skip_special_tokens=True))
    return " ".join(rephrased_sentences)

# Process Summary with Constraint-Guided BART
def process_summary(summary, abstractive_summary, original_text):
    """Rephrase *summary* group by group with constraint-guided BART.

    Sentences are grouped greedily so that each group holds at most
    ``CHUNK_SIZE`` BART tokens; a single longer sentence forms its own group.
    Each group is passed to rephrase_text together with ``original_text``, and
    the results are joined with spaces. ``abstractive_summary`` is accepted for
    interface compatibility but is not used.
    """
    groups = []
    group, group_tokens = [], 0
    for sent in sent_tokenize(summary):
        n_tokens = len(tokenizer_bart.encode(sent, add_special_tokens=False))
        if group_tokens + n_tokens > CHUNK_SIZE:
            if group:
                groups.append(" ".join(group))
            group, group_tokens = [], 0
        group.append(sent)
        group_tokens += n_tokens
    if group:
        groups.append(" ".join(group))
    return " ".join(
        rephrase_text(text, original_text, boost_factor=1.5, extraction_percent=0.15)
        for text in groups
    )

# Evaluation Metrics
# The scorers are built on first use and reused across calculate_metrics calls.
# The module-level bert_score.score() builds a new scorer, including its
# pretrained model, on every call.
_METRIC_SCORERS = {}


def _get_metric_scorers():
    """Return the cached ``(RougeScorer, BERTScorer)`` pair, building it on the first call."""
    if not _METRIC_SCORERS:
        _METRIC_SCORERS["rouge"] = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        _METRIC_SCORERS["bert"] = bert_score.BERTScorer(lang="en")
    return _METRIC_SCORERS["rouge"], _METRIC_SCORERS["bert"]


def calculate_metrics(original_text, reference, candidate):
    """Score *candidate* against *reference* (and *original_text*).

    Returns a dict, with every value rounded to 4 decimals. The keys, in order,
    are ``rouge1``, ``rouge2``, ``rougeL`` (stemmed F-measures),
    ``bertscore_precision``, ``bertscore_recall``, ``bertscore_f1``,
    ``coverage_score`` (get_sentence_coverage vs. the source) and
    ``factcc_score`` (calculate_factcc_score vs. the source).
    """
    rouge_obj, bert_obj = _get_metric_scorers()
    rouge_scores = rouge_obj.score(reference, candidate)
    metrics = {name: round(rouge_scores[name].fmeasure, 4) for name in ("rouge1", "rouge2", "rougeL")}
    P, R, F1 = bert_obj.score([candidate], [reference], verbose=False)
    metrics.update(
        bertscore_precision=round(P[0].item(), 4),
        bertscore_recall=round(R[0].item(), 4),
        bertscore_f1=round(F1[0].item(), 4),
    )
    metrics["coverage_score"] = get_sentence_coverage(original_text, candidate)
    metrics["factcc_score"] = calculate_factcc_score(original_text, candidate)
    return metrics

# Main Execution with Test Data
if __name__ == "__main__":
    # Evaluate the pipeline (extractive Longformer -> constraint-guided BART) on
    # the first MAX_TEST_SAMPLES test records and print the mean of each metric.
    # `model` is deliberately bound at module level: generate_extractive_summary
    # reads it as a global.
    from itertools import islice

    MAX_TEST_SAMPLES = 100
    try:
        model = load_model(CHECKPOINT_PATH)
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading model: {e}")
        # Exit with a non-zero status so shells and CI see the failure. The
        # site-provided exit() reported status 0 and is not guaranteed to exist.
        raise SystemExit(1)
    test_file = "govreport_tfidf_vscode2/test.json"
    with jsonlines.open(test_file) as reader:
        test_data = list(islice(reader, MAX_TEST_SAMPLES))
    all_metrics = []
    for item in tqdm(test_data, desc="Processing samples"):
        original_text = item.get('original_text', '')
        abstractive_summary = item.get('abstractive_summary', '')
        if not original_text or not abstractive_summary:
            continue  # incomplete record: excluded from the averages
        extractive_summary = generate_extractive_summary(original_text)
        final_summary = process_summary(extractive_summary, abstractive_summary, original_text)
        all_metrics.append(calculate_metrics(original_text, abstractive_summary, final_summary))
    if all_metrics:
        print(f"Evaluated {len(all_metrics)} of {len(test_data)} samples.")
        avg_metrics = {key: round(sum(m[key] for m in all_metrics) / len(all_metrics), 4) for key in all_metrics[0]}
        for metric, value in avg_metrics.items():
            print(f"{metric}: {value}")
    else:
        print("No valid summaries generated for evaluation.")