In [None]:
import os
import subprocess

# Detect the Colab runtime. `get_ipython` only exists inside IPython, so a plain
# interpreter is treated as "not Colab" instead of failing with NameError.
try:
    is_colab = 'google.colab' in str(get_ipython())
except NameError:
    is_colab = False

if is_colab:
    from google.colab import userdata
    print("Running on CoLab")

    def _read_colab_secret(name):
        """Return the Colab secret ``name``, or None if it cannot be read.

        ``userdata.get`` raises (e.g. for a missing secret or when notebook
        access was not granted) rather than returning None; catching it here
        lets the later "key not found" checks report the problem instead of
        the cell aborting.
        """
        try:
            return userdata.get(name)
        except Exception as exc:
            print(f"Could not read Colab secret '{name}': {exc}")
            return None

    tavily_key = _read_colab_secret('TAVILY_API_KEY')
    oai_key = _read_colab_secret('OPENAI_API_KEY')
else:
    print("Not running on CoLab, attempting to load keys from environment variables.")
    tavily_key = os.environ.get('TAVILY_API_KEY')
    oai_key = os.environ.get('OPENAI_API_KEY')

# Install necessary packages
!pip install -qU "langchain-community>=0.2.11" tavily-python openpyxl scipy scikit-learn xhtml2pdf joblib
!pip install -qU langchain langchain-core langchain-openai langchain_experimental langgraph chromadb pydantic python-dotenv tiktoken openpyxl scipy scikit-learn xhtml2pdf joblib

# Export the keys to the process environment. Clients such as ChatOpenAI fall
# back to OPENAI_API_KEY when no api_key argument is given; on Colab the key
# otherwise only exists in the local `oai_key` variable.
if tavily_key:
    os.environ["TAVILY_API_KEY"] = tavily_key
else:
    print("TAVILY_API_KEY not found.")

if oai_key:
    os.environ["OPENAI_API_KEY"] = oai_key
else:
    print("OPENAI_API_KEY not found.")

def check_and_install(package_name):
    """Install ``package_name`` with pip unless it is already installed.

    pip is run as ``<current interpreter> -m pip`` so the package is checked
    and installed in the same environment as this kernel, not in whichever
    ``pip`` executable happens to be first on PATH.

    Args:
        package_name: Distribution name as understood by ``pip show`` / ``pip install``.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    import sys

    pip_cmd = [sys.executable, '-m', 'pip']
    try:
        # `pip show` exits non-zero for an absent package; silence its
        # "Package(s) not found" warning on stderr.
        subprocess.check_output(pip_cmd + ['show', package_name], stderr=subprocess.DEVNULL)
    except subprocess.CalledProcessError:
        print(f"{package_name} not found, installing...")
        subprocess.check_output(pip_cmd + ['install', package_name])
    else:
        print(f"{package_name} already installed.")

# Ensure pydantic is upgraded if necessary
!pip install -U pydantic

In [None]:
import os
import subprocess
import pandas as pd
from pathlib import Path
from pprint import pprint
import uuid
from collections import OrderedDict
from tempfile import TemporaryDirectory
import io
import numpy as np
import json

from typing import Dict, Optional, List, Tuple, Union, Annotated, Literal
from pydantic import BaseModel, Field, validator, model_validator

from langchain_core.tools import tool, InjectedToolArg
from langchain.tools import Tool
from langchain_experimental.utilities import PythonREPL
from typing_extensions import TypedDict

from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt.chat_agent_executor import AgentState
from langgraph.prebuilt import create_react_agent
from langchain_core.runnables.config import RunnableConfig
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, trim_messages, ToolMessageChunk, ToolCall
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore
from langchain_core.stores import BaseStore
from langchain_core.prompts import MessagesPlaceholder
from langgraph.types import Command
from scipy import stats
import kagglehub
from IPython.display import Image, display
from tavily import TavilyClient # Added for search_web_for_context tool

In [None]:
# Scratch directory for the per-DataFrame CSV files (DataFrameRegistry defaults
# raw paths to WORKING_DIRECTORY / "<df_id>.csv"). The TemporaryDirectory object
# stays referenced by a module global so the directory is not cleaned up early.
_TEMP_DIRECTORY = TemporaryDirectory()
WORKING_DIRECTORY = Path(_TEMP_DIRECTORY.name)
print("Working directory set to: " + str(WORKING_DIRECTORY))

In [None]:
# Pydantic models, DataFrameRegistry, global_df_registry, State model
# Imports like pydantic, typing, uuid, OrderedDict, Path, AgentState, RunnableConfig, BaseChatModel, PromptTemplate are assumed from essential_imports_cell

class AnalysisConfig(BaseModel):
    """User-configurable settings for the data analysis workflow."""

    # Name of a matplotlib style sheet.
    default_visualization_style: str = Field(
        default="seaborn-v0_8-whitegrid",
        description="Default style for matplotlib/seaborn visualizations.",
    )
    report_author: Optional[str] = Field(
        default=None,
        description="Author name to include in generated reports.",
    )
    # strftime-style pattern.
    datetime_format_preference: str = Field(
        default="%Y-%m-%d %H:%M:%S",
        description="Preferred format for datetime string representations.",
    )
    large_dataframe_preview_rows: int = Field(
        default=5,
        description="Number of rows for previewing large dataframes.",
    )
    # Not yet implemented: a default correlation method ("pearson") and an
    # automatic outlier-removal flag.

class CleaningMetadata(BaseModel):
    """Metadata about the data cleaning actions taken."""

    # Both fields are required (explicit `...` default).
    steps_taken: list[str] = Field(..., description="List of cleaning steps performed.")
    data_description_after_cleaning: str = Field(
        ...,
        description="Brief description of the dataset after cleaning.",
    )

class InitialDescription(BaseModel):
    """Initial description of the dataset."""

    dataset_description: str = Field(..., description="Brief description of the dataset.")
    # No default: in pydantic v2 the key is still required, only its value may be None.
    data_sample: Optional[str] = Field(..., description="Sample of the data (first few rows).")

class AnalysisInsights(BaseModel):
    """Insights from the exploratory data analysis."""

    summary: str = Field(..., description="Overall summary of EDA findings.")
    correlation_insights: str = Field(..., description="Key correlation insights identified.")
    anomaly_insights: str = Field(..., description="Anomalies or interesting patterns detected.")
    recommended_visualizations: list[str] = Field(
        ...,
        description="List of recommended visualizations to illustrate findings.",
    )
    # The only optional field; defaults to None.
    recommended_next_steps: Optional[List[str]] = Field(
        default=None,
        description="List of recommended next analysis steps or questions to investigate based on the findings.",
    )

class VisualizationResults(BaseModel):
    """Results from the visualization generation."""

    # Entries are plain dicts; their keys are described but not validated.
    visualizations: List[dict] = Field(
        ...,
        description="List of visualizations generated. Each dictionary should have the plot type and the base64 encoded image",
    )

class ReportResults(BaseModel):
    """Results from the report generation."""

    # Required; no default.
    report_path: str = Field(..., description="Path to the generated report file.")

class DataQueryParams(BaseModel):
    """Parameters for querying the DataFrame."""

    columns: List[str] = Field(..., description="List of columns to include in the output")
    # Optional equality filter: keep rows where `filter_column` equals `filter_value`.
    filter_column: Optional[str] = Field(default=None, description="Column to apply the filter on")
    filter_value: Optional[str] = Field(default=None, description="Value to filter the rows by")
    # NOTE(review): keep this description in sync with the operations
    # QueryDataframe actually implements.
    operation: str = Field(
        default="select",
        description="Operation to perform: 'select', 'sum', 'mean', 'count', 'max', 'min', 'median', etc.",
    )

# schema_extra_cols and schema_extra_cells removed as they are no longer used.

class CellIdentifier(BaseModel):
    """Identifies a single cell by row index and column name."""

    # GetData resolves these with label-based `df.loc[row_index, column_name]`.
    row_index: int = Field(description="Row index of the cell.")
    column_name: str = Field(description="Column name of the cell.")

class GetDataParams(BaseModel):
    """Parameters for retrieving data from the DataFrame."""
    df_id: str = Field(..., description="DataFrame ID in the global registry.")
    index: Union[int, List[int], Tuple[int, int]] = Field(..., description="Specifies the rows to retrieve. Can be: 1) A single integer for one row. 2) A list of integers for multiple specific rows. 3) A 2-element tuple `(start, end)` for a range of rows (inclusive).")
    columns: Union[str, List[str]] = Field("all", description="A string (single column), a list of strings (multiple columns), or 'all' for all columns (default: 'all').")
    cells: Optional[List[CellIdentifier]] = Field(None, description="A list of cell identifier objects, each specifying a 'row_index' and 'column_name'.")

    @model_validator(mode='before')
    @classmethod
    def validate_index(cls, values):
        """Reject malformed ``index`` values before field validation.

        In ``mode='before'`` the raw input is not guaranteed to be a dict (an
        existing ``GetDataParams`` instance, for example), so non-dict input
        is passed through unchanged and left to normal field validation.

        Raises:
            ValueError: if ``index`` is not an int, a list of ints, or a 2-tuple.
        """
        if not isinstance(values, dict):
            return values
        index = values.get('index')
        # bool is a subclass of int, but True/False are not meaningful row numbers.
        if isinstance(index, bool) or not isinstance(index, (int, list, tuple)):
            raise ValueError("Invalid 'index' type. Must be int, list, or tuple.")
        if isinstance(index, tuple) and len(index) != 2:
            raise ValueError("Invalid tuple length for 'index'. Must be a 2-tuple for range.")
        if isinstance(index, list) and not all(isinstance(i, int) and not isinstance(i, bool) for i in index):
            raise ValueError("Invalid list elements for 'index'. Must contain only integers.")
        return values

class DataFrameRegistry:
    """In-memory registry mapping DataFrame IDs to DataFrames and their CSV paths.

    ``registry`` holds every entry as ``{"df": DataFrame or None, "raw_path": str}``;
    ``df_id_to_raw_path`` mirrors the paths; ``cache`` is an LRU-ordered view of
    at most ``capacity`` DataFrames (most recently used last).
    """

    def __init__(self, capacity=20):
        self.registry: Dict[str, dict] = {}
        self.df_id_to_raw_path: Dict[str, str] = {}
        self.cache = OrderedDict()
        self.capacity = capacity

    def _cache_put(self, df_id: str, df) -> None:
        """Mark ``df_id`` as most recently used and evict the oldest entry if over capacity."""
        self.cache[df_id] = df
        # Assigning to an existing OrderedDict key keeps its old position, so
        # recency must be refreshed explicitly.
        self.cache.move_to_end(df_id)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

    def register_dataframe(self, df=None, df_id=None, raw_path=""):
        """Register ``df`` under ``df_id`` and return the ID.

        Args:
            df: The DataFrame, or None to record only a path to load later.
            df_id: ID to use; a new ``uuid4`` string when None.
            raw_path: CSV path backing the DataFrame. When empty or None it
                defaults to ``WORKING_DIRECTORY / f"{df_id}.csv"``.

        Returns:
            The ID the DataFrame was registered under.
        """
        if df_id is None:
            df_id = str(uuid.uuid4())
        if not raw_path:
            # Also covers None (e.g. get_raw_path_from_id on an unknown ID),
            # which would otherwise be stored as the literal path "None".
            raw_path = WORKING_DIRECTORY / f"{df_id}.csv"
        self.registry[df_id] = {"df": df, "raw_path": str(raw_path)}
        self.df_id_to_raw_path[df_id] = str(raw_path)
        if df is not None:
            self._cache_put(df_id, df)
        else:
            # Don't keep serving a previously cached frame for this ID.
            self.cache.pop(df_id, None)
        return df_id

    def get_dataframe(self, df_id: str, load_if_not_exists=False):
        """Return the DataFrame for ``df_id``, or None if it is unavailable.

        Lookup order: the LRU cache, then the registry entry, then - only when
        ``load_if_not_exists`` is true - ``pd.read_csv`` on the entry's
        ``raw_path``. A missing file yields None; other load errors are
        printed and also yield None.
        """
        if df_id in self.cache:
            self.cache.move_to_end(df_id)
            return self.cache[df_id]

        entry = self.registry.get(df_id)
        if entry is None:
            return None

        df = entry.get("df")
        if df is not None:
            self._cache_put(df_id, df)
            return df

        raw_path = entry.get("raw_path")
        if not (load_if_not_exists and raw_path):
            return None
        try:
            loaded_df = pd.read_csv(raw_path)
        except FileNotFoundError:
            return None
        except Exception as e:
            print(f"Error loading DataFrame from {raw_path}: {e}")
            return None
        entry["df"] = loaded_df
        self._cache_put(df_id, loaded_df)
        return loaded_df

    def remove_dataframe(self, df_id: str):
        """Forget ``df_id`` (registry, cache and path map); unknown IDs are ignored."""
        if df_id in self.registry:
            del self.registry[df_id]
            self.cache.pop(df_id, None)
            self.df_id_to_raw_path.pop(df_id, None)

    def get_raw_path_from_id(self, df_id: str):
        """Return the CSV path recorded for ``df_id``, or None if it is unknown."""
        return self.df_id_to_raw_path.get(df_id)


global_df_registry = DataFrameRegistry()

class State(AgentState):
  """Graph state shared by the supervisor and the agent nodes.

  Holds the user's request, the df_ids to work on, each stage's structured
  output (InitialDescription, CleaningMetadata, AnalysisInsights,
  VisualizationResults, ReportResults) and a completion flag per stage.
  """
  # NOTE(review): if AgentState is a TypedDict (as in recent langgraph
  # releases), the `= ...` defaults below -- including the pydantic
  # `Field(default_factory=...)` entries -- are not applied when state is
  # created, so nodes must set these keys explicitly. Confirm against the
  # installed langgraph version.
  next: str  # presumably the name of the next node to route to -- TODO confirm
  user_prompt: str
  df_ids: List[str] = Field(default_factory=list)  # presumably IDs registered in global_df_registry
  _config: Optional[RunnableConfig] = None
  # Structured outputs of the individual stages.
  initial_description: Optional[InitialDescription] = None
  cleaning_metadata: Optional[CleaningMetadata] = None
  analysis_insights: Optional[AnalysisInsights] = None
  # Chat models for the individual agents.
  initial_analysis_agent: Optional[BaseChatModel] = None
  data_cleaner_agent: Optional[BaseChatModel] = None
  analyst_agent: Optional[BaseChatModel] = None
  # Per-stage completion flags.
  initial_analysis_complete: Optional[bool] = False
  data_cleaning_complete: Optional[bool] = False
  analyst_complete: Optional[bool] = False
  file_writer_complete: Optional[bool] = False
  _count_: int = 0
  _id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
  visualization_results: Optional[VisualizationResults] = None
  visualization_complete: Optional[bool] = False
  report_results: Optional[ReportResults] = None
  report_generator_complete: Optional[bool] = False
  analysis_config: Optional[AnalysisConfig] = Field(None, description="User-defined analysis configurations.")


In [None]:
# Prompt for the Data Cleaner agent. Placeholders: the dataset description and
# sample, the rendered tool descriptions, the expected output schema
# ({output_format}) and the df_ids the agent may use.
# NOTE(review): the text refers to `query_dataframe`, `get_column_names` and
# `get_data`, but those tools are registered as "QueryDataframe",
# "GetColumnNames" and "GetData", and `data_cleaning_tools` does not include
# QueryDataframe or GetData -- confirm which tools this agent actually gets.
data_cleaner_prompt_template = PromptTemplate(
    input_variables=['dataset_description', 'data_sample', 'tool_descriptions', 'output_format', 'available_df_ids'],
    template="""You are a Data Cleaner Agent equipped with tools to clean and preprocess a dataset.\n    The df_ids you can use to access the data are: {available_df_ids}\n\n    Here's a description of the dataset: {dataset_description}\n    Here's a sample of the data (first few rows):\n    {data_sample}\n\n    Your available tools are:\n    {tool_descriptions}\n\n    You have access to a tool called `query_dataframe` for querying the DataFrame.\n    Refer to the tool's docstring for detailed usage instructions and examples.\n    You can get the column names with `get_column_names`, and get any specified data from the dataframe using `get_data`.\n\n    Identify potential issues like missing values, outliers, incorrect data types, and inconsistencies.\n    Propose and execute a cleaning strategy, step-by-step, using the provided tools.\n    For each step, clearly state the tool you are using and the parameters.\n    Explain your reasoning for each cleaning step.\n\n    Example Plan & Execution:\n    Step 1: Check for missing values in each column using the 'CheckMissingValues' tool with input 'df_id'.\n    Step 2: If 'Age' column has missing values, fill them using the 'FillMissingMedian' tool with input 'df_id' and 'column_name' as 'Age'.\n    ... and so on.\n\n    After cleaning, summarize the actions taken and describe the current state of the dataset in a structured JSON format following this schema:\n    {output_format}\n\n    Let's begin! What is your data cleaning plan and execution using tools, and structured output?\n    """,
)

# Prompt for the initial "describe and sample" pass, whose output
# (dataset_description, data_sample) is handed to the Data Cleaner agent.
# Fixed the article in the opening line ("an Data Describer" -> "a Data Describer").
analyst_prompt_template_initial = PromptTemplate(
    input_variables=['user_prompt', 'tool_descriptions', 'output_format','available_df_ids'],
    template="""You are a Data Describer and Sampler. Your role is to perform exploratory data analysis (EDA) on a dataset.\n    Here's a text description of the dataset: {user_prompt}\n\n    First, we need a basic description of the dataset, along with a sample of the data, to pass to the Data Cleaner Agent.\n\n    The df_ids you can use to access the data are: {available_df_ids}\n\n    Your available tools are:\n    {tool_descriptions}\n\n    Describe the dataset in a structured JSON format following this schema:\n    {output_format}\n    FIRST, think of a step by step plan for how to proceed with collecting the data you need, when to stop collecting, and how and when to report back in the requested format.\n\n    Using your tools in a conservative manner, please 1. Write a text description of the dataset based on a few strategically used tool calls, and write it to dataset_description attribute.\n    Then, 2. use your tools to produce a sample of the data for the data_sample attribute.\n\n    Do not make unnecessary or repeated tool calls! Report straight to the Supervisor with the expected output format.\n\n    After performing any necessary tool calls, IMMEDIATELY output your final answer in the specified format. Do NOT make any further tool calls!\n    You do not need to perform the functions of other agents, do your job and submit the results.\n    """,
)

