In [None]:
# Install the agent framework (google-adk) and LiteLLM, which backs the LiteLlm
# model wrapper imported further down. The %pip magic installs into the
# environment of the running kernel.
%pip install google-adk
%pip install litellm

In [None]:
pip install pandas python-dotenv duckdb numpy plotly

In [1]:
import os
from dotenv import load_dotenv

# Merge variables from a local .env file (if any) into the process environment.
load_dotenv()

api_key_openai = os.getenv("OPENAI_API_KEY")
api_key_google = os.getenv("GOOGLE_API_KEY")

# Only report presence of the keys; their values are never printed.
print("Keys loaded" if (api_key_openai and api_key_google) else "Keys are not loaded")

Keys loaded


In [2]:
# Set the google-genai Vertex AI switch to "False".
# NOTE(review): presumably this keeps the SDK on the API-key backend — confirm against the ADK docs.
os.environ.update({"GOOGLE_GENAI_USE_VERTEXAI": "False"})
# os.environ["GOOGLE_API_KEY"] = "YOUR_API_KEY"  # Replace with your actual key here OR add it to your .env

In [3]:
from google.adk.agents import Agent
from google.genai.types import GenerationConfig
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai.types import Content, Part
from google.adk.tools import google_search
from google.adk.tools.tool_context import ToolContext
from google.adk.models.lite_llm import LiteLlm # For multi-model support
import asyncio

# from google.genai import types

In [4]:
import json
import pandas as pd
import numpy as np
import duckdb
from typing import TypedDict, Optional, List, Dict, Union, Any
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

