# Extractive summarization with a user-defined summary length

# In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import LongformerModel, LongformerTokenizer
from sentence_transformers import SentenceTransformer, util
import os
import nltk
from nltk.tokenize import sent_tokenize
import hashlib

# sent_tokenize needs NLTK's Punkt sentence-splitter data. Older NLTK releases load the
# pickled 'punkt' resource, while 3.8.2+ loads 'punkt_tab' instead, so ensure both exist.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)

# ==============================================================================
# 1. Re-define the Custom Model Architecture
# ==============================================================================
class LongformerExtractiveSummarizationModel(nn.Module):
    """Per-token binary classifier on top of ``allenai/longformer-base-4096``.

    Every token's final hidden state is passed through dropout and a single linear
    unit, producing one logit per token.

    Args:
        pos_weight: Positive-class weight for ``BCEWithLogitsLoss``. Either a tensor or
            anything ``torch.as_tensor`` accepts (e.g. a Python float). Defaults to 1.0.
    """

    def __init__(self, pos_weight=None):
        super().__init__()
        self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096')
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(self.longformer.config.hidden_size, 1)
        # Normalize to a tensor so forward() can always call .to(device) on it. It is kept
        # as a plain attribute (not a registered buffer), so it is not part of state_dict().
        if pos_weight is None:
            pos_weight = torch.tensor(1.0)
        elif not torch.is_tensor(pos_weight):
            pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32)
        self.pos_weight = pos_weight

    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None, labels=None):
        """Return per-token logits, or ``(loss, logits)`` when ``labels`` is given.

        The classifier's trailing size-1 dimension is squeezed away, so logits are
        presumably ``(batch, seq_len)`` for 2-D ``input_ids`` (TODO confirm).
        """
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output).squeeze(-1)

        if labels is None:
            return logits

        # NOTE(review): the loss is averaged over every position, including any padding
        # tokens. Confirm that the training labels intend padded positions to count.
        loss_fct = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight.to(logits.device))
        loss = loss_fct(logits, labels.float())
        return (loss, logits)

# ==============================================================================
# 2. Setup
# ==============================================================================
# --- Inference configuration ----------------------------------------------------
CHECKPOINT_PATH = "./extractive_summarization_results/checkpoint-3148"  # fine-tuned weights to restore
CHUNK_SIZE = 4096  # max tokens per Longformer window, CLS/SEP included
# Relevance vs. diversity trade-off for MMR re-ranking.
# NOTE(review): nothing in this file reads LAMBDA_MMR; the summarizer does plain top-k.
LAMBDA_MMR = 0.7

# --- Shared pretrained components -----------------------------------------------
# The tokenizer must match the Longformer encoder wrapped by the summarization model.
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
# Sentence-embedding model, presumably kept for MMR-style diversity scoring.
sent_model = SentenceTransformer('all-MiniLM-L6-v2')

# ==============================================================================
# 3. Load the Fine-tuned Model
# ==============================================================================
def load_model(checkpoint_path):
    """Load the fine-tuned summarization model from a checkpoint directory.

    ``model.safetensors`` is preferred when present; otherwise ``pytorch_model.bin`` is
    used. Weights are loaded onto the CPU, and the model is put in eval mode.

    Args:
        checkpoint_path: Directory that holds the saved state dict.

    Returns:
        A ``LongformerExtractiveSummarizationModel`` with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: If neither weight file exists in ``checkpoint_path``.
        RuntimeError: From ``load_state_dict`` if the checkpoint keys or shapes do not
            match the model.
    """
    print(f"Loading model from {checkpoint_path}...")

    # Safetensors holds raw tensors only; a .bin file is a pickle, so it is the fallback.
    model_path = None
    for file_name in ("model.safetensors", "pytorch_model.bin"):
        candidate = os.path.join(checkpoint_path, file_name)
        if os.path.isfile(candidate):
            model_path = candidate
            break
    if model_path is None:
        raise FileNotFoundError(f"No valid model file found in {checkpoint_path}. Expected 'pytorch_model.bin' or 'model.safetensors'.")

    if model_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(model_path)
    else:
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)

    # Build the model only after the weights file is known to exist, so a bad path
    # fails fast without first fetching the base Longformer.
    model = LongformerExtractiveSummarizationModel()
    model.load_state_dict(state_dict)
    model.eval()
    return model


# ==============================================================================
# 4. Main Summary Generation Function with Hierarchical Chunking (No MMR)
# ==============================================================================
def _split_into_paragraph_sentences(document):
    """Split *document* into paragraphs on double newlines, then each paragraph into sentences.

    Returns one list of sentence strings per non-blank paragraph, in document order.
    """
    paragraphs = [p.strip() for p in document.split('\n\n') if p.strip()]
    if not paragraphs:
        # Fallback: treat the whole text as a single paragraph.
        return [sent_tokenize(document)]
    return [sent_tokenize(paragraph) for paragraph in paragraphs]