# Main EDA prompt, used on the cleaned dataset together with the cleaner's
# metadata. NOTE(review): like the cleaner prompt it names `query_dataframe`,
# `get_column_names` and `get_data`, while the registered tool names are
# "QueryDataframe", "GetColumnNames" and "GetData".
analyst_prompt_template_main = PromptTemplate(
    template="""You are an Analyst Agent. Your role is to perform exploratory data analysis (EDA) on a cleaned dataset.\n    Here's a description of the cleaned dataset: {cleaned_dataset_description}\n    Here's metadata about the data cleaning actions taken: {cleaning_metadata}\n\n    Your available tools are:\n    {tool_descriptions}\n\n    The df_ids you can use to access the data are: {available_df_ids}\n\n    You have access to a tool called `query_dataframe` for querying the DataFrame.\n    Refer to the tool's docstring for detailed usage instructions and examples.\n    You can get the column names with `get_column_names`, and get any specified data from the dataframe using `get_data`.\n\n    Perform EDA to understand the dataset. Include:\n    - Descriptive statistics (mean, median, mode, standard deviation, etc.) for relevant columns.\n    - Identify potential correlations between features.\n    - Highlight any anomalies or interesting patterns you find.\n    - Reason step-by-step about your analysis (Chain-of-Thought).\n    - Recommend visualizations that would best illustrate your findings.\n    - Suggest potential next steps for deeper analysis or further questions to investigate based on your findings.\n\n    Output should be a summary of your EDA findings, insights, recommended visualizations, and next steps, based on your tool usage, all written into the\n    {output_format} class.\n\n    Let's begin! What are your EDA insights and visualization recommendations using tools?\n    """,
    input_variables=[
        'cleaned_dataset_description',
        'cleaning_metadata',
        'tool_descriptions',
        'output_format',
        'available_df_ids',
    ],
)

# Prompt for the file-writer agent: writes `content` to `file_name` in the
# requested `file_type`. Fixed typos in the instructions
# ("a analyst-friendly" -> "an analyst-friendly", "Leaver" -> "Leave").
file_writer_prompt_template = PromptTemplate(
    input_variables=['file_name', 'content', 'file_type','tool_descriptions','available_df_ids'],
    template="""You are an agent that specializes in writing data to a file in the format of {file_type}. You are one member of a data analysis team. You ONLY write content as requested in an analyst-friendly manner. Leave other tasks to other agents on the team.\n    You have the following tools at your disposal:\n    {tool_descriptions}\n\n    The df_ids you can use to access the data are: {available_df_ids}\n\n    Write the following content to a file named {file_name}:\n    {content}\n    """,
)
# Prompt for the Visualization agent: turns the analyst's recommended
# visualizations into plots and reports them in the {output_format} schema.
visualization_prompt_template = PromptTemplate(
    template="""You are a Visualization Agent equipped with tools to create visualizations.\n    Here's a description of the cleaned dataset: {cleaned_dataset_description}\n    Here are the insights from the Analyst Agent and the list of visualizations to create: {analysis_insights}\n\n    Your available tools are:\n    {tool_descriptions}\n\n    The df_ids you can use to access the data are: {available_df_ids}\n\n    Create the visualizations step-by-step, using the provided tools.\n    For each step, clearly state the tool you are using and the parameters.\n    Explain your reasoning for each visualization.\n\n    After creating the visualizations, summarize the actions taken and describe the current state of the visualizations in a structured JSON format following this schema:\n    {output_format}\n\n    Let's begin!\n    What is your visualization plan and execution using tools, and structured output?\n    """,
    input_variables=["cleaned_dataset_description", "analysis_insights", "tool_descriptions", "output_format", "available_df_ids"],
)
# Prompt for the Report Generator agent: combines cleaning metadata, analyst
# insights and visualization results into one report.
report_generator_prompt_template = PromptTemplate(
    template="""You are a Report Generator Agent equipped with tools to generate reports.\n    Here's the metadata about the data cleaning actions taken: {cleaning_metadata}\n    Here are the insights from the Analyst Agent: {analysis_insights}\n    Here are the visualization results: {visualization_results}\n\n    Your available tools are:\n    {tool_descriptions}\n\n    The df_ids you can use to access the data are: {available_df_ids}\n\n    Generate a structured report that combines textual explanations, statistics, and visualizations.\n    Explain your reasoning for the report structure.\n\n    After generating the report, summarize the actions taken and describe the current state of the report in a structured JSON format following this schema:\n    {output_format}\n\n    Let's begin!\n    What is your report generation plan and execution using tools, and structured output?\n    """,
    input_variables=["cleaning_metadata", "analysis_insights", "visualization_results", "tool_descriptions", "output_format", "available_df_ids"],
)

In [None]:
# Added imports for create_histogram
import matplotlib.pyplot as plt
import seaborn as sns
import base64
from io import BytesIO

# Tools from original cell 4 (rzo9i8HtDrmO)
@tool(name_or_callable="GetDataframeSchema",response_format="content_and_artifact")
def get_dataframe_schema(df_id: str) -> tuple[str, dict]:
    """Return a summary of the DataFrame's schema and sample data.

    Args:
        df_id: ID of the DataFrame in the global registry.
    """
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            # Fall back to the CSV recorded for this ID.
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            if not raw_path:
                return f"Error: DataFrame with ID '{df_id}' not found.", {}
            df = pd.read_csv(raw_path)
            global_df_registry.register_dataframe(df, df_id, raw_path)
        schema = {
            "columns": list(df.columns),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "sample": df.head(3).to_dict(orient="records"),
        }
        # With content_and_artifact only the content string is sent to the
        # model, so the schema is serialized there too (an empty string left
        # the model with nothing to read). default=str handles numpy scalars.
        return json.dumps(schema, default=str), {"schema": schema}
    except Exception as e:
        return f"Error: {str(e)}", {}

@tool("GetColumnNames")
def get_column_names(df_id: str) -> str:
    """Useful to get the names of the columns in the current DataFrame."""
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            # Report unknown IDs explicitly; pd.read_csv(None) would fail with
            # an unrelated "invalid file path" error instead.
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found."
            df = pd.read_csv(raw_path)
            global_df_registry.register_dataframe(df, df_id, raw_path)
        # `df.empty` is also True for a frame with columns but zero rows, so
        # check the column axis directly.
        if len(df.columns) == 0:
            return f"Warning: DataFrame '{df_id}' is empty. No columns available."
        # Column labels are not guaranteed to be strings (e.g. integer labels).
        return ", ".join(str(col) for col in df.columns)
    except FileNotFoundError as e:
        return f"Error loading DataFrame from path: {e}"
    except Exception as e:
        return f"Error getting column names for DataFrame '{df_id}': {e}"

@tool("CheckMissingValues")
def check_missing_values(df_id: str) -> str:
    """Checks for missing values in a pandas DataFrame and returns a summary."""
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            # Report unknown IDs explicitly; pd.read_csv(None) would fail with
            # an unrelated "invalid file path" error instead.
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found."
            df = pd.read_csv(raw_path)
            global_df_registry.register_dataframe(df, df_id, raw_path)
        missing_per_column = df.isnull().sum()
        if missing_per_column.sum() == 0:
            return f"No missing values in DataFrame '{df_id}'."
        # One "<column> <count>" line per column, including columns with zero.
        return missing_per_column.to_string()
    except FileNotFoundError as e:
        return f"Error loading DataFrame from path: {e}"
    except Exception as e:
        return f"Error checking missing values for DataFrame '{df_id}': {e}"

@tool("DropColumn")
def drop_column(df_id: str, column_name: str) -> str:
    """Drops a specified column from the DataFrame."""
    pprint(f"Dropping column {column_name} from {df_id}")
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found."
            try:
                df = pd.read_csv(raw_path)
            except Exception as e:
                return f"Error loading DataFrame: {e}"
            global_df_registry.register_dataframe(df, df_id, raw_path)
        if column_name not in df.columns:
            return f"Error: Column '{column_name}' not found in DataFrame '{df_id}'. Available columns: {list(df.columns)}"
        df = df.drop(columns=[column_name])
        # Re-register so the registry entry and cache hold the updated frame.
        global_df_registry.register_dataframe(df, df_id, global_df_registry.get_raw_path_from_id(df_id))
        # str() guards against non-string column labels, which would otherwise
        # raise here after the column had already been dropped.
        return "Column dropped successfully. New columns: " + ", ".join(str(col) for col in df.columns)
    except Exception as e:
        return f"Error dropping column: {e}"

@tool
def delete_rows(df_id: str, conditions: Union[str, List[str], Dict], inplace: bool = True) -> str:
    """Deletes rows from the DataFrame based on specified conditions."""
    # `conditions` are pandas `DataFrame.query` expressions: a string is one
    # expression, a list is several, and a dict maps a label to a list of
    # expressions. All expressions are ANDed. Matching rows are deleted
    # (inplace=True) or returned as JSON without deleting (inplace=False).
    try:
        if not isinstance(conditions, (str, list, dict)):
            return f"Error: 'conditions' must be a string, list of strings, or dict. Received type: {type(conditions).__name__}"

        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found."
            try:
                df = pd.read_csv(raw_path)
                global_df_registry.register_dataframe(df, df_id, raw_path)
            except Exception as e:
                return f"Error loading DataFrame: {e}"

        # Flatten every accepted format into one list of query expressions.
        if isinstance(conditions, str):
            condition_list = [conditions]
        elif isinstance(conditions, list):
            condition_list = list(conditions)
        else:
            condition_list = []
            for condition_type, group in conditions.items():
                if not isinstance(group, list):
                    return f"Error: condition list for '{condition_type}' must be a list."
                condition_list.extend(group)

        # With no expressions every row would count as "matched" and the whole
        # frame would be deleted.
        if not condition_list:
            return "Error: No conditions provided; refusing to delete all rows."

        matched = df
        for condition in condition_list:
            matched = matched.query(condition)

        if matched.empty:
            return f"No rows match the provided condition(s): {conditions}"

        if not inplace:
            return matched.to_json()

        # Build the reduced frame and re-register it (as DropColumn does) so the
        # registry and cache hold the updated data.
        # NOTE: drop() is label-based; with a non-unique index every row that
        # shares a matched label is removed.
        remaining = df.drop(index=matched.index)
        global_df_registry.register_dataframe(remaining, df_id, global_df_registry.get_raw_path_from_id(df_id))
        return "Rows deleted successfully."
    except Exception as e:
        return f"Error deleting rows: {e}"

@tool("FillMissingMedian")
def fill_missing_median(df_id: str, column_name: str) -> str:
    """Fills missing values in a specified column with the median."""
    pprint(f"Filling missing values in column {column_name} from {df_id}")
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found."
            try:
                df = pd.read_csv(raw_path)
                global_df_registry.register_dataframe(df, df_id, raw_path)
            except Exception as e:
                return f"Error loading DataFrame: {e}"
        if column_name not in df.columns:
            return f"Error: Column '{column_name}' not found in DataFrame '{df_id}'."
        if not pd.api.types.is_numeric_dtype(df[column_name]):
            return f"Error: Column '{column_name}' in DataFrame '{df_id}' is not numeric and cannot compute median."
        median_value = df[column_name].median()
        if pd.isna(median_value):
            return f"Error: Column '{column_name}' in DataFrame '{df_id}' has no non-missing values; median is undefined."
        # Assign through the DataFrame itself: `df[col].fillna(..., inplace=True)`
        # is chained assignment and does not update `df` under pandas Copy-on-Write.
        df[column_name] = df[column_name].fillna(median_value)
        return f"Missing values in column '{column_name}' filled with median: {median_value}."
    except Exception as e:
        return f"Error filling missing values: {e}"

# The @tool-decorated functions above are already structured tools whose
# argument schemas come from their signatures and whose descriptions come from
# their docstrings. Re-wrapping them in `Tool` (a single string-input tool) hid
# those schemas, so multi-argument tools such as DropColumn, DeleteRows and
# FillMissingMedian could never receive all their arguments.
# NOTE(review): the model now sees the decorator names ("GetDataframeSchema",
# "delete_rows", ...) instead of the old wrapper names.
data_cleaning_tools = [
    get_dataframe_schema,
    get_column_names,
    check_missing_values,
    drop_column,
    delete_rows,
    fill_missing_median,
]

# Tools from original cell 5 (8Yb-OklIFuFw)
# Aggregations QueryDataframe applies to numeric columns only.
_QUERY_NUMERIC_OPERATIONS = ("sum", "mean", "max", "min", "median")
# Maximum number of "select" records echoed into the tool message content.
_QUERY_MAX_RECORDS_IN_CONTENT = 50