In [5]:
def preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Normalise the different spellings of "missing" in a DataFrame.

    Empty / whitespace-only strings and the literal strings ``'nan'``,
    ``'NaN'`` and ``'null'`` are treated as missing values, and missing
    values are then stored as ``None`` wherever the column dtype can hold it.

    Args:
        df: Frame to clean. It is not modified; a new frame is returned.

    Returns:
        A cleaned copy of ``df``.
    """
    # np.nan (not None) is the replacement value on purpose: with a scalar or
    # list ``to_replace``, ``DataFrame.replace(..., None)`` is version
    # dependent - pandas < 1.4 ignores the explicit None and forward-fills the
    # previous row's value instead of blanking the cell.
    cleaned = df.replace(r'^\s*$', np.nan, regex=True)
    cleaned = cleaned.replace(['nan', 'NaN', 'null'], np.nan)
    # Convert the NaN markers to None where the dtype allows it (object
    # columns); purely numeric columns may keep NaN.
    return cleaned.where(pd.notnull(cleaned), None)


def preview_excel_structure(input_str: str, tool_context: ToolContext) -> str:
    """
    Use this first to examine the Excel file structure and data types.
    The input should be a JSON string with format: {"file_name": "your_file.xlsx"}

    For robustness an already-parsed dict with the same key, or a bare file
    name, is accepted as well.

    Args:
        input_str: JSON string (or dict / plain file name) identifying the workbook.
        tool_context: Tool context whose ``state`` receives ``preview_structure``,
            ``file_name`` and ``full_data_rows`` for use by later tools.

    Returns:
        str: JSON string, either ``{"result": {"columns", "dtypes", "sample_rows"}}``
        or ``{"error": "..."}``.
    """
    try:
        # The caller does not always follow the documented format, so be
        # lenient instead of failing with a TypeError/AttributeError message.
        if isinstance(input_str, dict):
            payload = input_str
        else:
            try:
                payload = json.loads(input_str)
            except (TypeError, ValueError):
                # Not JSON at all: treat the raw text as the file name itself.
                payload = input_str

        if isinstance(payload, dict):
            file_name = payload.get("file_name")
        elif isinstance(payload, str):
            file_name = payload.strip()
        else:
            file_name = None

        if not file_name:
            return json.dumps({"error": "File name must be provided"})

        df = pd.read_excel(file_name)
        df = preprocess_dataframe(df)  # Apply preprocessing
        # Stringify the sample so every cell is JSON-serialisable.
        df_sample = df.head(3).astype(str)

        print("✅ Preview successful")
        display(df_sample)

        result = {
            "columns": df.columns.tolist(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "sample_rows": df_sample.to_dict(orient="records")
        }

        # Persist result to state
        tool_context.state["preview_structure"] = result
        tool_context.state["file_name"] = file_name  # also store file name for later tools
        tool_context.state["full_data_rows"] = df.to_dict(orient="records")

        return json.dumps({"result": result})

    except Exception as e:
        # Tools report failures as JSON rather than raising.
        return json.dumps({"error": str(e)})


In [6]:
def _frame_to_json_payload(frame: pd.DataFrame) -> dict:
    """Turn a query result into a ``{"columns", "rows"}`` dict that json.dumps accepts."""
    # Everything that is not numeric/boolean (text, datetimes, nested values)
    # is emitted as a string.
    text_columns = [
        name for name, dtype in frame.dtypes.items()
        if not pd.api.types.is_numeric_dtype(dtype)
    ]
    # Infinities have no JSON representation: fold them into NaN, then turn
    # every missing marker into None. The cast to object is what allows
    # numeric columns to hold None instead of silently keeping NaN.
    cleaned = frame.replace([np.inf, -np.inf], np.nan).astype(object)
    cleaned = cleaned.where(cleaned.notna(), None)
    for name in text_columns:
        cleaned[name] = cleaned[name].apply(
            lambda cell: None if cell is None else str(cell)
        )
    return {
        "columns": cleaned.columns.tolist(),
        "rows": cleaned.to_dict(orient="records"),
    }


def complex_duckdb_query(input_data: dict, tool_context: ToolContext) -> str:
    """
    Use this tool for SQL operations (GROUP BY, aggregations, etc.).

    Args:
        input_data (dict): A dictionary with the following keys:
            - "file_name" (str): Name of the Excel file to query (optional if already stored in state).
            - "query" (str): SQL query to run against the file. The data is
              registered as the table `data`; references to the file name in
              the query are rewritten to `data`.
        tool_context (ToolContext): Tool context; ``state`` is read for
            ``file_name`` / ``full_data_rows`` and receives ``last_query`` and
            ``query_result``.

    Returns:
        str: JSON string with structure:
            {
                "status": "success",
                "message": "...",
                "result": {
                    "columns": [...],
                    "rows": [...]
                }
            }
            or {"error": "..."}.
    """

    try:
        # Validate input
        if not isinstance(input_data, dict):
            return json.dumps({"error": "Input must be a dictionary."})

        query = input_data.get("query")
        file_name = input_data.get("file_name") or tool_context.state.get("file_name")

        if not file_name:
            return json.dumps({"error": "'file_name' must be provided or available in state."})

        if not isinstance(query, str) or not query.strip():
            return json.dumps({"error": "'query' must be a non-empty string."})

        print("\n🔍 Executing DuckDB query:")
        print(query)
        print(file_name)

        # Reuse the rows cached by the preview tool when they belong to the
        # requested file, or when the requested name is not a real file (for
        # example a placeholder). Otherwise load the workbook, so the tool
        # also works without a prior preview and never answers from another
        # file's cached data.
        cached_rows = tool_context.state.get("full_data_rows")
        cached_name = tool_context.state.get("file_name")
        if cached_rows is not None and (
            cached_name == file_name or not os.path.exists(str(file_name))
        ):
            df = pd.DataFrame(cached_rows)
        else:
            df = pd.read_excel(file_name)

        df = preprocess_dataframe(df)

        # Point any reference to the workbook at the registered table. Quoted
        # forms go first: a leftover 'data' would be read by DuckDB as a file path.
        for reference in (f"'{file_name}'", f'"{file_name}"', str(file_name)):
            query = query.replace(reference, "data")

        with duckdb.connect() as con:
            con.register("data", df)
            result = con.execute(query).fetchdf()

        print("✅ Query successful")
        display(result)

        result_dict = _frame_to_json_payload(result)

        # Store in state
        tool_context.state["last_query"] = query
        tool_context.state["query_result"] = result_dict

        return json.dumps(
            {
                "status": "success",
                "message": "Query executed successfully. You can now proceed to create the visualization.",
                "result": result_dict,
            },
            default=str,  # last-resort guard for any value json cannot encode natively
        )

    except Exception as e:
        # Tools report failures as JSON rather than raising.
        return json.dumps({"error": str(e)})


In [7]:
def _sankey_figure(df: pd.DataFrame, source: str, target: str, value: str) -> "go.Figure":
    """Build a Sankey figure whose node labels show each node's throughput."""
    sources = df[source].tolist()
    targets = df[target].tolist()
    # Non-numeric or missing amounts count as zero instead of breaking the sums.
    amounts = pd.to_numeric(df[value], errors="coerce").fillna(0).tolist()

    # dict.fromkeys de-duplicates while keeping first-seen order; the index map
    # replaces a linear list.index() scan for every row.
    unique_nodes = list(dict.fromkeys(sources + targets))
    node_index = {node: position for position, node in enumerate(unique_nodes)}

    outgoing = dict.fromkeys(unique_nodes, 0)
    incoming = dict.fromkeys(unique_nodes, 0)
    for src, dst, amount in zip(sources, targets, amounts):
        outgoing[src] += amount
        incoming[dst] += amount

    # A node's throughput is the larger of what enters and what leaves it;
    # adding both would double-count every intermediate node.
    label_with_values = [
        f"{node}\n({max(outgoing[node], incoming[node]):,.2f})"
        for node in unique_nodes
    ]

    return go.Figure(data=[go.Sankey(
        arrangement="freeform",
        domain=dict(x=[0, 1], y=[0.1, 0.90]),
        node=dict(
            pad=70, thickness=30,
            line=dict(color="black", width=0.9),
            label=label_with_values
        ),
        link=dict(
            source=[node_index[node] for node in sources],
            target=[node_index[node] for node in targets],
            value=amounts
        )
    )])


