In [417]:
import hashlib
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

import langgraph
import pandas as pd
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from sqlalchemy import create_engine, Column, String, Table, MetaData
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import text
from typing_extensions import TypedDict

from utils.gemini_service import GeminiModel, GeminiJsonEngine, GeminiSimpleChatEngine
from utils.logger import LOGGER

In [418]:
# Local-development defaults for the Google Cloud settings.
# `setdefault` only fills in a value when the variable is not already set, so
# credentials/project/location exported before the kernel starts are respected
# instead of being overwritten.
# NOTE(review): the credentials default is a machine-specific absolute path; on any
# other machine export GOOGLE_APPLICATION_CREDENTIALS before launching the notebook.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "/Users/debasmitroy/Desktop/programming/gemini-agent-assist/key.json")
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "hackathon0-project")
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", "us-central1")

## Pydantic Base Classes

In [420]:
# Schema for one templated financial question; `query` is a required string
# (the `...` default marks the field as required in pydantic).
# NOTE(review): pydantic exposes the class docstring and the field title in the generated
# JSON schema, and this model is nested in FinancialQueries, which is handed to
# GeminiJsonEngine as `basemodel` — so the wording below presumably reaches the LLM as
# tool documentation. Confirm in GeminiJsonEngine before editing it.
class Query(BaseModel):
    """
    Represents a financial query involving placeholders for key financial attributes: 
    <BUIS>, <DATE>, <NET>, <FACTOR>, <PROF_LOSS>, <CUR>, <PF>, and <DSK>.  

    Example Queries:  
    - What are the <FACTOR>s that contributed the highest <NET> profit/loss?  
    - Which <CUR> currencies are driving the top-performing portfolios <PF>?  
    """
    # The question text; placeholders such as <NET> stand in for concrete values.
    query: str = Field(..., title="Financial query using placeholders <BUIS>, <DATE>, <NET>, <FACTOR>, <PROF_LOSS>, <CUR>, <PF>, and <DSK>.")

# Container schema passed as `basemodel` to GeminiJsonEngine in generate_subj_query_agent;
# that agent reads the generated list back through the 'queries' key of the engine output.
class FinancialQueries(BaseModel):
    """
    A collection of structured queries designed to generate financial summaries.  
    Each query should use placeholders (<FIELD>) instead of actual values.  
    """
    # Required list of Query items.
    queries: List[Query] = Field(..., title="List of financial queries using placeholders <FIELD> instead of actual values.")

In [421]:
# Schema passed as `basemodel` to GeminiJsonEngine in generate_sql_script_agent.
# sql_result_agent later reads the 'sql_script' and 'description' keys of each generated item;
# 'columns' is not read by any code visible in this notebook.
class SQLScript(BaseModel):
    """
    SQL Script to query data from the given table. You have to use this tool to generate the SQL script.
    """
    # The SQL text that sql_result_agent executes against the in-memory `title_data` table.
    sql_script: str = Field(..., title="SQL Script to query data from the given table.")
    # Column names the model says the script projects (model-reported, not verified here).
    columns: List[str] = Field(..., title="Which columns are being projected in the SQL script.")
    # Plain-language description; reused as the question label when building the summary prompts.
    description: str = Field(..., title="What does the SQL script do in Finance Analyst's perspective")

## Agent State

In [422]:
# Shared state passed between the LangGraph nodes; every node receives it,
# updates it in place and returns it.
class AgentState(TypedDict):
    # Name of the step currently running; each node overwrites it on entry and uses it
    # as the key into `results`, `cache_location` and `cache_flag`.
    state: str
    # Model settings read by the Gemini engines: 'model_name', 'temperature',
    # 'max_output_tokens', 'max_retries', 'wait_time'.
    model: Dict[str, Any]
    # Output of each step, keyed by step name.
    results: Dict[str, Any]
    # Step name -> path of that step's JSON cache file, plus the paths of the two input
    # data files ('sample_summarized_pnl_commentaries', 'rule_based_title_comment_data').
    cache_location: Dict[str, Any]
    # Step name -> whether the step should try to reuse its cache file.
    cache_flag: Dict[str, Any]

## Generic Utils

