<a href="https://colab.research.google.com/github/jessiechd/RAG_Model/blob/main/0710_chathistory_summarizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip install supabase numpy psycopg2 --q

In [None]:
import os
import json
import torch
import uuid
import numpy as np
from supabase import create_client, Client
from transformers import AutoTokenizer, AutoModel

# Initialize Supabase. Credentials come from the environment so secrets are not
# stored in the notebook; export SUPABASE_URL / SUPABASE_KEY before running.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY", "")

supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# Load Embedding Model. trust_remote_code=True is required because this model
# ships its own modeling code. Use the GPU when one is available.
EMBEDDING_MODEL_ID = "Alibaba-NLP/gte-multilingual-base"
EMBEDDING_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(EMBEDDING_MODEL_ID, trust_remote_code=True).to(EMBEDDING_DEVICE)


# Embedding and keyword helpers

In [None]:
import numpy as np
import ast
import re
from scipy.spatial.distance import cosine
from collections import Counter
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import nltk
# Download only the NLTK resources this notebook uses, instead of
# nltk.download('all'), which fetches every corpus and model (several GB):
#   - punkt / punkt_tab: tokenizer data behind word_tokenize (newer NLTK
#     releases look up "punkt_tab"; older ones use "punkt")
#   - stopwords: the English stop-word list used by extract_keywords_simple
for _nltk_resource in ("punkt", "punkt_tab", "stopwords"):
    nltk.download(_nltk_resource, quiet=True)

