In [1]:
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple
import numpy as np

from core.models import Message, MessageId, ThreadId, SpaceId
from core.stores import EmbeddingStore, MembershipStore, MessageStore, ThreadStore
from core.interfaces import Formatter, Embedder, Reducer, Clusterer, ThreadRepComputer, Assigner, UpdateStrategy

import re
import datetime
import numpy as np
import pandas as pd

import logging
logging.basicConfig(level=logging.INFO)

## Defining the components for the processor

In [2]:
# ---------- Formatter (matches your ipynb add_context_window) ----------

@dataclass
class ContextWindowFormatter(Formatter):
    window_back: int = 2
    window_fwd: int = 1
    time_threshold_minutes: int = 10
    repeat_center: int = 2

    def format(self, idx: int, messages: List[Message]) -> str:
        cur = messages[idx]
        cur_time = cur.timestamp

        back_parts: List[str] = []
        for k in range(1, self.window_back + 1):
            j = idx - k
            if j < 0:
                break
            prev = messages[j]
            dt_min = (cur_time - prev.timestamp).total_seconds() / 60.0
            if dt_min > self.time_threshold_minutes:
                break
            back_parts.insert(0, f"{prev.user}: {prev.text}")

        fwd_parts: List[str] = []
        for k in range(1, self.window_fwd + 1):
            j = idx + k
            if j >= len(messages):
                break
            nxt = messages[j]
            dt_min = (nxt.timestamp - cur_time).total_seconds() / 60.0
            if dt_min > self.time_threshold_minutes:
                break
            fwd_parts.append(f"{nxt.user}: {nxt.text}")

        center = f"{cur.user}: {cur.text}"
        full = back_parts + [center] * self.repeat_center + fwd_parts
        return " \n ".join(full)


# ---------- Embedder (SentenceTransformer all-MiniLM-L6-v2) ----------

class MiniLMEmbedder(Embedder):
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name)

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        X = self.model.encode(texts, show_progress_bar=True)
        return np.asarray(X, dtype=np.float32)


# ---------- Reducer (UMAP params from ipynb) ----------

class UMAPReducer(Reducer):
    def __init__(
        self,
        n_neighbors: int = 30,
        n_components: int = 5,
        min_dist: float = 0.0,
        metric: str = "cosine",
        random_state: int = 42,
    ):
        import umap
        self._umap = umap.UMAP(
            n_neighbors=n_neighbors,
            n_components=n_components,
            min_dist=min_dist,
            metric=metric,
            random_state=random_state,
        )
        self._is_fit = False

    def fit_transform(self, X: np.ndarray) -> np.ndarray:
        Y = self._umap.fit_transform(X)
        self._is_fit = True
        return np.asarray(Y, dtype=np.float32)

    def transform(self, X: np.ndarray) -> np.ndarray:
        if not self._is_fit:
            raise RuntimeError("UMAPReducer.transform() called before fit_transform().")
        Y = self._umap.transform(X)
        return np.asarray(Y, dtype=np.float32)


# ---------- Clusterer (HDBSCAN params from ipynb) ----------

