In [None]:
# Go to https://ncbi.nlm.nih.gov/gds
# Select 'GEO DataSets'
# Search for 'Heart Disease' or for "MBNL1 regulates programmed postnatal switching between regenerative and differentiated cardiac states":
# > https://ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE246743
# Go to the end of the page and inspect the data files used in this study.
# Download this CSV to the ./documents folder: 
# > GSE246743_raw_counts.csv.gz	    703.6 Kb    (ftp)(http) CSV
# > https://ncbi.nlm.nih.gov/geo/download/?acc=GSE246743&format=file&file=GSE246743%5Fraw%5Fcounts%2Ecsv%2Egz
# Gunzip the file: gunzip GSE246743_raw_counts.csv.gz

In [None]:
import os
from dotenv import load_dotenv
# Load the variables defined in the local ".env" file.
# NOTE(review): presumably this provides the OpenAI API key used by ChatOpenAI — confirm against your .env.
load_dotenv(".env")

In [2]:
from langchain_openai import ChatOpenAI

# Chat model shared by every chain in this notebook.
llm = ChatOpenAI(
    model="gpt-3.5-turbo-0125",
    temperature=0,
)

In [5]:
# !pip install pandas

In [None]:
import pandas as pd

# Read the raw-counts CSV with its first column as the index and keep only
# the first 10 rows, so the table stays small.
df = pd.read_csv("./documents/GSE246743_raw_counts.csv", index_col=0).head(10)
df.head()

In [None]:
# Swap rows and columns so that each gene ID becomes a column.
# Caution: this cell rebinds `df`, so running it a second time flips the table back.
df = df.transpose()
df.head()

In [5]:
# !pip install langchain-experimental

In [None]:
import pandas as pd
from langchain_experimental.tools import PythonAstREPLTool

# Make the dataframe available inside the Python REPL tool under the name `df`.
tool = PythonAstREPLTool(locals=dict(df=df))

# Sanity check: run a pandas expression directly through the tool.
tool.invoke("df['ENSMUSG00000051951'].mean()")

# 0.11764705882352941

In [None]:
# Show the tool's name; it is reused below as `tool_choice` and as the parser's `key_name`.
tool.name

In [None]:
# Attach the REPL tool to the model and select it by name via `tool_choice`.
llm_with_tools = llm.bind_tools(
    [tool],
    tool_choice=tool.name,
)

# The reply carries the generated pandas expression as a tool call.
llm_with_tools.invoke(
    "I have a dataframe 'df' "
    "and want to know the standard deviation of 'ENSMUSG00000051951' column"
)

# AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_eHA4gd8l4NnJCsifoKHz27JE', 'function': {'arguments': '{"query":"df[\'ENSMUSG00000051951\'].std()"}', 'name': 'python_repl_ast'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 127, 'total_tokens': 144, 'completion_tokens_details': {'audio_tokens': None, 'reasoning_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': None, 'cached_tokens': 0}}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-2b49cab3-6334-404d-a9c1-d609f1cd69cf-0', tool_calls=[{'name': 'python_repl_ast', 'args': {'query': "df['ENSMUSG00000051951'].std()"}, 'id': 'call_eHA4gd8l4NnJCsifoKHz27JE', 'type': 'tool_call'}], usage_metadata={'input_tokens': 127, 'output_tokens': 17, 'total_tokens': 144, 'input_token_details': {'cache_read': 0}, 'output_token_details': {'reasoning': 0}})

In [None]:
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser

# Pull the arguments of every call to `tool` out of the model response.
# The parser is deliberately left in its default list-returning mode: the
# following cells and `analyze_df` read the result as `resp[0]['query']`, so
# asking for a single (non-list) result here would break them.
parser = JsonOutputKeyToolsParser(key_name=tool.name)

(llm_with_tools | parser).invoke(
    "I have a dataframe 'df' and want to know the standard deviation of 'ENSMUSG00000051951' column"
)

# [{'query': "df['ENSMUSG00000051951'].std()"}]

In [10]:
# Keep the parsed tool-call arguments so the generated code can be inspected
# and executed in the next cells.
resp = (llm_with_tools | parser).invoke(
    "I have a dataframe 'df' "
    "and want to know the standard deviation of 'ENSMUSG00000051951' column"
)

In [None]:
# The generated pandas expression is stored under the 'query' key of the first parsed tool call.
resp[0]['query']

# "df['ENSMUSG00000051951'].std()"

In [None]:
# Run the model-generated expression through the REPL tool to get the actual value.
tool.invoke(resp[0]['query'])

# 0.3321055820775357

In [13]:
# !pip install tabulate

In [None]:
# Preview the rows that will be embedded (as markdown) in the system prompt below.
df.head()

In [15]:
from langchain_core.prompts import ChatPromptTemplate