In [423]:
def load_json_data(path: str) -> Optional[Dict[str, Any]]:
    """Load a JSON document from ``path``.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded JSON content, or ``None`` when nothing exists at ``path``
        (callers use ``None`` to detect a missing cache file).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not os.path.exists(path):
        return None
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)

def save_json_data(path: str, data: Dict[str, Any]):
    """Write ``data`` to ``path`` as indented JSON, creating parent directories as needed.

    Args:
        path: Destination file. May be a bare filename (written to the current directory).
        data: JSON-serialisable content.

    Raises:
        TypeError: If ``data`` contains values ``json`` cannot serialise.
    """
    directory = os.path.dirname(path)
    # dirname is "" for a bare filename; os.makedirs("") would raise, so only create
    # a directory when the path actually has one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)

def load_cached_results(state):
    """Reuse the cached output of the current step, if caching is enabled and a cache exists.

    The current step is ``state['state']``. When its cache flag is truthy and its cache
    file yields a non-empty document, the document's ``'result'`` entry is copied into
    ``state['results']`` for that step.

    Returns:
        ``state`` when a cached result was loaded, otherwise ``None``.
    """
    step = state['state']

    # Caching disabled for this step: nothing to load.
    if not state['cache_flag'][step]:
        return None

    payload = load_json_data(state['cache_location'][step])
    # Missing (or empty) cache file: the caller has to compute the result.
    if not payload:
        return None

    state['results'][step] = payload['result']
    LOGGER.info(f"State: {step} | Loaded cached data and skipping the model, {len(state['results'][step])} old result found")
    return state

## SQL Engine

In [452]:
class InMemoryDB:
    """SQLite in-memory database accessed through a single SQLAlchemy session.

    Tables are created dynamically with a string ``id`` primary key. Rows are inserted
    with ``id`` set to a SHA-256 hash of their content, so inserting the same row again
    is a no-op.
    """

    def __init__(self):
        """Create the in-memory engine, the table registry and one open session."""
        self.engine = create_engine('sqlite:///:memory:', echo=False)
        self.metadata = MetaData()
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()

    def create_table(self, table_name, columns):
        """Creates a table dynamically with a SHA-256 hash primary key to prevent duplicates.

        Args:
            table_name: Name of the table to create.
            columns: Names of the data columns; each becomes a ``String`` column. The
                ``id`` primary key is added here, so callers leave ``"id"`` out
                (``get_title_data_inmemory_db`` strips it).
        """
        table = Table(
            table_name, self.metadata,
            Column("id", String, primary_key=True),  # Primary key hash column
            *[Column(col, String) for col in columns],
        )
        table.create(self.engine)

    def generate_hash(self, data):
        """Generates a SHA-256 hash over the string representation of a row.

        Args:
            data: Mapping of column name to value.

        Returns:
            Hex digest of the row's items, sorted so key order does not matter.
        """
        row_string = str(sorted(data.items()))  # Ensure consistent ordering
        return hashlib.sha256(row_string.encode()).hexdigest()

    def insert_data(self, table_name, data):
        """Inserts a row into the table using parameterized values and avoids duplicates.

        The caller's ``data`` mapping is not modified: the hash key is added to a copy.
        (Writing ``id`` back into ``data`` would change the hash of a later insert of
        the same dict and let the duplicate through.)

        Args:
            table_name: Target table.
            data: Mapping of column name to value for one row.
        """
        row = dict(data)
        row["id"] = self.generate_hash(data)  # Hash of the row exactly as supplied
        column_list = ", ".join(row.keys())
        placeholders = ", ".join(f":{key}" for key in row)
        # NOTE(review): values are bound as parameters, but the table and column names
        # are interpolated into the SQL text — only pass trusted identifiers.
        query = text(f"""
            INSERT INTO {table_name} ({column_list})
            VALUES ({placeholders})
            ON CONFLICT(id) DO NOTHING
        """)
        self.session.execute(query, row)
        self.session.commit()

    def query_data(self, query):
        """Executes a SQL query and returns its column names and rows.

        Args:
            query: SQL text to execute.

        Returns:
            ``(columns, data)`` where ``columns`` is a list of column names and ``data``
            is a list of rows, each row converted from a tuple-like result to a list.
        """
        result = self.session.execute(text(query))
        columns = result.keys()  # Get column names
        data = [list(row) for row in result.fetchall()]
        return list(columns), data  # Return both columns and data

    def __del__(self):
        # `session` is missing when __init__ failed part-way; never raise from a finalizer.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()

In [453]:
def get_title_data_inmemory_db(rule_based_title_comment_data):
    """Build an ``InMemoryDB`` with a ``title_data`` table filled from the given records.

    The table columns are taken from the keys of the first record.

    Args:
        rule_based_title_comment_data: Non-empty list of dicts, one per row.

    Returns:
        The populated ``InMemoryDB`` instance.

    Raises:
        ValueError: If no records are supplied (empty list or ``None``), since the
            table columns cannot be derived.
    """
    if not rule_based_title_comment_data:
        raise ValueError(
            "rule_based_title_comment_data is empty or missing; "
            "at least one record is required to derive the table columns"
        )

    # "id" is reserved for the hash primary key that create_table adds itself.
    columns = [key for key in rule_based_title_comment_data[0].keys() if key != "id"]
    inmemory_db = InMemoryDB()
    inmemory_db.create_table("title_data", columns)

    # Insert data
    for data in rule_based_title_comment_data:
        inmemory_db.insert_data("title_data", data)

    return inmemory_db

In [454]:
# Module-level handle to the InMemoryDB that holds the `title_data` table.
# It stays None until register_data() assigns it; generate_sql_script_agent() and
# sql_result_agent() read it, so register_data must run first.
TITLE_DATA_INMEM_DB = None

## Agent utils

In [455]:
def start_agent(state: AgentState):
    """Entry node of the workflow: announce the run and tag the state as 'start'.

    Returns the same ``state`` mapping it received.
    """
    LOGGER.info("Starting the agent-assist")
    state.update(state='start')
    return state

In [456]:
def end_agent(state: AgentState):
    """Final node of the workflow: announce completion and tag the state as 'end'.

    Returns the same ``state`` mapping it received.
    """
    LOGGER.info("Ending the agent-assist")
    state.update(state='end')
    return state

In [457]:
def refine_old_summary_agent(state: AgentState):
    """Refine the admin-provided sample P&L summaries with the Gemini chat engine.

    Sets ``state['state']`` to ``'refine_old_summaries'``. When a cached result can be
    loaded (see ``load_cached_results``) it is reused and the model is skipped.
    Otherwise every summary read from
    ``state['cache_location']['sample_summarized_pnl_commentaries']`` is sent to the
    model; the refined texts are stored under ``state['results']['refine_old_summaries']``,
    the step's cache flag is set to True and the result is written to the step's cache file.

    Args:
        state: Shared agent state; reads ``model``, ``cache_location`` and ``cache_flag``.

    Returns:
        The same ``state`` object, updated in place.

    Raises:
        FileNotFoundError: If the sample summaries file does not exist.
    """
    state['state'] = 'refine_old_summaries'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to refine old summaries")

    # Load the data from cache if the cache flag is set to True
    cached_result = load_cached_results(state)
    if cached_result:
        return state

    # Initialize the model
    gemini_simple_chat_engine = GeminiSimpleChatEngine(model_name=state['model']['model_name'],
                                                   temperature=state['model']['temperature'],
                                                   max_output_tokens=state['model']['max_output_tokens'],
                                                   systemInstructions="You are an expert financial bot. You will be given a financial report and you need to refine the report. Keep everything in a single large paragraph. Dont use any markdown or bullet points. ",
                                                   max_retries=state['model']['max_retries'],
                                                   wait_time=state['model']['wait_time'])

    # Old summaries from sample_summarized_pnl_commentaries (Note: This is a sample data not cached, admin will provide the data)
    sample_path = state['cache_location']['sample_summarized_pnl_commentaries']
    sample_summarized_pnl_commentaries = load_json_data(sample_path)
    # load_json_data returns None for a missing file; report that clearly instead of
    # failing later with "object of type 'NoneType' has no len()".
    if sample_summarized_pnl_commentaries is None:
        raise FileNotFoundError(f"Sample summaries file not found: {sample_path}")
    LOGGER.info(f"State: {state['state']} | Loaded the sample data, {len(sample_summarized_pnl_commentaries)} old summaries found")

    # Refine the old summaries
    result = []

    for summary in sample_summarized_pnl_commentaries:
        _refinement_prompt = [
            f"Given financial report: {summary}",
            "Please refine the financial report in a more readable and meangingful way without losing any important information and entitites and technical/financial terms. Dont unnecessarily change the meaning of the report and dont increase the length of the report. "
        ]
        refined_summary = gemini_simple_chat_engine(_refinement_prompt)
        result.append(refined_summary)
        LOGGER.info(f"State: {state['state']} | Summary refined , {summary[:30]}... to {refined_summary[:30]}...")

    # Save the result to state var and set the cache flag to True
    state['results'][state['state']] = result
    state['cache_flag'][state['state']] = True

    # Save the result to cache with the state name and {result} key
    save_json_data(state['cache_location'][state['state']], {"result":state['results'][state['state']]})

    LOGGER.info(f"State: {state['state']} | Refinement of old summaries completed, saved the result to cache and set the cache flag to True")
    return state

In [458]:
def generate_subj_query_agent(state: AgentState):
    """Ask the model for templated financial queries (``FinancialQueries``).

    Sets ``state['state']`` to ``'subj_query_generation'``. When a cached result can be
    loaded it is reused. Otherwise the first refined summary from the
    ``'refine_old_summaries'`` step is shown to the model as an example pattern, the
    returned queries are stored under ``state['results']['subj_query_generation']``,
    each query gets an ``'id'`` (SHA-256 of its JSON dump), and the result is cached.

    Args:
        state: Shared agent state; reads ``model``, ``results['refine_old_summaries']``,
            ``cache_location`` and ``cache_flag``.

    Returns:
        The same ``state`` object, updated in place.

    Raises:
        ValueError: If there are no refined summaries to use as the example pattern.
    """
    state['state'] = 'subj_query_generation'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to generate financial queries")

    # Load the data from cache if the cache flag is set to True
    cached_result = load_cached_results(state)
    if cached_result:
        return state

    # Initialize the model
    fin_qry_engine =  GeminiJsonEngine(
                                    model_name=state['model']['model_name'],
                                    basemodel=FinancialQueries,
                                    temperature=state['model']['temperature'],
                                    max_output_tokens=state['model']['max_output_tokens'],
                                    systemInstructions=None,
                                    max_retries=state['model']['max_retries'],
                                    wait_time=state['model']['wait_time'])

    # Outputs from the previous state
    refined_sample_summarized_pnl_commentaries = state['results']['refine_old_summaries']
    # The prompt quotes the first refined summary; fail with a clear message rather
    # than an IndexError when the previous step produced nothing.
    if not refined_sample_summarized_pnl_commentaries:
        raise ValueError("No refined summaries available in state['results']['refine_old_summaries']; run the refine_old_summaries step first")

    # Prompt
    title_comment_template = "For Buisness <BUIS>, on  <DATE>, driven by <NET>$  <FACTOR> <PROF_LOSS> to PL on <CUR> Currency on Portfolio <PF> and Desk <DSK>"
    user_prompt_list = [
    "You are a financial assistant. Your task is to generate structured queries from given templates to create financial summaries.",
    
    f"Here is an example pattern for financial summaries: {refined_sample_summarized_pnl_commentaries[0]}.",
    
    f"You are provided with a list of rule-based templates in the format List[{title_comment_template}]. Extract meaningful queries from these templates.",
    
    """Generate at least 15 diverse queries that can be used to generate sample financial summaries.
    
    - The queries should focus on aggregations such as min, max, mean, and sum, or retrieve the top 5 / bottom 5 entities.  
    - Avoid queries that fetch all rows or list all entities without aggregation.  
    - Do not create separate queries for different aggregations on the same entity; instead, combine them into a single query.  
    - Dont ask for a particular value; instead, ask for a top k or bottom k value. Say, top 5 Business Units or bottom 5 Desks.
    - The queries should be sufficient to address the financial summary patterns mentioned above.  
    - Replace all field values with placeholders using the format <FIELD>. Do not include actual values.  
    - Do not summarize the data; just generate structured queries.""",
    
    "You must use the tool `FinancialQueries`. Your response must strictly follow the argument structure of `FinancialQueries`."
    ]

    # Generate queries
    queries = fin_qry_engine(user_prompt_list)[0]['queries']
    LOGGER.info(f"State: {state['state']} | Generated {len(queries)} queries")

    # Save the result to state var and set the cache flag to True
    state['results'][state['state']] = queries
    state['cache_flag'][state['state']] = True

    # Assign id to each query with sha hash (computed before the 'id' key is added)
    for query in state['results'][state['state']]:
        query['id'] = hashlib.sha256(json.dumps(query).encode()).hexdigest()

    # Save the result to cache with the state name and {result} key
    save_json_data(state['cache_location'][state['state']], {"result":state['results'][state['state']]})

    LOGGER.info(f"State: {state['state']} | Query generation completed, saved the result to cache and set the cache flag to True")
    return state

In [459]:
def generate_stat_query_agent(state: AgentState):
    """Produce the fixed list of statistical query templates (no model call).

    Sets ``state['state']`` to ``'stat_query_generation'``. When a cached result can be
    loaded it is reused. Otherwise the rule-based templates below are wrapped as
    ``{"query": ...}`` dicts, each gets an ``'id'`` (SHA-256 of the JSON dump of the
    dict before the id is added), and the list is stored under
    ``state['results']['stat_query_generation']`` and written to the step's cache file.

    Args:
        state: Shared agent state; reads ``cache_location`` and ``cache_flag``.

    Returns:
        The same ``state`` object, updated in place.
    """
    state['state'] = 'stat_query_generation'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to generate statistical queries")

    # Load the data from cache if the cache flag is set to True
    cached_result = load_cached_results(state)
    if cached_result:
        return state

    ## It is not an AI task, it is a rule-based task. So, we can directly write the code here.

    # NOTE: the template texts feed the query ids (hash of the text); editing any of
    # them changes the ids that downstream cached results are matched against.
    statistical_queries = [
        "What is average, max, min, varaince, sum of <NET> profit/loss?",
        "What is average, max, min, varaince, sum of <FACTOR> profit/loss grouped by <BUIS>?",
        "What is average, max, min, varaince, sum of <NET> profit/loss grouped by <CUR> currency?",
        "What is average, max, min, varaince, sum of <NET> profit/loss grouped by top 5 <PF> portfolios?",
        "What is average, max, min, varaince, sum of <NET> profit/loss grouped by bottom 5 <PF> portfolios?",
        "What is average, max, min, varaince, sum of <NET> profit/loss grouped by top 5 <DSK> desks?",
        "What is average, max, min, varaince, sum of <NET> profit/loss grouped by bottom 5 <DSK> desks?",
        "What are the top currencies by average <NET> profit/loss?",
        "What are the bottom currencies by average <NET> profit/loss?",
        "What is the total count of transactions for each <FACTOR>?",
        "What is the percentage contribution of each <FACTOR> to total profit/loss?",
        "What is the trend of total <NET> profit/loss over time (daily, monthly, yearly)?",
        "What is the moving average of <NET> profit/loss over the past 7 days?",
        "What is the standard deviation of <NET> profit/loss grouped by <BUIS>?",
        "What is the correlation between <FACTOR> and <NET> profit/loss?",
        "What is the skewness and kurtosis of <NET> profit/loss distribution?",
        "Which <PF> portfolios have the highest standard deviation in <NET> profit/loss?",
        "Which <DSK> desks have the highest variance in <NET> profit/loss?",
        "What is the probability distribution of <NET> profit/loss?",
        "What is the cumulative sum of <NET> profit/loss over time?",
        "Which <CUR> currency has the most volatile <NET> profit/loss?",
        "What is the ratio of profitable to loss-making transactions per <FACTOR>?",
        "Which <PF> portfolios have the most consistently positive (low variance) profits?",
        "Which <FACTOR> contributes most to total profit/loss variance?",
        "What is the ratio of profit to loss per <DSK> desk?",
        "Which <CUR> currency contributes the most to total profit?",
        "Which <FACTOR> has the highest frequency of losses?"
    ]
    statistical_queries = [{"query": query} for query in statistical_queries]

    # Save the result to state var and set the cache flag to True
    state['results'][state['state']] = statistical_queries
    state['cache_flag'][state['state']] = True

    # Assign id to each query with sha hash
    for query in state['results'][state['state']]:
        query['id'] = hashlib.sha256(json.dumps(query).encode()).hexdigest()

    # Save the result to cache with the state name and {result} key
    save_json_data(state['cache_location'][state['state']], {"result":state['results'][state['state']]})

    LOGGER.info(f"State: {state['state']} | Query generation completed, saved the result to cache and set the cache flag to True")

    return state

In [460]:
def register_data(state: AgentState):
    """Load the rule-based title/comment records into the shared in-memory database.

    Sets ``state['state']`` to ``'register_data'``, reads the records from
    ``state['cache_location']['rule_based_title_comment_data']`` and publishes the
    resulting database through the module-level ``TITLE_DATA_INMEM_DB``.

    Returns:
        The same ``state`` object.
    """
    global TITLE_DATA_INMEM_DB

    state['state'] = 'register_data'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to register data")

    # Plain rule-based step: read the records and build the database from them.
    records_path = state['cache_location']['rule_based_title_comment_data']
    records = load_json_data(records_path)
    TITLE_DATA_INMEM_DB = get_title_data_inmemory_db(records)

    LOGGER.info(f"State: {state['state']} | Data registration completed, saved the data to in-memory DB. Global variable TITLE_DATA_INMEM_DB is set")

    return state

In [461]:
def generate_sql_script_agent(state: AgentState):
    """Ask the model for one SQL script per generated query.

    Sets ``state['state']`` to ``'sql_script_generation'``. When a cached result can be
    loaded it is reused. Otherwise, for every query from the ``'subj_query_generation'``
    and ``'stat_query_generation'`` steps, the model is shown the table schema and the
    first rows of ``title_data`` and returns a ``SQLScript``. Each script is tagged with
    the ``'id'`` of the query it answers; the list is stored under
    ``state['results']['sql_script_generation']`` and written to the step's cache file.

    Args:
        state: Shared agent state; reads ``model``, ``results``, ``cache_location``
            and ``cache_flag``.

    Returns:
        The same ``state`` object, updated in place.

    Raises:
        RuntimeError: If ``TITLE_DATA_INMEM_DB`` has not been set by ``register_data``.
    """
    global TITLE_DATA_INMEM_DB

    state['state'] = 'sql_script_generation'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to generate SQL script")

    # Load the data from cache if the cache flag is set to True
    cached_result = load_cached_results(state)
    if cached_result:
        return state

    # The table preview below needs the database; say so explicitly instead of failing
    # with "'NoneType' object has no attribute 'query_data'".
    if TITLE_DATA_INMEM_DB is None:
        raise RuntimeError("TITLE_DATA_INMEM_DB is not initialised; run register_data before generate_sql_script_agent")

    # Initialize the model
    sql_script_engine =  GeminiJsonEngine(
                                    model_name=state['model']['model_name'],
                                    basemodel=SQLScript,
                                    temperature=state['model']['temperature'],
                                    max_output_tokens=state['model']['max_output_tokens'],
                                    systemInstructions="You are an expert financial bot. You will be given a table and you need to generate a SQL script to query the data from the table. ",
                                    max_retries=state['model']['max_retries'],
                                    wait_time=state['model']['wait_time'])

    # Previous state outputs
    subj_queries = state['results']['subj_query_generation']
    stat_queries = state['results']['stat_query_generation']
    all_queries = subj_queries + stat_queries

    # Head of the table (the query already limits it to 5 rows). The hash id and the
    # COMMENT column are left out of the preview; errors='ignore' keeps this working
    # for data sets that have no COMMENT column.
    preview_cols, preview_rows = TITLE_DATA_INMEM_DB.query_data("SELECT * FROM title_data LIMIT 5")
    head = pd.DataFrame(preview_rows, columns=preview_cols).drop(columns=['id', 'COMMENT'], errors='ignore')

    sql_scripts = []
    # Count from 1 so the progress log ends at "N/N" rather than "N-1/N".
    for i, query in enumerate(all_queries, start=1):
        user_sql_prompt = [
            "You are a SQL expert. Your task is to write a SQL script to query data from the given table. Note: you are generating a SQL script for SQLLite's python library. You must be careful while writing complex queries as it is very sensitive.",
            "Library specific notes: STDDEV is not supported in SQLLite. You can use AVG and SUM to calculate the standard deviation.",
            f"Here is the schema of the table `title_data`: {TITLE_DATA_INMEM_DB.metadata.tables}",
            f"Here is the are the first few rows of the table `title_data`: {head}",
            f"User is trying to answer the following query: {query['query']}",
            "Write a SQL script to answer the query using the tool `SQLScript`. Your answer must follow the argument strucure of the tool `SQLScript`. You are encouraged to use compound and complex SQL queries to answer the query."
        ]
        sql_script = sql_script_engine(user_sql_prompt)[0]
        sql_scripts.append(sql_script)
        LOGGER.info(f"State: {state['state']} | {i}/{len(all_queries)} SQL script generated for the query: {query['query'][:20]} ... to {sql_script['sql_script'][:20]} ...")

    # Save the result to state var and set the cache flag to True
    state['results'][state['state']] = sql_scripts
    state['cache_flag'][state['state']] = True

    # Assign id to each query with the same hash as the query
    for raw_query, sql_script in zip(all_queries, state['results'][state['state']]):
        sql_script['id'] = raw_query['id']

    # Save the result to cache with the state name and {result} key
    save_json_data(state['cache_location'][state['state']], {"result":state['results'][state['state']]})

    LOGGER.info(f"State: {state['state']} | SQL script generation completed, saved the result to cache and set the cache flag to True")
    return state

In [488]:
def sql_result_agent(state:AgentState):
    """Execute the generated SQL scripts against ``title_data`` and record the outcome.

    Sets ``state['state']`` to ``'sql_result'``. When a cached result can be loaded it
    is reused. Otherwise every script from the ``'sql_script_generation'`` step is run
    and recorded with a status:

    - ``"success"``: executed and returned fewer than 20 rows;
    - ``"overlength"``: executed but returned 20 rows or more;
    - ``"failed"``: execution raised; ``columns`` and ``data`` are empty.

    The records are stored under ``state['results']['sql_result']`` and written to the
    step's cache file.

    Args:
        state: Shared agent state; reads ``results['sql_script_generation']``,
            ``cache_location`` and ``cache_flag``.

    Returns:
        The same ``state`` object, updated in place.

    Raises:
        RuntimeError: If ``TITLE_DATA_INMEM_DB`` has not been set by ``register_data``.
    """
    global TITLE_DATA_INMEM_DB

    state['state'] = 'sql_result'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to generate SQL result")

    # Load the data from cache if the cache flag is set to True
    cached_result = load_cached_results(state)
    if cached_result:
        return state

    # Without this check a missing database would be swallowed by the per-script
    # error handling below: every script would be recorded as "failed" and that
    # bogus result would then be written to the cache.
    if TITLE_DATA_INMEM_DB is None:
        raise RuntimeError("TITLE_DATA_INMEM_DB is not initialised; run register_data before sql_result_agent")

    # Previous state outputs
    sql_scripts = state['results']['sql_script_generation']

    # Results with this many rows or more are flagged as "overlength".
    overlength_threshold = 20

    # Execute the SQL scripts
    sql_results = []
    pass_count = 0
    fail_count = 0
    overlength_count = 0
    # Count from 1 so the progress log ends at "N/N" rather than "N-1/N".
    for i, sql_script in enumerate(sql_scripts, start=1):
        try:
            columns, data = TITLE_DATA_INMEM_DB.query_data(sql_script['sql_script'])
        except Exception as e:
            # Best effort by design: model-written SQL may be invalid, keep going.
            LOGGER.error(f"State: {state['state']} | {i}/{len(sql_scripts)} SQL script execution failed: {str(e)}. Skipping the query")
            columns, data, status = [], [], "failed"
            fail_count += 1
        else:
            if len(data) < overlength_threshold:
                LOGGER.info(f"State: {state['state']} | {i}/{len(sql_scripts)} SQL script executed, {len(data)} rows returned")
                status = "success"
                pass_count += 1
            else:
                LOGGER.warning(f"State: {state['state']} | {i}/{len(sql_scripts)} SQL script executed, {len(data)} rows returned. Too many rows returned, consider refining the query")
                status = "overlength"
                overlength_count += 1

        sql_results.append({
            "id": sql_script['id'],
            "columns": columns,
            "data": data,
            "status": status,
            "description": sql_script['description'],
            "sql_script": sql_script['sql_script']
        })

    LOGGER.info(f"State: {state['state']} | SQL script execution completed, {pass_count} passed, {fail_count} failed, {overlength_count} overlength")

    # Save the result to state var and set the cache flag to True
    state['results'][state['state']] = sql_results
    state['cache_flag'][state['state']] = True

    # Save the result to cache with the state name and {result} key
    save_json_data(state['cache_location'][state['state']], {"result":state['results'][state['state']]})

    LOGGER.info(f"State: {state['state']} | SQL result generation completed, saved the result to cache and set the cache flag to True")

    return state
    

In [None]:
def generate_bucket_query_agent(state: AgentState):
    """Assemble the two summary prompts ("buckets") from the successful SQL results.

    Sets ``state['state']`` to ``'bucket_query_generation'``. Successful SQL results are
    split by the id of the query they answer: one bucket for the
    ``'subj_query_generation'`` queries and one for the ``'stat_query_generation'``
    queries. Each bucket becomes a five-part prompt, and the two prompts are stored
    (subj first, stat second) under ``state['results']['bucket_query_generation']``.
    This step is rule-based: it calls no model and uses no cache.

    Returns:
        The same ``state`` object, updated in place.
    """
    state['state'] = 'bucket_query_generation'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to generate bucket queries")

    def _render_qa(results):
        # One "Q<n>. <description>:" header per result, followed by the first rows of its table.
        blocks = []
        for i, result in enumerate(results):
            table = pd.DataFrame(result['data'], columns=result['columns']).head()
            blocks.append(f"Q{i}. {result['description']}:\n{table}")
        return "\n".join(blocks)

    # Outputs of the earlier steps
    step_outputs = state['results']
    old_summary = step_outputs['refine_old_summaries'][0]
    subj_ids = [query['id'] for query in step_outputs['subj_query_generation']]
    stat_ids = [query['id'] for query in step_outputs['stat_query_generation']]

    # Only results with status 'success' are used.
    successful = [result for result in step_outputs['sql_result'] if result['status'] == 'success']

    # Split the successful results by the kind of query they answer.
    subj_bucket = [result for result in successful if result['id'] in subj_ids]
    stat_bucket = [result for result in successful if result['id'] in stat_ids]

    prompts = []
    for qa in (_render_qa(subj_bucket), _render_qa(stat_bucket)):
        prompts.append([
            "You are financial expert. Your task is to provide a DETAILED and SOPHISTICATED financial summary over P&L Trend and other financial metrics. You are provided with some structured questions and their answers. You need to generate the summary from the insights provided in the answers.",
            "The meaning for the columns are as follows: BUIS means Business Unit, DATE means Day of calcualtion, NET means Net Profit/Loss, Factor such as IRDelta, IRGamma, FXDelta etc., PROF_LOSS means Profit or Loss, CUR means Currency, PF means Portfolio, DSK means Desk.",
            f"Here is a sample summary for reference (Just follow the pattern, not the exact values): {old_summary}",
            f"Here are the structured questions and their answers: {qa}",
            "Generate a detailed financial summary based on the insights provided in the answers.",
        ])

    # Update the state
    state['results'][state['state']] = prompts

    LOGGER.info(f"State: {state['state']} | Bucket queries generated")

    return state

In [554]:
def generate_final_result(state: AgentState):
    """Generate one final summary per bucket prompt with the Gemini chat engine.

    Sets ``state['state']`` to ``'final_result'``. When a cached result can be loaded it
    is reused. Otherwise each prompt from the ``'bucket_query_generation'`` step is sent
    to the model, and the answers are stored (in the same order) under
    ``state['results']['final_result']`` and written to the step's cache file.

    The output-token limit for this step defaults to 1024 and can be overridden with the
    optional ``state['model']['final_max_output_tokens']`` entry; the general
    ``max_output_tokens`` setting is not used here.

    Args:
        state: Shared agent state; reads ``model``, ``results['bucket_query_generation']``,
            ``cache_location`` and ``cache_flag``.

    Returns:
        The same ``state`` object, updated in place.
    """
    state['state'] = 'final_result'
    LOGGER.info(f"State: {state['state']} | Initializing the agent to generate final result")

    # Load the data from cache if the cache flag is set to True
    cached_result = load_cached_results(state)
    if cached_result:
        return state

    # Summaries are longer than the intermediate outputs, so this step has its own
    # token limit (previously a fixed 1024, which remains the default).
    final_max_output_tokens = state['model'].get('final_max_output_tokens', 1024)

    # Initialize the model
    gemini_simple_chat_engine = GeminiSimpleChatEngine(model_name=state['model']['model_name'],
                                                   temperature=state['model']['temperature'],
                                                   max_output_tokens=final_max_output_tokens,
                                                   systemInstructions=None,
                                                   max_retries=state['model']['max_retries'],
                                                   wait_time=state['model']['wait_time'])

    # Previous state outputs
    bucket_queries = state['results']['bucket_query_generation']

    # Generate the final result
    final_results = []
    # Count from 1 so the progress log ends at "N/N" rather than "N-1/N".
    for i, bucket_query in enumerate(bucket_queries, start=1):
        final_result = gemini_simple_chat_engine(bucket_query)
        final_results.append(final_result)
        LOGGER.info(f"State: {state['state']} | {i}/{len(bucket_queries)} Final result generated from the bucket query. {final_result[:30]}...")

    # Save the result to state var and set the cache flag to True
    state['results'][state['state']] = final_results
    state['cache_flag'][state['state']] = True

    # Save the result to cache with the state name and {result} key
    save_json_data(state['cache_location'][state['state']], {"result":state['results'][state['state']]})

    LOGGER.info(f"State: {state['state']} | Final result generation completed, saved the result to cache and set the cache flag to True")
    return state

## Agent Builder

In [555]:
class MyAgent:
    """Builds the linear agent workflow and runs it under a fixed thread id.

    Attributes:
        app: The compiled LangGraph workflow (compiled with a ``MemorySaver`` checkpointer).
        config: Invocation config carrying the thread id as a string.
    """

    def __init__(self, thread_id=None):
        self.config = None
        self.app = None
        self.build(thread_id)

    def build(self, thread_id):
        """Create the graph: register every node, chain them in order and compile."""
        workflow = StateGraph(AgentState)

        # Node name -> handler, registered in this order.
        node_handlers = [
            ('start', start_agent),
            ('end', end_agent),
            ('refine_old_summaries', refine_old_summary_agent),
            ('generate_subj_queries', generate_subj_query_agent),
            ('generate_stat_queries', generate_stat_query_agent),
            ('register_data', register_data),
            ('sql_script_generation', generate_sql_script_agent),
            ('sql_result_generation', sql_result_agent),
            ('bucket_query_generation', generate_bucket_query_agent),
            ('final_result', generate_final_result),
        ]
        for node_name, handler in node_handlers:
            workflow.add_node(node_name, handler)

        # The workflow is a straight line: each step feeds the next one.
        execution_order = [
            'start',
            'refine_old_summaries',
            'generate_subj_queries',
            'generate_stat_queries',
            'register_data',
            'sql_script_generation',
            'sql_result_generation',
            'bucket_query_generation',
            'final_result',
            'end',
        ]
        for source, target in zip(execution_order, execution_order[1:]):
            workflow.add_edge(source, target)

        # Compile
        workflow.set_entry_point('start')
        self.app = workflow.compile(checkpointer=MemorySaver())
        self.config = {"configurable": {"thread_id": str(thread_id)}}

    def get_recent_state_snap(self):
        """Return a shallow copy of the latest checkpointed state values for this thread."""
        return self.app.get_state(self.config).values.copy()

    def get_graph(self):
        """Return the compiled workflow's graph (requested with ``xray=True``)."""
        return self.app.get_graph(xray=True)

    def continue_flow(self, state):
        """Invoke the workflow with ``state`` and return the resulting state snapshot."""
        self.app.invoke(state, config=self.config)
        return self.get_recent_state_snap()