def _query_filter_mask(column, value):
    """Return a boolean mask of the rows where ``column`` equals ``value``.

    ``DataQueryParams.filter_value`` is a string, so for a numeric column it
    is converted with ``pd.to_numeric`` first; otherwise an integer column
    compared with ``"5"`` would never match. Unconvertible values fall back
    to plain equality.
    """
    if (
        isinstance(value, str)
        and pd.api.types.is_numeric_dtype(column)
        and not pd.api.types.is_bool_dtype(column)
    ):
        try:
            return column == pd.to_numeric(value)
        except (TypeError, ValueError):
            pass
    return column == value


@tool(name_or_callable="QueryDataframe",response_format="content_and_artifact")
def query_dataframe(params: DataQueryParams, df_id: str) -> tuple[str, dict]:
    """Query the DataFrame based on specified columns, filter, and operation."""
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found.", {}
            try:
                df = pd.read_csv(raw_path)
                global_df_registry.register_dataframe(df, df_id, raw_path)
            except Exception as e:
                return f"Error loading DataFrame: {e}", {}

        if params.filter_column and params.filter_column not in df.columns:
            return "Error: Filter column does not exist.", {}

        missing_cols = [col for col in params.columns if col not in df.columns]
        if missing_cols:
            return f"Error: Columns not found: {', '.join(missing_cols)}", {}

        if params.filter_column:
            filtered_df = df[_query_filter_mask(df[params.filter_column], params.filter_value)]
        else:
            filtered_df = df
        selected = filtered_df[params.columns]

        if params.operation == "select":
            result = selected.to_dict(orient="records")
        elif params.operation == "count":
            result = selected.count().to_dict()
        elif params.operation in _QUERY_NUMERIC_OPERATIONS:
            result = getattr(selected, params.operation)(numeric_only=True).to_dict()
        else:
            return f"Unsupported operation: {params.operation}", {}

        # With content_and_artifact only the content string reaches the model,
        # so the result is serialized there too (truncated for long selections);
        # the full result stays in the artifact.
        shown = result
        note = ""
        if isinstance(result, list) and len(result) > _QUERY_MAX_RECORDS_IN_CONTENT:
            shown = result[:_QUERY_MAX_RECORDS_IN_CONTENT]
            note = f" (showing first {_QUERY_MAX_RECORDS_IN_CONTENT} of {len(result)} records)"
        content = f"Query successful.{note} Result: {json.dumps(shown, default=str)}"
        return content, {"result": result}
    except Exception as e:
        return f"Error: {str(e)}", {}

@tool(name_or_callable="GetData")
def get_data(params: GetDataParams, df_id: str = "") -> str:
    """Retrieves data from a DataFrame by ID, for flexible row/column selection and retrieval specific cells."""
    # This tool returns a plain string. `response_format="content_and_artifact"`
    # was dropped because that mode requires a (content, artifact) tuple and
    # raised on every call.
    if not df_id:
        df_id = params.df_id
    elif df_id.strip() != params.df_id.strip():
        return "Error: df_id mismatch."

    df = global_df_registry.get_dataframe(df_id)
    if df is None:
        raw_path = global_df_registry.get_raw_path_from_id(df_id)
        if raw_path is None:
            return f"Error: DataFrame with ID '{df_id}' not found."
        try:
            df = pd.read_csv(raw_path)
            global_df_registry.register_dataframe(df, df_id, raw_path)
        except Exception as e:
            return f"Error loading DataFrame: {e}"

    index, columns, cells = params.index, params.columns, params.cells
    try:
        if cells is not None:
            # Cell lookups are label-based (`.loc`), unlike the positional row
            # selection below.
            output_str = ""
            for cell in cells:
                val = df.loc[cell.row_index, cell.column_name]
                output_str += f"Value at ({cell.row_index}, {cell.column_name}): {val}\n"
            return output_str

        if isinstance(index, int):
            rows = df.iloc[[index]]
        elif isinstance(index, list):
            rows = df.iloc[index]
        elif isinstance(index, tuple):
            # The documented (start, end) range includes `end`; iloc slices
            # exclude it, so extend by one (end == -1 means "through the last row").
            start, end = index
            stop = None if end == -1 else end + 1
            rows = df.iloc[start:stop]
        else:
            return "Error: Invalid index format."

        if columns == "all":
            columns_to_include = df.columns
        elif isinstance(columns, str):
            columns_to_include = [columns]
        elif isinstance(columns, list):
            columns_to_include = columns
        else:
            return "Error: Invalid columns format."

        selected_data = rows[columns_to_include]
    except (KeyError, IndexError) as e:
        return f"Error retrieving data: {e}"

    output_str = ""
    for row_idx, row_data in selected_data.iterrows():
        output_str += f"Row {row_idx}:\n"
        for col, val in row_data.items():
            output_str += f"  {col}: {val}\n"
    return output_str

@tool("GetDescriptiveStatistics")
def get_descriptive_statistics(df_id: str, column_names: str = "all") -> str:
    """Calculates descriptive statistics for specified columns in the DataFrame."""
    pprint(f"Getting descriptive statistics for {column_names} from {df_id}")
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            raw_path = global_df_registry.get_raw_path_from_id(df_id)
            if raw_path is None:
                return f"Error: DataFrame with ID '{df_id}' not found."
            try:
                df = pd.read_csv(raw_path)
                global_df_registry.register_dataframe(df, df_id, raw_path)
            except Exception as e:
                return f"Error loading DataFrame: {e}"
        # `column_names` is a comma-separated list or "all". Strip whitespace so
        # "a, b" works like "a,b", and ignore empty entries.
        requested = [name.strip() for name in (column_names or "").split(",") if name.strip()]
        if not requested or (len(requested) == 1 and requested[0].lower() == "all"):
            columns_to_describe = list(df.columns)
        else:
            columns_to_describe = requested
        missing_cols = [col for col in columns_to_describe if col not in df.columns]
        if missing_cols:
            return f"Error: Columns not found: {', '.join(missing_cols)}"
        return df[columns_to_describe].describe().to_string()
    except Exception as e:
        return f"Error calculating descriptive statistics: {e}"

@tool("CalculateCorrelation")
def calculate_correlation(df_id: str, column1_name: str, column2_name: str) -> str:
    """Calculates the Pearson correlation coefficient between two columns of a registered DataFrame."""
    pprint(f"Calculating correlation between {column1_name} and {column2_name} from {df_id}")
    frame = global_df_registry.get_dataframe(df_id)
    try:
        if frame is None:
            # Fall back to reading the registered CSV, then cache the result.
            try:
                source_path = global_df_registry.get_raw_path_from_id(df_id)
                frame = pd.read_csv(source_path)
                global_df_registry.register_dataframe(frame, df_id, source_path)
            except Exception as load_err:
                return f"Error loading DataFrame: {load_err}"

        both_present = column1_name in frame.columns and column2_name in frame.columns
        if not both_present:
            return "Error: One or both columns not found."

        coefficient = frame[column1_name].corr(frame[column2_name])
        return f"Correlation between '{column1_name}' and '{column2_name}': {coefficient}"
    except Exception as calc_err:
        return f"Error calculating correlation: {calc_err}"

@tool("PerformHypothesisTest")
def perform_hypothesis_test(df_id: str, column_name: str, value: float) -> str:
    """Performs a one-sample t-test of whether a numeric column's mean differs from `value` (alpha = 0.05)."""
    pprint(f"Performing hypothesis test on {column_name} with value {value} from {df_id}")
    try:
        # Local import so the tool does not depend on a `stats` name defined elsewhere.
        from scipy import stats

        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            # Not cached yet: load from the registered raw CSV path and cache it.
            try:
                raw_path = global_df_registry.get_raw_path_from_id(df_id)
                df = pd.read_csv(raw_path)
                global_df_registry.register_dataframe(df, df_id, raw_path)
            except Exception as e:
                return f"Error loading DataFrame: {e}"
        if column_name not in df.columns:
            return f"Error: Column {column_name} not found."
        column_data = df[column_name].dropna()
        if not pd.api.types.is_numeric_dtype(column_data):
            return "Error: Hypothesis test can only be performed on numeric columns."
        # A t-test needs at least two observations to estimate variance; with fewer,
        # scipy yields NaN and the comparison below would silently "fail to reject".
        if len(column_data) < 2:
            return f"Error: At least two non-null values are required for a t-test (found {len(column_data)})."
        t_statistic, p_value = stats.ttest_1samp(a=column_data, popmean=value)
        if np.isnan(p_value):
            # e.g. zero variance with the sample mean equal to `value` (0/0).
            return "Error: The t-test result is undefined (NaN) for this column."
        alpha = 0.05
        if p_value < alpha:
            result = f"Reject null hypothesis. Mean is significantly different from {value}."
        else:
            result = f"Fail to reject null hypothesis. Mean is not significantly different from {value}."
        return result + f" T-statistic: {t_statistic}, P-value: {p_value}"
    except Exception as e:
        return f"Error performing hypothesis test: {e}"

analyst_tools = [
    get_dataframe_schema,
    get_descriptive_statistics,
    calculate_correlation,
    perform_hypothesis_test,
    get_column_names,
    get_data,
    query_dataframe,
]

# File I/O, sample-writing and Python-REPL tools (originally a separate notebook cell).
@tool
def create_sample(points: Annotated[List[str], "List of main points or sections."], file_name: Annotated[str, "File path to save the outline."]) -> Annotated[str, "Path of the saved file of sample data from the dataset."]:
    """Write the given points to a file in the working directory as a numbered list, one point per line."""
    numbered_lines = (f"{position}. {point}\n" for position, point in enumerate(points, start=1))
    with (WORKING_DIRECTORY / file_name).open("w") as handle:
        handle.writelines(numbered_lines)
    return f"sample data saved to {file_name}"

@tool
def read_file(file_name: Annotated[str, "File path to read the file from."], start: Annotated[Optional[int], "The start line (0-indexed). Default is 0."] = None, end: Annotated[Optional[int], "The end line (exclusive). Default is start + 10."] = None) -> str:
    """Read lines [start, end) of a file in the working directory (by default the first 10 lines)."""
    pprint(f"Reading file {file_name} from {WORKING_DIRECTORY} \n with start {start} and end {end}")
    with (WORKING_DIRECTORY / file_name).open("r") as file:
        lines = file.readlines()
    if start is None:
        start = 0
    if end is None:
        end = start + 10
    # readlines() keeps each line's trailing newline, so join with "" — joining
    # with "\n" would insert a blank line between every pair of lines.
    return "".join(lines[start:end])

@tool
def write_file(content: Annotated[str, "Text content to be written into the file."], file_name: Annotated[str, "File path to save the document."]) -> Annotated[str, "Path of the saved document file."]:
    """Create (or overwrite) a file in the working directory with the given text content; any format is accepted."""
    destination = WORKING_DIRECTORY / file_name
    destination.write_text(content)
    return f"Document saved to {file_name}"

@tool
def edit_file(file_name: Annotated[str, "Path of the file to be edited."], inserts: Annotated[Dict[int, str], "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line."]) -> Annotated[str, "Path of the edited document file."]:
    """Edit a document by inserting text at specific line numbers.

    Inserts are applied in ascending line order, so each later line number refers
    to the document as already modified by the earlier inserts. Nothing is written
    if any line number is out of range.
    """
    target = WORKING_DIRECTORY / file_name
    with target.open("r") as source:
        content = source.readlines()

    for line_no, snippet in sorted(inserts.items()):
        # Valid positions: any existing line, or one past the end (append).
        if not 1 <= line_no <= len(content) + 1:
            return f"Error: Line number {line_no} is out of range."
        content.insert(line_no - 1, snippet + "\n")

    with target.open("w") as sink:
        sink.writelines(content)
    return f"Document edited and saved to {file_name}"

# Single shared REPL: variables defined by one call stay visible to later calls.
repl = PythonREPL()


@tool
def python_repl_tool(code: Annotated[str, "The python code to execute."], df_id: Annotated[str, "The ID of the DataFrame in the global registry."]) -> str:
    """Executes Python code in a REPL that exposes `pd` and `get_df_from_registry(df_id)` for looking up registered DataFrames."""
    # SECURITY NOTE(review): `code` is executed verbatim and is typically produced by
    # an LLM; only run this in a sandboxed environment.
    # NOTE(review): the `df_id` argument is currently not used by this function;
    # code must call get_df_from_registry(...) itself.

    def get_df_from_registry(registry_id):
        return global_df_registry.get_dataframe(registry_id)

    try:
        repl.globals.update(get_df_from_registry=get_df_from_registry, pd=pd)
        return repl.run(code)
    except BaseException as exc:
        return f"Failed to execute. Error: {repr(exc)}"

analyst_tools.append(python_repl_tool)
analyst_tools.append(create_sample)


def _as_named_tool(base_tool, name: str, description: str):
    """Return `base_tool` exposed under a different name and description.

    The tools below are multi-argument ``@tool`` StructuredTools. Wrapping one in
    ``Tool(func=...)`` turns it into a single-string-input tool, and forwarding that
    single string to a tool that needs several arguments fails its schema
    validation, so such a wrapper can never be called successfully. Copying the
    StructuredTool keeps its argument schema (and ``response_format``) intact.
    Objects without ``model_copy`` (plain callables) keep the original ``Tool`` wrapping.
    """
    if hasattr(base_tool, "model_copy"):
        return base_tool.model_copy(update={"name": name, "description": description})
    return Tool(name=name, func=base_tool, description=description)


data_cleaning_tools.append(_as_named_tool(write_file, "WriteFile", "Useful to create and save a data file of any format."))
data_cleaning_tools.append(_as_named_tool(python_repl_tool, "PythonREPL", "Useful to execute python code. If you want to see the output of a value, you should print it out with `print(...)`. This is visible to the user."))
data_cleaning_tools.append(_as_named_tool(edit_file, "EditFile", "Useful to edit a document by inserting text at specific line numbers."))
data_cleaning_tools.append(_as_named_tool(query_dataframe, "QueryDataframe", "Useful to query a dataframe."))
data_cleaning_tools.append(_as_named_tool(get_data, "GetData", "Useful to retrieve data from the DataFrame with the specified ID, supporting flexible row and column selection, and specific cell retrieval."))

file_writer_tools = [get_dataframe_schema, write_file, edit_file, read_file, python_repl_tool]
visualization_tools = [python_repl_tool, get_dataframe_schema, get_data, get_column_names]
report_generator_tools = [python_repl_tool, write_file, edit_file, read_file]

@tool
def create_histogram(df_id: str, column_name: str) -> dict:
    """Generates a histogram for a specified numeric column in a DataFrame.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_name: The name of the numeric column for which to generate the histogram.

    Returns:
        A dictionary containing the plot type ('histogram'), column name,
        and a base64 encoded PNG image of the histogram,
        or a dictionary with an "error" key if generation fails.
    """
    fig = None
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return {"error": f"DataFrame with ID '{df_id}' not found."}

        if column_name not in df.columns:
            return {"error": f"Column '{column_name}' not found in DataFrame '{df_id}'."}

        if not pd.api.types.is_numeric_dtype(df[column_name]):
            return {"error": f"Column '{column_name}' is not numeric and a histogram cannot be generated."}

        # Draw on an explicit figure so cleanup closes exactly this figure.
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(df[column_name], kde=True, ax=ax)
        ax.set_title(f'Histogram of {column_name}')
        ax.set_xlabel(column_name)
        ax.set_ylabel('Frequency')

        with BytesIO() as buf:
            fig.savefig(buf, format="png")
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

        return {"plot_type": "histogram", "column_name": column_name, "image_base64": image_base64}

    except Exception as e:
        return {"error": f"Failed to generate histogram: {str(e)}"}
    finally:
        # Release the figure on success and failure alike; closing only our own
        # figure never touches an unrelated "current" pyplot figure.
        if fig is not None:
            plt.close(fig)