# Markdown preview of the dataframe that is shown to the model. Literal braces
# are doubled because the finished string is parsed as a prompt template, in
# which `{name}` marks an input variable; unescaped braces coming from the data
# would be mistaken for template variables.
df_preview = df.head().to_markdown().replace("{", "{{").replace("}", "}}")

# System message: describes the dataframe and restricts the model to returning
# plain Python/pandas code.
system = f"""
You have access to a pandas dataframe `df`.
Here is the output of `df.head().to_markdown()`:

```
{df_preview}
```

Given a user question, write the Python code to answer it.
Return ONLY the valid Python code and nothing else.
DON'T assume you have access to any libraries other than built-in Python ones and pandas.
"""

# `{question}` is the only input variable; it is filled in when the chain is invoked.
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{question}")])

In [None]:
# Display the assembled prompt template to check the embedded dataframe preview.
prompt

# ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template="\nYou hava access to a pandas dataframe `df`.\nHere is the output of `df.head().to_markdown()`:\n\n```\n|        |   ENSMUSG00000102693 |   ENSMUSG00000064842 |   ENSMUSG00000051951 |   ENSMUSG00000102851 |   ENSMUSG00000103377 |   ENSMUSG00000104017 |   ENSMUSG00000103025 |   ENSMUSG00000089699 |   ENSMUSG00000103201 |   ENSMUSG00000103147 |\n|:-------|---------------------:|---------------------:|---------------------:|---------------------:|---------------------:|---------------------:|---------------------:|---------------------:|---------------------:|---------------------:|\n| KO_0.1 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |\n| KO_0.2 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |\n| KO_0.3 |                    0 |                    0 |                    1 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |\n| KO_4.1 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |\n| KO_4.2 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |                    0 |          
          0 |\n```\n\nGiven a user question, write the Python code to answer it.\nReturn ONLY the valid Python code and nothing else.\nDONT'T assume you have access to any libraries other than built-in Python ones and pandas.\n"), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])

In [None]:
# Pipeline: prompt -> tool-calling model -> parser that extracts the tool arguments.
code_chain = prompt | llm_with_tools | parser

# The result holds the generated pandas code under the 'query' key.
code_chain.invoke(dict(question="What's the standard deviation of ENSMUSG00000051951?"))

# [{'query': "df['ENSMUSG00000051951'].std()"}]

In [18]:
def analyze_df(question):
    """Answer a natural-language question about `df` by generating and running pandas code.

    The question is sent through `prompt | llm_with_tools | parser`; the
    generated code (the `query` argument of the tool call) is printed and then
    executed by `tool`.

    Args:
        question: Plain-English question about the dataframe.

    Returns:
        Whatever `tool.invoke` returns for the generated code.

    Raises:
        ValueError: If the parsed model response contains no `query` to run.
    """
    chain = prompt | llm_with_tools | parser
    resp = chain.invoke({"question": question})

    # The parser yields a list of argument dicts (as recorded in this notebook),
    # or a single dict when it is configured to return only the first tool
    # call; accept both shapes.
    call_args = resp[0] if isinstance(resp, list) and resp else resp
    if not isinstance(call_args, dict) or "query" not in call_args:
        raise ValueError(
            f"No tool call with a 'query' argument found in model response: {resp!r}"
        )

    code = call_args["query"]
    print(code)
    # NOTE(review): this executes model-generated code inside the running
    # kernel — only use it with questions and data you trust.
    return tool.invoke(code)

In [None]:
# End-to-end check: should match the value obtained by running the tool manually above.
analyze_df("What's the standard deviation of ENSMUSG00000051951?")

# df['ENSMUSG00000051951'].std()
# 0.3321055820775357

In [None]:
# Same column, different statistic (mean).
analyze_df("What's the mean of ENSMUSG00000051951?")

# df['ENSMUSG00000051951'].mean()
# 0.11764705882352941

In [None]:
# Structural question about the dataframe rather than a statistic.
analyze_df("What are the columns of the df?")

# df.columns
# Index(['ENSMUSG00000102693', 'ENSMUSG00000064842', 'ENSMUSG00000051951',
#        'ENSMUSG00000102851', 'ENSMUSG00000103377', 'ENSMUSG00000104017',
#        'ENSMUSG00000103025', 'ENSMUSG00000089699', 'ENSMUSG00000103201',
#        'ENSMUSG00000103147'],
#       dtype='object')

In [None]:
# Question that needs a multi-step computation (count zeros per column, then take the maximum).
analyze_df("Which column has more zeroes?")

# zero_counts = df.eq(0).sum()
# zero_counts.idxmax()
# 'ENSMUSG00000102693'

In [None]:
# Counterpart of the previous question: the column with the smallest number of zeros.
analyze_df("Which column has less zeroes?")

# df.eq(0).sum().idxmin()
# 'ENSMUSG00000051951'

In [None]:
# Question whose answer is one value per column rather than a single scalar.
analyze_df("What is the sum of each column?")