In [1]:
from langchain.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy

from langchain.text_splitter import RecursiveCharacterTextSplitter
from llm.gemini import Gemini
from llm.llm_utils import *
#
# Hugging Face model id of the sentence encoder used for every chunk and summary.
EMBEDDING_MODEL_NAME = 'bkai-foundation-models/vietnamese-bi-encoder'

# Shared embedder: `embed()` further down calls `embd.embed_documents`, and the
# Chroma store is constructed with it as its embedding function.
# NOTE(review): the device is hard-coded to "cuda"; presumably this fails on a
# CPU-only host - confirm before running elsewhere.
embd = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    multi_process=True,
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": True},  # set True for cosine similarity
)


## LLM for summarization

In [2]:
# Project-local Gemini wrapper; it is called later as `llm(messages)` with a list
# of {"role", "content"} dicts to summarise each cluster.
llm = Gemini()

In [4]:
# Splitter used to cut every source file into the leaf chunks of the tree.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)

In [5]:
import chromadb
# Chroma client persisted at the given path.
raptor_client = chromadb.PersistentClient(path='database/raptor_2.db')
# LangChain vector-store wrapper; texts added to it are embedded with `embd`.
raptor_db = Chroma(client=raptor_client, embedding_function=embd)

## Clustering

In [13]:
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import umap
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import multiprocessing as mp

# Fixed seed for reproducibility. In this notebook it is only used as the default
# `random_state` of `get_optimal_clusters`; the UMAP calls are not seeded and
# `GMM_cluster` defaults to random_state=0.
RANDOM_SEED = 224

### --- Code from citations referenced above (added comments and docstrings) --- ###


def global_cluster_embeddings(
    embeddings: np.ndarray,
    dim: int,
    n_neighbors: Optional[int] = None,
    metric: str = "cosine",
) -> np.ndarray:
    """
    Project the full set of embeddings into a lower-dimensional space with UMAP.

    Parameters:
    - embeddings: array of input embeddings, one row per item.
    - dim: number of UMAP components to produce.
    - n_neighbors: UMAP neighbourhood size; when None it is derived from the
      input as int(sqrt(len(embeddings) - 1)).
    - metric: distance metric handed to UMAP.

    Returns:
    - The output of `UMAP.fit_transform` for `embeddings`.
    """
    neighbor_count = (
        int((len(embeddings) - 1) ** 0.5) if n_neighbors is None else n_neighbors
    )
    reducer = umap.UMAP(n_neighbors=neighbor_count, n_components=dim, metric=metric)
    return reducer.fit_transform(embeddings)


def local_cluster_embeddings(
    embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
) -> np.ndarray:
    """
    Project the embeddings of a single global cluster with UMAP.

    Parameters:
    - embeddings: array of input embeddings, one row per item.
    - dim: number of UMAP components to produce.
    - num_neighbors: UMAP neighbourhood size (fixed, unlike the global variant).
    - metric: distance metric handed to UMAP.

    Returns:
    - The output of `UMAP.fit_transform` for `embeddings`.
    """
    reducer = umap.UMAP(n_neighbors=num_neighbors, n_components=dim, metric=metric)
    projected = reducer.fit_transform(embeddings)
    return projected