@tool
def create_scatter_plot(df_id: str, x_column_name: str, y_column_name: str) -> dict:
    """Generates a scatter plot for two specified numeric columns in a DataFrame.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        x_column_name: The name of the numeric column for the x-axis.
        y_column_name: The name of the numeric column for the y-axis.

    Returns:
        A dictionary containing the plot type ('scatter'), x-column name,
        y-column name, and a base64 encoded PNG image of the scatter plot,
        or an error message dictionary if generation fails.
    """
    fig = None
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return {"error": f"DataFrame with ID '{df_id}' not found."}

        if x_column_name not in df.columns:
            return {"error": f"X-axis column '{x_column_name}' not found in DataFrame '{df_id}'."}
        if y_column_name not in df.columns:
            return {"error": f"Y-axis column '{y_column_name}' not found in DataFrame '{df_id}'."}

        if not pd.api.types.is_numeric_dtype(df[x_column_name]):
            return {"error": f"X-axis column '{x_column_name}' is not numeric."}
        if not pd.api.types.is_numeric_dtype(df[y_column_name]):
            return {"error": f"Y-axis column '{y_column_name}' is not numeric."}

        # Draw on an explicit figure so cleanup closes exactly this figure.
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.scatterplot(x=df[x_column_name], y=df[y_column_name], ax=ax)
        ax.set_title(f'Scatter Plot of {y_column_name} vs {x_column_name}')
        ax.set_xlabel(x_column_name)
        ax.set_ylabel(y_column_name)

        with BytesIO() as buf:
            fig.savefig(buf, format="png")
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

        return {
            "plot_type": "scatter",
            "x_column": x_column_name,
            "y_column": y_column_name,
            "image_base64": image_base64
        }

    except Exception as e:
        return {"error": f"Failed to generate scatter plot: {str(e)}"}
    finally:
        # Release only the figure created by this call, on success or failure.
        if fig is not None:
            plt.close(fig)

visualization_tools.append(create_histogram)
visualization_tools.append(create_scatter_plot)

@tool
def create_correlation_heatmap(df_id: str, column_names: Optional[List[str]] = None) -> dict:
    """Generates a correlation heatmap for numeric columns in a DataFrame.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_names: Optional list of column names to include.
                      If None or empty, all numeric columns are used.
                      Non-numeric columns in the list are silently dropped.

    Returns:
        A dictionary containing the plot type ('correlation_heatmap')
        and a base64 encoded PNG image of the heatmap,
        or an error message dictionary if generation fails.
    """
    fig = None
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return {"error": f"DataFrame with ID '{df_id}' not found."}

        df_to_correlate = df
        if column_names:
            missing_cols = [col for col in column_names if col not in df.columns]
            if missing_cols:
                return {"error": f"Columns not found: {', '.join(missing_cols)}."}
            df_to_correlate = df[column_names]

        df_numeric = df_to_correlate.select_dtypes(include=np.number)

        if df_numeric.empty:
            return {"error": "No numeric columns found to generate a correlation heatmap."}
        if len(df_numeric.columns) < 2:
            return {"error": "At least two numeric columns are required for a correlation heatmap."}

        corr_matrix = df_numeric.corr()

        # Draw on an explicit figure so cleanup closes exactly this figure.
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", ax=ax)
        ax.set_title('Correlation Heatmap')

        with BytesIO() as buf:
            fig.savefig(buf, format="png")
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

        return {"plot_type": "correlation_heatmap", "image_base64": image_base64}

    except Exception as e:
        return {"error": f"Failed to generate correlation heatmap: {str(e)}"}
    finally:
        # Release only the figure created by this call, on success or failure.
        if fig is not None:
            plt.close(fig)

visualization_tools.append(create_correlation_heatmap)

@tool
def create_box_plot(df_id: str, column_name: str, group_by_column: Optional[str] = None) -> dict:
    """Generates a box plot for a specified numeric column, optionally grouped by another column.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_name: The name of the numeric column for the box plot values.
        group_by_column: Optional. The name of the column to group by.

    Returns:
        A dictionary containing the plot type ('box_plot'), value column,
        group by column (if any), and a base64 encoded PNG image,
        or an error message dictionary if generation fails.
    """
    fig = None
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return {"error": f"DataFrame with ID '{df_id}' not found."}

        if column_name not in df.columns:
            return {"error": f"Value column '{column_name}' not found in DataFrame '{df_id}'."}

        if not pd.api.types.is_numeric_dtype(df[column_name]):
            return {"error": f"Value column '{column_name}' is not numeric."}

        # Validate everything before creating a figure, so an early return can
        # never leave an unclosed figure behind.
        if group_by_column and group_by_column not in df.columns:
            return {"error": f"Group by column '{group_by_column}' not found in DataFrame '{df_id}'."}

        fig, ax = plt.subplots(figsize=(12, 8))  # larger canvas leaves room for group labels
        plot_title = f'Box Plot of {column_name}'

        if group_by_column:
            sns.boxplot(x=df[group_by_column], y=df[column_name], ax=ax)
            plot_title += f' grouped by {group_by_column}'
        else:
            sns.boxplot(y=df[column_name], ax=ax)

        ax.set_title(plot_title)
        ax.set_xlabel(group_by_column if group_by_column else column_name)
        ax.set_ylabel(column_name)
        # Rotate x tick labels so long category names stay readable.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()  # prevent labels from overlapping / being clipped

        with BytesIO() as buf:
            fig.savefig(buf, format="png")
            image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')

        return {
            "plot_type": "box_plot",
            "value_column": column_name,
            "group_by_column": group_by_column,
            "image_base64": image_base64
        }

    except Exception as e:
        return {"error": f"Failed to generate box plot: {str(e)}"}
    finally:
        # Release only the figure created by this call, on success or failure.
        if fig is not None:
            plt.close(fig)

visualization_tools.append(create_box_plot)

@tool
def export_dataframe(df_id: str, file_name: str, file_format: str) -> str:
    """Exports a DataFrame to a file (CSV, Excel, or JSON) in the working directory.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        file_name: The desired name for the output file (e.g., 'my_data', 'report').
                   Any directory part and extension are dropped; the correct
                   extension (.csv, .xlsx, .json) is appended based on format.
        file_format: The format to export to (case-insensitive).
                     Supported: 'csv', 'excel' (alias 'xlsx'), 'json'.

    Returns:
        A success message with the path to the saved file,
        or an error message if export fails.
    """
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return f"Error: DataFrame with ID '{df_id}' not found."

        fmt = str(file_format).strip().lower()
        if fmt == "xlsx":
            fmt = "excel"
        extensions = {"csv": ".csv", "excel": ".xlsx", "json": ".json"}
        if fmt not in extensions:
            return f"Error: Unsupported file format '{file_format}'. Supported formats are 'csv', 'excel', 'json'."

        # Keep only the final path component (prevents path traversal), then drop
        # any extension the caller supplied.
        stem = Path(Path(file_name).name).stem
        if not stem:
            return "Error: file_name must contain a base name (e.g., 'report')."
        full_path = WORKING_DIRECTORY / f"{stem}{extensions[fmt]}"

        if fmt == "csv":
            df.to_csv(full_path, index=False)
        elif fmt == "excel":
            df.to_excel(full_path, index=False)
        else:
            df.to_json(full_path, orient='records', indent=4)

        return f"DataFrame '{df_id}' successfully exported to '{str(full_path)}' as {fmt}."

    except Exception as e:
        return f"Failed to export DataFrame '{df_id}' to {file_format}: {str(e)}"

analyst_tools.append(export_dataframe)
file_writer_tools.append(export_dataframe)
data_cleaning_tools.append(export_dataframe)

@tool
def detect_and_remove_duplicates(df_id: str) -> str:
    """Detects and removes duplicate rows from a DataFrame.

    The first occurrence of each duplicated row is kept. The de-duplicated frame
    replaces the one in the registry (in memory only; nothing is written to disk).

    Args:
        df_id: The ID of the DataFrame in the global registry.

    Returns:
        A message summarizing the number of duplicates found and removed,
        or an error message if the operation fails.
    """
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return f"Error: DataFrame with ID '{df_id}' not found."

        num_duplicates = int(df.duplicated().sum())

        if num_duplicates == 0:
            return f"No duplicate rows found in DataFrame '{df_id}'."

        original_row_count = len(df)
        # Build a new frame rather than mutating the registered one in place, so the
        # registry is unchanged if anything below fails.
        df_no_duplicates = df.drop_duplicates(keep='first')
        rows_removed = original_row_count - len(df_no_duplicates)

        # Re-register under the same ID, keeping the original raw_path. The file at
        # raw_path is NOT rewritten; only the in-memory frame is updated.
        original_raw_path = global_df_registry.get_raw_path_from_id(df_id)
        global_df_registry.register_dataframe(df_no_duplicates, df_id=df_id, raw_path=original_raw_path)

        return f"Found {num_duplicates} duplicate rows. Removed {rows_removed} rows. DataFrame '{df_id}' updated in memory."

    except Exception as e:
        return f"Error during duplicate detection/removal for DataFrame '{df_id}': {str(e)}"

data_cleaning_tools.append(detect_and_remove_duplicates)

@tool
def convert_data_types(df_id: str, column_types: dict) -> str:
    """Converts specified columns in a DataFrame to new data types.

    Conversions are applied to a copy. A failed conversion leaves its column
    unchanged, and the registry is updated only if at least one conversion succeeds.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_types: A dictionary where keys are column names and values are
                      the target data type strings (e.g., 'int', 'float', 'datetime64[ns]').

    Returns:
        A message summarizing the successful and failed data type conversions,
        or an error message if the DataFrame is not found.
    """
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return f"Error: DataFrame with ID '{df_id}' not found."

        df_copy = df.copy()
        successful_conversions = []
        failed_conversions = []

        for column_name, new_type_str in column_types.items():
            if column_name not in df_copy.columns:
                failed_conversions.append(f"{column_name} (column not found)")
                continue

            try:
                # NOTE(review): 'boolean' is mapped to NumPy 'bool', not pandas' nullable
                # 'boolean' dtype; astype('bool') uses truthiness, so e.g. the string
                # 'False' becomes True. Confirm this is intended.
                if new_type_str.lower() == 'boolean':
                    new_type_str = 'bool'

                # Work on a temporary Series so a failure in astype() below cannot
                # leave the column half-converted (already coerced to numeric/NaN).
                converted = df_copy[column_name]
                # For numeric types, coerce non-convertible values to NaN first.
                if new_type_str in ['int', 'float', 'int64', 'float64']:
                    converted = pd.to_numeric(converted, errors='coerce')

                df_copy[column_name] = converted.astype(new_type_str)
                successful_conversions.append(f"{column_name} to {new_type_str}")
            except (ValueError, TypeError) as e:
                failed_conversions.append(f"{column_name} to {new_type_str} (Error: {e})")
            except Exception as e:  # Catch any other unexpected errors during conversion
                failed_conversions.append(f"{column_name} to {new_type_str} (Unexpected error: {e})")

        if successful_conversions:
            original_raw_path = global_df_registry.get_raw_path_from_id(df_id)
            global_df_registry.register_dataframe(df_copy, df_id=df_id, raw_path=original_raw_path)
            summary_message = f"Data type conversions applied for DataFrame '{df_id}'."
        else:
            summary_message = f"No data type conversions made for DataFrame '{df_id}'."

        if successful_conversions:
            summary_message += " Success: [" + ", ".join(successful_conversions) + "]."
        if failed_conversions:
            summary_message += " Failed: [" + ", ".join(failed_conversions) + "]."

        return summary_message

    except Exception as e:
        return f"Error during data type conversion process for DataFrame '{df_id}': {str(e)}"

data_cleaning_tools.append(convert_data_types)