def create_visualization(input_data: dict, tool_context: ToolContext) -> str:
    """
    Create and immediately display various types of visualizations using Plotly.

    Args:
        input_data (dict): A dictionary with the following keys:
            {
              "data": {
                "result": {
                  "rows": [...],
                  "columns": [...]
                }
              },
              "plot_type": "...",  # line (default), bar, scatter, box, histogram, pie, heatmap, sankey
              "x": "...",          # required for every plot type except sankey
              "y": "...",
              "title": "...",
              "color": "...",      # for heatmap: name of the value column (default "value")
              "source": "...",    # For Sankey
              "target": "...",    # For Sankey
              "value": "..."      # For Sankey
            }
            If "data" is omitted, the last query result stored in session
            state is used instead.
        tool_context (ToolContext): Injected by ADK to access session state.

    Returns:
        str: A success message or an error.
    """
    try:
        print("📊 create_visualization tool was called.")

        if not input_data:
            return "❌ Error: No input data provided for visualization."
        if not isinstance(input_data, dict):
            return "❌ Error: Input must be a dictionary."

        print(
            "📊 Input received for visualization:",
            json.dumps(input_data, indent=2, default=str),
        )

        params = input_data
        data = params.get("data")
        plot_type = params.get("plot_type", "line")
        x = params.get("x")
        y = params.get("y")
        title = params.get("title", "")
        color = params.get("color")
        orientation = params.get("orientation", "v")
        barmode = params.get("barmode", "group")
        size = params.get("size")
        nbins = params.get("nbins")
        source = params.get("source")
        target = params.get("target")
        value = params.get("value")

        # Fallback from context: the stored query result is a bare
        # {"columns", "rows"} dict, so wrap it in the shape expected below.
        if not data and tool_context:
            cached = tool_context.state.get("query_result")
            if isinstance(cached, dict) and "rows" in cached:
                data = {"result": cached}
                print("📥 Used query_result from session state.")

        # Validate and convert. Sankey diagrams are described by
        # source/target/value and have no x axis, so x is not required there.
        if not data or (plot_type != "sankey" and not x):
            return "❌ Error: Missing required parameters: data and x."

        if (
            isinstance(data, dict)
            and isinstance(data.get("result"), dict)
            and "rows" in data["result"]
        ):
            df = pd.DataFrame(data["result"]["rows"])
        else:
            return "❌ Error: Invalid data format."

        if df.empty:
            return "❌ Error: No rows to visualize."

        # Layout shared by every plot type
        layout_settings = {
            'title': {'text': title, 'x': 0.5, 'xanchor': 'center', 'font': dict(size=16)},
            'plot_bgcolor': 'white',
            'paper_bgcolor': 'white',
            'font': dict(size=12),
            'margin': dict(l=50, r=50, t=50, b=100),
            'height': 800,
            'width': 1600,
            'template': 'plotly_white'
        }

        if plot_type == "line":
            fig = px.line(df, x=x, y=y, color=color, title=title)
            fig.update_traces(mode='lines+markers')
        elif plot_type == "bar":
            fig = px.bar(df, x=x, y=y, color=color, title=title, barmode=barmode, orientation=orientation)
        elif plot_type == "scatter":
            fig = px.scatter(df, x=x, y=y, color=color, size=size, title=title)
        elif plot_type == "box":
            fig = px.box(df, x=x, y=y, color=color, title=title)
        elif plot_type == "histogram":
            fig = px.histogram(df, x=x, color=color, nbins=nbins, title=title)
        elif plot_type == "pie":
            fig = px.pie(df, names=x, values=y if y else None, color=x, title=title)
        elif plot_type == "heatmap":
            # Long-format input (x, y and a value column) must be pivoted into
            # a matrix first; anything else is assumed to already be a matrix.
            value_column = color if color else 'value'
            if all(name in df.columns for name in (x, y, value_column)):
                pivot_df = df.pivot(index=y, columns=x, values=value_column)
                fig = px.imshow(pivot_df, title=title)
            else:
                fig = px.imshow(df, title=title)
        elif plot_type == "sankey":
            if not (source and target and value):
                return "❌ Error: Sankey diagrams require 'source', 'target', and 'value'."
            fig = _sankey_figure(df, source, target, value)
        else:
            return f"❌ Error: Unsupported plot type: {plot_type}."

        fig.update_layout(layout_settings)

        if plot_type == "sankey":
            # Applied after the shared layout so the larger Sankey font and
            # tighter bottom margin are not overwritten by it.
            fig.update_layout(
                font=dict(size=16, family="Arial, sans-serif", color="black"),
                margin=dict(l=50, r=50, t=50, b=50)
            )

        # Axis styling only makes sense for cartesian charts.
        if plot_type not in ['pie', 'heatmap', 'sankey']:
            fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', title_text=x.replace('_', ' ').title())
            if y:
                fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', title_text=y.replace('_', ' ').title())

        fig.show()

        if tool_context:
            tool_context.state["last_visualization"] = {
                "plot_type": plot_type,
                "x": x, "y": y, "title": title, "color": color
            }

        return "✅ Visualization displayed successfully."

    except Exception as e:
        # Tools report failures as text rather than raising.
        return f"❌ Visualization error: {str(e)}."


