In [None]:
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.probability import FreqDist
from nltk.tokenize.treebank import TreebankWordDetokenizer
from collections import defaultdict

# Fetch the NLTK data used by HierarchicalStyleSummarizer: sentence/word
# tokenizer models and the English stop-word list.
# Newer NLTK releases load tokenizer data for word_tokenize/sent_tokenize from
# the 'punkt_tab' resource instead of the pickled 'punkt' model, so both are
# fetched to work on older and newer installations alike.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

class HierarchicalStyleSummarizer:
    """Extractive, word-frequency based summarizer that condenses text hierarchically.

    A text longer than ``context_window`` tokens (as counted by NLTK's
    ``word_tokenize``) is split into sentence-aligned chunks; every chunk is
    reduced to its highest-scoring sentences; the chunk summaries are joined
    and the process repeats until the result fits in ``context_window`` tokens.
    """

    def __init__(self, context_window=4000):
        """
        Args:
            context_window: maximum number of ``word_tokenize`` tokens in a chunk
                and in a finished summary. Must be at least 1.

        Raises:
            ValueError: if ``context_window`` is less than 1.
        """
        if context_window < 1:
            raise ValueError(f"context_window must be a positive integer, got {context_window!r}")
        self.context_window = context_window  # max tokens in a summary chunk

    @staticmethod
    def _count_tokens(text):
        """Return the number of ``word_tokenize`` tokens in *text*."""
        return len(word_tokenize(text))

    def get_frequency_distribution(self, text):
        """Count lower-cased, alphanumeric, non-English-stop-word tokens of *text*.

        Returns:
            nltk.probability.FreqDist mapping token -> occurrence count.
        """
        words = word_tokenize(text.lower())
        stop_words = set(stopwords.words('english'))
        return FreqDist(word for word in words if word.isalnum() and word not in stop_words)

    def summarize_text(self, text, freq_dist, max_tokens=None):
        """Extract the highest-scoring sentences of *text* that fit a token budget.

        Each sentence is scored as the sum of ``freq_dist`` counts of its
        lower-cased alphanumeric tokens. Sentences are visited from highest to
        lowest score (equal scores keep document order) and kept whenever they
        still fit in the remaining budget; a sentence that does not fit is
        skipped so that shorter ones can still use the space. Kept sentences
        are returned in their original document order.

        Args:
            text: the text chunk to summarize.
            freq_dist: mapping token -> frequency used for scoring (e.g. the
                result of ``get_frequency_distribution``).
            max_tokens: token budget for the summary; defaults to
                ``self.context_window``.

        Returns:
            The kept sentences joined by single spaces ('' if none fit).
        """
        if max_tokens is None:
            max_tokens = self.context_window

        scored = []  # (score, document position, sentence, token count)
        for position, sentence in enumerate(sent_tokenize(text)):
            tokens = word_tokenize(sentence)  # tokenize once: used for score and length
            score = sum(freq_dist.get(tok.lower(), 0) for tok in tokens if tok.isalnum())
            scored.append((score, position, sentence, len(tokens)))

        # Highest score first; sorted() is stable, so ties keep document order.
        ranked = sorted(scored, key=lambda item: item[0], reverse=True)

        kept = []
        used = 0
        for _score, position, sentence, n_tokens in ranked:
            if used + n_tokens <= max_tokens:
                kept.append((position, sentence))
                used += n_tokens

        # Restore reading order. Sentences are joined with plain spaces: the
        # Treebank detokenizer expects word tokens and, applied to whole
        # sentences, rewrites the text (e.g. "can not" -> "cannot").
        kept.sort()
        return ' '.join(sentence for _position, sentence in kept)

    def split_into_chunks(self, text, max_tokens):
        """Group consecutive sentences of *text* into chunks of at most *max_tokens* tokens.

        Sentences are never split, so a single sentence longer than
        *max_tokens* forms a chunk on its own (exceeding the limit).

        Returns:
            list of non-empty chunk strings; sentences joined by single spaces.
        """
        chunks = []
        current_chunk = []
        current_length = 0

        for sentence in sent_tokenize(text):
            sentence_length = self._count_tokens(sentence)
            # Close the chunk when this sentence would overflow it; the
            # `current_chunk` guard avoids emitting an empty chunk when the very
            # first sentence is already longer than max_tokens.
            if current_chunk and current_length + sentence_length > max_tokens:
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_length = 0
            current_chunk.append(sentence)
            current_length += sentence_length

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    def hierarchical_summarization(self, text, max_iterations=10):
        """Condense *text* until it fits in ``self.context_window`` tokens.

        Each pass:
          - splits the current text into chunks fitting the context window,
          - builds one frequency distribution over the whole current text,
          - summarizes every chunk to an equal share of the window
            (``context_window // number_of_chunks`` tokens),
          - joins the non-empty chunk summaries.

        Another pass only runs if the joined text still exceeds the window
        (e.g. because re-tokenizing the joined text counts slightly more
        tokens). The loop stops early if a pass does not shrink the text.

        Args:
            text: the text to condense.
            max_iterations: maximum number of summarization passes.

        Returns:
            The condensed text; *text* itself if it already fits. The result can
            be short or empty when sentences are longer than their chunk's share
            of the window.
        """
        current_text = text
        token_count = self._count_tokens(current_text)

        for _ in range(max_iterations):
            if token_count <= self.context_window:
                return current_text  # summary fits context window

            chunks = self.split_into_chunks(current_text, self.context_window)
            freq_dist = self.get_frequency_distribution(current_text)

            # A chunk never exceeds the window, so summarizing it to the full
            # window would keep every sentence and the text would never shrink.
            # Splitting the window between chunks makes the joined result fit.
            per_chunk_budget = max(1, self.context_window // max(1, len(chunks)))
            chunk_summaries = [
                self.summarize_text(chunk, freq_dist, per_chunk_budget) for chunk in chunks
            ]
            new_text = ' '.join(summary for summary in chunk_summaries if summary)
            new_count = self._count_tokens(new_text)

            if new_count >= token_count:
                # Further passes would produce the same text again.
                print("Hierarchical summarization made no further progress; stopping.")
                return current_text

            current_text, token_count = new_text, new_count

        if token_count > self.context_window:
            print("Reached max iterations in hierarchical summarization.")
        return current_text

    def compute_proportional_lengths(self, text1, text2):
        """Split ``context_window`` between two texts in proportion to their token counts.

        Returns:
            (share1, share2) as ints truncated toward zero; an even split of the
            window when both texts are empty.
        """
        len1 = self._count_tokens(text1)
        len2 = self._count_tokens(text2)
        total = len1 + len2
        if total == 0:  # avoid division by zero
            half = self.context_window // 2
            return half, half
        return int((len1 / total) * self.context_window), int((len2 / total) * self.context_window)

    def save_document(self, text, filename):
        """Write *text* to *filename* as UTF-8, overwriting any existing file."""
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(text)

    def generate_query(self, input_summary, style_summary):
        """Build the prompt that pairs the two summaries with a style-transfer instruction.

        Returns:
            A string with a labelled input summary, a labelled style summary and
            a closing instruction, separated by blank lines.
        """
        query = (
            "Input Text Summary:\n"
            + input_summary
            + "\n\nStyle Text Summary:\n"
            + style_summary
            + "\n\n"
            + "Please generate a summary of the input text following the style of the style text."
        )
        return query

    def process_documents(self, input_file, style_file, output_input_summary, output_style_summary, output_query_file):
        """Summarize an input and a style document and write the prompt built from them.

        Reads both files as UTF-8, condenses each with
        ``hierarchical_summarization``, writes the two summaries and the
        generated query to the given output paths, and prints progress.

        Returns:
            (input_summary, style_summary, query)
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            input_text = f.read()
        with open(style_file, 'r', encoding='utf-8') as f:
            style_text = f.read()

        # Summarize style text hierarchically
        style_summary = self.hierarchical_summarization(style_text)
        self.save_document(style_summary, output_style_summary)
        print(f"Style text summarized to {self._count_tokens(style_summary)} tokens.")

        # Summarize input text hierarchically
        input_summary = self.hierarchical_summarization(input_text)
        self.save_document(input_summary, output_input_summary)
        print(f"Input text summarized to {self._count_tokens(input_summary)} tokens.")

        # Generate and save query
        query = self.generate_query(input_summary, style_summary)
        self.save_document(query, output_query_file)
        print("Query generated and saved.")

        return input_summary, style_summary, query

if __name__ == "__main__":
    # Colab-style locations of the two source texts and the three outputs.
    input_file = "/content/input.txt"
    style_file = "/content/style.txt"
    output_input_summary = "/content/input_summary.txt"
    output_style_summary = "/content/style_summary.txt"
    output_query_file = "/content/final_query.txt"

    summarizer = HierarchicalStyleSummarizer(context_window=4000)
    input_summary, style_summary, query = summarizer.process_documents(
        input_file=input_file,
        style_file=style_file,
        output_input_summary=output_input_summary,
        output_style_summary=output_style_summary,
        output_query_file=output_query_file,
    )

    # Show the first 1000 characters of each produced artifact.
    for preview_label, preview_text in (
        ("Input Summary", input_summary),
        ("Style Summary", style_summary),
        ("Query", query),
    ):
        print(f"\n{preview_label} Preview:\n", preview_text[:1000])


Style text summarized to 400 tokens.
Input text summarized to 344 tokens.
Query generated and saved.

Input Summary Preview:
 Cognitive Behavioral Therapy

Cognitive behavioral therapy (CBT) is a psycho-social intervention that aims to improve mental health. CBT focuses on challenging and changing unhelpful cognitive distortions and behaviors, improving emotional regulation, and developing personal coping strategies that target solving current problems.

Originally developed to treat depression, CBT is now used for a variety of mental health conditions, including anxiety disorders, alcohol and drug use problems, marital problems, eating disorders, and severe mental illness. It is typically short-term and goal-oriented, involving collaboration between therapist and client.

CBT is based on the concept that your thoughts, feelings, and actions are interconnected, and that negative thoughts and feelings can trap you in a vicious cycle. It helps you deal with overwhelming problems in a mor

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