@tool
def generate_html_report(report_title: str, text_sections: Dict[str, str], image_sections: Dict[str, str]) -> str:
    """Generates an HTML report from text and image sections and saves it to a file.

    All titles and text are HTML-escaped; newlines in text become <br>.

    Args:
        report_title: The main title for the report.
        text_sections: A dictionary where keys are section titles (e.g., "Data Description")
                       and values are the corresponding text content (can be multiline).
        image_sections: A dictionary where keys are section titles (e.g., "Histogram of Age")
                        and values are base64 encoded PNG image strings.

    Returns:
        A success message with the path to the saved HTML report file,
        or an error message if generation fails.
    """
    import html  # local import: only this tool needs HTML escaping

    try:
        escaped_report_title = html.escape(report_title)
        parts = [
            "<html>\n",
            f"<head><title>{escaped_report_title}</title></head>\n",
            "<body>\n",
            f"<h1>{escaped_report_title}</h1>\n",
        ]

        for title, text in text_sections.items():
            # Escape first, then convert newlines. Doing the replace outside the
            # f-string also avoids a backslash inside an f-string expression,
            # which is a SyntaxError before Python 3.12.
            body_html = html.escape(text).replace("\n", "<br>")
            parts.append(f"<h2>{html.escape(title)}</h2>\n<p>{body_html}</p>\n")

        for title, base64_image_string in image_sections.items():
            escaped_title = html.escape(title)
            parts.append(f"<h2>{escaped_title}</h2>\n")
            parts.append(
                f'<img src="data:image/png;base64,{html.escape(base64_image_string)}" '
                f'alt="{escaped_title}" style="max-width:100%;height:auto;">\n'
            )

        parts.append("</body>\n</html>")
        html_content = "".join(parts)

        # Derive a filesystem-safe file name from the title.
        file_stem = "".join(c if c.isalnum() else "_" for c in report_title)
        file_name = f"{file_stem}_report.html"
        full_path = WORKING_DIRECTORY / file_name

        with open(full_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        return f"HTML report generated: '{str(full_path)}'"

    except Exception as e:
        return f"Failed to generate HTML report: {str(e)}"

report_generator_tools.append(generate_html_report)

@tool
def calculate_correlation_matrix(df_id: str, column_names: Optional[List[str]] = None) -> str:
    """Calculates the Pearson correlation matrix for numeric columns in a DataFrame.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_names: Optional. A list of column names to include in the calculation.
                      If None or empty, all numeric columns will be used.

    Returns:
        A JSON string of the correlation matrix (orient='index'),
        or a JSON object with an "error" key if calculation fails.
    """
    try:
        source_df = global_df_registry.get_dataframe(df_id)
        if source_df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        working = source_df.copy()

        # Optionally narrow to the requested columns, rejecting unknown names.
        if column_names:
            absent = [name for name in column_names if name not in working.columns]
            if absent:
                return json.dumps({"error": f"Columns not found in DataFrame: {', '.join(absent)}."})
            working = working[column_names]

        numeric_only = working.select_dtypes(include=np.number)

        if numeric_only.empty:
            return json.dumps({"error": "No numeric columns found to calculate correlation matrix."})
        if numeric_only.shape[1] < 2:
            return json.dumps({"error": "At least two numeric columns are required to calculate a correlation matrix."})

        return numeric_only.corr().to_json(orient='index')

    except Exception as exc:
        return json.dumps({"error": f"Failed to calculate correlation matrix: {str(exc)}"})

analyst_tools.append(calculate_correlation_matrix)

@tool
def detect_outliers(df_id: str, column_name: str) -> str:
    """Detects outliers in a numeric column of a DataFrame using the IQR method.

    Values below Q1 - 1.5 * IQR or above Q3 + 1.5 * IQR are reported as outliers.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_name: The name of the numeric column to check for outliers.

    Returns:
        A JSON string summarizing the outlier detection findings (IQR, bounds,
        number of outliers, sample of outliers) or a message if no outliers are found.
        Returns a JSON string with an error message if the operation fails.
    """
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        if column_name not in df.columns:
            return json.dumps({"error": f"Column '{column_name}' not found in DataFrame '{df_id}'."})

        s = df[column_name]
        if not pd.api.types.is_numeric_dtype(s):
            return json.dumps({"error": f"Column '{column_name}' must be numeric to detect outliers using IQR."})

        # Quantiles of an all-null column are NaN, which would leak into the JSON.
        if s.dropna().empty:
            return json.dumps({"error": f"Column '{column_name}' has no non-null values to check for outliers."})

        # Cast to built-in floats so json.dumps never sees NumPy/pandas scalars.
        Q1 = float(s.quantile(0.25))
        Q3 = float(s.quantile(0.75))
        IQR = Q3 - Q1
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR

        # NaNs compare False on both sides, so they are never flagged as outliers.
        outliers = s[(s < lower_bound) | (s > upper_bound)]

        if not outliers.empty:
            return json.dumps({
                "column": column_name,
                "iqr": IQR,
                "lower_bound": lower_bound,
                "upper_bound": upper_bound,
                "num_outliers": len(outliers),
                "outliers_sample": outliers.head().tolist()  # Convert sample to list for JSON serialization
            })
        return json.dumps({
            "column": column_name,
            "message": "No outliers detected using IQR method.",
            "iqr": IQR,
            "lower_bound": lower_bound,
            "upper_bound": upper_bound
        })

    except Exception as e:
        return json.dumps({"error": f"Failed to detect outliers: {str(e)}"})

# Register the outlier tool once with the analyst agent.
analyst_tools.append(detect_outliers)

@tool
def perform_normality_test(df_id: str, column_name: str) -> str:
    """Performs a Shapiro-Wilk normality test on a numeric column (alpha = 0.05).

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_name: The name of the numeric column to test.

    Returns:
        A JSON string with the test statistic, p-value, and interpretation,
        or an error message string (as JSON) if the test cannot be performed.
    """
    try:
        # Local import so the tool does not depend on a `stats` name defined elsewhere.
        from scipy import stats

        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        if column_name not in df.columns:
            return json.dumps({"error": f"Column '{column_name}' not found in DataFrame '{df_id}'."})

        s = df[column_name].dropna()
        if not pd.api.types.is_numeric_dtype(s):
            return json.dumps({"error": f"Column '{column_name}' must be numeric for normality testing."})

        # Shapiro-Wilk needs at least 3 samples; scipy warns that p-values may be
        # inaccurate for N > 5000, so larger samples are rejected here.
        if not (3 <= len(s) < 5000):
            return json.dumps({"error": f"Column '{column_name}' must contain between 3 and 4999 non-null samples for Shapiro-Wilk test. Found {len(s)}."})

        stat, p_value = stats.shapiro(s)
        alpha = 0.05
        # scipy returns NumPy scalars and `p_value > alpha` is a numpy.bool_, which
        # json.dumps cannot serialize; convert everything to built-in types.
        stat, p_value = float(stat), float(p_value)
        is_normal = bool(p_value > alpha)
        interpretation = "Data looks Gaussian (fail to reject H0)" if is_normal else "Data does not look Gaussian (reject H0)"

        return json.dumps({
            "column": column_name,
            "test_type": "Shapiro-Wilk",
            "statistic": stat,
            "p_value": p_value,
            "is_normal": is_normal,
            "interpretation": interpretation
        })

    except Exception as e:
        return json.dumps({"error": f"Failed to perform normality test: {str(e)}"})

analyst_tools.append(perform_normality_test)

@tool
def assess_data_quality(df_id: str) -> str:
    """Provides a comprehensive data quality assessment for a DataFrame.

    Checks for shape, missing values, data types, duplicate rows, and memory usage.

    Args:
        df_id: The ID of the DataFrame in the global registry.

    Returns:
        A JSON string summarizing the data quality assessment (keys: ``shape``,
        ``missing_values_summary``, ``total_missing_values``, ``percentage_missing``,
        ``data_types``, ``duplicate_rows``, ``memory_usage``), or an error message
        string (as JSON) if the assessment fails.
    """
    try:
        df = global_df_registry.get_dataframe(df_id)
        if df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        n_rows, n_cols = df.shape
        quality_report = {}

        # Basic Info
        quality_report["shape"] = {"rows": int(n_rows), "columns": int(n_cols)}

        # Missing Values.
        # Column labels are stringified because json.dumps only accepts
        # str/int/float/bool/None dict keys (tuple labels from a MultiIndex would
        # raise). Counts are cast to native int so they serialize as numbers
        # rather than being turned into strings by `default=str`.
        missing_info = df.isnull().sum()
        quality_report["missing_values_summary"] = {
            str(col): int(count) for col, count in missing_info.items() if count > 0
        }
        quality_report["total_missing_values"] = int(missing_info.sum())
        total_cells = n_rows * n_cols
        quality_report["percentage_missing"] = (quality_report["total_missing_values"] / total_cells) * 100 if total_cells > 0 else 0

        # Data Types (same key stringification as above)
        quality_report["data_types"] = {str(col): str(dtype) for col, dtype in df.dtypes.items()}

        # Duplicate Rows
        num_duplicates = int(df.duplicated().sum())
        quality_report["duplicate_rows"] = {
            "count": num_duplicates,
            "percentage": (num_duplicates / n_rows) * 100 if n_rows > 0 else 0
        }

        # Memory Usage (deep=True also counts the payload of object/string columns)
        memory_usage_bytes = df.memory_usage(deep=True).sum()
        quality_report["memory_usage"] = f"{memory_usage_bytes / (1024**2):.2f} MB"

        return json.dumps(quality_report, indent=4, default=str)  # default=str for any remaining numpy types

    except Exception as e:
        return json.dumps({"error": f"Failed to assess data quality: {str(e)}"})

analyst_tools.append(assess_data_quality)
data_cleaning_tools.append(assess_data_quality) # Also useful for data cleaning stage

# NOTE: a second, verbatim copy of `perform_normality_test` and a second
# `analyst_tools.append(perform_normality_test)` used to live here. The copy
# silently shadowed the definition registered earlier in this cell. It also put
# two tools with the same name into `analyst_tools`, so it was removed. The
# earlier definition is the single source of truth for this tool.

@tool
def search_web_for_context(query: str, max_results: int = 3) -> str:
    """Performs a web search using Tavily API to find external context or insights.

    The API key is read from the ``TAVILY_API_KEY`` environment variable on every
    call; a "basic" depth search is used (faster than "advanced").

    Args:
        query: The search query string.
        max_results: The maximum number of search results to return (default is 3).

    Returns:
        A JSON string containing a list of search results (each with title, url, content),
        or a JSON string with an error message if the search fails or API key is missing.
    """
    try:
        tavily_api_key = os.environ.get('TAVILY_API_KEY')
        if not tavily_api_key:
            return json.dumps({"error": "TAVILY_API_KEY not found in environment variables."})

        client = TavilyClient(api_key=tavily_api_key)
        # Use search_depth="advanced" for more comprehensive results if needed, basic is faster.
        response = client.search(query=query, search_depth="basic", max_results=max_results)

        # Keep only the fields the agents need from each hit.
        formatted_results = [
            {
                "title": res.get("title"),
                "url": res.get("url"),
                "content": res.get("content"),
            }
            for res in (response.get("results") or [])
        ]
        return json.dumps(formatted_results, indent=4)

    except Exception as e:
        return json.dumps({"error": f"Failed to perform web search: {str(e)}"})

# Register the web-search tool itself; previously this line re-registered
# perform_normality_test, leaving search_web_for_context unavailable to agents.
analyst_tools.append(search_web_for_context)

@tool
def load_multiple_files(file_paths: List[str], file_type: str) -> str:
    """Loads multiple data files (e.g., CSVs, JSONs) into DataFrames.

    Each successfully loaded DataFrame is registered with a new unique ID.
    Assumes file paths are accessible by the system.

    Args:
        file_paths: A list of strings, where each string is the full path to a data file.
        file_type: The type of the files to load. Supported: 'csv', 'json'.

    Returns:
        A JSON string summarizing the loading operation for each file, including
        original path, new df_id (if successful), row/column counts, and status.
    """
    def _failure(path_str, message):
        # Uniform per-file error record.
        return {"original_path": path_str, "status": "error", "message": message}

    # Dispatch table from normalized file type to pandas reader.
    # For JSON, orient='records' / lines=True may be needed for some structures.
    readers = {'csv': pd.read_csv, 'json': pd.read_json}
    summary = []
    for path_str in file_paths:
        path = Path(path_str)
        try:
            if not (path.exists() and path.is_file()):
                summary.append(_failure(path_str, "File not found or is not a file."))
                continue

            reader = readers.get(file_type.lower())
            if reader is None:
                summary.append(_failure(path_str, f"Unsupported file type: {file_type}. Supported: 'csv', 'json'."))
                continue

            frame = reader(path)

            # Unique ID: file stem plus a short random suffix.
            new_df_id = f"loaded_df_{path.stem}_{str(uuid.uuid4())[:8]}"
            global_df_registry.register_dataframe(frame, df_id=new_df_id, raw_path=path_str)

            summary.append({
                "original_path": path_str,
                "df_id": new_df_id,
                "rows": len(frame),
                "columns": len(frame.columns),
                "status": "success"
            })
        except Exception as exc:
            summary.append(_failure(path_str, str(exc)))

    return json.dumps(summary, indent=4)

analyst_tools.append(load_multiple_files)
# Registration of load_multiple_files for the data-cleaning stage.
data_cleaning_tools.append(load_multiple_files)


@tool
def merge_dataframes(
    left_df_id: str,
    right_df_id: str,
    how: str,
    on: Optional[Union[str, List[str]]] = None,
    left_on: Optional[Union[str, List[str]]] = None,
    right_on: Optional[Union[str, List[str]]] = None
) -> str:
    """Merges two DataFrames based on specified keys and join type.

    The result is produced by ``pd.merge`` and registered under a new ID of the
    form ``merged_df_<left_df_id>_<right_df_id>_<random suffix>``.

    Args:
        left_df_id: The ID of the left DataFrame.
        right_df_id: The ID of the right DataFrame.
        how: Type of merge to be performed. One of 'left', 'right', 'outer', 'inner', 'cross'.
        on: Column or index level names to join on. Must be found in both DataFrames.
            If None and not merging on indexes, this defaults to the intersection of the columns in both DataFrames.
        left_on: Column or index level names to join on in the left DataFrame.
        right_on: Column or index level names to join on in the right DataFrame.

    Returns:
        A JSON string with the new DataFrame ID for the merged DataFrame and its dimensions,
        or a JSON string with an error message if merging fails.
    """
    allowed_how_types = ['left', 'right', 'outer', 'inner', 'cross']
    if how not in allowed_how_types:
        return json.dumps({
            "error": f"Invalid merge type '{how}'. Allowed types are: {allowed_how_types}"
        })

    try:
        left_df = global_df_registry.get_dataframe(left_df_id, load_if_not_exists=True)
        if left_df is None:
            return json.dumps({"error": f"Left DataFrame with ID '{left_df_id}' not found."})

        right_df = global_df_registry.get_dataframe(right_df_id, load_if_not_exists=True)
        if right_df is None:
            return json.dumps({"error": f"Right DataFrame with ID '{right_df_id}' not found."})

        merged_df = pd.merge(
            left_df,
            right_df,
            how=how,
            on=on,
            left_on=left_on,
            right_on=right_on
        )

        # 8-character suffix, matching load_multiple_files, keeps ID collisions unlikely.
        new_merged_df_id = f"merged_df_{left_df_id}_{right_df_id}_{str(uuid.uuid4())[:8]}"
        # For derived dataframes, raw_path records its derived nature instead of a file path.
        raw_path_info = f"derived_from_merge_{left_df_id}_{right_df_id}"
        global_df_registry.register_dataframe(merged_df, df_id=new_merged_df_id, raw_path=raw_path_info)

        return json.dumps({
            "new_df_id": new_merged_df_id,
            "rows": len(merged_df),
            "columns": len(merged_df.columns),
            "message": f"Merge successful. New DataFrame '{new_merged_df_id}' created."
        })

    except KeyError as e:
        return json.dumps({"error": f"KeyError during merge: {str(e)}. Check if 'on', 'left_on', or 'right_on' keys exist in respective DataFrames."})
    except pd.errors.MergeError as e:
        return json.dumps({"error": f"MergeError: {str(e)}"})
    except Exception as e:
        return json.dumps({"error": f"An unexpected error occurred during merge: {str(e)}"})

analyst_tools.append(merge_dataframes)
data_cleaning_tools.append(merge_dataframes)

import re # Import re locally for the standardize_column_names tool
from sklearn.preprocessing import LabelEncoder, OneHotEncoder # For handle_categorical_encoding

@tool
def standardize_column_names(df_id: str, rule: str) -> str:
    """Standardizes column names of a DataFrame.

    Supported rules: 'snake_case', 'lower_case', 'upper_case'.
    Column labels are converted with ``str()`` first, so non-string labels
    (e.g. integer headers) can be standardized too. The renamed DataFrame is
    re-registered under the same ``df_id``.

    The rename is refused, and an error is returned, when it would give two
    different original columns the same name, or turn a non-empty name into
    an empty one. Duplicate labels make later ``df[col]`` lookups return a
    DataFrame instead of a Series.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        rule: The standardization rule to apply.

    Returns:
        A JSON string summarizing the changes, or an error message.
    """

    def to_snake_case(name):
        # Split camelCase / PascalCase boundaries, lowercase, collapse any
        # non-alphanumeric run to a single underscore, and trim underscores.
        s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        s2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
        s3 = re.sub(r'[^a-z0-9_]+' , '_', s2)
        s4 = re.sub(r'_+', '_', s3).strip('_')
        return s4

    try:
        df = global_df_registry.get_dataframe(df_id, load_if_not_exists=True)
        if df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        if rule == 'lower_case':
            convert = str.lower
        elif rule == 'upper_case':
            convert = str.upper
        elif rule == 'snake_case':
            convert = to_snake_case
        else:
            return json.dumps({
                "error": f"Unsupported rule: '{rule}'. Supported rules are 'snake_case', 'lower_case', 'upper_case'."
            })

        df_copy = df.copy()
        original_columns = df_copy.columns.tolist()
        new_columns = [convert(str(col)) for col in original_columns]

        # Detect collisions introduced by the rule (e.g. "A B" and "a_b" -> "a_b").
        # Pre-existing duplicate labels map to a single-element set and are not flagged.
        sources = {}
        for orig_col, new_col in zip(original_columns, new_columns):
            sources.setdefault(new_col, set()).add(orig_col)
        collisions = {
            new_col: sorted(map(str, origs))
            for new_col, origs in sources.items()
            if len(origs) > 1
        }
        # snake_case strips punctuation, so e.g. "%" would become "".
        emptied = [
            orig_col for orig_col, new_col in zip(original_columns, new_columns)
            if not new_col and str(orig_col)
        ]
        if collisions or emptied:
            return json.dumps({
                "error": f"Applying rule '{rule}' would produce duplicate or empty column names; columns were not renamed.",
                "colliding_columns": collisions,
                "columns_becoming_empty": emptied
            }, default=str)

        df_copy.columns = new_columns

        original_raw_path = global_df_registry.get_raw_path_from_id(df_id)
        if original_raw_path is None or original_raw_path.startswith("derived_from_"):
            original_raw_path = f"df_{df_id}_cols_std"

        global_df_registry.register_dataframe(df_copy, df_id=df_id, raw_path=original_raw_path)

        return json.dumps({
            "df_id": df_id,
            "rule_applied": rule,
            "original_columns": original_columns,
            "new_columns": new_columns,
            "message": "Column names standardized successfully."
        }, default=str)  # default=str: original labels may be numpy scalars

    except Exception as e:
        return json.dumps({"error": f"An unexpected error occurred during column name standardization: {str(e)}"})

data_cleaning_tools.append(standardize_column_names)


@tool
def format_markdown_report(report_title: str, text_sections: Dict[str, str], image_sections: Dict[str, str]) -> str:
    """Formats a report from text and image sections into a Markdown file.

    All text sections are written first (in dict order), then all image sections.
    Each image value is classified as follows:

    * A ``data:`` URI is embedded unchanged.
    * Base64-encoded PNG data is wrapped in a ``data:image/png;base64,`` URI.
      It is recognized by the PNG signature prefix ``iVBORw0KGgo``, or by being
      longer than 2000 characters with a pure-base64 tail.
    * Anything else is treated as a file path / URL and linked as-is.

    The file is written to ``WORKING_DIRECTORY/<sanitized title>_report.md``.

    Args:
        report_title: The main title for the report.
        text_sections: A dictionary where keys are section titles and values are text content.
        image_sections: A dictionary where keys are section titles and values are
                        either base64 encoded PNG image strings or paths to image files.

    Returns:
        A JSON string with a success message and the path to the saved Markdown report,
        or a JSON string with an error message if generation fails.
    """
    b64_alphabet = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=")

    def image_src(value: str) -> str:
        # Existing data URI: wrapping it again would produce a broken double prefix.
        if value.startswith("data:"):
            return value
        # Every base64 PNG starts with the encoded PNG signature, so small images
        # (under the 2000-char heuristic) are embedded too. Whitespace from
        # line-wrapped base64 would break the data URI, so it is removed.
        compact = "".join(value.split())
        if compact.startswith("iVBORw0KGgo"):
            return f"data:image/png;base64,{compact}"
        # Legacy heuristic: very long strings whose last 100 chars are pure base64.
        if len(value) > 2000 and set(value[-100:]) <= b64_alphabet:
            return f"data:image/png;base64,{value}"
        return value

    try:
        parts = [f"# {report_title}\n\n"]

        for title, text_content in text_sections.items():
            parts.append(f"## {title}\n{text_content}\n\n")

        for title, image_value in image_sections.items():
            parts.append(f"## {title}\n![{title}]({image_src(image_value)})\n\n")

        md_content = "".join(parts)

        # Sanitize title for file name
        safe_title = "".join(c if c.isalnum() or c in [' ', '_'] else '_' for c in report_title).replace(' ', '_')
        file_name = f"{safe_title}_report.md"

        full_path = WORKING_DIRECTORY / file_name

        with open(full_path, 'w', encoding='utf-8') as f:
            f.write(md_content)

        return json.dumps({
            "report_path": str(full_path),
            "message": "Markdown report generated successfully."
        })

    except Exception as e:
        return json.dumps({"error": f"Failed to generate Markdown report: {str(e)}"})

report_generator_tools.append(format_markdown_report)

from xhtml2pdf import pisa # For create_pdf_report tool

@tool
def create_pdf_report(html_file_path_str: str) -> str:
    """Converts a given HTML file (in the working directory) to a PDF report.

    Path handling (sandboxed to ``WORKING_DIRECTORY``):

    * Relative paths are resolved against ``WORKING_DIRECTORY``, so
      sub-directories are honoured.
    * After resolving symlinks and ``..`` segments, a path that does not lie
      inside ``WORKING_DIRECTORY`` is not used. Its bare file name is looked up
      directly in ``WORKING_DIRECTORY`` instead.

    The PDF is written next to the source HTML with a ``.pdf`` suffix. If
    xhtml2pdf reports an error, the partially written PDF is deleted.

    Args:
        html_file_path_str: The path to the source HTML file, relative to the working directory,
                            or an absolute path if it's within the working directory sandbox.

    Returns:
        A JSON string with the path to the generated PDF report and a success message,
        or a JSON string with an error message if conversion fails.
    """
    try:
        work_dir = Path(WORKING_DIRECTORY).resolve()
        requested = Path(html_file_path_str)
        candidate = (requested if requested.is_absolute() else work_dir / requested).resolve()
        if candidate.is_relative_to(work_dir):
            source_html_path = candidate
        else:
            # Outside the sandbox (foreign absolute path or a `..` escape):
            # fall back to the bare file name inside the working directory.
            source_html_path = work_dir / requested.name

        if not source_html_path.is_file():
            return json.dumps({"error": f"Source HTML file not found at: {str(source_html_path)}"})

        pdf_file_path = source_html_path.with_suffix('.pdf')

        html_content = source_html_path.read_text(encoding="utf-8")

        with open(pdf_file_path, "wb") as pdf_file:
            pisa_status = pisa.CreatePDF(
                html_content,  # the HTML to convert
                dest=pdf_file  # file handle to receive result
            )

        if pisa_status.err:
            # Don't leave a partial/corrupt PDF behind for later steps to pick up.
            pdf_file_path.unlink(missing_ok=True)
            return json.dumps({"error": f"PDF generation failed: {pisa_status.err}"})

        return json.dumps({
            "pdf_report_path": str(pdf_file_path),
            "message": "PDF report generated successfully."
        })

    except FileNotFoundError:
        return json.dumps({"error": f"Source HTML file not found (FileNotFoundError): {html_file_path_str}"})
    except Exception as e:
        return json.dumps({"error": f"An unexpected error occurred during PDF generation: {str(e)}"})

report_generator_tools.append(create_pdf_report)

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, mean_squared_error
import joblib
import numpy as np # For np.sqrt, though already in essential_imports, good to have locally for clarity

@tool
def train_ml_model(df_id: str, feature_columns: List[str], target_column: str, model_type: str, test_size: float = 0.2, random_state: Optional[int] = 42, save_model: bool = False) -> str:
    """Trains a specified ML model on the DataFrame.

    Supported model_types:

    * ``'logistic_regression'``: scored by test-set accuracy.
    * ``'linear_regression'``: scored by test-set RMSE.

    Rows with NaNs in any feature column or the target are dropped before the
    train/test split.

    Args:
        df_id: ID of the DataFrame.
        feature_columns: List of column names to use as features.
        target_column: Name of the column to use as the target variable.
        model_type: Type of model to train.
        test_size: Proportion of dataset for the test split.
        random_state: Seed for reproducibility (split and logistic regression).
        save_model: If True, saves the trained model with joblib to
            ``WORKING_DIRECTORY/<model_type>_<df_id>_<target_column>.joblib``.
            Every character of the ID/target that is not a letter, digit or
            underscore is replaced by ``_``, so the name is one plain file name.

    Returns:
        JSON string with training results (model type, metrics, model path if saved),
        or a JSON error string.
    """
    try:
        df = global_df_registry.get_dataframe(df_id, load_if_not_exists=True)
        if df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        if target_column not in df.columns:
            return json.dumps({"error": f"Target column '{target_column}' not found in DataFrame."})

        missing_features = [col for col in feature_columns if col not in df.columns]
        if missing_features:
            return json.dumps({"error": f"Feature column(s) '{', '.join(missing_features)}' not found in DataFrame."})

        if target_column in feature_columns:
            return json.dumps({"error": f"Target column '{target_column}' cannot also be in feature_columns."})

        # list() so any sequence of names (e.g. a tuple) can be concatenated.
        relevant_columns = list(feature_columns) + [target_column]
        df_cleaned = df[relevant_columns].dropna().copy() # Use .copy() to avoid SettingWithCopyWarning later

        if df_cleaned.empty:
            return json.dumps({"error": "DataFrame is empty after dropping NaNs from feature and target columns."})

        X = df_cleaned[feature_columns]
        y = df_cleaned[target_column]

        if X.empty or y.empty:
            return json.dumps({"error": "Features (X) or target (y) are empty after processing."})

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

        model_path = ""
        metric_name = ""
        metric_value = None

        if model_type == 'logistic_regression':
            model = LogisticRegression(random_state=random_state, max_iter=1000)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            metric_name = "accuracy"
            metric_value = accuracy_score(y_test, y_pred)

        elif model_type == 'linear_regression':
            if not pd.api.types.is_numeric_dtype(y_train): # Check y_train, not just y
                return json.dumps({"error": f"Target column '{target_column}' must be numeric for linear regression."})
            model = LinearRegression()
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            metric_name = "rmse"
            metric_value = np.sqrt(mean_squared_error(y_test, y_pred))

        else:
            return json.dumps({"error": f"Unsupported model_type: '{model_type}'. Supported: 'logistic_regression', 'linear_regression'."})

        if save_model:
            import re  # only needed here, to build a filesystem-safe file name

            def _safe_name_part(part) -> str:
                # \W covers '-', ' ', '/', '\\', ':', '.', etc., so the name can
                # neither contain path separators nor traverse with '..'.
                return re.sub(r'\W', '_', str(part))

            model_filename = f"{model_type}_{_safe_name_part(df_id)}_{_safe_name_part(target_column)}.joblib"
            model_full_path = WORKING_DIRECTORY / model_filename
            try:
                joblib.dump(model, model_full_path)
                model_path = str(model_full_path)
            except Exception as e:
                return json.dumps({"error": f"Failed to save model: {str(e)}"})

        results = {
            "model_type": model_type,
            "target_column": target_column,
            "features_used_count": len(feature_columns),
            "training_set_size": len(X_train),
            "test_set_size": len(X_test),
            metric_name: metric_value,
            "model_saved_path": model_path if save_model else "Not saved",
            "message": "Model training complete."
        }
        return json.dumps(results, cls=NpEncoder) # NpEncoder converts numpy scalar types

    except Exception as e:
        return json.dumps({"error": f"An unexpected error occurred during model training: {str(e)}"})

# Helper NpEncoder class for json.dumps if numpy types are present in results
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalars and arrays to native Python types.

    Handles:

    * numpy integers and floats;
    * numpy booleans (``np.bool_`` is not a subclass of Python ``bool``, so the
      stock encoder rejects it);
    * ndarrays, via ``tolist()``.

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# Expose the model-training tool to the analyst agent.
analyst_tools.extend([train_ml_model])


@tool
def handle_categorical_encoding(df_id: str, column_name: str, strategy: str) -> str:
    """Applies categorical encoding to a specified column.

    Supported strategies: 'label_encoding', 'one_hot_encoding'.

    * 'label_encoding' creates a new column '{column_name}_label_encoded'.
      Missing values in the source column stay missing (nullable ``Int64``)
      rather than being encoded as a category.
    * 'one_hot_encoding' replaces the original column with new one-hot encoded
      columns. It is refused if a generated name already exists in the DataFrame.

    The result is re-registered under the same ``df_id``.

    Args:
        df_id: The ID of the DataFrame in the global registry.
        column_name: The name of the categorical column to encode.
        strategy: The encoding strategy to apply.

    Returns:
        A JSON string summarizing the encoding operation, or an error message.
    """
    try:
        df = global_df_registry.get_dataframe(df_id, load_if_not_exists=True)
        if df is None:
            return json.dumps({"error": f"DataFrame with ID '{df_id}' not found."})

        if column_name not in df.columns:
            return json.dumps({"error": f"Column '{column_name}' not found in DataFrame '{df_id}'."})

        df_copy = df.copy()
        original_raw_path = global_df_registry.get_raw_path_from_id(df_id)
        if original_raw_path is None or original_raw_path.startswith("derived_from_") or original_raw_path.endswith("_cols_std"): # if it's already a modified df
            original_raw_path = f"df_{df_id}_encoded"

        if strategy == 'label_encoding':
            encoder = LabelEncoder()
            # New column, so the original values are preserved.
            new_col_name = f"{column_name}_label_encoded"
            source = df_copy[column_name]
            present = source.notna()
            if present.all():
                df_copy[new_col_name] = encoder.fit_transform(source)
            else:
                # LabelEncoder sorts the values, so NaN (a float) mixed with strings
                # raises TypeError. Encode only the present values; keep NaN as <NA>.
                codes = pd.Series(pd.NA, index=source.index, dtype="Int64")
                codes.loc[present] = encoder.fit_transform(source[present])
                df_copy[new_col_name] = codes
            message = f"Label encoding applied to '{column_name}', new column: '{new_col_name}'."
            columns_added = [new_col_name]
            columns_removed = []

        elif strategy == 'one_hot_encoding':
            # Treat the column as categorical even if numeric (e.g. 0/1 category codes).
            df_copy[column_name] = df_copy[column_name].astype('category')

            try:
                encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
            except TypeError:  # scikit-learn < 1.2 names this parameter `sparse`
                encoder = OneHotEncoder(sparse=False, handle_unknown='ignore')
            encoded_data = encoder.fit_transform(df_copy[[column_name]])

            try:
                new_cols = encoder.get_feature_names_out([column_name])
            except AttributeError: # Fallback for older sklearn versions
                new_cols = [f"{column_name}_{cat}" for cat in encoder.categories_[0]]
            new_cols = [str(col) for col in new_cols]

            remaining = df_copy.drop(column_name, axis=1)
            # Refuse to create duplicate column labels.
            clashes = [col for col in new_cols if col in remaining.columns]
            if clashes:
                return json.dumps({
                    "error": f"One-hot column name(s) already exist in DataFrame '{df_id}': {clashes}"
                })

            encoded_df = pd.DataFrame(encoded_data, columns=new_cols, index=df_copy.index)
            df_copy = pd.concat([remaining, encoded_df], axis=1)
            message = f"One-hot encoding applied to '{column_name}'. Original column dropped. New columns: {', '.join(new_cols)}."
            columns_added = new_cols
            columns_removed = [column_name]

        else:
            return json.dumps({
                "error": f"Unsupported encoding strategy: '{strategy}'. Supported: 'label_encoding', 'one_hot_encoding'."
            })

        global_df_registry.register_dataframe(df_copy, df_id=df_id, raw_path=original_raw_path)

        return json.dumps({
            "df_id": df_id,
            "strategy_applied": strategy,
            "column_processed": column_name,
            "columns_added": columns_added,
            "columns_removed": columns_removed,
            "message": message
        })

    except Exception as e:
        return json.dumps({"error": f"An unexpected error occurred during encoding: {str(e)}"})

data_cleaning_tools.append(handle_categorical_encoding)


In [None]:
# Chat model used by every agent factory below; oai_key comes from the setup cell.
llm = ChatOpenAI(api_key=oai_key, model="gpt-4o-mini")
# Long-term store passed as `store=` to every agent created below.
in_memory_store = InMemoryStore()

def create_data_cleaner_agent(initial_description: str, df_ids: Optional[List[str]] = None) -> BaseChatModel:
    """Builds the ReAct agent responsible for data cleaning.

    Args:
        initial_description: Dataset description rendered into the system prompt
            as ``dataset_description``.
        df_ids: DataFrame IDs rendered into the prompt as ``available_df_ids``.
            ``None`` (the default) means an empty list. ``None`` is used instead
            of ``[]`` to avoid a shared mutable default argument.

    Returns:
        The agent built by ``create_react_agent``. It gets a copy of
        ``data_cleaning_tools``, a fresh ``MemorySaver`` checkpointer, the shared
        ``in_memory_store`` and ``CleaningMetadata`` as its response format.
        NOTE(review): create_react_agent builds an agent graph, not a chat model,
        so the ``BaseChatModel`` annotation looks inaccurate. Confirm.
    """
    df_ids = [] if df_ids is None else df_ids
    checkpointer = MemorySaver()
    tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in data_cleaning_tools)
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=data_cleaner_prompt_template.format(
            tool_descriptions=tool_descriptions,
            output_format=CleaningMetadata.model_json_schema(),
            dataset_description=initial_description,
            data_sample=None,
            available_df_ids=df_ids,
        )),
        MessagesPlaceholder(variable_name="messages"),
    ])
    return create_react_agent(
        llm, tools=[*data_cleaning_tools], state_schema=State, checkpointer=checkpointer,
        store=in_memory_store, response_format=CleaningMetadata, prompt=prompt,
        name="data_cleaner", version="v2",
    )

