Merge pull request #2 from nyanp/feature/function-call
Support function call (update docs, refactoring)
nyanp committed Jun 24, 2023
2 parents e0bee03 + 29d58e1 commit 4d1c940
Showing 6 changed files with 242 additions and 42 deletions.
100 changes: 93 additions & 7 deletions README.md
@@ -65,31 +65,117 @@ This design limits the visualization expression compared to Python code generati
- Interactive
- Declarative data can be modified by the user to improve plots through collaborative work between the user and LLM.

## Usage
By default, chat2plot uses the [function calling API](https://openai.com/blog/function-calling-and-other-api-updates).
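If you want to disable it explicitly, you can pass `function_call=False`; a minimal sketch (see the `function_call` parameter in the API section below):

```python
import pandas as pd
from chat2plot import chat2plot

# Opt out of the function calling API
plot = chat2plot(pd.DataFrame(), function_call=False)
ret = plot.query("<your query>")
```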

## Examples

### Custom language models
`gpt-3.5-turbo-0613` is used by default, but you can use other language models.

```python
import pandas as pd
from langchain.chat_models import AzureChatOpenAI
from chat2plot import chat2plot

plot = chat2plot(pd.DataFrame(), chat=AzureChatOpenAI())
ret = plot.query("<your query>")
```

### Vega-lite format

```python
import pandas as pd
from chat2plot import chat2plot

plot = chat2plot(pd.DataFrame(), schema_definition="vega")
ret = plot.query("<your query>")

assert isinstance(ret.config, dict) # vega-lite format
print(ret.config)
```

### Custom chart definition

```python
import pydantic
import pandas as pd
from chat2plot import chat2plot

class CustomChartConfig(pydantic.BaseModel):
    chart_type: str
    x_axis_name: str
    y_axis_name: str
    y_axis_aggregate: str

plot = chat2plot(pd.DataFrame(), schema_definition=CustomChartConfig)
ret = plot.query("<your query>")

# chat2plot parses the response into the type you passed as the chart setting
assert isinstance(ret.config, CustomChartConfig)
```

### Specifying output language
You can specify the language in which chart explanations are output.
If not specified, chat2plot tries to respond in the same language as the user's question,
so this option is useful when you always want output in a specific language.

```python
import pandas as pd
from chat2plot import chat2plot

plot = chat2plot(pd.DataFrame(), language="Chinese")
ret = plot.query("<your query>")

print(ret.explanation)  # explanation of the chart, returned in Chinese
```

### Privacy preserving
When `description_strategy="dtypes"` is specified, chat2plot sends only the column names
and dtypes of the data (not its contents) to the LLM.

```python
import pandas as pd
from chat2plot import chat2plot

plot = chat2plot(pd.DataFrame(), description_strategy="dtypes")
ret = plot.query("<your query>")
```


## API

A `Chat2Plot` instance can be created using the `chat2plot` function.

```Python
def chat2plot(
    df: pd.DataFrame,
    model_type: str = "simple",
    chat: langchain.chat_models.base.BaseChatModel | None = None,
    schema_definition: Literal["simple", "vega"] | Type[pydantic.BaseModel] = "simple",
    chat: BaseChatModel | None = None,
    function_call: bool | Literal["auto"] = "auto",
    language: str | None = None,
    description_strategy: str = "head",
    custom_deserializer: ModelDeserializer | None = None,
    verbose: bool = False,
) -> Chat2PlotBase:
```

- **df** - Data source for visualization.
- **model_type** (optional) - Type of json format.
- **schema_definition** (optional) - Type of json format.
- `vega` - A vega-lite compliant format
- `simple` - A simpler format, parsed as `chat2plot.PlotConfig`
- `simple` - chat2plot's built-in format, parsed as `chat2plot.PlotConfig`

If you want chat2plot to generate chart definitions according to your own defined schema,
you can pass any type that extends pydantic.BaseModel instead of these two options.
- **chat** (optional) - The chat instance for interaction with LLMs.
If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")` will be used.
If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")` will be used.
- **function_call** (optional) - Specifies whether to use the [function calling API](https://openai.com/blog/function-calling-and-other-api-updates).
If omitted, it is automatically determined based on the underlying model type.
- **language** (optional) - Language of explanations. If not specified, it will be automatically inferred from user prompts.
- **description_strategy** - Type of how the information in the dataset is embedded in the prompt.
- **description_strategy** (optional) - How the dataset information is embedded in the prompt.
- `head` - send `df.head(5)` to LLMs.
- `dtypes` - send `df.dtypes` to LLMs. This can be used when you do not want to send the contents of `df` to LLMs.
- **custom_deserializer** (optional) - Specifies a custom deserializer to convert json returned from the LLM into a chart configuration.
- **verbose** (optional) - If `True`, chat2plot will output logs.
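
As a rough sketch of how these options can be combined (the query string and parameter values below are only illustrative):

```python
import pandas as pd
from chat2plot import chat2plot

c2p = chat2plot(
    pd.DataFrame(),                 # your data source
    schema_definition="simple",     # parse responses into chat2plot.PlotConfig
    function_call="auto",           # decided from the underlying model type
    language="English",             # always explain charts in English
    description_strategy="dtypes",  # send only column names and dtypes to the LLM
    verbose=True,                   # output logs
)
ret = c2p.query("<your query>", config_only=True)
print(ret.config)       # parsed chart configuration
print(ret.explanation)  # explanation of the generated chart
```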


23 changes: 14 additions & 9 deletions chat2plot/chat2plot.py
@@ -79,13 +79,6 @@ def __init__(
def history(self) -> list[BaseMessage]:
return copy.deepcopy(self._conversation_history)

def set_chatmodel(self, chat: BaseChatModel) -> None:
self._chat = chat

def query_without_history(self, q: str) -> BaseMessage:
response = self._chat([HumanMessage(content=q)])
return response

def query(self, q: str, raw: bool = False) -> BaseMessage:
prompt = q if raw else self._user_prompt_template.format(text=q)
response = self._query(prompt)
@@ -117,6 +110,10 @@ class Chat2PlotBase:
def session(self) -> ChatSession:
raise NotImplementedError()

@property
def function_call(self) -> bool:
return False

def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> Plot:
raise NotImplementedError()

@@ -172,6 +169,10 @@ def __init__(
def session(self) -> ChatSession:
return self._session

@property
def function_call(self) -> bool:
return self._function_call

def query(self, q: str, config_only: bool = False, show_plot: bool = False) -> Plot:
raw_response = self._session.query(q)

@@ -326,21 +327,25 @@ def chat2plot(
function_call: bool | Literal["auto"] = "auto",
language: str | None = None,
description_strategy: str = "head",
verbose: bool = False,
custom_deserializer: ModelDeserializer | None = None,
verbose: bool = False,
) -> Chat2PlotBase:
"""Create Chat2Plot instance.
Args:
df: Data source for visualization.
schema_definition: Type of json format. "vega" for a vega-lite compliant format, or "simple" or a simpler format.
schema_definition: Type of json format; "vega" for vega-lite compliant json, "simple" for chat2plot's built-in
data structure. If you want a custom schema definition, pass a type inheriting from pydantic.BaseModel
to define your own chart setting.
chat: The chat instance for interaction with LLMs.
If omitted, `ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")` will be used.
function_call: Specifies whether to use the function calling API. If "auto", it is automatically
determined based on the underlying model type.
language: Language of explanations. If not specified, it will be automatically inferred from user prompts.
description_strategy: How the dataset information is embedded in the prompt.
Defaults to "head" which embeds the contents of df.head(5) in the prompt.
"dtypes" sends only columns and types to LLMs and does not send the contents of the dataset,
which allows for privacy but may reduce accuracy.
custom_deserializer: A custom function to convert the json returned by the LLM into an object.
verbose: If `True`, chat2plot will output logs.
Returns:
5 changes: 4 additions & 1 deletion chat2plot/dataset_description.py
Expand Up @@ -17,7 +17,10 @@ def description(


def description_by_head(df: pd.DataFrame, num_rows: int = 5) -> str:
head_part = str(df.sample(num_rows, random_state=0).to_markdown())
if len(df) < num_rows:
head_part = str(df.to_markdown())
else:
head_part = str(df.sample(num_rows, random_state=0).to_markdown())

return dedent(
f"""
15 changes: 0 additions & 15 deletions example/streamlit_app/main.py
@@ -8,7 +8,6 @@

import pandas as pd
import streamlit as st
from langchain.chat_models import ChatOpenAI
from plotly.graph_objs import Figure
from pydantic import BaseModel
from streamlit_chat import message
@@ -52,18 +51,6 @@ def initialize_logger():
if "logger" not in st.session_state:
st.session_state["logger"] = initialize_logger()

with st.sidebar:
model_name = st.selectbox(
"Model type",
(
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0613",
"gpt-4-32k-0613",
),
index=0,
)

api_key = st.text_input("Step1: Input your OpenAI API-KEY", value="")
csv_file = st.file_uploader("Step2: Upload csv file", type={"csv"})

@@ -111,8 +98,6 @@ def reset_history():

c2p = st.session_state["chat"]

c2p.session.set_chatmodel(ChatOpenAI(temperature=0, model_name=model_name))

chat_container = st.container()
input_container = st.container()

140 changes: 131 additions & 9 deletions tests/test_chat2plot.py
@@ -1,7 +1,10 @@
import pandas as pd
import pydantic
import pytest
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.schema import FunctionMessage

from chat2plot import schema, chat2plot, PlotConfig
from chat2plot import PlotConfig, chat2plot, schema


@pytest.mark.parametrize(
@@ -10,15 +13,17 @@
"Average price per category",
"カテゴリごとの平均価格",
"avg price for each category",
"Show me average price per category in bar chart."
]
"Show me average price per category in bar chart.",
],
)
def test_plot_bar(prompt):
df = pd.DataFrame({
"category": ["A", "B", "C", "A", "B"],
"price": [100, 200, 100, 150, 250],
"x": [1, 2, 3, 4, 5]
})
def test_plot_bar(prompt: str):
df = pd.DataFrame(
{
"category": ["A", "B", "C", "A", "B"],
"price": [100, 200, 100, 150, 250],
"x": [1, 2, 3, 4, 5],
}
)

plot = chat2plot(df)
ret = plot.query(prompt, config_only=True)
@@ -29,3 +34,120 @@ def test_plot_bar(prompt):
assert config.y.column == "price"
assert config.y.aggregation == schema.AggregationType.AVG


def test_vega_json():
df = pd.DataFrame(
{
"date": [
"2021-01-01",
"2021-02-02",
"2021-02-03",
"2021-02-04",
"2021-02-05",
],
"price": [100, 200, 300, 400, 500],
"x": [1, 2, 3, 4, 5],
}
)
plot = chat2plot(df, schema_definition="vega")
ret = plot.query("Daily total sales in line chart", config_only=True)

assert isinstance(ret.config, dict)

# https://vega.github.io/vega-lite/docs/line.html#line-chart
expected = {
"mark": "line",
"encoding": {
"x": {"field": "date", "type": "temporal"},
"y": {"field": "price", "aggregate": "sum", "type": "quantitative"},
},
}
assert ret.config["mark"] == expected["mark"]
assert ret.config["encoding"]["x"] == expected["encoding"]["x"]
assert ret.config["encoding"]["y"] == expected["encoding"]["y"]


class CustomChartConfig(pydantic.BaseModel):
chart_type: str
x_axis_name: str
y_axis_name: str
y_axis_aggregate: str


def test_custom_schema():
df = pd.DataFrame(
{
"date": [
"2021-01-01",
"2021-02-02",
"2021-02-03",
"2021-02-04",
"2021-02-05",
],
"price": [100, 200, 300, 400, 500],
"x": [1, 2, 3, 4, 5],
}
)
plot = chat2plot(df, schema_definition=CustomChartConfig)
ret = plot.query("Daily total sales in line chart", config_only=True)

assert isinstance(ret.config, CustomChartConfig)
assert ret.config.chart_type == "line"
assert ret.config.x_axis_name == "date"
assert ret.config.y_axis_name == "price"
assert ret.config.y_axis_aggregate.lower() == "sum"


def test_function_call():
df = pd.DataFrame(
{
"date": [
"2021-01-01",
"2021-02-02",
"2021-02-03",
"2021-02-04",
"2021-02-05",
],
"price": [100, 200, 300, 400, 500],
"x": [1, 2, 3, 4, 5],
}
)

for function_call in [False, True, "auto"]:
plot = chat2plot(df, function_call=function_call)
if function_call == "auto":
assert plot.function_call
else:
assert plot.function_call == function_call
ret = plot.query("Daily total sales in line chart", config_only=True)
assert ret.config.chart_type == schema.ChartType.LINE
assert ret.config.x.column == "date"
assert ret.config.y.column == "price"
assert ret.config.y.aggregation == schema.AggregationType.SUM

if plot.function_call:
assert any(
isinstance(msg, FunctionMessage) for msg in ret.conversation_history
)
else:
assert not any(
isinstance(msg, FunctionMessage) for msg in ret.conversation_history
)


def test_function_call_auto():
chat = ChatOpenAI(model_name="gpt-3.5-turbo")
plot = chat2plot(pd.DataFrame(), chat=chat)
assert not plot.function_call

chat = ChatOpenAI(model_name="gpt-4")
plot = chat2plot(pd.DataFrame(), chat=chat)
assert not plot.function_call

chat = ChatOpenAI(model_name="gpt-3.5-turbo-0613")
plot = chat2plot(pd.DataFrame(), chat=chat)
assert plot.function_call

chat = AzureChatOpenAI(openai_api_base="azure", openai_api_version="dummy")
plot = chat2plot(pd.DataFrame(), chat=chat)
assert not plot.function_call
