# Document Summarization

This notebook demonstrates an application of long document summarization techniques to a work of literature.

## Install Dependencies

Granite Kitchen comes with a bundle of dependencies that are required for notebooks. See the list of packages in its [`setup.py`](https://github.com/ibm-granite-community/granite-kitchen/blob/main/setup.py). 

In [None]:
! pip install git+https://github.com/ibm-granite-community/utils \
    "langchain_community<0.3.0" \
    "transformers>=4.45.2" \
    langchain-huggingface \
    replicate \
    torch \
    tiktoken

## Select your model

Select a Granite model from the [`ibm-granite`](https://replicate.com/ibm-granite) org on Replicate. Here we use the Replicate Langchain client to connect to the model.

To get set up with Replicate, see [Getting Started with Replicate](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Getting_Started/Getting_Started_with_Replicate.ipynb).

To connect to a model on a provider other than Replicate, substitute this code cell with one from the [LLM component recipe](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Components/Langchain_LLMs.ipynb).

In [None]:
from langchain_community.llms import Replicate
from ibm_granite_community.notebook_utils import get_env_var

model = Replicate(
    model="ibm-granite/granite-3.0-8b-instruct",
    replicate_api_token=get_env_var('REPLICATE_API_TOKEN'),
)

## Download a book

Here we fetch H.D. Thoreau's "Walden" from [Project Gutenberg](https://www.gutenberg.org/) for summarization.

We have to trim it down so that the text, together with the generated summary, fits in the model's 4,000-token context window.

In [None]:
import requests

# The following URL contains a text version of H.D. Thoreau's "Walden"
url = "https://www.gutenberg.org/cache/epub/205/pg205.txt"

# Get the contents
response = requests.get(url)
response.raise_for_status()
full_contents = response.text

# Extract the text of the book, leaving out the gutenberg boilerplate.
start_str = "*** START OF THE PROJECT GUTENBERG EBOOK WALDEN, AND ON THE DUTY OF CIVIL DISOBEDIENCE ***"
start_index = full_contents.index(start_str) + len(start_str)
end_str = "*** END OF THE PROJECT GUTENBERG EBOOK WALDEN, AND ON THE DUTY OF CIVIL DISOBEDIENCE ***"
end_index = full_contents.index(end_str)
book_contents = full_contents[start_index:end_index]
print("Length of book text: {} chars".format(len(book_contents)))

# We limit the text to 10k characters, which is roughly 2.9k tokens at about 3.5 chars per token.
# (For reference: 200k chars is ~57k tokens; 300k chars is ~86k tokens; 350k chars is ~100k tokens; 400k chars is ~114k tokens.)
char_limit = 10000
contents = book_contents[:char_limit]
print("Length of text for summarization: {} chars".format(len(contents)))
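The character-to-token figures quoted in the cell above imply a ratio of roughly 3.5 characters per token. The sketch below turns that ratio into a quick estimate. The `estimate_tokens` helper and its default ratio are assumptions made for illustration; the real count depends on the tokenizer and on the text itself.

```python
def estimate_tokens(num_chars: int, chars_per_token: float = 3.5) -> int:
    """Hypothetical helper: rough token estimate from a character count.

    The 3.5 chars/token default is an assumption derived from the
    "200k chars is ~57k tokens" figure; it is not a property of the model.
    """
    return round(num_chars / chars_per_token)


for num_chars in (10_000, 200_000, 400_000):
    print(f"{num_chars} chars is roughly {estimate_tokens(num_chars)} tokens")
```

An estimate like this is only useful for picking a starting `char_limit`; the tokenizer remains the authority on the actual token count.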

## Count the tokens

Before sending our text to the AI model, it's crucial to understand how much of the model's capacity we're using. Language models typically have a limit on the number of tokens they can process in a single request.

Key points:
- We're using the [`granite-3.0-8b-instruct`](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct) model, which has a context window of 4,000 tokens.
- Tokenization can vary between models, so we use the specific tokenizer for our chosen model.

Understanding token count helps us optimize our prompts and ensure we're using the model efficiently.

In [None]:
from transformers import AutoTokenizer

model_path = "ibm-granite/granite-3.0-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Your model uses the tokenizer " + type(tokenizer).__name__)

print(f"Your document has {len(tokenizer.tokenize(contents))} tokens. ")
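Once the token count is known, it can be compared against the context window. The following is a minimal sketch, assuming the prompt and the generated output share the same 4,000-token window; the `fits_in_context` helper is hypothetical and not part of any library used in this notebook.

```python
def fits_in_context(prompt_tokens: int, max_new_tokens: int, context_window: int = 4000) -> bool:
    """Hypothetical helper: check that prompt plus requested output fit in the window.

    Assumes input and output tokens count against the same limit.
    """
    return prompt_tokens + max_new_tokens <= context_window


# A ~2,900-token prompt leaves room for a 512-token summary.
print(fits_in_context(2900, 512))
# A 3,000-token prompt does not leave room for a 2,000-token summary.
print(fits_in_context(3000, 2000))
```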

## Summarize the text

Use this optimal question-answer format, as recommended by the Granite Prompting Guide.

In [None]:
prompt_guide_template = """\
<|start_of_role|>user<|end_of_role|>{prompt}<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>"""

Construct our prompt and send it to the AI model on Replicate for processing.

In [None]:
prompt = prompt_guide_template.format(prompt = f"""
Summarize the following text:
{contents}
""")

output = model.invoke(
    prompt,
    model_kwargs={
        "max_tokens": 10000, # Set the maximum number of tokens to generate as output.
        "min_tokens": 200, # Set the minimum number of tokens to generate as output.
        "temperature": 0.75,
        "system_prompt": "You are a helpful assistant.",
        "presence_penalty": 0,
        "frequency_penalty": 0
    }
    )

print(output)

## Summary of Summaries

Here we use a hierarchical abstractive summarization technique to adapt to the context length of the model. Our approach is naïve, in that it takes equal-width chunks and groups of chunks from the document. A more sophisticated approach would build a hierarchy that follows the document's own structure and features, grouping text passages by topic or section.
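The overall idea can be sketched in a few lines. This is a simplified illustration, not the exact procedure used below: it loops until one summary remains rather than using a fixed number of levels, and `hierarchical_summary`, `summarize_fn`, and `fake_summarize` are hypothetical names. The stand-in summarizer simply keeps the first three words so that the sketch runs without a model.

```python
from typing import Callable, List


def hierarchical_summary(
    chunks: List[str],
    summarize_fn: Callable[[str], str],
    group_size: int = 6,
) -> str:
    """Summarize chunks, then summarize groups of summaries until one remains."""
    if group_size < 2:
        raise ValueError("group_size must be at least 2, or the levels never shrink")
    if not chunks:
        return ""
    # Level 1: summarize every chunk independently.
    level = [summarize_fn(chunk) for chunk in chunks]
    # Higher levels: join neighboring summaries into groups and summarize each group.
    while len(level) > 1:
        groups = [level[i:i + group_size] for i in range(0, len(level), group_size)]
        level = [summarize_fn("\n\n".join(group)) for group in groups]
    return level[0]


def fake_summarize(text: str) -> str:
    """Stand-in for a model call (an assumption for this sketch): keep the first three words."""
    return " ".join(text.split()[:3])


result = hierarchical_summary(
    ["one two three four", "five six seven eight"], fake_summarize, group_size=2
)
print(result)
```

In practice `summarize_fn` would wrap a call to the model, and `group_size` would be chosen so that each joined group fits in the context window.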

### Chunk the text

Divide the full text into smaller passages for separate processing. The `chunk_size` (given in tokens) must account for the size of both the messages (input) and the completions (output). The resulting chunk size may sometimes exceed the `chunk_size` provided, so we give it additional headroom. 

The `chunk_overlap` parameter allows us to overlap chunks by a certain number of tokens, to help preserve coherence between chunks.

In [None]:
from langchain.text_splitter import TokenTextSplitter
from langchain.docstore.document import Document

# Use entire book, or truncate if testing.
text = book_contents
if get_env_var('GRANITE_TESTING', 'false') == 'true':
    text = book_contents[:20000]
print(f"The text is {len(tokenizer.tokenize(text))} tokens.")

# Split the documents into chunks
chunk_token_limit = 3000  # In tokens: 3000 message + 512 completion + ~350 padding < 4000 context length
text_splitter = TokenTextSplitter.from_huggingface_tokenizer(tokenizer=tokenizer, chunk_size=chunk_token_limit, chunk_overlap=0)
chunks = text_splitter.split_text(text)

print("Chunk count: " + str(len(chunks)))
print("Max chunk length: " + str(max([len(tokenizer.tokenize(chunk)) for chunk in chunks])))
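To make the roles of `chunk_size` and `chunk_overlap` concrete, here is a simplified sliding-window sketch over a list of tokens. It illustrates the concept only and is not LangChain's implementation; the `chunk_with_overlap` function is a hypothetical helper, and whitespace-separated letters stand in for real tokens.

```python
from typing import List, Sequence


def chunk_with_overlap(tokens: Sequence[str], chunk_size: int, chunk_overlap: int) -> List[List[str]]:
    """Hypothetical helper: split tokens into windows that share `chunk_overlap` tokens."""
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(list(tokens[start:start + chunk_size]))
        # Stop once a window has reached the end of the token list.
        if start + chunk_size >= len(tokens):
            break
    return chunks


tokens = "a b c d e f g h".split()
print(chunk_with_overlap(tokens, chunk_size=4, chunk_overlap=1))
# [['a', 'b', 'c', 'd'], ['d', 'e', 'f', 'g'], ['g', 'h']]
```

With `chunk_overlap=0`, as in the cell above, the windows do not share any tokens, which minimizes the number of chunks at the cost of abrupt boundaries.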

### Summarize the chunks

Here we create a separate summary of each passage. This can take a few minutes.

In [None]:
def summarize(texts, prompt_template, min_tokens, max_tokens):
    """Summarize each text with the model and return the list of labeled summaries."""
    summaries = []
    for i, text in enumerate(texts):
        print(f"{i + 1}. Input size: {len(tokenizer.tokenize(text))} tokens")
        prompt = prompt_template.format(text=text)
        output = model.invoke(
            prompt,
            model_kwargs={
                "max_tokens": max_tokens, # Set the maximum number of tokens to generate as output.
                "min_tokens": min_tokens, # Set the minimum number of tokens to generate as output.
                "temperature": 0.75,
                "system_prompt": "You are a helpful assistant.",
                "presence_penalty": 0,
                "frequency_penalty": 0
            }
        )
        print(f"{i + 1}. Output size: {len(tokenizer.tokenize(output))} tokens")
        summary = f"Summary {i+1}:\n{output}\n\n"
        summaries.append(summary)
        print(summary)

    print("Summary count: " + str(len(summaries)))
    summary_contents = "\n\n".join(summaries)
    print(f"Total: {len(tokenizer.tokenize(summary_contents))} tokens")

    return summaries


prompt = prompt_guide_template.format(prompt = """
    Summarize the following text using only the information found in the text:
    {text}
    """)

summaries_lvl_1 = summarize(chunks, prompt, 200, 2000)
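The prompt above is filled in two stages: first the instruction is placed into the chat template, then each passage is placed into the instruction. The short sketch below shows why this works with `str.format`. The `chat_template` string mirrors the template defined earlier; the sample instruction and passage are made up for illustration.

```python
chat_template = """\
<|start_of_role|>user<|end_of_role|>{prompt}<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>"""

# Stage 1: fill {prompt}. str.format does not re-scan the inserted value,
# so the {text} placeholder inside the instruction survives as-is.
stage_1 = chat_template.format(prompt="Summarize the following text:\n{text}")
print("{text}" in stage_1)

# Stage 2: fill {text} with a passage. Braces inside the passage are also left alone.
stage_2 = stage_1.format(text="A passage with {braces} in it.")
print(stage_2)
```

One consequence is that the instruction string itself must not contain stray braces other than `{text}`, since it is parsed as a format string in the second stage.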


### Summarize the Summaries

We tell the model that it is receiving separate summaries of passages from an original text, and ask it to create a unified summary of that text.

In [None]:
# Define a method for aggregating groups of summaries.
def group_array(arr, n):
    # Calculate the size of each chunk
    avg_len = len(arr) // n
    remainder = len(arr) % n
    result = []
    start = 0

    for i in range(n):
        # Distribute the remainder elements across the first chunks
        end = start + avg_len + (1 if i < remainder else 0)
        result.append(arr[start:end])
        start = end

    return result

# Aggregate groups of summaries for further summarization.
# Assuming each summary is <=512 tokens, we want at most 6 summaries (<=3072 tokens) per group.
num_groups = (len(chunks) // 6) + (1 if len(chunks) % 6 else 0)
summary_groups = group_array(summaries_lvl_1, num_groups)
# Note: the [:3500] slice below truncates each group by characters, not tokens.
texts_lvl_2 = ["\n\n".join(summary_group)[:3500] for summary_group in summary_groups]

prompt = prompt_guide_template.format(prompt = """\
A text was summarized in separate passages; those passage summaries are provided below. 

{text}

From these summaries alone, compose a single, unified summary of the text.
""")

summaries_lvl_2 = summarize(texts_lvl_2, prompt, 500, 1000)
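The grouping arithmetic used above can be checked in isolation. The sketch below computes the number of groups with ceiling division and then the size of each group when items are spread as evenly as possible. The `balanced_group_sizes` helper is hypothetical and reports only the sizes, not the grouped items.

```python
from typing import List


def balanced_group_sizes(num_items: int, max_per_group: int) -> List[int]:
    """Hypothetical helper: group sizes when items are spread evenly across groups."""
    if num_items == 0:
        return []
    num_groups = -(-num_items // max_per_group)  # ceiling division
    base, remainder = divmod(num_items, num_groups)
    # The first `remainder` groups each take one extra item.
    return [base + (1 if i < remainder else 0) for i in range(num_groups)]


print(balanced_group_sizes(20, 6))  # [5, 5, 5, 5]
print(balanced_group_sizes(13, 6))  # [5, 4, 4]
```

Spreading items evenly avoids a small leftover group at the end, so each second-level prompt carries a similar amount of text.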

### Create the Final Summary

Generate a single summary from the second-level summaries produced above, reusing the same prompt.

In [None]:
texts_lvl_3 = ["\n\n".join(summaries_lvl_2)]
final_summary = summarize(texts_lvl_3, prompt, 500, 1000)[0]