def get_optimal_clusters(
    embeddings: np.ndarray, max_clusters: int = 400, random_state: int = RANDOM_SEED
) -> int:
    """
    Choose the number of Gaussian-mixture components with the lowest BIC.

    The candidate counts are capped at half the number of samples and scanned on
    a grid whose step depends on that cap:
    - cap >= 300: 50, 60, ..., cap (inclusive when aligned to the step)
    - 100 < cap < 300: 1, 3, 5, ... below cap
    - otherwise: 1, 2, ... below cap

    Parameters:
    - embeddings: array of points to cluster, one row per item.
    - max_clusters: upper bound on the number of components to consider.
    - random_state: seed passed to every GaussianMixture fit.

    Returns:
    - int, the candidate with the smallest BIC. When the input is too small to
      yield any candidate (fewer than 4 samples) a single cluster is returned
      instead of failing.
    """
    max_clusters = min(max_clusters, len(embeddings) // 2)

    if max_clusters >= 300:
        n_clusters = np.arange(50, max_clusters + 1, 10)
    elif max_clusters > 100:
        n_clusters = np.arange(1, max_clusters, 2)
    else:
        n_clusters = np.arange(1, max_clusters)

    # With very few samples the candidate range is empty; there is nothing to
    # compare, so fall back to one cluster rather than indexing an empty array.
    if len(n_clusters) == 0:
        return 1

    bics = []
    for n in tqdm(n_clusters, desc="Optimizing clusters length: "+str(len(embeddings))):
        gm = GaussianMixture(n_components=int(n), random_state=random_state)
        gm.fit(embeddings)
        bics.append(gm.bic(embeddings))

    # Return a plain Python int, as the annotation promises.
    return int(n_clusters[np.argmin(bics)])




def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
    """
    Soft-cluster embeddings with a Gaussian mixture model.

    The number of components comes from `get_optimal_clusters`. A point is
    assigned to every component whose posterior probability exceeds
    `threshold`, so a point may belong to several clusters or to none.

    Parameters:
    - embeddings: array of points to cluster, one row per item.
    - threshold: minimum membership probability for an assignment.
    - random_state: seed for the final GaussianMixture fit.

    Returns:
    - (labels, n_clusters): `labels` is a list with one integer index array per
      input row; `n_clusters` is the number of mixture components used.
    """
    component_count = get_optimal_clusters(embeddings)
    mixture = GaussianMixture(n_components=component_count, random_state=random_state)
    mixture.fit(embeddings)

    memberships = []
    for row_probs in mixture.predict_proba(embeddings):
        memberships.append(np.where(row_probs > threshold)[0])
    return memberships, component_count


def perform_clustering(
    embeddings: np.ndarray,
    dim: int,
    threshold: float,
) -> List[np.ndarray]:
    """
    Two-stage soft clustering: global clusters first, then local clusters
    inside each global cluster.

    Both stages reduce dimensionality with UMAP and cluster with a Gaussian
    mixture (`GMM_cluster`). Local cluster ids are offset so that they are
    unique across all global clusters.

    Parameters:
    - embeddings: array of input embeddings, one row per item.
    - dim: target dimensionality for the UMAP reductions.
    - threshold: membership probability threshold passed to `GMM_cluster`.

    Returns:
    - List with one array per input row holding the ids of the local clusters
      that row belongs to. An array may be empty (no membership above the
      threshold) or hold several ids. When there are `dim + 1` rows or fewer,
      every row is placed in the single cluster 0.
    """
    n_points = len(embeddings)
    if n_points <= dim + 1:
        # Avoid clustering when there's insufficient data
        return [np.array([0]) for _ in range(n_points)]

    # Global dimensionality reduction and clustering
    reduced_embeddings_global = global_cluster_embeddings(embeddings, dim)
    global_clusters, n_global_clusters = GMM_cluster(
        reduced_embeddings_global, threshold
    )

    all_local_clusters = [np.array([]) for _ in range(n_points)]
    total_clusters = 0

    for i in range(n_global_clusters):
        # Row indices (into `embeddings`) of the members of global cluster i.
        # Tracking indices, rather than matching embedding values afterwards,
        # avoids an (m x n x d) comparison tensor and keeps rows with identical
        # embeddings from receiving each other's labels or duplicate labels.
        member_idx = np.flatnonzero(
            np.array([i in gc for gc in global_clusters], dtype=bool)
        )
        if len(member_idx) == 0:
            continue

        if len(member_idx) <= dim + 1:
            # Too few members to reduce/cluster: treat them as one local cluster
            local_clusters = [np.array([0]) for _ in member_idx]
            n_local_clusters = 1
        else:
            # Local dimensionality reduction and clustering
            reduced_embeddings_local = local_cluster_embeddings(
                embeddings[member_idx], dim
            )
            local_clusters, n_local_clusters = GMM_cluster(
                reduced_embeddings_local, threshold
            )

        # Record local memberships, offset by the clusters already numbered
        for j in range(n_local_clusters):
            for position, lc in enumerate(local_clusters):
                if j in lc:
                    idx = member_idx[position]
                    all_local_clusters[idx] = np.append(
                        all_local_clusters[idx], j + total_clusters
                    )

        total_clusters += n_local_clusters

    return all_local_clusters


## Embedding and Summarization

In [14]:
### --- Our code below --- ###
import time

def embed(texts):
    """
    Embed a list of texts with the notebook-level `embd` embedder.

    Parameters:
    - texts: List[str], the documents to embed.

    Returns:
    - numpy.ndarray built from the list returned by `embd.embed_documents(texts)`.
    """
    return np.array(embd.embed_documents(texts))


def embed_cluster_texts(texts, dim=15, threshold=0.11):
    """
    Embed a list of texts and cluster them.

    Parameters:
    - texts: List[str], the documents to process.
    - dim: int, target dimensionality passed to `perform_clustering`
      (default 15, the value previously hard-coded here).
    - threshold: float, membership probability threshold passed to
      `perform_clustering` (default 0.11, the value previously hard-coded here).

    Returns:
    - pandas.DataFrame with one row per text and the columns:
      'text' (the original text), 'embd' (its embedding vector) and
      'cluster' (array of cluster ids assigned to that text).
    """
    text_embeddings_np = embed(texts)  # Generate embeddings
    cluster_labels = perform_clustering(text_embeddings_np, dim, threshold)

    df = pd.DataFrame()
    df["text"] = texts
    df["embd"] = list(text_embeddings_np)  # one vector per row
    df["cluster"] = cluster_labels
    return df


def fmt_txt(df: pd.DataFrame) -> str:
    """
    Join the 'text' column of a DataFrame into one prompt-ready string.

    Parameters:
    - df: DataFrame with a 'text' column of strings.

    Returns:
    - All texts concatenated with a dashed separator, with every underscore
      replaced by a space.
    """
    separator = "--- --- \n --- --- "
    joined = separator.join(df["text"].tolist())
    return joined.replace("_", " ")


def embed_cluster_summarize_texts(
    texts: List[str], level: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Embed and cluster `texts`, then ask the LLM for one summary per cluster.

    Each summary is also appended to 'text2.txt' as a running log. A cluster
    whose summarisation raises an exception is reported and skipped, so the
    summary frame may hold fewer rows than there are clusters.

    Parameters:
    - texts: List[str], the documents to process.
    - level: int, tree level recorded alongside every summary.

    Returns:
    - (df_clusters, df_summary):
      df_clusters has the columns 'text', 'embd', 'cluster' (one row per text);
      df_summary has the columns 'summaries', 'level', 'cluster' (one row per
      successfully summarised cluster; empty when no text received a cluster).
    """
    # Embed and cluster the texts: columns 'text', 'embd', 'cluster'
    df_clusters = embed_cluster_texts(texts)

    # Expand to one row per (text, cluster) pair; a text can sit in several clusters
    expanded_list = [
        {"text": text, "embd": vector, "cluster": cluster}
        for text, vector, clusters in zip(
            df_clusters["text"], df_clusters["embd"], df_clusters["cluster"]
        )
        for cluster in clusters
    ]

    # Naming the columns keeps 'cluster' available even when nothing was assigned
    expanded_df = pd.DataFrame(expanded_list, columns=["text", "embd", "cluster"])

    # Retrieve unique cluster identifiers for processing
    all_clusters = expanded_df["cluster"].unique()

    print(f"--Generated {len(all_clusters)} clusters--")

    # Summarization prompt (Vietnamese); {context} is replaced with the cluster texts
    template = """
Bạn là một AI được huấn luyện trong việc tóm tắt văn bản. Bạn được cung cấp các thông tin hữu ích sau đây. 
    {context}
Hãy viết một văn bản mới có nội tóm tắt và trích xuất nội dung quan trọng có trong các văn bản trên trong khoảng 100-150 từ
    """

    summaries = []
    take_cluster = []  # ids of the clusters that were summarised successfully
    for i in all_clusters:
        try:
            print(f"Summarizing cluster {i}...")
            df_cluster = expanded_df[expanded_df["cluster"] == i]
            formatted_txt = fmt_txt(df_cluster)
            message = template.replace("{context}", formatted_txt)
            messages = [{"role":"user", "content":message}]
            response = llm(messages)
            with open("text2.txt", "a", encoding="utf-8") as f:
                f.write(response)
                f.write("=====================================\n")
            summaries.append(response)
            take_cluster.append(i)
        except Exception as exc:
            # Best-effort: skip the failing cluster but say why. Catching
            # Exception (not everything) lets Ctrl-C still stop a long run.
            print(f"Error summarizing cluster {i}: {exc!r}")
            continue

    # Summaries with their corresponding cluster and level
    df_summary = pd.DataFrame(
        {
            "summaries": summaries,
            "level": [level] * len(summaries),
            "cluster": take_cluster,
        }
    )

    return df_clusters, df_summary


def recursive_embed_cluster_summarize(
    texts: List[str], level: int = 1, n_levels: int = 3
) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
    """
    Build the summary tree level by level, feeding each level's summaries into
    the next one.

    Processing stops once `n_levels` is reached or a level yields at most one
    distinct cluster in its summaries.

    Parameters:
    - texts: List[str], texts to be processed at the starting level.
    - level: int, level number assigned to the first pass (starts at 1).
    - n_levels: int, highest level number that will be processed.

    Returns:
    - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], mapping each processed level
      to its (clusters DataFrame, summaries DataFrame) pair.
    """
    results = {}
    current_texts = texts
    current_level = level

    while True:
        clusters_df, summary_df = embed_cluster_summarize_texts(
            current_texts, current_level
        )
        results[current_level] = (clusters_df, summary_df)

        distinct_clusters = summary_df["cluster"].nunique()
        if current_level >= n_levels or distinct_clusters <= 1:
            break

        # The summaries of this level are the input texts of the next one
        current_texts = summary_df["summaries"].tolist()
        current_level += 1

    return results

## Read the documents

In [16]:
import os
# Collect the unique leaf chunks from every .txt file under chunks/.
leaf_texts = set()
for file in os.listdir('chunks'):
    if not file.endswith(".txt"):
        continue
    # Read as UTF-8 explicitly: the summaries are written as UTF-8 elsewhere in
    # this notebook, and the platform default encoding is not UTF-8 everywhere.
    with open(os.path.join('chunks', file), 'r', encoding='utf-8') as f:
        text = f.read()
    # A set drops chunks that appear verbatim in more than one place
    leaf_texts.update(text_splitter.split_text(text))


In [17]:
# Sort instead of plain list(): set iteration order for strings changes between
# Python processes, and the row order feeds the seeded clustering below, so a
# fixed order is needed for runs to be repeatable.
leaf_texts = sorted(leaf_texts)

In [18]:
# Number of distinct leaf chunks that will be embedded at level 1
len(leaf_texts)

4526

## Iterative approach

In [19]:
# Build the summary tree level by level. The stack holds at most one pending
# (texts, level) pair, so this behaves like recursive_embed_cluster_summarize
# written as a loop, with a progress message after each level.
stack = [(leaf_texts, 1)]
n_levels = 7  # highest level number that will be processed
results = {}  # level -> (clusters DataFrame, summaries DataFrame)
while stack:
    texts, level = stack.pop()
    df_clusters, df_summary = embed_cluster_summarize_texts(texts, level)
    results[level] = (df_clusters, df_summary)
    
    # Continue only while this level still produced more than one cluster summary
    unique_clusters = df_summary["cluster"].nunique()
    if level < n_levels and unique_clusters > 1:
        # The summaries of this level become the input texts of the next level
        new_texts = df_summary["summaries"].tolist()
        stack.append((new_texts, level+1))
    print(f"Level {level} completed")

Optimizing clusters length: 4526: 100%|██████████| 36/36 [05:02<00:00,  8.39s/it]
Optimizing clusters length: 91: 100%|██████████| 44/44 [00:02<00:00, 15.03it/s]
Optimizing clusters length: 80: 100%|██████████| 39/39 [00:02<00:00, 15.73it/s]
Optimizing clusters length: 86: 100%|██████████| 42/42 [00:02<00:00, 17.63it/s]
Optimizing clusters length: 83: 100%|██████████| 40/40 [00:02<00:00, 16.84it/s]
Optimizing clusters length: 104: 100%|██████████| 51/51 [00:03<00:00, 16.72it/s]
Optimizing clusters length: 160: 100%|██████████| 79/79 [00:05<00:00, 14.32it/s]
Optimizing clusters length: 65: 100%|██████████| 31/31 [00:02<00:00, 14.26it/s]
Optimizing clusters length: 83: 100%|██████████| 40/40 [00:02<00:00, 16.52it/s]
Optimizing clusters length: 103: 100%|██████████| 50/50 [00:02<00:00, 18.84it/s]
Optimizing clusters length: 80: 100%|██████████| 39/39 [00:02<00:00, 16.12it/s]
Optimizing clusters length: 68: 100%|██████████| 33/33 [00:01<00:00, 21.32it/s]
Optimizing clusters length: 146: 10

--Generated 274 clusters--
Summarizing cluster 62.0...
Summarizing cluster 85.0...
Summarizing cluster 63.0...
Summarizing cluster 124.0...
Summarizing cluster 167.0...
Summarizing cluster 101.0...
Summarizing cluster 248.0...
Summarizing cluster 245.0...
Summarizing cluster 227.0...
Summarizing cluster 29.0...
Summarizing cluster 81.0...
Summarizing cluster 27.0...
Summarizing cluster 263.0...
Summarizing cluster 30.0...
Summarizing cluster 169.0...
Summarizing cluster 220.0...
Summarizing cluster 196.0...
Summarizing cluster 223.0...
Summarizing cluster 40.0...
Summarizing cluster 250.0...
Summarizing cluster 273.0...
Summarizing cluster 59.0...
Summarizing cluster 113.0...
Summarizing cluster 122.0...
Summarizing cluster 259.0...
Summarizing cluster 202.0...
Summarizing cluster 266.0...
Summarizing cluster 146.0...
Summarizing cluster 157.0...
Summarizing cluster 86.0...
Summarizing cluster 7.0...
Summarizing cluster 25.0...
Summarizing cluster 26.0...
Summarizing cluster 11.0...
Su

Optimizing clusters length: 274: 100%|██████████| 68/68 [00:15<00:00,  4.25it/s]
Optimizing clusters length: 39: 100%|██████████| 18/18 [00:00<00:00, 21.52it/s]
Optimizing clusters length: 44: 100%|██████████| 21/21 [00:00<00:00, 21.10it/s]
Optimizing clusters length: 37: 100%|██████████| 17/17 [00:00<00:00, 19.30it/s]
Optimizing clusters length: 22: 100%|██████████| 10/10 [00:00<00:00, 19.70it/s]
Optimizing clusters length: 33: 100%|██████████| 15/15 [00:00<00:00, 16.61it/s]
Optimizing clusters length: 40: 100%|██████████| 19/19 [00:00<00:00, 19.58it/s]
Optimizing clusters length: 59: 100%|██████████| 28/28 [00:01<00:00, 20.51it/s]


--Generated 37 clusters--
Summarizing cluster 1.0...
Summarizing cluster 22.0...
Summarizing cluster 28.0...
Summarizing cluster 11.0...
Summarizing cluster 20.0...
Summarizing cluster 7.0...
Summarizing cluster 36.0...
Summarizing cluster 5.0...
Summarizing cluster 34.0...
Summarizing cluster 10.0...
Summarizing cluster 19.0...
Summarizing cluster 31.0...
Summarizing cluster 17.0...
Summarizing cluster 15.0...
Summarizing cluster 8.0...
Summarizing cluster 13.0...
Summarizing cluster 33.0...
Summarizing cluster 27.0...
Summarizing cluster 24.0...
Summarizing cluster 3.0...
Summarizing cluster 23.0...
Summarizing cluster 18.0...
Summarizing cluster 30.0...
Summarizing cluster 32.0...
Summarizing cluster 25.0...
Summarizing cluster 6.0...
Summarizing cluster 14.0...
Summarizing cluster 2.0...
Summarizing cluster 16.0...
Summarizing cluster 0.0...
Summarizing cluster 21.0...
Summarizing cluster 4.0...
Summarizing cluster 12.0...
Summarizing cluster 29.0...
Summarizing cluster 9.0...
Summ

Optimizing clusters length: 37: 100%|██████████| 17/17 [00:01<00:00, 16.90it/s]


--Generated 4 clusters--
Summarizing cluster 1.0...
Summarizing cluster 2.0...
Summarizing cluster 0.0...
Summarizing cluster 3.0...
Level 3 completed
--Generated 1 clusters--
Summarizing cluster 0...
Level 4 completed


In [20]:
# raptor_client2 = chromadb.PersistentClient(path='database/raptor_backup.db')
# raptor_db_backup = Chroma(client=raptor_client2, embedding_function=embd)

# Persist the input texts of every level (leaf chunks at level 1, the previous
# level's summaries afterwards) into the Chroma store, in fixed-size batches.
# NOTE(review): the summaries produced by the last level (results[max][1]) are
# not among any level's input texts, so they are not stored - confirm whether
# the root summaries should be added as well.
batch_size = 100
for level_clusters_df, _ in results.values():
    texts = level_clusters_df['text'].tolist()
    for i in range(0, len(texts), batch_size):
        # Slice this level's texts; slicing leaf_texts here re-added leaf chunks
        # for every level and never stored any summary.
        batch_texts = texts[i:i+batch_size]

        raptor_db.add_texts(batch_texts)
