# Clustered Summary

This is an attempt to implement the idea of [How to Summarize Large Documents with LangChain and OpenAI](https://medium.com/@myscale/how-to-summarize-large-documents-with-langchain-and-openai-4312568e80b1) in .NET with Amazon Bedrock.

## Parsing

For PDF parsing we use [PdfPig](https://github.com/UglyToad/PdfPig). See the [NOTICE](NOTICE) file for license information.

```csharp
#r "nuget: PdfPig"
```

```csharp
using System.IO;
using System.Linq;
using System.Text;
using UglyToad.PdfPig;

// Extracts the text of a PDF, writing one output line per group of words that share the same bottom coordinate.
string GetText(FileInfo file)
{
    var text = new StringBuilder();

    using (var pdfDocument = PdfDocument.Open(file.FullName))
    {
        foreach (var page in pdfDocument.GetPages())
        {
            // Word grouping by bottom coordinate, taken from https://stackoverflow.com/a/75043692/6466378
            var lines = page.GetWords().GroupBy(x => x.BoundingBox.Bottom);
            foreach (var line in lines)
            {
                foreach (var word in line)
                {
                    text.Append(word.Text).Append(' ');
                }
                text.Append('\n');
            }
        }
    }
    return text.ToString();
}
```

```csharp
var file = new FileInfo(Path.Combine("documents", "Towards Trust in Legal AI - Enhancing LLMs with Retrieval Augmented Generation.pdf"));
var text = GetText(file);
display($"{text.Length} characters");
```

```
299821 characters
```
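`GetText` groups words only when their bounding boxes have exactly the same bottom coordinate, and emits the groups in first-seen order. Words on one visual line can therefore land in separate groups when their bottoms differ slightly (superscripts, mixed fonts). A possible alternative, sketched below under stated assumptions, sorts words top to bottom, starts a new line whenever a word sits more than a tolerance below the first word of the current line, and orders each line left to right. The `(Text, Left, Bottom)` tuples stand in for PdfPig's `Word` and `BoundingBox` so the sketch runs without a PDF; the `GroupIntoLines` name and the 2-point tolerance are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: (Text, Left, Bottom) stands in for PdfPig's Word.Text and BoundingBox.
// PDF y-coordinates grow upwards, so the top line has the largest Bottom value.
List<string> GroupIntoLines(IEnumerable<(string Text, double Left, double Bottom)> words, double tolerance = 2.0)
{
    var lines = new List<List<(string Text, double Left, double Bottom)>>();
    double? lineBottom = null;
    foreach (var word in words.OrderByDescending(w => w.Bottom))
    {
        // Start a new line when the word sits more than `tolerance` below the first word of the current line.
        if (lineBottom is null || lineBottom.Value - word.Bottom > tolerance)
        {
            lines.Add(new List<(string Text, double Left, double Bottom)>());
            lineBottom = word.Bottom;
        }
        lines[^1].Add(word);
    }
    // Within a line, order the words left to right.
    return lines.Select(l => string.Join(" ", l.OrderBy(w => w.Left).Select(w => w.Text))).ToList();
}

var sampleWords = new List<(string Text, double Left, double Bottom)>
{
    ("world", 60, 700.4), ("Hello", 10, 700.0), ("line", 50, 680.0), ("Second", 10, 680.2),
};
foreach (var textLine in GroupIntoLines(sampleWords))
{
    Console.WriteLine(textLine);
}
// Hello world
// Second line
```

A fixed tolerance is the simplest choice; a value relative to the font size would likely be more robust across documents.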

## Partitioning

For text partitioning (chunking) we use [Semantic Kernel](https://github.com/microsoft/semantic-kernel).

```csharp
#r "nuget: Microsoft.SemanticKernel.Core"
```

```csharp
using Microsoft.SemanticKernel.Text;

string[] Partition(string text)
{
    // Currently we use static (size-based) chunking. We should replace this with semantic chunking in the future.
    const int maxTokensPerLine = 200;
    const int maxTokensPerParagraph = 700;
    const int overlappingTokens = 70;
    const double charactersPerToken = 4.7; // https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html

    #pragma warning disable SKEXP0050 // experimental API
    // Estimate the token count from the character count instead of running a tokenizer.
    TextChunker.TokenCounter tokenCounter = (string s) => (int)(s.Length / charactersPerToken);
    var lines = TextChunker.SplitPlainTextLines(text, maxTokensPerLine: maxTokensPerLine, tokenCounter: tokenCounter);
    var partitions = TextChunker.SplitPlainTextParagraphs(lines, maxTokensPerParagraph: maxTokensPerParagraph, overlapTokens: overlappingTokens, tokenCounter: tokenCounter, chunkHeader: null);
    #pragma warning restore SKEXP0050 // experimental API

    return partitions.ToArray();
}
```

```csharp
var partitions = Partition(text);
display($"{partitions.Length} partitions");
```

```
107 partitions
```
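To see how the token budget and the overlap interact, here is a deliberately simplified, hypothetical chunker. It is not Semantic Kernel's algorithm (`TextChunker` first splits the text into lines with `SplitPlainTextLines` and then packs those lines into paragraphs with `SplitPlainTextParagraphs`); it packs whole words into windows and reuses the same 4.7 characters-per-token estimate as `Partition`. With `Partition`'s settings, a 700-token paragraph corresponds to roughly 700 × 4.7 = 3290 characters. The `ChunkWords` name and the toy parameters in the demo are assumptions.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

const double CharactersPerToken = 4.7; // same heuristic as Partition()

// Hypothetical greedy chunker (not Semantic Kernel's algorithm): packs whole words into windows of
// at most maxTokens estimated tokens and repeats about overlapTokens worth of words between windows.
List<string> ChunkWords(string text, int maxTokens, int overlapTokens)
{
    var words = text.Split(new[] { ' ', '\t', '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);
    int maxChars = (int)(maxTokens * CharactersPerToken);
    int overlapChars = (int)(overlapTokens * CharactersPerToken);
    var result = new List<string>();
    int start = 0;
    while (start < words.Length)
    {
        // Grow the window while its estimated size (each word plus one space) fits the budget;
        // a single oversized word still forms a window on its own.
        int end = start, chars = 0;
        while (end < words.Length && (end == start || chars + words[end].Length + 1 <= maxChars))
        {
            chars += words[end].Length + 1;
            end++;
        }
        result.Add(string.Join(" ", words[start..end]));
        if (end == words.Length) break;

        // Step back so the next window repeats up to overlapChars characters of words,
        // but always advance by at least one word so the loop terminates.
        int next = end, overlap = 0;
        while (next - 1 > start && overlap + words[next - 1].Length + 1 <= overlapChars)
        {
            overlap += words[next - 1].Length + 1;
            next--;
        }
        start = next;
    }
    return result;
}

// 100 words of 4 characters: each window holds 9 words and the next one repeats the last word.
var sample = string.Join(" ", Enumerable.Repeat("abcd", 100));
var windows = ChunkWords(sample, maxTokens: 10, overlapTokens: 2);
Console.WriteLine($"{windows.Count} windows"); // 13 windows
```

Because windows overlap, the total number of embedded tokens is larger than the document's own token count, which is worth keeping in mind when comparing costs.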

## Embedding

For text embedding we use [Amazon Titan Text Embeddings v2](https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html) on Amazon Bedrock.

```csharp
#r "nuget: AWSSDK.BedrockRuntime"
#r "nuget: Microsoft.Extensions.Configuration.Json"
```

```csharp
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using Microsoft.Extensions.Configuration;

// Values in the optional appsettings.local.json override those in appsettings.json.
var config = new ConfigurationBuilder()
    .SetBasePath(Directory.GetCurrentDirectory())
    .AddJsonFile("appsettings.json", optional: false)
    .AddJsonFile("appsettings.local.json", optional: true)
    .Build();

var bedrock = new AmazonBedrockRuntimeClient(
    awsAccessKeyId: config["AWSBedrockAccessKeyId"]!,
    awsSecretAccessKey: config["AWSBedrockSecretAccessKey"]!,
    region: Amazon.RegionEndpoint.GetBySystemName(config["AWSBedrockRegion"]!));
```

```csharp
using System.Text.Json;
using System.Threading;
using Amazon.Util;

// Request and response bodies of Amazon Titan Text Embeddings v2 (property names must match the JSON keys).
record EmbeddingRequest(string inputText);
record EmbeddingResponse(float[] embedding, int inputTextTokenCount);

async Task<EmbeddingResponse> Embed(string text, CancellationToken cancellationToken = default)
{
    var requestBody = AWSSDKUtils.GenerateMemoryStreamFromString(JsonSerializer.Serialize(new EmbeddingRequest(text)));
    var request = new InvokeModelRequest
    {
        ModelId = "amazon.titan-embed-text-v2:0",
        Body = requestBody,
        ContentType = "application/json",
    };
    var response = await bedrock.InvokeModelAsync(request, cancellationToken);
    var embedded = await JsonSerializer.DeserializeAsync<EmbeddingResponse>(response.Body, cancellationToken: cancellationToken);
    return embedded ?? throw new InvalidOperationException("Bedrock returned an empty embedding response.");
}
```
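`EmbeddingRequest` sends only `inputText`, so the model's defaults apply to everything else. Titan Text Embeddings v2 also accepts the optional fields `dimensions` (1024, 512 or 256) and `normalize`; pinning `dimensions` to 1024 would keep the request explicitly in sync with the `VectorType(1024)` used for clustering. A minimal sketch of such an extended body, where the sample text and the chosen values are assumptions:

```csharp
using System;
using System.Text.Json;

// Assumed extended request body for amazon.titan-embed-text-v2:0. The field names follow the
// Titan Text Embeddings v2 request format; the values are illustrative choices.
var body = JsonSerializer.Serialize(new
{
    inputText = "Retrieval augmented generation for legal questions",
    dimensions = 1024,
    normalize = true,
});
Console.WriteLine(body);
// {"inputText":"Retrieval augmented generation for legal questions","dimensions":1024,"normalize":true}
```

In `Embed`, an object like this (or an `EmbeddingRequest` record extended with the two fields) could replace `new EmbeddingRequest(text)`; the response shape stays the same.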

We embed all partitions concurrently: one Bedrock request per partition, all started at once and awaited together.
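Starting every request at once works for 107 partitions, but Bedrock enforces per-account request and token quotas, so larger documents may run into throttling errors. A common remedy is to cap the number of requests in flight with a `SemaphoreSlim`. The sketch below shows the pattern with a fake request instead of `Embed`; the `MapConcurrentAsync` helper is hypothetical and the limit of 4 is arbitrary.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: applies `work` to every item with at most `maxConcurrency`
// calls in flight and returns the results in input order.
async Task<TResult[]> MapConcurrentAsync<T, TResult>(
    IEnumerable<T> items, Func<T, Task<TResult>> work, int maxConcurrency)
{
    using var gate = new SemaphoreSlim(maxConcurrency);
    var tasks = items.Select(async item =>
    {
        await gate.WaitAsync();
        try { return await work(item); }
        finally { gate.Release(); }
    }).ToList();
    // Task.WhenAll preserves the order of `tasks`, so the results line up with `items`.
    return await Task.WhenAll(tasks);
}

// Demo with a fake "request" that records how many calls run at the same time.
var sync = new object();
int inFlight = 0, peak = 0;
var results = await MapConcurrentAsync(Enumerable.Range(0, 20), async i =>
{
    lock (sync) { inFlight++; peak = Math.Max(peak, inFlight); }
    await Task.Delay(10);
    lock (sync) { inFlight--; }
    return i * i;
}, maxConcurrency: 4);
Console.WriteLine($"peak concurrency: {peak}, last result: {results[^1]}"); // last result is 361
```

In the notebook, a call such as `await MapConcurrentAsync(partitions, p => Embed(p), maxConcurrency: 8)` could take the place of the unbounded `Select(async ...)`; a suitable limit depends on the quotas of your account and region.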

```csharp
using System.Linq;
using System.Diagnostics;

record Embedding(float[] embedding, int inputTokens, int paritionIndex);

var embeddingStopwatch = Stopwatch.StartNew();
var embeddingTasks = partitions
    .Select(async (partition, index) =>
    {
        var embedding = await Embed(partition);
        return new Embedding(embedding.embedding, embedding.inputTextTokenCount, index);
    });
// Task.WhenAll returns the results in the order of the input tasks.
var embeddings = (await Task.WhenAll(embeddingTasks)).ToList();
var totalInputTokens = embeddings.Sum(s => s.inputTokens);
embeddingStopwatch.Stop();

display($"Parallel embedding {embeddings.Count} times took {embeddingStopwatch.Elapsed} and {totalInputTokens} input tokens");
```

```
Parallel embedding 107 times took 00:00:01.7143092 and 73674 input tokens
```

## Clustering

For clustering we use K-means with [ML.NET](https://dotnet.microsoft.com/en-us/apps/ai/ml-dotnet).
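K-means alternates two steps: assign every point to its nearest centroid, then move each centroid to the mean of the points assigned to it. ML.NET hides this behind `KMeansTrainer`; the hypothetical `LloydKMeans` function below spells out the basic (Lloyd's) variant on a few 2-D points. It is not ML.NET's implementation: the fixed iteration count and the first-k initialization are simplifying assumptions that keep the example deterministic.

```csharp
using System;
using System.Linq;

// Minimal Lloyd's k-means sketch (not ML.NET's trainer). Initialization simply takes the first k points.
(float[][] Centroids, int[] Labels) LloydKMeans(float[][] points, int k, int iterations = 20)
{
    var centroids = points.Take(k).Select(p => (float[])p.Clone()).ToArray();
    var labels = new int[points.Length];
    for (int iter = 0; iter < iterations; iter++)
    {
        // Assignment step: each point joins the centroid with the smallest squared Euclidean distance.
        for (int i = 0; i < points.Length; i++)
        {
            var point = points[i];
            labels[i] = Enumerable.Range(0, k).MinBy(c => SquaredDistance(point, centroids[c]));
        }
        // Update step: move each centroid to the mean of its members (an empty cluster keeps its centroid).
        for (int c = 0; c < k; c++)
        {
            var members = points.Where((_, i) => labels[i] == c).ToArray();
            if (members.Length == 0) continue;
            centroids[c] = Enumerable.Range(0, points[0].Length)
                .Select(d => members.Average(m => m[d]))
                .ToArray();
        }
    }
    return (centroids, labels);
}

float SquaredDistance(float[] a, float[] b) => a.Zip(b, (x, y) => (x - y) * (x - y)).Sum();

float[][] data =
{
    new[] { 0f, 0f }, new[] { 10f, 10f }, new[] { 0f, 1f },
    new[] { 1f, 0f }, new[] { 10f, 11f }, new[] { 11f, 10f },
};
var (centers, assigned) = LloydKMeans(data, k: 2);
Console.WriteLine(string.Join(",", assigned)); // 0,1,0,0,1,1
```

Real implementations add smarter initialization (for example k-means++) and stop once the assignments no longer change; the result of plain Lloyd's iterations depends on the starting centroids.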

```csharp
#r "nuget: Microsoft.ML"
```

```csharp
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

// PartitionIndex is carried along as a plain column; only Features is used for training.
// 1024 is the default output dimension of Titan Text Embeddings v2.
record DataPoint(uint PartitionIndex, [property: VectorType(1024)] float[] Features);

List<float[]> Kmeans(IEnumerable<DataPoint> dataPoints)
{
    const int k = 10;
    var mlContext = new MLContext(seed: 0);

    // Load
    var dataView = mlContext.Data.LoadFromEnumerable(dataPoints);

    // Normalize? It is unclear whether this makes a difference. If it does, we should normalize
    // the embeddings in general, not only for clustering.
    // var dataProcessingPipeline = mlContext.Transforms.NormalizeMeanVariance(nameof(DataPoint.Features));
    // var processedData = dataProcessingPipeline.Fit(dataView).Transform(dataView);
    var processedData = dataView;

    // Train
    var pipeline = mlContext.Clustering.Trainers.KMeans(new KMeansTrainer.Options
    {
        NumberOfClusters = k,
        FeatureColumnName = nameof(DataPoint.Features),
    });
    var model = pipeline.Fit(processedData);

    // Read the cluster centroids from the trained model
    VBuffer<float>[] centroids = default;
    var modelParams = model.Model;
    modelParams.GetClusterCentroids(ref centroids, out int clusters);
    return centroids.Select(c => c.DenseValues().ToArray()).ToList();
}
```

```csharp
var kmeansStopwatch = Stopwatch.StartNew();
var dataPoints = embeddings.Select(e => new DataPoint((uint)e.paritionIndex, e.embedding));
var clusterCenters = Kmeans(dataPoints);
kmeansStopwatch.Stop();
display($"K-means took {kmeansStopwatch.Elapsed} to build {clusterCenters.Count} clusters");
```

```
K-means took 00:00:00.0794333 to build 10 clusters
```

## Indexing

For each semantic cluster we find a representative that is the nearest neighbour of the cluster center. We use [HNSW.Net](https://github.com/curiosity-ai/hnsw-sharp) for fast vector search in memory.
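K-means places its centroids using Euclidean distance, while the representative search uses cosine distance. Assuming unit-length embeddings $\mathbf{a}$ (the documented default for Titan Text Embeddings v2) and a centroid $\mathbf{c} \neq \mathbf{0}$, both criteria select the same nearest neighbour:

$$
\lVert \mathbf{a} - \mathbf{c} \rVert^2 = \lVert \mathbf{a} \rVert^2 - 2\,\mathbf{a} \cdot \mathbf{c} + \lVert \mathbf{c} \rVert^2 = 1 + \lVert \mathbf{c} \rVert^2 - 2\,\lVert \mathbf{c} \rVert \cos(\mathbf{a}, \mathbf{c})
$$

For a fixed centroid, $\lVert \mathbf{c} \rVert$ is a constant, so minimizing the Euclidean distance over the embeddings is equivalent to maximizing $\cos(\mathbf{a}, \mathbf{c})$. This holds even though a centroid (a mean of unit vectors) is usually shorter than unit length.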

```csharp
#r "nuget: HNSW"
```

```csharp
using HNSW.Net;

// Builds an in-memory HNSW graph over the embeddings and returns a function
// that finds the embedding nearest (by cosine distance) to a query vector.
Func<float[], Embedding> Hnsw(List<Embedding> embeddings)
{
    // Parameter explanation: https://github.com/curiosity-ai/hnsw-sharp/blob/master/Src/HNSW.Net/SmallWorld.cs#L216
    var parameters = new SmallWorld<Embedding, float>.Parameters()
    {
        M = 15,
        LevelLambda = 1 / Math.Log(15),
    };

    var graph = new SmallWorld<Embedding, float>((e1, e2) => CosineDistance.NonOptimized(e1.embedding, e2.embedding), DefaultRandomGenerator.Instance, parameters);
    graph.AddItems(embeddings);

    return search =>
    {
        // Wrap the query vector in a placeholder Embedding so it can be compared with the graph items.
        var searchEmbedding = new Embedding(search, -1, -1);
        var nn = graph.KNNSearch(searchEmbedding, 1).Single();
        return nn.Item;
    };
}
```

```csharp
var nnStopwatch = Stopwatch.StartNew();
var nn = Hnsw(embeddings);
var clusterRepresentatives = clusterCenters
    .Select(nn)
    .Select(crnn => new { crnn.paritionIndex, partition = partitions[crnn.paritionIndex] })
    .ToList();
nnStopwatch.Stop();

display($"Building HNSW index with {embeddings.Count} items and finding {clusterRepresentatives.Count} 'cluster center nearest neighbours' took {nnStopwatch.Elapsed}.");
display(clusterRepresentatives.Select(cr => cr.paritionIndex).Order());
```

```
Building HNSW index with 107 items and finding 10 'cluster center nearest neighbours' took 00:00:00.5676770.
```

Alternatively, we can skip the index. With only 107 embeddings, computing the cosine similarities directly may even be faster. This takes $k \cdot |E|$ similarity computations, i.e. $O(k \cdot |E| \cdot d)$ time, with $k$ the number of semantic clusters, $E$ the set of embeddings and $d$ the embedding dimension.
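With the numbers from the cells above ($k = 10$ clusters and $|E| = 107$ embeddings of dimension $d = 1024$), a brute-force search over all cluster centers amounts to:

$$
k \cdot |E| = 10 \cdot 107 = 1070 \ \text{similarity evaluations}, \qquad k \cdot |E| \cdot d = 1070 \cdot 1024 = 1\,095\,680 \ \text{vector components visited.}
$$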

```csharp
#r "nuget: System.Numerics.Tensors"
```

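```csharp
using System.Numerics.Tensors;

// Linear scan: returns the embedding with the highest cosine similarity to the query vector.
// MaxBy needs a single pass, whereas OrderByDescending(...).First() would sort all embeddings.
Func<float[], Embedding> BruteForce(List<Embedding> embeddings)
{
    return search => embeddings.MaxBy(e => TensorPrimitives.CosineSimilarity(search, e.embedding))!;
}
```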

```csharp
var nnStopwatch = Stopwatch.StartNew();
var nn = BruteForce(embeddings);
var clusterRepresentatives = clusterCenters
    .Select(nn)
    .Select(crnn => new { crnn.paritionIndex, partition = partitions[crnn.paritionIndex] })
    .ToList();
nnStopwatch.Stop();

display($"Brute-force finding {clusterRepresentatives.Count} 'cluster center nearest neighbours' took {nnStopwatch.Elapsed}.");
display(clusterRepresentatives.Select(cr => cr.paritionIndex).Order());
```

```
Brute-force finding 10 'cluster center nearest neighbours' took 00:00:00.0030793.
```

## Summarizing

For summarization we use [Anthropic Claude 3.5 Sonnet](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html) on Amazon Bedrock.

````csharp
// Subset of the Claude Messages API response body; each content block has a "type" and a "text".
record Content(string type, string text);
record Usage(int input_tokens, int output_tokens);
record InvokeResponse(Content[] content, string stop_reason, Usage usage);
record SummarizeResponse(string summary, int inputTokens, int outputTokens);

async Task<SummarizeResponse> Summarize(string text, CancellationToken cancellationToken = default)
{
    var prompt =
        $"""
        Provide a summary of the following text. Your result must be detailed and at least 2 paragraphs long.
        When summarizing, directly dive into the narrative or descriptions from the text without using
        introductory phrases like 'In this passage'. Directly address the main events, characters, and
        themes, encapsulating the essence and significant details from the text in a flowing narrative.
        The goal is to present a unified view of the content, continuing the story seamlessly as if the
        passage naturally progresses into the summary.

        Passage:
        ```{text}```
        """;

    var nativeRequest = new
    {
        anthropic_version = "bedrock-2023-05-31",
        max_tokens = 512,
        temperature = 0.0,
        messages = new[]
        {
            new { role = "user", content = prompt }
        }
    };

    var response = await bedrock.InvokeModelAsync(new InvokeModelRequest
    {
        ModelId = "anthropic.claude-3-5-sonnet-20240620-v1:0", // "anthropic.claude-3-haiku-20240307-v1:0" may be a cheaper, faster alternative
        Body = AWSSDKUtils.GenerateMemoryStreamFromString(JsonSerializer.Serialize(nativeRequest)),
        ContentType = "application/json",
    }, cancellationToken);

    var responseBody = await JsonSerializer.DeserializeAsync<InvokeResponse>(response.Body, cancellationToken: cancellationToken)
        ?? throw new InvalidOperationException("Bedrock returned an empty response.");
    return new SummarizeResponse(responseBody.content[0].text, responseBody.usage.input_tokens, responseBody.usage.output_tokens);
}
````
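`Summarize` reads only `content[0].text`. In the Claude Messages API response, `role` sits at the top level, `content` is an array of blocks that each carry a `type` (usually `"text"`), and `stop_reason` is `"max_tokens"` when the output was cut off at the `max_tokens` limit (512 in the request above). A hedged sketch that inspects a made-up, abbreviated response with `JsonDocument`, joining all text blocks and flagging truncation:

```csharp
using System;
using System.Linq;
using System.Text.Json;

// Abbreviated, made-up response body in the Claude Messages API format.
const string responseJson = """
    {
      "type": "message",
      "role": "assistant",
      "content": [ { "type": "text", "text": "First paragraph." }, { "type": "text", "text": " Second paragraph." } ],
      "stop_reason": "end_turn",
      "usage": { "input_tokens": 900, "output_tokens": 270 }
    }
    """;

using var doc = JsonDocument.Parse(responseJson);
var root = doc.RootElement;

// Join every text block instead of reading only content[0].
string summary = string.Concat(root.GetProperty("content").EnumerateArray()
    .Where(block => block.GetProperty("type").GetString() == "text")
    .Select(block => block.GetProperty("text").GetString()));
int inputTokens = root.GetProperty("usage").GetProperty("input_tokens").GetInt32();
int outputTokens = root.GetProperty("usage").GetProperty("output_tokens").GetInt32();

// "max_tokens" means generation stopped at the max_tokens limit, so the summary is incomplete.
bool truncated = root.GetProperty("stop_reason").GetString() == "max_tokens";
Console.WriteLine($"{summary} [{inputTokens} in / {outputTokens} out, truncated: {truncated}]");
```

If truncated summaries show up in practice, raising `max_tokens` or relaxing the "at least 2 paragraphs" requirement in the prompt are the obvious levers.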

```csharp
record PartitionSummary(SummarizeResponse summary, int paritionIndex);

var summarizeStopwatch = Stopwatch.StartNew();
var summaryTasks = clusterRepresentatives
    .Select(async clusterRepresentative =>
    {
        var summary = await Summarize(clusterRepresentative.partition);
        return new PartitionSummary(summary, clusterRepresentative.paritionIndex);
    });
var summaries = (await Task.WhenAll(summaryTasks))
    .OrderBy(s => s.paritionIndex) // keep the order of the partitions in the source document
    .ToList();

var concatenatedSummary = string.Join("\n\n", summaries.Select(s => s.summary.summary));
var totalInputTokens = summaries.Sum(s => s.summary.inputTokens);
var totalOutputTokens = summaries.Sum(s => s.summary.outputTokens);
summarizeStopwatch.Stop();

display($"Summarizing {clusterRepresentatives.Count} cluster representatives took {summarizeStopwatch.Elapsed}, {totalInputTokens} input tokens and {totalOutputTokens} output tokens.");
//display(concatenatedSummary);
```

```
Summarizing 10 cluster representatives took 00:00:06.7322285, 9013 input tokens and 2744 output tokens.
```