Implementing short term memory storage

In [1]:
from transformers import AutoTokenizer
from typing import List, Dict

class ContextManager:
    def __init__(self, model_name="bert-base-uncased", max_tokens=512):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_tokens = max_tokens
        self.history: List[Dict[str, str]] = []  # List of dicts like {"role": "user", "content": "..."}
    
    def add_message(self, role: str, content: str):
        """Add a message to history"""
        self.history.append({"role": role, "content": content})
    
    def get_context(self) -> List[Dict[str, str]]:
        """
        Returns a list of messages that fit within the token limit.
        Starts from the most recent and adds until the token limit is hit.
        """
        total_tokens = 0
        context = []
        for message in reversed(self.history):
            tokens = len(self.tokenizer.tokenize(message["content"]))
            if total_tokens + tokens > self.max_tokens:
                break
            context.insert(0, message)  # insert at beginning to preserve order
            total_tokens += tokens
        return context
    
    def reset(self):
        """Clear conversation history"""
        self.history = []


In [2]:
ctx = ContextManager(max_tokens=200)

ctx.add_message("user", "What's the weather like in Bangalore?")
ctx.add_message("assistant", "It’s sunny today.")
ctx.add_message("user", "What about tomorrow?")
ctx.add_message("assistant", "Expect light rain.")

context = ctx.get_context()
for msg in context:
    print(f"{msg['role'].capitalize()}: {msg['content']}")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

User: What's the weather like in Bangalore?
Assistant: It’s sunny today.
User: What about tomorrow?
Assistant: Expect light rain.


Long term Memory storing

In [3]:
pip install sentence-transformers faiss-cpu

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cublas_cu1

In [4]:
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict

2025-08-05 20:01:30.259468: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754424090.502439      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754424090.576720      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
class ContextManager:
    def __init__(self, model_name="bert-base-uncased", embedding_model_name="all-MiniLM-L6-v2", max_tokens=512):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_tokens = max_tokens
        self.history: List[Dict[str, str]] = []  # temporary memory
        self.emb_model = SentenceTransformer(embedding_model_name)
        
        self.index = faiss.IndexFlatL2(384)  # 384 = dimension of MiniLM embeddings
        self.memory_texts = []  # Keep track of original messages
        self.memory_roles = []  # Who said what (user/assistant)

    def add_message(self, role: str, content: str):
        """Adds message to short-term + long-term memory"""
        self.history.append({"role": role, "content": content})
        
        # Add to long-term memory
        embedding = self.emb_model.encode(content)
        self.index.add(np.array([embedding]).astype("float32"))
        self.memory_texts.append(content)
        self.memory_roles.append(role)

    def get_context(self, query: str = None, summarizer: Summarizer = None) -> List[Dict[str, str]]:
        total_tokens = 0
        context = []
    
        for i, message in enumerate(reversed(self.history)):
            tokens = len(self.tokenizer.tokenize(message["content"]))
            if total_tokens + tokens > self.max_tokens:
                if summarizer:
                    old_msgs = [msg["content"] for msg in reversed(self.history[:len(self.history)-i])]
                    summary = summarizer.summarize(old_msgs)
                    context.insert(0, {"role": "system", "content": f"Summary of earlier: {summary}"})
                break
            context.insert(0, message)
            total_tokens += tokens
        return context


    def reset(self):
        """Clear all memory"""
        self.history = []
        self.memory_texts = []
        self.memory_roles = []
        self.index.reset()

    def get_relevant_memory(self, query: str, k: int = 3) -> List[Dict[str, str]]:
        query_vec = self.emb_model.encode(query)
    
        if self.index.ntotal == 0:
            return []
    
        D, I = self.index.search(np.array([query_vec]).astype("float32"), k)
        relevant = []
        for idx in I[0]:
            if idx < len(self.memory_texts):
                relevant.append({
                    "role": self.memory_roles[idx],
                    "content": self.memory_texts[idx]
                })
        return relevant


NameError: name 'Summarizer' is not defined

In [None]:
ctx = ContextManager(max_tokens=200)

ctx.add_message("user", "How do I reset my password?")
ctx.add_message("assistant", "Click on 'Forgot password' to reset.")
ctx.add_message("user", "How to change email ID?")
ctx.add_message("assistant", "Go to account settings and update your email.")
ctx.add_message("user", "what is the weather now?")
ctx.add_message("assistant", "It is sunny.")
ctx.add_message("user", "I cannot log in to my account in this page")
ctx.add_message("assistant", "You can either reset you password or change your email id")

# Simulate a new query
# query = "I can’t log into my account"
# context = ctx.get_context(query)

# # Print retrieved + relevant messages
# for msg in context:
#     print(f"{msg['role'].capitalize()}: {msg['content']}")

query = "I can’t log into my account"
relevant_only = ctx.get_relevant_memory(query, k=3)

for msg in relevant_only:
    print(f"{msg['role'].capitalize()}: {msg['content']}")


 Summarizer / Condenser (Optional but powerful)

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class Summarizer:
    def __init__(self, model_name="t5-small"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def summarize(self, texts: List[str], max_input_tokens=512, max_summary_tokens=100) -> str:
        input_text = " ".join(texts)
        tokens = self.tokenizer.tokenize(input_text)
        if len(tokens) > max_input_tokens:
            tokens = tokens[-max_input_tokens:]  # keep last N tokens
            input_text = self.tokenizer.convert_tokens_to_string(tokens)
        
        inputs = self.tokenizer("summarize: " + input_text, return_tensors="pt", truncation=True)
        summary_ids = self.model.generate(
            inputs["input_ids"], 
            max_length=max_summary_tokens, 
            num_beams=4, 
            early_stopping=True
        )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:
ctx = ContextManager(max_tokens=20)
summarizer = Summarizer()

# Add messages to fill memory
ctx.add_message("user", "How do I reset my password?")
ctx.add_message("assistant", "Click on 'Forgot password' to reset.")
ctx.add_message("user", "How to change email ID?")
ctx.add_message("assistant", "Go to settings...")
ctx.add_message("user", "What is the weather now?")
ctx.add_message("assistant", "It is sunny.")

query = "I can’t log into my account"
context = ctx.get_context(query=query, summarizer=summarizer)

for msg in context:
    print(f"{msg['role'].capitalize()}: {msg['content']}")