def create_initial_analysis_agent(user_prompt: str, df_ids: Optional[List[str]] = None) -> BaseChatModel:
    """Builds the ReAct agent that produces the initial dataset description.

    Args:
        user_prompt: The user's request, rendered into the system prompt.
        df_ids: DataFrame IDs rendered into the prompt as ``available_df_ids``.
            ``None`` (the default) means an empty list. ``None`` is used instead
            of ``[]`` to avoid a shared mutable default argument.

    Returns:
        The agent built by ``create_react_agent``. It gets ``analyst_tools``, a
        fresh ``MemorySaver`` checkpointer, the shared ``in_memory_store`` and
        ``InitialDescription`` as its response format.
        NOTE(review): create_react_agent builds an agent graph, not a chat model,
        so the ``BaseChatModel`` annotation looks inaccurate. Confirm.
    """
    df_ids = [] if df_ids is None else df_ids
    tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in analyst_tools)
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=analyst_prompt_template_initial.format(
            tool_descriptions=tool_descriptions,
            output_format=InitialDescription.model_json_schema(),
            user_prompt=user_prompt,
            available_df_ids=df_ids,
        )),
        MessagesPlaceholder(variable_name="messages"),
    ])
    checkpointer = MemorySaver()
    return create_react_agent(
        llm, tools=analyst_tools, state_schema=State, checkpointer=checkpointer,
        store=in_memory_store, response_format=InitialDescription, prompt=prompt,
        name="initial_analysis", version="v2",
    )

