In [1]:
import anthropic
import pandas as pd
import json
import plotly.express as px
import re
import plotly.graph_objects as go
from typing import Dict, Any, Optional

In [2]:
def fake_data():
    """Build a tiny demo frame and return the number of people per city.

    Returns:
        pd.DataFrame: one row per distinct ``city`` with the columns
        ``city`` and ``count`` (how many names belong to that city).
    """
    people = pd.DataFrame(
        {
            'name': ['a', 'b', 'c', 'd', 'e'],
            'city': ['ny', 'cf', 'ny', 'cf', 'ny'],
        }
    )
    # Count names within each city, label the counts, then turn the
    # group key back into an ordinary column.
    per_city = people.groupby('city')['name'].count()
    return per_city.rename('count').reset_index()

In [3]:
def df_info(df_result, query, graph_type):
    """Summarise a result frame together with the request that produced it.

    Args:
        df_result: DataFrame whose schema should be described.
        query: The question/query the dataframe answers; stored unchanged.
        graph_type: Requested chart type; stored unchanged.

    Returns:
        dict with the keys ``columns`` (list of column labels),
        ``data_types`` (column label -> dtype object), ``query`` and
        ``graph_type``.
    """
    column_names = df_result.columns.tolist()
    column_dtypes = df_result.dtypes.to_dict()
    return dict(
        columns=column_names,
        data_types=column_dtypes,
        query=query,
        graph_type=graph_type,
    )

In [4]:
def _build_chart_prompt(dataframe_info):
    """Render the user prompt: dataframe description plus the required reply format."""
    # default=str because the dtype objects in 'data_types' are not JSON-serialisable.
    dataframe_str = json.dumps(dataframe_info, indent=2, default=str)
    return f"""
    I have a result in the form of a dataframe based on the query: {dataframe_info['query']}
    The dataframe is described below (columns, data types, query and requested graph type):

    {dataframe_str}

    Based on the dataframe from the above query, please recommend:
    x-axis and y-axis values (important: use column names exactly as in {dataframe_info['columns']}), and any grouping/coloring.
    Remove unnecessary text like "based on ....", "here is my recommendation..." etc.

    Reply with exactly two fenced code blocks:
    1. A ```json block holding a dictionary with the keys 'x_axis', 'y_axis', 'color', 'title', 'x_label', 'y_label'.
    2. A ```python block with plotly code for the requested graph type.
       - The dataframe already exists as the variable `df`; do not recreate or redefine it.
       - `pd` (pandas) and `go` (plotly.graph_objects) are already imported; import anything else you need.
       - Assign the figure to a variable named `fig` and do not call fig.show().
    """


def call_llm(api_key, dataframe_info, model="claude-opus-4-20250514", max_tokens=1024):
    """Ask the Anthropic Messages API for a chart recommendation.

    The prompt spells out the reply format (a fenced ``json`` block and a
    fenced ``python`` block that uses ``df`` and assigns ``fig``) so the
    response can be parsed and executed by code that looks for those fences.

    Args:
        api_key: Anthropic API key used to construct the client.
        dataframe_info: Mapping describing the dataframe; must contain the
            keys ``query`` and ``columns`` and is embedded in the prompt as JSON.
        model: Model identifier sent to the API (defaults to the value this
            function previously hard-coded).
        max_tokens: Upper bound on generated tokens (defaults to the value
            this function previously hard-coded).

    Returns:
        The message object returned by ``client.messages.create``.

    Raises:
        KeyError: if ``dataframe_info`` lacks ``query`` or ``columns``.
    """
    client = anthropic.Anthropic(api_key=api_key)
    prompt = _build_chart_prompt(dataframe_info)
    return client.messages.create(
        model=model,
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": prompt}],
    )

In [5]:
def execute_llm_chart_code(llm_response: str, 
                          dataframe: Optional[pd.DataFrame] = None, 
                          data_values: Optional[Dict[str, int]] = None) -> go.Figure:
    """Extract the fenced Python snippet from an LLM reply, run it, and return its figure.

    Args:
        llm_response: Raw text of the LLM reply. A fenced ``python`` block is
            required; a fenced ``json`` block is optional and only printed.
        dataframe: Exposed to the executed snippet as ``df``. When omitted the
            snippet sees ``df = None``.
        data_values: Currently unused; kept so existing callers keep working.

    Returns:
        Whatever object the snippet bound to the name ``fig``.

    Raises:
        ValueError: if the reply contains no fenced ``python`` block.
        RuntimeError: if the snippet raises, or finishes without binding ``fig``.
    """
    # Extract JSON configuration (optional, for metadata)
    json_match = re.search(r'```json\s*(\{.*?\})\s*```', llm_response, re.DOTALL)
    if json_match:
        try:
            config = json.loads(json_match.group(1))
            print(f"Found configuration: {config}")
        except json.JSONDecodeError:
            print("Could not parse JSON configuration")

    # Extract Python code
    python_match = re.search(r'```python\s*(.*?)\s*```', llm_response, re.DOTALL)
    if not python_match:
        raise ValueError("No Python code found in LLM response")
    python_code = python_match.group(1)

    # Names visible to the snippet. `px` is included because LLM-written plotly
    # code frequently calls plotly.express without importing it first.
    exec_globals = {
        'go': go,
        'px': px,
        'pd': pd,
        'df': dataframe,
        'fig': None  # Will be set by the executed code
    }

    # SECURITY: this executes model-generated code with the full privileges of
    # the current process (file system, network, environment). Only run it on
    # responses you trust, or inside a sandboxed kernel/container.
    try:
        exec(python_code, exec_globals)
    except Exception as e:
        # Chain the original exception so the snippet's traceback is kept.
        raise RuntimeError(f"Error executing Python code: {str(e)}") from e

    fig = exec_globals.get('fig')
    if fig is None:
        # Same exception type and message callers saw when this check lived
        # inside the try block and was re-wrapped by its handler.
        raise RuntimeError(
            "Error executing Python code: Code execution did not create a 'fig' variable"
        )
    return fig



In [None]:
def create_graph(df, query, graph_type, api_key=None):
    """Ask the LLM for a chart of ``df`` and display the resulting figure.

    Args:
        df: Result dataframe to describe to the LLM and hand to the generated code.
        query: Natural-language description of what ``df`` contains.
        graph_type: Requested chart type (for example ``"bar"``).
        api_key: Anthropic API key. When omitted, the ``ANTHROPIC_API_KEY``
            environment variable is used so no credential lives in the notebook.

    Returns:
        None. The figure is shown via ``fig.show()``.

    Raises:
        ValueError: if no API key is available or the reply holds no text.
    """
    import os  # local import so this cell does not depend on editing the import cell

    if api_key is None:
        api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        raise ValueError(
            "No Anthropic API key: pass api_key or set the ANTHROPIC_API_KEY environment variable"
        )

    dataframe_info = df_info(df, query, graph_type)
    response = call_llm(api_key, dataframe_info)

    # Join every text block instead of assuming content[0] exists and is text.
    response_text = "".join(
        getattr(block, "text", None) or "" for block in response.content
    )
    if not response_text:
        raise ValueError("LLM response contained no text content")

    fig = execute_llm_chart_code(response_text, df)
    fig.show()

In [7]:
# Demo run: chart the per-city head count using the LLM-suggested bar chart.
df = fake_data()
query, graph_type = "No of people from each city", "bar"
create_graph(df=df, query=query, graph_type=graph_type)

Found configuration: {'x_axis': 'city', 'y_axis': 'count', 'color': '', 'title': 'Number of People by City', 'x_label': 'City', 'y_label': 'Number of People'}
