<a href="https://colab.research.google.com/github/bhuvighosh3/store_codes/blob/main/self_querying.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install lark-parser
!pip install chromadb
import lark



In [None]:
#Necessary import
import numpy as np
import pandas as pd
import datetime as dt
import os
import matplotlib.pyplot as plt
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate,  ChatPromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.vectorstores import FAISS, Chroma
import csv
from typing import Dict, List, Optional
from langchain.document_loaders.base import BaseLoader
from langchain.docstore.document import Document

In [None]:
!pip install langchain



# Dataset

In [None]:
 # Define 4-th order runge-kutta
def rk_4(y_0,t,f,args=()):
  n = len(t)
  y = np.zeros((n, len(y0)))
  y[0] = y0
  for i in range(n - 1):
      h = t[i+1] - t[i]
      k1 = f(y[i], t[i], *args)
      k2 = f(y[i] + k1 * h / 2., t[i] + h / 2., *args)
      k3 = f(y[i] + k2 * h / 2., t[i] + h / 2., *args)
      k4 = f(y[i] + k3 * h, t[i] + h, *args)
      y[i+1] = y[i] + (h / 6.) * (k1 + 2*k2 + 2*k3 + k4)
  return y

# Define parameters of SIR model
beta=0.7 # Infection rate
gamma=0.1  # Removal reate
T_max= 90 # total days in consideration
dt=1
T_num=int(T_max/dt)
t= np.linspace(start=0,stop=T_max,num=int(T_num))


# Define f_SIR
def f_SIR(y,t,N,beta, gamma):
  f_SIR=np.zeros(3)
  f_SIR[0]= - beta/N*y[0]*y[1]
  f_SIR[1]= beta/N*y[0]*y[1]-gamma*y[1]
  f_SIR[2]= gamma*y[1]
  return f_SIR

np.random.seed(0)
y=[]
y0_i=np.random.randint(low=50, high=2000, size=10, dtype=int)
N=np.random.randint(low=5000, high=2e4, size=10, dtype=int)
for i in range(10):
    y0=np.array([N[i]-y0_i[i],N[i]-y0_i[i], 0])
    # Compute the value of SIR
    y.append(rk_4(y0, t, f_SIR,args=(N[i], beta, gamma)).astype(int))


In [None]:
sir_table_list=[]
for i in range(10):
    sir_table=pd.DataFrame(
        {
    "time":pd.date_range(start='1/1/2018', periods=90),"susceptible":y[i][:,0],"infectious":y[i][:,1] ,"removed":y[i][:,2],
"city": f"city{i}"        })
    sir_table_list.append(sir_table)
sir_table=pd.concat(sir_table_list)

In [None]:
sir_table.head()

Unnamed: 0,time,susceptible,infectious,removed,city
0,2018-01-01,8639,8639,0,city0
1,2018-01-02,3857,12338,1081,city0
2,2018-01-03,1458,13414,2405,city0
3,2018-01-04,545,12983,3749,city0
4,2018-01-05,214,12046,5017,city0


In [None]:
sir_table['time']=sir_table['time'].astype('object')

In [None]:
sir_table.to_csv("sir.csv")

# Embedding

In [None]:
# Load data and set embeddings
loader = CSVLoader(file_path='/content/sir.csv')
data = loader.load()

In [None]:
!pip install openai



In [None]:
!pip install tiktoken
embeddings = OpenAIEmbeddings(openai_api_key='')
vectorstore = Chroma.from_documents(data, embeddings)



In [None]:
retriever=vectorstore.as_retriever(search_kwargs={"k": 6})

In [None]:
llm=ChatOpenAI(model_name="gpt-4",openai_api_key='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',temperature=0)

# Define the system message template
system_template = """The provided {context} is a tabular dataset containing Suspectible, infectious and removed population during 90 days in 10 cities.
The dataset includes the following columns:
'time': time the population was meseaured,
'city': city in which the popoluation was measured,
"susceptible": the susceptible population of the disease,
"infectious": the infectious population of the disease,
"removed": the removed popolation of the disease.
----------------
{context}"""

# Create the chat prompt templates
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=vectorstore.as_retriever(), return_source_documents=False,combine_docs_chain_kwargs={"prompt": qa_prompt},memory=memory,verbose=True)