class HDBSCANClusterer(Clusterer):
    def __init__(
        self,
        min_cluster_size: int = 30,
        min_samples: int = 3,
        metric: str = "euclidean",
        cluster_selection_method: str = "eom",
    ):
        import hdbscan
        self._hdbscan = hdbscan.HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=min_samples,
            metric=metric,
            cluster_selection_method=cluster_selection_method,
        )

    def cluster(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        self._hdbscan.fit(X)
        labels = self._hdbscan.labels_.astype(int)
        # "scores" in your pipeline can be probabilities_ (0..1); noise gets 0
        probs = getattr(self._hdbscan, "probabilities_", None)
        if probs is None:
            scores = np.ones((len(labels),), dtype=np.float32)
        else:
            scores = np.asarray(probs, dtype=np.float32)
            scores[labels == -1] = 0.0
        return labels, scores


# ---------- ThreadRepComputer (centroid of msg_space embeddings for active memberships) ----------

@dataclass
class CentroidThreadRepComputer(ThreadRepComputer):
    memberships: MembershipStore
    embeddings: EmbeddingStore
    msg_space: SpaceId = "msg:full"

    def _infer_dim(self) -> int:
        # find any message embedding to determine d
        for (space, _id), v in self.embeddings.data.items():
            if space == self.msg_space:
                return int(v.shape[0])
        return 0

    def compute(self, thread_id: ThreadId) -> np.ndarray:
        ms = self.memberships.for_thread(thread_id, status="active")
        mids = [m.message_id for m in ms]
        if not mids:
            d = self._infer_dim()
            return np.zeros((d,), dtype=np.float32) if d else np.zeros((0,), dtype=np.float32)

        vecs = []
        for mid in mids:
            if self.embeddings.has(self.msg_space, mid):
                vecs.append(self.embeddings.get(self.msg_space, mid))
        if not vecs:
            d = self._infer_dim()
            return np.zeros((d,), dtype=np.float32) if d else np.zeros((0,), dtype=np.float32)

        V = np.stack(vecs, axis=0).astype(np.float32)
        return V.mean(axis=0)


# ---------- Stubs for batch mode (run_batch doesnâ€™t need streaming assignment) ----------

class NoOpAssigner(Assigner):
    def assign(self, message_id: MessageId):
        return []

class NoOpUpdateStrategy(UpdateStrategy):
    def on_new_message(self, message_id: MessageId) -> None:
        return
    def flush(self) -> None:
        return


In [3]:
messages = MessageStore()
threads = ThreadStore()
memberships = MembershipStore()
embeddings = EmbeddingStore()

formatter = ContextWindowFormatter(window_back=2, window_fwd=1, time_threshold_minutes=10, repeat_center=2)
embedder = MiniLMEmbedder("all-MiniLM-L6-v2")
reducer = UMAPReducer(n_neighbors=30, n_components=5, min_dist=0.0, metric="cosine", random_state=42)
clusterer = HDBSCANClusterer(min_cluster_size=30, min_samples=3, metric="euclidean", cluster_selection_method="eom")

thread_rep = CentroidThreadRepComputer(memberships=memberships, embeddings=embeddings, msg_space="msg:full")

  from .autonotebook import tqdm as notebook_tqdm
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


## Formatting raw data

In [4]:
def raw2df(file, key):
    """
    Converts raw .txt file into a Data Frame

    By tusharnankani, taken from github.com/tusharnankani/whatsapp-chat-data-analysis
    """

    split_formats = {
        '12hr' : '\d{1,2}/\d{1,2}/\d{2,4},\s\d{1,2}:\d{2}\s[APap][mM]\s-\s',
        '24hr' : '\d{1,2}/\d{1,2}/\d{2,4},\s\d{1,2}:\d{2}\s-\s',
        'custom' : ''
    }
    datetime_formats = {
        '12hr' : '%d/%m/%Y, %I:%M %p - ',
        '24hr' : '%d/%m/%Y, %H:%M - ',
        'custom': ''
    }

    with open(file, 'r', encoding='utf-8') as raw_data:
        # print(raw_data.read())
        raw_string = ' '.join(raw_data.read().split('\n')) # converting the list split by newline char. as one whole string as there can be multi-line messages
        user_msg = re.split(split_formats[key], raw_string) [1:] # splits at all the date-time pattern, resulting in list of all the messages with user names
        date_time = re.findall(split_formats[key], raw_string) # finds all the date-time patterns

        df = pd.DataFrame({'date_time': date_time, 'user_msg': user_msg}) # exporting it to a df

    # converting date-time pattern which is of type String to type datetime,
    # format is to be specified for the whole string where the placeholders are extracted by the method
    df['date_time'] = pd.to_datetime(df['date_time'], format=datetime_formats[key])

    # split user and msg
    usernames = []
    msgs = []
    for i in df['user_msg']:
        a = re.split('([\w\W]+?):\s', i) # lazy pattern match to first {user_name}: pattern and spliting it aka each msg from a user
        if(a[1:]): # user typed messages
            usernames.append(a[1])
            msgs.append(a[2])
        else: # other notifications in the group(eg: someone was added, some left ...)
            usernames.append("group_notification")
            msgs.append(a[0])

    # creating new columns
    df['user'] = usernames
    df['message'] = msgs

    # dropping the old user_msg col.
    df.drop('user_msg', axis=1, inplace=True)

    return df

  '12hr' : '\d{1,2}/\d{1,2}/\d{2,4},\s\d{1,2}:\d{2}\s[APap][mM]\s-\s',
  '24hr' : '\d{1,2}/\d{1,2}/\d{2,4},\s\d{1,2}:\d{2}\s-\s',
  a = re.split('([\w\W]+?):\s', i) # lazy pattern match to first {user_name}: pattern and spliting it aka each msg from a user


In [5]:
clean_df = raw2df('data/chats/whatsapp_chat_data_test.txt', '12hr')

## Populating message store

In [6]:
def load_from_df(df_clean) -> None:
    msgs = []
    for i, row in df_clean.iterrows():
        msgs.append(
            Message(
                id=f"m{i}",
                timestamp=row["date_time"].to_pydatetime() if hasattr(row["date_time"], "to_pydatetime") else row["date_time"],
                user=str(row["user"]),
                text=str(row["message"]),
            )
        )
    messages.add(msgs)

load_from_df(clean_df)

## Creating and running the main processor

In [7]:
from core.processor import ChatProcessor

processor = ChatProcessor(
    messages=messages,
    threads=threads,
    memberships=memberships,
    embeddings=embeddings,
    embedder=embedder,
    reducer=reducer,
    clusterer=clusterer,
    thread_rep_computer=thread_rep,
    assigner=NoOpAssigner(),
    update_strategy=NoOpUpdateStrategy(),
    formatter=formatter,
)

In [8]:
processor.run_batch()

INFO:core.processor:run_batch: start
INFO:core.processor:run_batch: messages=13655
INFO:core.processor:run_batch: formatting messages
INFO:core.processor:run_batch: formatted texts=13655
INFO:core.processor:run_batch: embedding texts
Batches: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 427/427 [00:54<00:00,  7.88it/s]
INFO:core.processor:run_batch: embeddings shape=(13655, 384) dtype=float32
INFO:core.processor:run_batch: stored msg embeddings space=msg:full count=13655
INFO:core.processor:run_batch: reducing embeddings
  warn(
INFO:core.processor:run_batch: reduced shape=(13655, 5) dtype=float32
INFO:core.processor:run_batch: stored cluster embeddings space=msg:cluster count=13655
INFO:core.processor:run_batch: clustering
INFO:core.processor:run_batch: clustering done clusters=96 noise=4553
INFO:core.processor:_labels_to_threads: start messages=13655
INFO:core.processor:_labels_to_threads: unique_labels=97 (including -1=True)
INFO:core.processor:_labels_to_threads: created_threads=96
INFO:co

In [9]:
processor.threads.all()

[Thread(id='thread_8cf73dfaff', title='Topic 0', summary='', created_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546707), updated_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546709), metadata={}),
 Thread(id='thread_317bc4ffaa', title='Topic 1', summary='', created_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546718), updated_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546719), metadata={}),
 Thread(id='thread_6bf662f77e', title='Topic 2', summary='', created_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546722), updated_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546723), metadata={}),
 Thread(id='thread_16c8d013a1', title='Topic 3', summary='', created_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546725), updated_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546725), metadata={}),
 Thread(id='thread_4c32291f48', title='Topic 4', summary='', created_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546728), updated_at=datetime.datetime(2026, 1, 4, 18, 41, 21, 546728), metadata={}),
 Thre

In [10]:
processor.memberships

<MembershipStore with 9102 memberships>

In [14]:
processor.messages

<MessageStore with 13655 messages>

In [11]:
import psutil, os
p = psutil.Process(os.getpid())
print("RSS GB:", p.memory_info().rss / 1e9)

RSS GB: 1.236197376


In [12]:
processor.messages.all()[1]

Message(id='m1', timestamp=datetime.datetime(2020, 1, 24, 20, 25), user='group_notification', text='Tanay Kamath (TSEC, CS) created group "CODERSðŸ‘¨\u200dðŸ’»ðŸ‘©\u200dðŸ’»ðŸ–¥ðŸ’»" ', metadata={})