In [2]:
import os
import pickle
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
import gradio as gr
import nltk
from nltk.tokenize import sent_tokenize
import time

# Download the NLTK sentence-tokenizer data needed by sent_tokenize().
# Newer NLTK releases (3.8.2+) load the "punkt_tab" resource instead of the
# pickled "punkt" models, so make sure both are present. By default
# nltk.download() reports a failed/unknown download rather than raising, so
# requesting "punkt_tab" on an older NLTK is harmless.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except LookupError:
        nltk.download(_resource)

class TextSummarizer:
    """
    Abstractive text summarizer built on a Hugging Face BART checkpoint.

    Inputs that fit the model are summarized in one pass. Longer inputs are
    split on sentence boundaries, and each piece is summarized separately.
    The partial summaries are then joined.
    """

    # Input-length limit (in tokens, including <s>/</s>) assumed for the
    # default facebook/bart-large-cnn checkpoint.
    MAX_INPUT_TOKENS = 1024

    def __init__(self, model_name="facebook/bart-large-cnn"):
        """
        Initialize the summarizer with a specific BART model.

        Loads the tokenizer and model, then moves the model to CUDA when it is
        available and to CPU otherwise.

        Args:
            model_name: The name of the model to use (default: facebook/bart-large-cnn)
        """
        print(f"Loading model: {model_name}")
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print(f"Model loaded successfully. Using device: {self.device}")

    def count_tokens(self, text):
        """Return the number of tokens in the text, including the <s>/</s> special tokens."""
        return len(self.tokenizer.encode(text))

    def chunk_text(self, text, max_chunk_size=1024):
        """
        Split text into sentence-aligned chunks that fit within the token limit.

        Sentences are packed greedily. Each sentence's count includes its own
        special tokens, so a chunk's running total overestimates its real
        encoded length. This keeps chunks under the limit.

        A single sentence longer than ``max_chunk_size`` is truncated (its tail
        is dropped) and emitted as a chunk of its own.

        Args:
            text: The text to chunk
            max_chunk_size: Maximum number of tokens per chunk
        Returns:
            List of text chunks (empty when the text has no sentences)
        """
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = []
        current_chunk_size = 0

        for sentence in sentences:
            sentence_tokens = self.count_tokens(sentence)

            if sentence_tokens > max_chunk_size:
                # Flush whatever has been collected before the long sentence.
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_chunk_size = 0

                # Keep only the leading tokens. Encoding WITHOUT special tokens
                # stops a literal "<s>" from leaking into the decoded text.
                # Two slots are reserved for the <s>/</s> that get added when
                # the chunk is encoded again for generation.
                token_ids = self.tokenizer.encode(sentence, add_special_tokens=False)
                keep = max(max_chunk_size - 2, 1)
                chunks.append(self.tokenizer.decode(token_ids[:keep]))
                continue

            # Start a new chunk when this sentence would overflow the current one.
            if current_chunk_size + sentence_tokens > max_chunk_size:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sentence]
                current_chunk_size = sentence_tokens
            else:
                current_chunk.append(sentence)
                current_chunk_size += sentence_tokens

        # Add the last chunk if it's not empty
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def _generate(self, text, max_length, min_length, length_penalty, num_beams, early_stopping):
        """Run beam-search generation on one model-sized input and return the decoded summary."""
        # BART summarization checkpoints take the raw article text. No
        # T5-style "summarize: " prefix is added, because BART would treat it
        # as part of the article.
        input_ids = self.tokenizer.encode(
            text,
            return_tensors="pt",
            max_length=self.MAX_INPUT_TOKENS,
            truncation=True,  # safety net: chunking is approximate
        ).to(self.device)

        summary_ids = self.model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            length_penalty=length_penalty,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def summarize(self, text, max_length=150, min_length=40, length_penalty=2.0,
                  num_beams=4, early_stopping=True):
        """
        Summarize the given text.

        Texts longer than MAX_INPUT_TOKENS are handed to ``summarize_long_text``.

        Args:
            text: The text to summarize
            max_length: Maximum length of the summary
            min_length: Minimum length of the summary
            length_penalty: Length penalty for beam search
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop when num_beams complete sentences are found
        Returns:
            Tuple ``(summary, processing_time)``, where processing_time is the
            wall-clock time in seconds. Empty or whitespace-only input gives
            an empty summary.
        """
        start_time = time.time()

        # Nothing to summarize: don't run the model on an empty sequence.
        if not text or not text.strip():
            return "", time.time() - start_time

        if self.count_tokens(text) > self.MAX_INPUT_TOKENS:
            return self.summarize_long_text(
                text, max_length, min_length, length_penalty, num_beams, early_stopping
            )

        summary = self._generate(
            text, max_length, min_length, length_penalty, num_beams, early_stopping
        )
        return summary, time.time() - start_time

    def summarize_long_text(self, text, max_length=150, min_length=40, length_penalty=2.0,
                            num_beams=4, early_stopping=True):
        """
        Summarize text that exceeds the model's input limit.

        The text is split with ``chunk_text``, and each chunk is summarized
        with the same generation settings. The chunk summaries are joined
        with single spaces, so the result grows with the number of chunks.

        Args:
            text: The text to summarize
            max_length: Maximum length of each chunk's summary
            min_length: Minimum length of each chunk's summary
            length_penalty: Length penalty for beam search
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop when num_beams complete sentences are found
        Returns:
            Tuple ``(summary, processing_time)``, matching ``summarize``.
        """
        start_time = time.time()
        partial_summaries = [
            self._generate(chunk, max_length, min_length, length_penalty, num_beams, early_stopping)
            for chunk in self.chunk_text(text, max_chunk_size=self.MAX_INPUT_TOKENS)
            if chunk.strip()
        ]
        return " ".join(partial_summaries), time.time() - start_time

# Persist a summarizer (or any other picklable object) to disk with pickle.
def save_model(summarizer, filename="text_summarizer.pkl"):
    """
    Serialize ``summarizer`` into ``filename`` using the pickle module.

    The whole object is pickled, including every attribute it holds. Pickle
    records classes by reference, so reloading requires the object's class
    to be importable at load time.

    Args:
        summarizer: Object to save (normally a TextSummarizer instance)
        filename: Destination path; an existing file is overwritten
    """
    with open(filename, mode="wb") as out_file:
        pickle.Pickler(out_file).dump(summarizer)

# Restore a previously pickled summarizer from disk.
def load_model(filename="text_summarizer.pkl"):
    """
    Deserialize and return the object stored in ``filename``.

    Security: unpickling can execute arbitrary code. Only load files you
    created yourself or otherwise trust.

    Args:
        filename: Path of the pickle file to read
    Returns:
        The unpickled object
    """
    with open(filename, mode="rb") as in_file:
        restored = pickle.Unpickler(in_file).load()
    return restored

# Build the summarizer once (this loads facebook/bart-large-cnn) and persist
# it to text_summarizer.pkl so it can be restored later with load_model().
summarizer = TextSummarizer(model_name="facebook/bart-large-cnn")
save_model(summarizer, filename="text_summarizer.pkl")

print("Model saved successfully as text_summarizer.pkl")


Loading model: facebook/bart-large-cnn
Model loaded successfully. Using device: cpu
Model saved successfully as text_summarizer.pkl