def create_analyst_agent(initial_description: str, df_ids: Optional[List[str]] = None) -> BaseChatModel:
    """Builds the main analyst ReAct agent.

    Args:
        initial_description: Rendered into the system prompt as
            ``cleaned_dataset_description``.
        df_ids: DataFrame IDs rendered into the prompt as ``available_df_ids``.
            ``None`` (the default) means an empty list. ``None`` is used instead
            of ``[]`` to avoid a shared mutable default argument.

    Returns:
        The agent built by ``create_react_agent``. It gets ``analyst_tools``, a
        fresh ``MemorySaver`` checkpointer, the shared ``in_memory_store`` and
        ``AnalysisInsights`` as its response format.
        NOTE(review): create_react_agent builds an agent graph, not a chat model,
        so the ``BaseChatModel`` annotation looks inaccurate. Confirm.
    """
    df_ids = [] if df_ids is None else df_ids
    checkpointer = MemorySaver()
    tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in analyst_tools)
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=analyst_prompt_template_main.format(
            tool_descriptions=tool_descriptions,
            output_format=AnalysisInsights.model_json_schema(),
            cleaned_dataset_description=initial_description,
            cleaning_metadata=None,
            available_df_ids=df_ids,
        )),
        MessagesPlaceholder(variable_name="messages"),
    ])
    return create_react_agent(
        llm, tools=analyst_tools, state_schema=State, response_format=AnalysisInsights,
        checkpointer=checkpointer, store=in_memory_store, prompt=prompt,
        name="analyst", version="v2",
    )

def create_file_writer_agent() -> BaseChatModel:
    """Builds the file-writer ReAct agent.

    Unlike the other factories, this one passes no custom prompt, response
    format or name. The agent gets only ``file_writer_tools``, the shared
    ``State`` schema, a fresh ``MemorySaver`` checkpointer and the shared
    ``in_memory_store``.
    """
    memory = MemorySaver()
    return create_react_agent(
        llm,
        tools=file_writer_tools,
        state_schema=State,
        checkpointer=memory,
        store=in_memory_store,
    )

def create_visualization_agent(df_ids: Optional[List[str]] = None) -> BaseChatModel:
    """Builds the visualization ReAct agent.

    The system prompt is rendered with empty ``cleaned_dataset_description`` and
    ``analysis_insights`` placeholders.

    Args:
        df_ids: DataFrame IDs rendered into the prompt as ``available_df_ids``.
            ``None`` (the default) means an empty list. ``None`` is used instead
            of ``[]`` to avoid a shared mutable default argument.

    Returns:
        The agent built by ``create_react_agent``. It gets a copy of
        ``visualization_tools``, a fresh ``MemorySaver`` checkpointer, the shared
        ``in_memory_store`` and ``VisualizationResults`` as its response format.
        NOTE(review): create_react_agent builds an agent graph, not a chat model,
        so the ``BaseChatModel`` annotation looks inaccurate. Confirm.
    """
    df_ids = [] if df_ids is None else df_ids
    tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in visualization_tools)
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=visualization_prompt_template.format(
            tool_descriptions=tool_descriptions,
            output_format=VisualizationResults.model_json_schema(),
            cleaned_dataset_description="",
            analysis_insights="",
            available_df_ids=df_ids,
        )),
        MessagesPlaceholder(variable_name="messages"),
    ])
    checkpointer = MemorySaver()
    return create_react_agent(
        llm, tools=[*visualization_tools], state_schema=State, checkpointer=checkpointer,
        store=in_memory_store, response_format=VisualizationResults, prompt=prompt,
        name="visualization", version="v2",
    )

def create_report_generator_agent(df_ids: Optional[List[str]] = None) -> "CompiledStateGraph":
    """Build the ReAct "report_generator" agent.

    The system prompt is ``report_generator_prompt_template`` filled with the report tool
    descriptions, the ``ReportResults`` JSON schema, empty cleaning metadata / analysis insights /
    visualization results, and ``df_ids`` as the available DataFrame ids.

    Args:
        df_ids: DataFrame ids the agent is told about. Defaults to an empty list.

    Returns:
        The compiled agent graph produced by ``create_react_agent`` (not a chat model). It has its
        own ``MemorySaver`` checkpointer, shares ``in_memory_store``, and uses ``ReportResults`` as
        its structured response format.
    """
    # None sentinel instead of a mutable [] default shared across calls.
    if df_ids is None:
        df_ids = []
    tool_descriptions = "\n".join(f"{tool.name}: {tool.description}" for tool in report_generator_tools)
    system_text = report_generator_prompt_template.format(
        tool_descriptions=tool_descriptions,
        output_format=ReportResults.model_json_schema(),
        cleaning_metadata="",
        analysis_insights="",
        visualization_results="",
        available_df_ids=df_ids,
    )
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=system_text),
        MessagesPlaceholder(variable_name="messages"),
    ])
    return create_react_agent(
        llm,
        tools=[*report_generator_tools],
        state_schema=State,
        checkpointer=MemorySaver(),
        store=in_memory_store,
        response_format=ReportResults,
        prompt=prompt,
        name="report_generator",
        version="v2",
    )