def _pack_sentences_into_chunks(paragraph_sentences):
    """Pack sentence tokens into Longformer windows that never cross a paragraph boundary.

    Returns a list of ``(content_ids, spans)`` pairs:
      * ``content_ids`` are the sentence token ids, without CLS/SEP.
      * ``spans`` holds ``(sentence_index, start, end)`` offsets into the final
        ``[CLS] + content_ids + [SEP]`` sequence.
    Sentence indices count across all paragraphs in document order. A window holds at
    most ``CHUNK_SIZE`` tokens, special tokens included.
    """
    max_content = CHUNK_SIZE - 2  # reserve one slot each for CLS and SEP
    chunks = []
    sentence_index = 0
    for sentences in paragraph_sentences:
        content_ids, spans = [], []
        for sentence in sentences:
            # Truncate pathologically long sentences so one sentence always fits a window.
            token_ids = tokenizer.encode(sentence, add_special_tokens=False)[:max_content]
            if content_ids and len(content_ids) + len(token_ids) > max_content:
                # The current window is full: close it and start a new one for this sentence.
                chunks.append((content_ids, spans))
                content_ids, spans = [], []
            start = len(content_ids) + 1  # +1 skips the leading CLS token
            content_ids.extend(token_ids)
            spans.append((sentence_index, start, start + len(token_ids)))
            sentence_index += 1
        if content_ids:
            chunks.append((content_ids, spans))
    return chunks


def _score_sentences(chunks, num_sentences):
    """Score each sentence with the module-level ``model``, running one chunk at a time.

    A sentence's score is the maximum sigmoid probability over its tokens. Sentences
    that produced no tokens keep a score of 0.0.
    """
    scores = [0.0] * num_sentences
    device = next(model.parameters()).device
    with torch.no_grad():
        for content_ids, spans in chunks:
            input_ids = torch.tensor(
                [[tokenizer.cls_token_id] + content_ids + [tokenizer.sep_token_id]],
                device=device,
            )
            attention_mask = torch.ones_like(input_ids)
            # Global attention on the leading CLS token only.
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )
            probabilities = torch.sigmoid(logits)[0]
            for sentence_index, start, end in spans:
                if end > start:
                    scores[sentence_index] = probabilities[start:end].max().item()
    return scores


def generate_extractive_summary(document: str, top_k: int) -> str:
    """Return an extractive summary built from the ``top_k`` highest-scoring sentences.

    Steps:
      1. Split the document into paragraphs on double newlines, and each paragraph into
         sentences with NLTK.
      2. Pack the sentences into Longformer windows of at most ``CHUNK_SIZE`` tokens that
         never span two paragraphs.
      3. Score the sentences with the module-level ``model``, which must already be
         loaded (e.g. via ``load_model``).
      4. Join the best ``top_k`` sentences with spaces, highest score first.

    Args:
        document: Raw input text.
        top_k: Requested number of summary sentences. It is clamped to the number of
            sentences; a value <= 0 yields an empty summary.

    Returns:
        The summary, or a short message if the document is empty or has no sentences.
    """
    if not document or not document.strip():
        return "Input document is empty."

    paragraph_sentences = _split_into_paragraph_sentences(document)
    document_sentences = [s for sentences in paragraph_sentences for s in sentences]
    if not document_sentences:
        return "No sentences found in the document."

    # Clamp to the available sentences. Non-positive requests are handled explicitly,
    # because a tail slice like ranked[-0:] would select every sentence.
    top_k = min(top_k, len(document_sentences))
    if top_k <= 0:
        return ""

    chunks = _pack_sentences_into_chunks(paragraph_sentences)
    sentence_scores = _score_sentences(chunks, len(document_sentences))

    # Highest score first; the stable sort breaks ties in document order.
    ranked = np.argsort(-np.asarray(sentence_scores), kind="stable")[:top_k]
    return " ".join(document_sentences[i] for i in ranked)

# ==============================================================================
# 5. Example Usage
# ==============================================================================
if __name__ == "__main__":
    # Load the fine-tuned model
    try:
        model = load_model(CHECKPOINT_PATH)
    except (FileNotFoundError, ValueError) as e:
        print(f"Error loading model: {e}")
        exit()
    # Get user input
    user_input = input("Enter the document text to summarize: \n\n")
    try:
        top_k_input = int(input("\nEnter the desired number of sentences for the summary (e.g., 3, 5): "))
    except ValueError:
        print("Invalid input. Using default of 3 sentences.")
        top_k_input = 3
    # Generate and print the summary
    summary = generate_extractive_summary(user_input, top_k_input)
    print("\n--- Generated Summary ---")
    print(summary)
    print("-------------------------")