In [556]:
# Build and compile the workflow once; thread_id=1 is stored (as a string) in the
# invocation config that every later call on this agent uses.
MY_AGENT = MyAgent(thread_id=1)

In [557]:
# Inspect the checkpointed state before the first invocation (empty at this point).
MY_AGENT.get_recent_state_snap()

{}

In [558]:
# Render the compiled workflow as ASCII art to check the order of the nodes.
graph = MY_AGENT.get_graph()
print(graph.draw_ascii())

       +-----------+         
       | __start__ |         
       +-----------+         
              *              
              *              
              *              
         +-------+           
         | start |           
         +-------+           
              *              
              *              
              *              
  +----------------------+   
  | refine_old_summaries |   
  +----------------------+   
              *              
              *              
              *              
 +-----------------------+   
 | generate_subj_queries |   
 +-----------------------+   
              *              
              *              
              *              
 +-----------------------+   
 | generate_stat_queries |   
 +-----------------------+   
              *              
              *              
              *              
     +---------------+       
     | register_data |       
     +---------------+       
          

In [564]:
# Run the whole workflow with an initial state. The trailing `;` keeps the returned
# state snapshot from being echoed as the cell output.
MY_AGENT.continue_flow({
    # Overwritten by each node as it starts.
    'state': 'start',
    # Settings handed to the Gemini engines. generate_final_result uses its own
    # output-token limit instead of 'max_output_tokens'.
    'model':{
        'model_name':'gemini-2.0-flash-001',
        'temperature': 0.5,
        'max_output_tokens': 512,
        'max_retries':5,
        'wait_time':30
    },
    # One slot per step; each node replaces its own entry.
    'results':{
        'refine_old_summaries':[],
        'subj_query_generation':[],
        'stat_query_generation':[],
        'sql_script_generation':[],
        'sql_result':[],
        'bucket_query_generation':[],
        'final_result':[]
    },
    'cache_location':{
        # Input data files (read by refine_old_summary_agent and register_data).
        "sample_summarized_pnl_commentaries":"../sample_data/sample_summarized_pnl_commentaries.json",
        "rule_based_title_comment_data":"../sample_data/rule_based_title_comment_data.json",
        
        # Per-step cache files. 'bucket_query_generation' has no entry here or in
        # 'cache_flag' because that step never reads or writes a cache.
        "refine_old_summaries":"../sample_data/cached/refine_old_summaries.json",
        "subj_query_generation":"../sample_data/cached/subj_query_generation.json",
        "stat_query_generation":"../sample_data/cached/stat_query_generation.json",
        "sql_script_generation":"../sample_data/cached/sql_script_generation.json",
        "sql_result":"../sample_data/cached/sql_result.json",
        "final_result":"../sample_data/cached/final_result.json"
    },
    # True: reuse the step's cache file when it exists (see load_cached_results);
    # set a flag to False to force that step to be recomputed.
    'cache_flag':{
        "refine_old_summaries":True,
        "subj_query_generation":True,
        "stat_query_generation":True,
        "sql_script_generation":True,
        "sql_result":True,
        "final_result":True
    }
});