def update_memory(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
    """Save the content of the latest message in ``state`` to the long-term store.

    Each call writes a new item under the namespace ``(user_id, "memories")`` with a fresh
    UUID4 string as key and ``{"memory": <latest message content>}`` as value.

    Args:
        state: Graph state; only ``state["messages"]`` is read.
        config: Run config; ``config["configurable"]["user_id"]`` selects the namespace
            (a missing key raises KeyError).
        store: Store that receives the item via ``store.put(namespace, key, value)``.

    Returns:
        None. With an empty message list nothing is stored.
    """
    messages = state["messages"]
    if not messages:
        # Nothing to remember yet; indexing [-1] would raise IndexError.
        return
    user_id = config["configurable"]["user_id"]
    namespace = (user_id, "memories")
    memory_id = str(uuid.uuid4())
    store.put(namespace, memory_id, {"memory": messages[-1].content})

def make_supervisor_node(llm: BaseChatModel, members: list[str]) -> str:
    options = ["FINISH"] + members
    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        f" following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )
    class Router(TypedDict):
        next: Literal[*options]

    def supervisor_node(state: State) -> Command[Literal[*members, "__end__"]]:
        _count_ = state.get("_count_", 0) + 1
        state["_count_"] = _count_
        completed_str_parts = []
        agent_bool_map = {
            "initial_analysis": state.get("initial_analysis_complete", False),
            "data_cleaner": state.get("data_cleaning_complete", False),
            "analyst": state.get("analyst_complete", False),
            "file_writer": state.get("file_writer_complete", False),
            "visualization": state.get("visualization_complete", False),
            "report_generator": state.get("report_generator_complete", False),
        }
        for agent_name, is_complete in agent_bool_map.items():
            if is_complete:
                completed_str_parts.append(f"{agent_name} is complete, so dont pass to {agent_name} again.")
        
        current_system_prompt = system_prompt
        if completed_str_parts:
            current_system_prompt += "\n" + "\n".join(completed_str_parts)
            
        messages = [SystemMessage(content=current_system_prompt)] + state["messages"]
        response = llm.with_structured_output(Router).invoke(messages, state["_config"])
        goto = response["next"]
        if goto == "FINISH":
            goto = END

        print(f"Coordinator node: current state keys: {list(state.keys())}. Current count: {_count_}")
        # No need to manually create update_dict, LangGraph handles state updates from Command
        print(f"\nCoordinator node: routing to: {goto} \n")
        update_memory(state, state["_config"], store=in_memory_store)
        return Command(goto=goto)
    return supervisor_node

In [None]:
# Download the sample dataset from Kaggle Hub, load it, and register it in the global DataFrame registry.
# os, pandas (pd) and pprint come from the imports cell. kagglehub is NOT installed by the first
# cell's pip commands; if it is missing (Colab usually preinstalls it — verify), run
# `%pip install kagglehub` first. Importing here keeps this cell self-contained on a fresh kernel.
import kagglehub

path = kagglehub.dataset_download("datafiniti/consumer-reviews-of-amazon-products")
print("Path to dataset files:", path)

raw_path_str = os.path.join(path, "Datafiniti_Amazon_Consumer_Reviews_of_Amazon_Products.csv")
pprint(raw_path_str)

# Pick the pandas reader from the file extension; any other extension is rejected.
df = None
ext = os.path.splitext(raw_path_str)[-1].lower()
try:
    if ext == ".csv":
        df = pd.read_csv(raw_path_str)
    elif ext == ".json":
        df = pd.read_json(raw_path_str)
    else:
        raise ValueError("Unsupported file format. Please use CSV or JSON.")
except Exception as e:
    print(f"Error reading file: {e}")
    raise

df_name = "Datafiniti_Amazon_Consumer_Reviews_of_Amazon_Products"
df_id = global_df_registry.register_dataframe(df, df_name, raw_path_str)

# The prompt tells the agents which df_id string to pass to their data-access tools.
sample_prompt_text = f"Please analyze the dataset named {df_name}. You have tools available to you for accessing the data using the following str as the df_id parameter: `{df_id}`."
sample_prompt_tuple = ("user", sample_prompt_text)
pprint(sample_prompt_tuple)

# Agent instantiations: every data-facing agent is told about the registered df_id.
data_cleaner_agent = create_data_cleaner_agent(initial_description=sample_prompt_text, df_ids=[df_id])
initial_analysis_agent = create_initial_analysis_agent(user_prompt=sample_prompt_text, df_ids=[df_id])
analyst_agent = create_analyst_agent(initial_description=sample_prompt_text, df_ids=[df_id])
file_writer_agent = create_file_writer_agent()
visualization_agent = create_visualization_agent(df_ids=[df_id])
report_generator_agent = create_report_generator_agent(df_ids=[df_id])

In [None]:
# Node Functions
def initial_analysis_node(state: State) -> Command[Literal["supervisor"]]:
    """Run the initial-analysis agent and hand control back to the supervisor.

    The system prompt is rebuilt from ``analyst_prompt_template_initial`` using
    ``state["user_prompt"]`` and ``state["df_ids"]``; any SystemMessages already in the
    conversation are dropped in favour of it. The agent's structured response is stored as
    ``initial_description`` and ``initial_analysis_complete`` is set to True.
    """
    described_tools = "\n".join(f"{t.name}: {t.description}" for t in analyst_tools)
    schema = InitialDescription.model_json_schema()

    # token_counter=len measures length in messages, so max_tokens=1000 keeps at most 1000 messages.
    recent = trim_messages(state["messages"], max_tokens=1000, token_counter=len)
    system_text = analyst_prompt_template_initial.format(
        tool_descriptions=described_tools,
        output_format=schema,
        user_prompt=state["user_prompt"],
        available_df_ids=state["df_ids"],
    )
    agent_input = [SystemMessage(content=system_text)]
    agent_input.extend(m for m in recent if not isinstance(m, SystemMessage))

    result = initial_analysis_agent.invoke({"messages": agent_input}, config=state["_config"])
    print(f"Initial analysis result: {result}")
    # The agent's final reply is relayed to the supervisor as a named HumanMessage.
    reply = HumanMessage(content=result["messages"][-1].content, name="initial_analysis")
    return Command(
        update={
            "messages": [reply],
            "initial_description": result["structured_response"],
            "initial_analysis_complete": True,
        },
        goto="supervisor",
    )

def data_cleaner_node(state: State) -> Command[Literal["supervisor"]]:
    """Run the data-cleaner agent and hand control back to the supervisor.

    The system prompt is built from ``data_cleaner_prompt_template`` using the
    ``dataset_description`` and ``data_sample`` of ``state["initial_description"]`` plus
    ``state["df_ids"]``. The agent's structured response is stored as ``cleaning_metadata``
    and ``data_cleaning_complete`` is set to True.
    """
    described_tools = "\n".join(f"{t.name}: {t.description}" for t in data_cleaning_tools)
    schema = CleaningMetadata.model_json_schema()

    # token_counter=len measures length in messages, so max_tokens=1000 keeps at most 1000 messages.
    recent = trim_messages(state["messages"], max_tokens=1000, token_counter=len)
    initial = state["initial_description"]
    system_text = data_cleaner_prompt_template.format(
        tool_descriptions=described_tools,
        output_format=schema,
        dataset_description=initial.dataset_description,
        data_sample=initial.data_sample,
        available_df_ids=state["df_ids"],
    )
    agent_input = [SystemMessage(content=system_text)]
    agent_input.extend(m for m in recent if not isinstance(m, SystemMessage))

    result = data_cleaner_agent.invoke({"messages": agent_input}, config=state["_config"])
    pprint(result)
    reply = HumanMessage(content=result["messages"][-1].content, name="data_cleaner")
    return Command(
        update={
            "messages": [reply],
            "cleaning_metadata": result["structured_response"],
            "data_cleaning_complete": True,
        },
        goto="supervisor",
    )

def analyst_node(state: State) -> Command[Literal["supervisor"]]:
    """Run the analyst agent and hand control back to the supervisor.

    The system prompt is built from ``analyst_prompt_template_main`` using
    ``state["cleaning_metadata"]`` (and its ``data_description_after_cleaning``) plus
    ``state["df_ids"]``. The agent's structured response is stored as ``analysis_insights``
    and ``analyst_complete`` is set to True.
    """
    described_tools = "\n".join(f"{t.name}: {t.description}" for t in analyst_tools)
    schema = AnalysisInsights.model_json_schema()

    # token_counter=len measures length in messages, so max_tokens=1000 keeps at most 1000 messages.
    recent = trim_messages(state["messages"], max_tokens=1000, token_counter=len)
    cleaning = state["cleaning_metadata"]
    system_text = analyst_prompt_template_main.format(
        tool_descriptions=described_tools,
        output_format=schema,
        cleaned_dataset_description=cleaning.data_description_after_cleaning,
        cleaning_metadata=cleaning,
        available_df_ids=state["df_ids"],
    )
    agent_input = [SystemMessage(content=system_text)]
    agent_input.extend(m for m in recent if not isinstance(m, SystemMessage))

    result = analyst_agent.invoke({"messages": agent_input}, config=state["_config"])
    pprint(f"Analyst result: {result}")
    reply = HumanMessage(content=result["messages"][-1].content, name="analyst")
    return Command(
        update={
            "messages": [reply],
            "analysis_insights": result["structured_response"],
            "analyst_complete": True,
        },
        goto="supervisor",
    )

def file_writer_node(state: State) -> Command[Literal["supervisor"]]:
    """Run the file-writer agent on the conversation and hand control back to the supervisor.

    The file-writer agent was built without its own system prompt, so it works only from the
    conversation messages. Its final reply is relayed to the supervisor as a HumanMessage
    named "file_writer", and ``file_writer_complete`` is set to True.
    """
    # Same message window as the sibling nodes: token_counter=len keeps at most 1000 messages.
    msgs = trim_messages(state["messages"], max_tokens=1000, token_counter=len)
    # A create_react_agent graph takes a state dict; passing the bare message list
    # (as before) is rejected by the graph instead of being treated as messages.
    result = file_writer_agent.invoke({"messages": msgs}, config=state["_config"])
    pprint(f"File writer result: {result}")
    update = {
        "messages": [HumanMessage(content=result["messages"][-1].content, name="file_writer")],
        "file_writer_complete": True,
    }
    return Command(update=update, goto="supervisor")

def visualization_node(state: State) -> Command[Literal["supervisor"]]:
    """Run the visualization agent and hand control back to the supervisor.

    The system prompt is built from ``visualization_prompt_template`` using the cleaned
    dataset description from ``state["cleaning_metadata"]``, ``state["analysis_insights"]``
    and ``state["df_ids"]``. The agent's structured response is stored as
    ``visualization_results`` and ``visualization_complete`` is set to True.
    """
    described_tools = "\n".join(f"{t.name}: {t.description}" for t in visualization_tools)
    schema = VisualizationResults.model_json_schema()

    # token_counter=len measures length in messages, so max_tokens=1000 keeps at most 1000 messages.
    recent = trim_messages(state["messages"], max_tokens=1000, token_counter=len)
    system_text = visualization_prompt_template.format(
        tool_descriptions=described_tools,
        output_format=schema,
        cleaned_dataset_description=state["cleaning_metadata"].data_description_after_cleaning,
        analysis_insights=state["analysis_insights"],
        available_df_ids=state["df_ids"],
    )
    agent_input = [SystemMessage(content=system_text)]
    agent_input.extend(m for m in recent if not isinstance(m, SystemMessage))

    result = visualization_agent.invoke({"messages": agent_input}, config=state["_config"])
    pprint(f"Visualization result: {result}")
    reply = HumanMessage(content=result["messages"][-1].content, name="visualization")
    return Command(
        update={
            "messages": [reply],
            "visualization_results": result["structured_response"],
            "visualization_complete": True,
        },
        goto="supervisor",
    )

def report_generator_node(state: State) -> Command[Literal["supervisor"]]:
    """Run the report-generator agent and hand control back to the supervisor.

    The system prompt is built from ``report_generator_prompt_template`` using
    ``state["cleaning_metadata"]``, ``state["analysis_insights"]``,
    ``state["visualization_results"]`` and ``state["df_ids"]``. The agent's structured
    response is stored as ``report_results`` and ``report_generator_complete`` is set to True.
    """
    described_tools = "\n".join(f"{t.name}: {t.description}" for t in report_generator_tools)
    schema = ReportResults.model_json_schema()

    # token_counter=len measures length in messages, so max_tokens=1000 keeps at most 1000 messages.
    recent = trim_messages(state["messages"], max_tokens=1000, token_counter=len)
    system_text = report_generator_prompt_template.format(
        tool_descriptions=described_tools,
        output_format=schema,
        cleaning_metadata=state["cleaning_metadata"],
        analysis_insights=state["analysis_insights"],
        visualization_results=state["visualization_results"],
        available_df_ids=state["df_ids"],
    )
    agent_input = [SystemMessage(content=system_text)]
    agent_input.extend(m for m in recent if not isinstance(m, SystemMessage))

    result = report_generator_agent.invoke({"messages": agent_input}, config=state["_config"])
    pprint(f"Report generator result: {result}")
    reply = HumanMessage(content=result["messages"][-1].content, name="report_generator")
    return Command(
        update={
            "messages": [reply],
            "report_results": result["structured_response"],
            "report_generator_complete": True,
        },
        goto="supervisor",
    )

In [None]:
# Worker nodes in registration order. The same table drives the supervisor's member list,
# the graph nodes, and the worker -> supervisor edges, so the three cannot drift apart.
worker_nodes = {
    "initial_analysis": initial_analysis_node,
    "data_cleaner": data_cleaner_node,
    "analyst": analyst_node,
    "file_writer": file_writer_node,
    "visualization": visualization_node,
    "report_generator": report_generator_node,
}

coordinator_node = make_supervisor_node(llm, list(worker_nodes))
config = {"configurable": {"thread_id": "thread-1", "user_id": "user-1"}, "recursion_limit": 150}

data_analysis_team_builder = StateGraph(State)
# Graph-level checkpointer; each sub-agent was compiled with its own separate MemorySaver.
checkpointer = MemorySaver()

data_analysis_team_builder.add_node("supervisor", coordinator_node)
for node_name, node_fn in worker_nodes.items():
    data_analysis_team_builder.add_node(node_name, node_fn)

# The run starts at the supervisor, and every worker reports back to it.
data_analysis_team_builder.add_edge(START, "supervisor")
for node_name in worker_nodes:
    data_analysis_team_builder.add_edge(node_name, "supervisor")

# The same in_memory_store the agents were built with is attached to the team graph.
data_detective_graph = data_analysis_team_builder.compile(checkpointer=checkpointer, store=in_memory_store)

In [None]:
# Render the team graph. draw_mermaid_png() needs a Mermaid renderer (by default a web API),
# so it can fail offline; fall back to printing the Mermaid source instead of breaking Run-All.
try:
    display(Image(data_detective_graph.get_graph().draw_mermaid_png()))
except Exception as e:
    print(f"Could not render graph image ({e}); Mermaid source follows:")
    print(data_detective_graph.get_graph().draw_mermaid())

In [None]:
# Run the full multi-agent graph, printing each streamed update as it arrives.
import traceback

received_chunks = []
# The run starts from the user's request as a named HumanMessage.
sample_prompt_final_human = HumanMessage(content=sample_prompt_text, name="user")

try:
    for s_chunk in data_detective_graph.stream(
        # "_config" is stored in state because the nodes forward state["_config"] to their sub-agents.
        {"messages": [sample_prompt_final_human], "user_prompt": sample_prompt_text, "_config": config, "df_ids": [df_id]},
        config,
        stream_mode="updates",  # per-node state updates; "values" or "debug" are alternatives
        # stream() also accepts subgraphs=True to include sub-agent events.
    ):
        pprint(s_chunk)
        print("\n")
        received_chunks.append(s_chunk)
except Exception as e:
    # Best-effort: keep the notebook going, but show where it failed and what was received.
    pprint(f"Streaming Error: {e}")
    traceback.print_exc()
    pprint(f"Received chunks before error: {received_chunks}")

pprint(list(data_detective_graph.get_state_history(config)))

In [None]:
# Inspect the final state of the run.
# get_state_history() returns an iterator of checkpoints ordered newest-first: it cannot be
# indexed (the old `[-1]` raised TypeError), it is always truthy, and its FIRST item is the
# most recent snapshot.
try:
    last_state = next(iter(data_detective_graph.get_state_history(config)), None)
    if last_state is not None:
        print("\nFinal State Snapshot:")

        # Access keys defensively: a run that stopped early may not have produced them.
        final_insights = last_state.values.get('analysis_insights')
        if final_insights:
            pprint(f"Analyst result summary: {final_insights.summary}")
            pprint(f"Analyst result correlation_insights: {final_insights.correlation_insights}")
            pprint(f"Analyst result anomaly_insights: {final_insights.anomaly_insights}")
            pprint(f"Analyst result recommended_visualizations: {final_insights.recommended_visualizations}")
        else:
            print("Analysis insights not found in the final state.")

        final_report_results = last_state.values.get('report_results')
        if final_report_results:
            pprint(f"Report results: {final_report_results.report_path}")
        else:
            print("Report results not found in the final state.")
    else:
        print("No state history found.")
except Exception as e:
    print(f"Error accessing final state: {e}")

In [None]:
from pydantic import ValidationError
print("Testing GetDataParams Pydantic model:")
try:
    # Index forms expected to validate: an int, a list of ints, and a 2-tuple.
    for valid_index in (0, [0, 1], (0, 2)):
        print(GetDataParams(df_id='test', index=valid_index, columns='all'))

    # Index forms expected to be rejected. The `else` branch makes an accepted value visible;
    # previously such a case would pass silently.
    for invalid_index in ['invalid', (1,), [1, 'a']]:
        try:
            GetDataParams(df_id='test', index=invalid_index, columns='all')
        except ValidationError as e:
            print(f"Caught expected ValidationError for index={invalid_index}: {e}")
        else:
            print(f"UNEXPECTED: no ValidationError raised for index={invalid_index}")
except NameError as ne:
    print(f"Pydantic models (GetDataParams) not defined yet or error in definition: {ne}")

In [None]:
import unittest
import shutil # For cleaning up test directory if needed, though TemporaryDirectory handles it

# Ensure PatchedDataFrameRegistry is defined if it's different from DataFrameRegistry
# For now, assuming DataFrameRegistry is the one to test and is defined.

class TestDataFrameRegistry(unittest.TestCase):
    """Unit tests for DataFrameRegistry: register/get, removal, LRU eviction, raw-path lookup, lazy loading.

    Each test runs against a fresh registry (capacity=2), with the notebook-global
    WORKING_DIRECTORY temporarily pointed at a private temporary directory.
    """

    def setUp(self):
        global WORKING_DIRECTORY
        # Each test gets its own temporary directory and registry instance.
        self.test_temp_dir = TemporaryDirectory()
        # Cleanups also run when setUp fails part-way (tearDown would not), so the temp dir
        # cannot leak. Cleanups run last-in-first-out: this one runs after the restore below.
        self.addCleanup(self.test_temp_dir.cleanup)
        self.test_working_dir = Path(self.test_temp_dir.name)

        # Temporarily patch the global WORKING_DIRECTORY. The restore is registered at once so a
        # later failure in setUp cannot leave the notebook pointing at a deleted directory.
        self.original_working_directory = WORKING_DIRECTORY
        WORKING_DIRECTORY = self.test_working_dir
        self.addCleanup(self._restore_working_directory)

        self.registry = DataFrameRegistry(capacity=2)
        self.sample_df1 = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
        self.sample_df2 = pd.DataFrame({'C': [5, 6], 'D': [7, 8]})
        self.sample_df3 = pd.DataFrame({'E': [9, 10], 'F': [11, 12]})

        self.test_csv_path = self.test_working_dir / "test_load.csv"
        self.sample_df_for_csv = pd.DataFrame({'X': [100, 200]})
        self.sample_df_for_csv.to_csv(self.test_csv_path, index=False)

    def _restore_working_directory(self):
        """Restore the WORKING_DIRECTORY value captured in setUp."""
        global WORKING_DIRECTORY
        WORKING_DIRECTORY = self.original_working_directory

    def test_register_and_get_dataframe(self):
        df_id1 = self.registry.register_dataframe(self.sample_df1, "df1")
        self.assertEqual(df_id1, "df1")
        retrieved_df1 = self.registry.get_dataframe("df1")
        pd.testing.assert_frame_equal(retrieved_df1, self.sample_df1)
        self.assertIn("df1", self.registry.cache)

    def test_get_dataframe_not_exists(self):
        retrieved_df = self.registry.get_dataframe("non_existent_df")
        self.assertIsNone(retrieved_df)

    def test_remove_dataframe(self):
        self.registry.register_dataframe(self.sample_df1, "df1")
        self.registry.remove_dataframe("df1")
        self.assertNotIn("df1", self.registry.registry)
        self.assertNotIn("df1", self.registry.cache)
        self.assertNotIn("df1", self.registry.df_id_to_raw_path)

    def test_cache_lru_eviction(self):
        self.registry.register_dataframe(self.sample_df1, "df1")
        self.registry.register_dataframe(self.sample_df2, "df2")
        self.registry.get_dataframe("df1")  # Access df1 to make it most recently used
        self.registry.register_dataframe(self.sample_df3, "df3")
        self.assertIn("df1", self.registry.cache)
        self.assertIn("df3", self.registry.cache)
        self.assertNotIn("df2", self.registry.cache)

    def test_get_raw_path_from_id(self):
        raw_path_str = str(self.test_working_dir / "custom_path.csv")
        df_id = self.registry.register_dataframe(self.sample_df1, "df_custom", raw_path=raw_path_str)
        retrieved_path = self.registry.get_raw_path_from_id(df_id)
        self.assertEqual(retrieved_path, raw_path_str)

    def test_get_dataframe_load_if_not_exists(self):
        df_id_load = self.registry.register_dataframe(df=None, df_id="df_load", raw_path=str(self.test_csv_path))
        self.assertNotIn(df_id_load, self.registry.cache)
        loaded_df = self.registry.get_dataframe(df_id_load, load_if_not_exists=True)
        self.assertIsNotNone(loaded_df)
        pd.testing.assert_frame_equal(loaded_df, self.sample_df_for_csv)
        self.assertIn(df_id_load, self.registry.cache)

    def test_get_dataframe_load_if_not_exists_file_not_found(self):
        df_id_missing = self.registry.register_dataframe(df=None, df_id="df_missing", raw_path=str(self.test_working_dir / "non_existent.csv"))
        loaded_df = self.registry.get_dataframe(df_id_missing, load_if_not_exists=True)
        self.assertIsNone(loaded_df)

# Run the DataFrameRegistry suite in-notebook with per-test (verbosity=2) output.
print("Running DataFrameRegistry tests...")
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestSuite(unittest.TestLoader().loadTestsFromTestCase(TestDataFrameRegistry))
result = runner.run(suite)
print("Tests completed.")

In [None]:
# Imports from original cell 15
from functools import cache
from io import BytesIO
# from google.colab import drive # Not needed if not on Colab / driving mounting
import ipywidgets as widgets
# from IPython.display import display, clear_output # display already imported