In [8]:
# Model wrapper: the "openai/gpt-4o" model string is handed to LiteLLM.
# NOTE(review): temperature=0.0 is presumably meant to keep tool calls as
# repeatable as possible - confirm it is forwarded to the completion call.
llm = LiteLlm(
    model="openai/gpt-4o",
    temperature=0.0
)

# System prompt for the agent. The JSON examples are kept strictly valid
# (no trailing commas) because the model is told to reproduce this format.
AGENT_INSTRUCTION = """
You will be given a task to perform. You must follow these exact steps in order:

1. Always call the `preview_excel_structure` tool first to get the Excel columns. Match user input terms to Excel column names. You MUST print the matched values before continuing.
2. Use the matched column names to build SQL query using the `complex_duckdb_query` tool.

**IMPORTANT:**
   - The table is registered as `data`.
   - DO NOT use the file name or sheet name.
   - Always query using `FROM data`.


3. Use the results of the SQL query to generate a Plotly visualization using the `create_visualization` tool.
   - If user asks for a specific chart type (e.g., Sankey), use that.
   - ALWAYS convert wide-format data (e.g., a single row of many metrics) into long-format JSON before calling create_visualization.

**Visualization Output Format (CRITICAL):**

{
  "data": { "result": { "columns": [...], "rows": [...] } },
  "plot_type": "...",
  "x": "column_name_for_x_axis",
  "y": "column_name_for_y_axis",
  "title": "...",
  "color": "..."
}


**Sankey example:**
 {
  "data": {
    "result": {
      "rows": [
        { "source": "Company", "target": "Contacted by recruiter", "value": 1 },
        { "source": "Contacted by recruiter", "target": "Recruiter interview", "value": 16 },
        { "source": "Recruiter interview", "target": "1st round", "value": 7 },
        { "source": "1st round", "target": "Challenge / Assignment", "value": 4 },
        { "source": "Challenge / Assignment", "target": "2nd round", "value": 4 },
        { "source": "2nd round", "target": "3rd round", "value": 1 },
        { "source": "3rd round", "target": "Opening closed / put on hold", "value": 1 },
        { "source": null, "target": "1st round", "value": 6 }
      ],
      "columns": ["source", "target", "value"]
    }
  },
  "plot_type": "sankey",
  "title": "Job Application Process Transitions",
  "x": "source",
  "source": "source",
  "target": "target",
  "value": "value"
}

4. Display the visualization.
5. Evaluate if the task is completed. If the visualization was shown successfully OR if task completed, stop execution!

"""