In [None]:
qa.run(
    {
        "question": "Which city has the most infectious people on 2018-02-03?"
    })



[1m> Entering new StuffDocumentsChain chain...[0m


[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mSystem: The provided : 31
time: 2018-02-01
susceptible: 0
infectious: 1729
removed: 35608
city: city3

: 31
time: 2018-02-01
susceptible: 0
infectious: 1729
removed: 35608
city: city3

: 31
time: 2018-02-01
susceptible: 0
infectious: 1729
removed: 35608
city: city3

: 31
time: 2018-02-01
susceptible: 0
infectious: 1729
removed: 35608
city: city3 is a tabular dataset containing Suspectible, infectious and removed population during 90 days in 10 cities.
The dataset includes the following columns:
'time': time the population was meseaured,
'city': city in which the popoluation was measured,
"susceptible": the susceptible population of the disease,
"infectious": the infectious population of the disease,
"removed": the removed popolation of the disease.
----------------
: 31
time: 2018-02-01
susceptible: 0
infectious: 1729
removed: 35608
city: city3

: 31
ti

'The provided data does not include information for the date 2018-02-03.'

In [None]:
retriever.get_relevant_documents("Which city has the most infectious people on 2018-02-03?")

[Document(page_content=': 31\ntime: 2018-02-01\nsusceptible: 0\ninfectious: 1729\nremoved: 35608\ncity: city3', metadata={'row': 301, 'source': '/content/sir.csv'}),
 Document(page_content=': 31\ntime: 2018-02-01\nsusceptible: 0\ninfectious: 1729\nremoved: 35608\ncity: city3', metadata={'row': 301, 'source': '/content/sir.csv'}),
 Document(page_content=': 31\ntime: 2018-02-01\nsusceptible: 0\ninfectious: 1729\nremoved: 35608\ncity: city3', metadata={'row': 301, 'source': '/content/sir.csv'}),
 Document(page_content=': 31\ntime: 2018-02-01\nsusceptible: 0\ninfectious: 1729\nremoved: 35608\ncity: city3', metadata={'city': 'city3', 'row': 301, 'source': 'sir.csv', 'time': '2018-02-01'}),
 Document(page_content=': 31\ntime: 2018-02-01\nsusceptible: 0\ninfectious: 1729\nremoved: 35608\ncity: city3', metadata={'row': 301, 'source': '/content/sir.csv'}),
 Document(page_content=': 31\ntime: 2018-02-01\nsusceptible: 0\ninfectious: 1729\nremoved: 35608\ncity: city3', metadata={'city': 'city3', '

## Customized CSV loader with metadata

In [None]:
class MetaDataCSVLoader(BaseLoader):
    """Loads a CSV file into a list of documents.

    Each document represents one row of the CSV file. Every row is converted into a
    key/value pair and outputted to a new line in the document's page_content.

    The source for each document loaded from csv is set to the value of the
    `file_path` argument for all doucments by default.
    You can override this by setting the `source_column` argument to the
    name of a column in the CSV file.
    The source of each document will then be set to the value of the column
    with the name specified in `source_column`.

    Output Example:
        .. code-block:: txt

            column1: value1
            column2: value2
            column3: value3
    """

    def __init__(
        self,
        file_path: str,
        source_column: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
        content_columns: Optional[List[str]] =None ,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.content_columns= content_columns
        self.metadata_columns = metadata_columns        # < ADDED

    def load(self) -> List[Document]:
        """Load data into document objects."""

        docs = []
        with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
            csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
            for i, row in enumerate(csv_reader):
                if self.content_columns:
                    content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items() if k in self.content_columns)
                else:
                    content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
                try:
                    source = (
                        row[self.source_column]
                        if self.source_column is not None
                        else self.file_path
                    )
                except KeyError:
                    raise ValueError(
                        f"Source column '{self.source_column}' not found in CSV file."
                    )
                metadata = {"source": source, "row": i}
                # ADDED TO SAVE METADATA
                if self.metadata_columns:
                    for k, v in row.items():
                        if k in self.metadata_columns:
                            metadata[k] = v
                # END OF ADDED CODE
                doc = Document(page_content=content, metadata=metadata)
                docs.append(doc)

        return docs

In [None]:
# Load data and set embeddings
loader = MetaDataCSVLoader(file_path="sir.csv",metadata_columns=['time','city'])
data = loader.load()

In [None]:
embeddings = OpenAIEmbeddings(openai_api_key='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')
vectorstore = Chroma.from_documents(data, embeddings)

In [None]:
llm = ChatOpenAI(openai_api_key='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', model_name="gpt-4", temperature=0)
metadata_field_info = [
    AttributeInfo(
        name="time",
        description="time the population was measured",
        type="object",
    ),
    AttributeInfo(
        name="city",
        description="city in which the population was measured",
        type="string",
    ),
]
document_content_description = "Susceptible, infectious and removed population during 90 days in 10 cities"
retriever = SelfQueryRetriever.from_llm(
    llm, vectorstore, document_content_description, metadata_field_info, search_kwargs={"k": 5}, verbose=True
)

In [None]:
retriever.get_relevant_documents("Which city has the max infetious people?")



query='max infectious people' filter=None limit=None


[Document(page_content=': 45\ntime: 2018-02-15\nsusceptible: 0\ninfectious: 147\nremoved: 12884\ncity: city8', metadata={'row': 765, 'source': '/content/sir.csv'}),
 Document(page_content=': 45\ntime: 2018-02-15\nsusceptible: 0\ninfectious: 147\nremoved: 12884\ncity: city8', metadata={'row': 765, 'source': '/content/sir.csv'}),
 Document(page_content=': 45\ntime: 2018-02-15\nsusceptible: 0\ninfectious: 147\nremoved: 12884\ncity: city8', metadata={'row': 765, 'source': '/content/sir.csv'}),
 Document(page_content=': 45\ntime: 2018-02-15\nsusceptible: 0\ninfectious: 147\nremoved: 12884\ncity: city8', metadata={'city': 'city8', 'row': 765, 'source': 'sir.csv', 'time': '2018-02-15'}),
 Document(page_content=': 45\ntime: 2018-02-15\nsusceptible: 0\ninfectious: 147\nremoved: 12884\ncity: city8', metadata={'row': 765, 'source': '/content/sir.csv'})]

# Build the LLM

In [None]:
llm=ChatOpenAI(openai_api_key='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',model_name="gpt-4",temperature=0)

# Define the system message template
system_template = """The provided {context} is a tabular dataset containing Suspectible, infectious and removed population during 90 days in 10 cities.
The dataset includes the following columns:
'time': time the population was meseaured,
'city': city in which the popoluation was measured,
"susceptible": the susceptible population of the disease,
"infectious": the infectious population of the disease,
"removed": the removed popolation of the disease.
----------------
{context}"""

# Create the chat prompt templates
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
qa_prompt = ChatPromptTemplate.from_messages(messages)
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, return_source_documents=False,combine_docs_chain_kwargs={"prompt": qa_prompt},memory=memory,verbose=True)

In [None]:
def chat_bot(question):
    return qa.run(
    {
        "question": question
    })

In [None]:
chat_bot(
"Which city has the max infetious people")

query='max infectious people' filter=None limit=None


[1m> Entering new StuffDocumentsChain chain...[0m


[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mSystem: The provided : 45
time: 2018-02-15
susceptible: 0
infectious: 147
removed: 12884
city: city8

: 45
time: 2018-02-15
susceptible: 0
infectious: 147
removed: 12884
city: city8

: 45
time: 2018-02-15
susceptible: 0
infectious: 147
removed: 12884
city: city8

: 45
time: 2018-02-15
susceptible: 0
infectious: 147
removed: 12884
city: city8

: 45
time: 2018-02-15
susceptible: 0
infectious: 147
removed: 12884
city: city8 is a tabular dataset containing Suspectible, infectious and removed population during 90 days in 10 cities.
The dataset includes the following columns:
'time': time the population was meseaured,
'city': city in which the popoluation was measured,
"susceptible": the susceptible population of the disease,
"infectious": the infectious population of the disease,
"removed": the removed pop

'The provided data only includes information for city8. To determine which city has the most infectious people, data for other cities would be needed.'

In [None]:
retriever.get_relevant_documents("Which city has the most infetious people on  2018-02-03?")

query='most infectious people' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='time', value=datetime.date(2018, 2, 3)) limit=None


ValueError: ignored