def get_embedding(text):
    """Generate an embedding vector for ``text`` by mean-pooling token states.

    Args:
        text: The string to embed (a list of strings is also accepted by the
            tokenizer). Input beyond 512 tokens is truncated.

    Returns:
        A flat list of floats for a single string. For a batch of several
        strings, one such list per string (``squeeze`` only drops a batch
        axis of size 1).

    The mean is weighted by the attention mask, so padding tokens (which the
    tokenizer adds when a batch mixes lengths) do not skew the average. For a
    single string nothing is padded, and the result matches a plain mean over
    all tokens, up to floating-point rounding.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state  # dim 1 is the token axis (pooled below)
    mask = inputs.get("attention_mask")
    if mask is None:
        pooled = hidden.mean(dim=1)
    else:
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        # clamp guards against division by zero for an all-padding row
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
    return pooled.squeeze().cpu().tolist()

def extract_keywords_simple(text):
    """Return the lower-cased, alphanumeric, non-stopword tokens of ``text``.

    Token order and duplicates are preserved; punctuation tokens and English
    stopwords (per NLTK's list) are dropped.
    """
    english_stopwords = set(stopwords.words('english'))
    kept = []
    for token in word_tokenize(text.lower()):
        if not token.isalnum():
            continue
        if token in english_stopwords:
            continue
        kept.append(token)
    return kept

# Words suggesting the user is asking for tabular / numeric data.
_TABLE_KEYWORDS = frozenset({
    "table", "data", "values", "measurements", "limits",
    "thresholds", "parameters", "average", "sum", "percentage",
})


def query_requires_table(user_query):
    """Return True if ``user_query`` looks like a request for table data.

    Matching is on whole words, case-insensitively. Previously a substring
    test made "summary" match "sum", "database" match "data", and
    "acceptable" match "table". A simple singular/plural variant of each
    keyword also matches, so "tables" matches "table" and "value" matches
    "values".
    """
    for word in re.findall(r"[a-z0-9]+", user_query.lower()):
        if word in _TABLE_KEYWORDS:
            return True
        if word + "s" in _TABLE_KEYWORDS:
            return True
        if word.endswith("s") and word[:-1] in _TABLE_KEYWORDS:
            return True
    return False

def get_most_similar_keywords(query_keywords, top_text_chunks):
    """Keep only the query keywords that also occur in the retrieved chunks.

    Args:
        query_keywords: Keywords extracted from the user query.
        top_text_chunks: Retrieved chunk tuples; the chunk text is read from
            index 2.

    Returns:
        The keywords, in their original order, that appear among the
        lower-cased word tokens of the chunks. If none appear,
        ``query_keywords`` itself is returned.
    """
    chunk_vocabulary = set()
    for chunk in top_text_chunks:
        chunk_vocabulary |= set(word_tokenize(chunk[2].lower()))
    matched = list(filter(chunk_vocabulary.__contains__, query_keywords))
    # Fall back to the original keywords rather than returning an empty list.
    return matched or query_keywords



# Hybrid retrieval (BM25 + vector similarity)

In [None]:
!pip install rank_bm25 --q

In [None]:
import numpy as np
import ast
import re
import json
import psycopg2
from scipy.spatial.distance import cosine
from rank_bm25 import BM25Okapi

# Assume: get_embedding(), extract_keywords_simple(), query_requires_table() are already defined

def _bm25_tokenize(text):
    """Lower-case ``text`` and split it into word tokens (punctuation dropped)."""
    return re.findall(r"\w+", (text or "").lower())


def hybrid_retrieve(user_query, all_chunks, top_k=10, dense_weight=0.7):
    """Re-rank retrieved chunks by a weighted mix of dense and BM25 scores.

    Args:
        user_query: The raw user question.
        all_chunks: Sequence of ``(id, type, content, dense_similarity)``
            tuples. A falsy ``dense_similarity`` (e.g. None) counts as 0.
        top_k: Maximum number of chunks to return.
        dense_weight: Weight of the dense similarity in the combined score.
            BM25 gets ``1 - dense_weight``.

    Returns:
        A new list (the input is left untouched) of up to ``top_k`` tuples of
        the form ``(*chunk, bm25_score, combined_score)``, sorted by
        ``combined_score`` in descending order. ``bm25_score`` is the raw
        BM25 score. ``combined_score`` uses BM25 divided by the best BM25
        score among the candidates, so both terms are on a comparable,
        at-most-1 scale.
    """
    if not all_chunks:
        # BM25Okapi cannot be built from an empty corpus (it divides by the
        # corpus size), and there is nothing to rank anyway.
        return []

    tokenized_corpus = [_bm25_tokenize(chunk[2]) for chunk in all_chunks]
    bm25 = BM25Okapi(tokenized_corpus)
    bm25_scores = [float(s) for s in bm25.get_scores(_bm25_tokenize(user_query))]

    # Raw BM25 scores are unbounded, while cosine similarity is at most 1.
    # Without rescaling, BM25 would swamp the intended weighting.
    max_bm25 = max(bm25_scores)
    sparse_weight = 1.0 - dense_weight

    scored = []
    for chunk, raw_bm25 in zip(all_chunks, bm25_scores):
        dense_sim = chunk[3] or 0
        sparse = max(0.0, raw_bm25 / max_bm25) if max_bm25 > 0 else 0.0
        combined = dense_weight * dense_sim + sparse_weight * sparse
        scored.append((*chunk, raw_bm25, combined))

    scored.sort(key=lambda item: item[-1], reverse=True)  # by combined score
    return scored[:top_k]

def _fetch_chunks(cur, query_vec, vec_table, source_table, content_column, chunk_type, limit):
    """Fetch the ``limit`` nearest vectors of one kind and join in their text.

    Args:
        cur: An open psycopg2 cursor.
        query_vec: The query embedding serialized as a JSON array string.
        vec_table: Schema-qualified table with ``id`` and ``vec`` columns.
        source_table: Schema-qualified table with ``chunk_id`` and
            ``content_column`` columns.
        content_column: Column whose value becomes the chunk text.
        chunk_type: Label stored in each result tuple ("text" or "table").
        limit: Number of nearest vectors to fetch.

    Returns:
        A list of ``(chunk_id, chunk_type, content, similarity)`` tuples in
        nearest-first order, where similarity is ``1 - (vec <=> query)``.
        Ids with no matching row in ``source_table`` are dropped.
    """
    # Table/column names are fixed constants supplied by query_supabase (never
    # user input), so interpolating them is safe. All values are bound as
    # query parameters.
    cur.execute(
        f"""
        SELECT id, 1 - (vec <=> %(vec)s) AS similarity
        FROM {vec_table}
        ORDER BY vec <=> %(vec)s
        LIMIT %(limit)s
        """,
        {"vec": query_vec, "limit": limit},
    )
    nearest = cur.fetchall()
    if not nearest:
        return []

    # psycopg2 renders a Python tuple as a parenthesised SQL list that is
    # valid for any length >= 1. Interpolating the tuple's repr emitted
    # "('x',)" for a single id, which is invalid SQL.
    ids = tuple(str(row[0]) for row in nearest)
    cur.execute(
        f"SELECT chunk_id, {content_column} FROM {source_table} WHERE chunk_id IN %s",
        (ids,),
    )
    contents = {row[0]: row[1] for row in cur.fetchall()}
    return [
        (cid, chunk_type, contents[cid], sim)
        for cid, sim in nearest
        if cid in contents
    ]


def query_supabase(user_query, top_k=5, candidate_k=10):
    """Hybrid retrieval (BM25 + dense embedding) without reranking.

    Steps:
        1. Embed the query.
        2. Fetch the ``candidate_k`` nearest text chunks and the
           ``candidate_k`` nearest table chunks from Postgres.
        3. Re-rank the combined candidates with ``hybrid_retrieve``.
        4. Return the best ``top_k``.

    NOTE(review): this relies on a module-level ``DB_CONNECTION`` (a psycopg2
    connection string). It is not defined in the cells shown here, so set it
    before calling.
    """
    query_vec = json.dumps(
        np.asarray(get_embedding(user_query), dtype=np.float32).flatten().tolist()
    )

    conn = psycopg2.connect(DB_CONNECTION)
    try:
        with conn.cursor() as cur:
            text_results = _fetch_chunks(
                cur, query_vec, "vecs.vec_text", "public.documents", "content", "text", candidate_k
            )
            table_results = _fetch_chunks(
                cur, query_vec, "vecs.vec_table", "public.tables", "description", "table", candidate_k
            )
    finally:
        # Always release the connection, even if a query fails.
        conn.close()

    return hybrid_retrieve(user_query, text_results + table_results, top_k=top_k)


# LLM call (OpenAI chat completions)

In [None]:
import openai

# OpenAI API key: read from the environment so the secret is never stored in
# the notebook. Export OPENAI_API_KEY before running this cell.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
openai.api_key = OPENAI_API_KEY


In [None]:
def call_openai_llm(user_query, retrieved_chunks, chat_history=None):
    """Send the query with retrieved context and chat history to OpenAI.

    Args:
        user_query: The user's question.
        retrieved_chunks: Chunk tuples whose text is at index 2 (as returned
            by ``query_supabase``).
        chat_history: Optional list of ``{"role": ..., "content": ...}``
            messages sent before the current question. Malformed entries are
            skipped with a warning. The list is extended in place with this
            turn's user and assistant messages.

    Returns:
        ``(answer, chat_history)``: the model's reply (``""`` if it returned
        no text) and the updated history list. The list is a fresh one when
        ``chat_history`` was omitted.
    """
    # A mutable default ([]) would be shared by every call that omits the
    # argument, and the appends below would silently accumulate all past turns.
    if chat_history is None:
        chat_history = []

    # Sanitize chat history: keep only well-formed message dicts.
    safe_history = []
    for msg in chat_history:
        if isinstance(msg, dict) and "role" in msg and "content" in msg:
            safe_history.append(msg)
        else:
            print("⚠️ Skipping malformed chat history entry:", msg)

    # Number the retrieved chunks so the model can tell them apart.
    context_text = "\n\n".join(
        f"Chunk {i}: {chunk[2]}" for i, chunk in enumerate(retrieved_chunks, start=1)
    )

    messages = [
        {"role": "system", "content": "You are an intelligent assistant. Use the following retrieved information to answer the user's query."},
        *safe_history,
        {"role": "user", "content": f"Context:\n{context_text}\n\nUser's Question: {user_query}"}
    ]

    client = openai.OpenAI(api_key=openai.api_key)  # client-style API (openai>=1.0)
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=messages,
        temperature=0.7
    )

    # message.content may be None; normalize so history entries stay strings.
    answer = response.choices[0].message.content or ""
    chat_history.append({"role": "user", "content": user_query})
    chat_history.append({"role": "assistant", "content": answer})

    return answer, chat_history


# Chat history and summarizer as LLM context

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from transformers import PreTrainedTokenizerFast

# Local seq2seq model used to condense older chat turns into a summary.
SUMMARIZER_MODEL_ID = "t5-small"
summarizer_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL_ID)

# GPT-2's tokenizer is used only to measure how long the chat history has
# grown. It is not the chat model's own tokenizer, so counts are estimates.
llm_tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")


def count_tokens(text):
    """Return the number of GPT-2 token ids that ``text`` encodes to."""
    token_ids = llm_tokenizer.encode(text)
    return len(token_ids)

In [None]:
def summarize_text(text, max_input_tokens=512, max_output_tokens=150, min_output_tokens=30):
    """Summarize ``text`` with the local T5 summarizer.

    Args:
        text: Text to summarize. The "summarize: " task prefix is prepended,
            and the combined input is truncated to ``max_input_tokens``.
        max_input_tokens: Truncation limit for the encoder input.
        max_output_tokens: Upper bound on the generated summary length.
        min_output_tokens: Lower bound on the summary length. It is clamped
            to ``max_output_tokens`` so the two limits never conflict.

    Returns:
        The decoded summary, or ``""`` when ``text`` is empty or whitespace.
        Generating from an empty prompt would only produce filler to satisfy
        the minimum length.
    """
    if not text or not text.strip():
        return ""

    input_ids = summarizer_tokenizer.encode(
        "summarize: " + text,
        return_tensors="pt",
        max_length=max_input_tokens,
        truncation=True,
    )
    summary_ids = summarizer_model.generate(
        input_ids,
        max_length=max_output_tokens,
        min_length=min(min_output_tokens, max_output_tokens),
        length_penalty=2.0,  # positive values favour longer beams
        num_beams=4,
        early_stopping=True,
    )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:
class ChatContextManager:
    """Keep recent chat turns verbatim and fold older ones into a summary.

    A summary pass runs after ``summarize_every_turns`` turns, or earlier once
    the buffered turns reach ``summarize_every_tokens`` tokens (measured with
    ``count_tokens``). Each pass:
        - summarizes the buffered turns,
        - appends the result to ``self.summary``,
        - clears the buffer.
    """

    def __init__(self, summarize_every_turns=3, summarize_every_tokens=1000):
        self.chat_history = []  # (role, message) pairs buffered since the last summary
        self.summary = ""  # newline-separated summaries produced so far
        self.turns_since_last_summary = 0
        self.summarize_every_turns = summarize_every_turns
        self.summarize_every_tokens = summarize_every_tokens

    def summarize_text(self, text, max_input_tokens=512, max_output_tokens=150):
        """Summarize ``text``; delegates to the module-level ``summarize_text``."""
        # Reuse the notebook's summarizer function instead of duplicating it.
        return summarize_text(
            text,
            max_input_tokens=max_input_tokens,
            max_output_tokens=max_output_tokens,
        )

    def add_turn(self, user_msg, bot_msg):
        """Record one user/assistant exchange; summarize if a threshold is hit."""
        self.chat_history.append(("user", user_msg))
        self.chat_history.append(("assistant", bot_msg))
        self.turns_since_last_summary += 1

        if self.should_summarize():
            self.update_summary()

    def should_summarize(self):
        """Return True once the turn-count or token-count threshold is reached."""
        # Check the cheap turn counter first; only tokenize when it is not met.
        if self.turns_since_last_summary >= self.summarize_every_turns:
            return True
        full_text = " ".join(msg for _, msg in self.chat_history)
        return count_tokens(full_text) >= self.summarize_every_tokens

    def update_summary(self):
        """Summarize the buffered turns into ``self.summary`` and clear the buffer."""
        if not self.chat_history:
            # Nothing buffered: summarizing an empty transcript only yields filler.
            self.turns_since_last_summary = 0
            return

        full_text = "\n".join(f"{role}: {msg}" for role, msg in self.chat_history)
        new_summary = self.summarize_text(full_text)
        print("\n📝 Summary updated:\n", new_summary, "\n")
        if new_summary:
            # NOTE(review): the summary only grows (one entry per pass). Very
            # long sessions may need the summary itself re-summarized.
            self.summary = f"{self.summary}\n{new_summary}" if self.summary else new_summary
        self.chat_history = []
        self.turns_since_last_summary = 0

    def get_context_for_llm(self, recent_n=2):
        """Build the LLM context: the running summary plus the last ``recent_n`` turns.

        Sections with nothing to show are omitted, so the result is ``""``
        when there is neither a summary nor any buffered turn, and callers can
        skip sending it. When both are present, the output is a "Summary:"
        section, a blank line, and a "Recent Turns:" section with one
        "role: message" line per entry.
        """
        sections = []
        if self.summary:
            sections.append(f"Summary:\n{self.summary}")
        # Guard recent_n <= 0: history[-0:] would return the whole buffer.
        recent = self.chat_history[-recent_n * 2:] if recent_n > 0 else []
        if recent:
            recent_text = "\n".join(f"{role}: {msg}" for role, msg in recent)
            sections.append(f"Recent Turns:\n{recent_text}")
        return "\n\n".join(sections)

# Test: chat history & summarizer

In [None]:
# before summarizer + chat history
# Baseline: each question is answered on its own (empty history) so results
# can be compared with the summarizer-backed chat loop below.

def test_llm(user_query):
    """Run one retrieval + LLM round trip with an empty history and print it."""
    print("\n🔹🔹🔹🔹🔹🔹\n")
    chunks = query_supabase(user_query)
    print("\n🔹 Input Query:\n", user_query)
    answer = call_openai_llm(user_query, chunks, [])[0]
    print("\n🔹 Chatbot Response:\n", answer)


test_llm("what are the uses of AI in the ecosystem?")
test_llm("what happens in the US?")
test_llm("how is the AI development in the US?")
test_llm("give me a company that utilizes this")


In [None]:
chat_ctx = ChatContextManager()

print("Type '0' to exit, '1' to reset chat history. \n")

while True:
    print("\n🔹🔹🔹🔹🔹🔹\n")
    try:
        user_query = input("\n🔹 Input Query:\n").strip()
    except (EOFError, KeyboardInterrupt):
        # Input closed or cell interrupted: leave the loop cleanly.
        print("\n exiting...")
        break

    if user_query == "0":
        print("\n exiting...")
        break

    if user_query == "1":
        chat_ctx = ChatContextManager()  # reset context manager
        print("\n chat history cleared. \n")
        continue

    if not user_query:
        # Blank line: nothing to retrieve or answer.
        continue

    # Send prior-conversation context only once there is some. Using a system
    # message keeps it separate from what the user is asking now.
    if chat_ctx.summary or chat_ctx.chat_history:
        chat_history = [{"role": "system", "content": chat_ctx.get_context_for_llm()}]
    else:
        chat_history = []

    try:
        # Retrieval sits inside the try so a database error is reported and
        # the chat continues instead of the whole loop crashing.
        retrieved_chunks = query_supabase(user_query)
        response, _ = call_openai_llm(user_query, retrieved_chunks, chat_history)
        print("\n 🔹 Chatbot Response:\n", response)
        chat_ctx.add_turn(user_query, response)

    except Exception as e:
        print("ERROR:", e)