# The agent is given the three notebook tools in the order the prompt tells it
# to use them: preview -> SQL query -> visualization.
root_agent = Agent(
    name="data_agent",
    model=llm,
    # model="gemini-2.0-flash-exp", #gemini-2.0-flash-exp, gemini-2.5-pro-preview-03-25 gemini-2.5-pro-exp-03-25
    description=(
        "Agent to answer data questions and create visualizations."
    ),
    instruction=AGENT_INSTRUCTION,
    tools=[preview_excel_structure, complex_duckdb_query, create_visualization],
    output_key="last_agent_response",
)


In [None]:
# Setup the session and runner
session_service = InMemorySessionService()
app_name = "viz_app"
user_id = "jeny"
session_id = "session_viz_001"


def _resolve(result):
    """Return ``result``, first running it to completion if it is a coroutine.

    Depending on the installed google-adk version, ``create_session`` either
    returns the session directly or returns a coroutine that must be awaited.
    """
    if not asyncio.iscoroutine(result):
        return result
    # Jupyter already runs an event loop on this thread, so asyncio.run()
    # cannot be called here; drive the coroutine on a short-lived worker
    # thread with its own loop instead.
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, result).result()


session = _resolve(
    session_service.create_session(app_name=app_name, user_id=user_id, session_id=session_id)
)

runner = Runner(agent=root_agent, app_name=app_name, session_service=session_service)

# Create the user message
user_message = Content(role="user", parts=[Part(text="""
Use file: data_export.xlsx to
    1. Create and display a visualization which shows how the total Forcast flows through the unique groups in Channel and subchannel`
    2. You must follow this sankey flow: Source (Total forecast) -> target (Channel); Source (Channel) -> target (sub-channel)
    3. Visualization name: Total Forecast flow through Channel, Sub Channel

""")])

# Run and display the final response
for event in runner.run(user_id=user_id, session_id=session.id, new_message=user_message):
    if event.is_final_response():
        if event.content and event.content.parts:
            print(event.content.parts[0].text)
        else:
            print("⚠️ No final text response was returned by the agent.")