[1;32m2025-03-10 02:20:57,638 - INFO ==> Starting the agent-assist[0m
[1;32m2025-03-10 02:20:57,640 - INFO ==> State: refine_old_summaries | Initializing the agent to refine old summaries[0m
[1;32m2025-03-10 02:20:57,640 - INFO ==> State: refine_old_summaries | Loaded cached data and skipping the model, 3 old result found[0m
[1;32m2025-03-10 02:20:57,641 - INFO ==> State: subj_query_generation | Initializing the agent to refine old summaries[0m
[1;32m2025-03-10 02:20:57,641 - INFO ==> State: subj_query_generation | Loaded cached data and skipping the model, 15 old result found[0m
[1;32m2025-03-10 02:20:57,642 - INFO ==> State: stat_query_generation | Initializing the agent to refine old summaries[0m
[1;32m2025-03-10 02:20:57,643 - INFO ==> State: stat_query_generation | Loaded cached data and skipping the model, 27 old result found[0m
[1;32m2025-03-10 02:20:57,644 - INFO ==> State: register_data | Initializing the agent to register data[0m
[1;32m2025-03-10 02:20:57,756

In [565]:
# Keep a copy of the state values checkpointed by the run above for inspection below.
snap = MY_AGENT.get_recent_state_snap()

In [566]:
# First final summary. When freshly generated, index 0 comes from the first bucket prompt,
# which generate_bucket_query_agent builds from the 'subj_query_generation' answers.
print(snap['results']['final_result'][0])

Overall, the top 5 factors impacting net profit/loss were Credit Spread (+€204m), Bond Basis (+€200m), IRGamma (+€193m), Theta (+€184m), and FXDelta (+€171m). However, specific desks, particularly US/LDN and LATAM/NYC, experienced the lowest net profit/loss.

Analyzing portfolio performance, the American London CEEMAEA Portfolio stands out with the highest overall net profit (+€392m), followed by LATAM Emerging (+€380m) and European CEEMAEA (+€346m). Within the American London CEEMAEA Portfolio, Bond Basis contributed +€78m, Credit Spread +€61m, FXDelta +€58m, IRDelta +€63m, and IRGamma +€67m.

Business unit performance, broken down by currency, reveals that CEEMAEA's net profit was significantly influenced by INR (+€94m), MXN (+€83m), CZK (+€82m), EUR (+€79m), and BRL (+€75m). However, some CEEMAEA and LATAM business units experienced the lowest net profit/loss overall.

Desk-level analysis shows that Bond Basis had a significant impact on LATAM/NYC DSK (+€81m), while Credit Spread an

In [567]:
# Second final summary. When freshly generated, index 1 comes from the second bucket prompt,
# which generate_bucket_query_agent builds from the 'stat_query_generation' answers.
print(snap['results']['final_result'][1])

The average Net Profit/Loss (NET) across all transactions is approximately €1.01 million, with a total sum of €1.12 billion across 1103 transactions. The skewness of 0.015 and kurtosis of -1.29 indicate a near-normal distribution with slightly flatter tails.

**Business Unit (BUIS) Performance:** Both CEEMAEA and LATAM Business Units show an equal distribution of profit and loss transactions, with a sum of €0.

**Currency (CUR) Performance:** The top-performing currencies by average NET are INR (€1.05m), GBP (€1.05m), and EUR (€1.04m). The bottom-performing currencies are BRL (€0.90m) and USD (€0.97m). USD has the highest total profit with €0.

**Portfolio (PF) Performance:** Among the bottom three portfolios by average NET, American London CEEMAEA Portfolio has the highest average NET at €1.06m, followed by LATAM Emerging Portfolio (€1.03m) and European CEEMAEA Portfolio (€0.95m). The portfolios with the highest standard deviation in NET are European CEEMAEA Portfolio, American London