In [2]:
# Cell 1: Import Libraries and Setup
import os
import pandas as pd
import sqlite3
from dotenv import load_dotenv
import warnings
warnings.filterwarnings('ignore')

# LangChain imports
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain.callbacks import StdOutCallbackHandler
from langchain.schema import HumanMessage, SystemMessage

# Data visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

print("All libraries imported successfully!")

All libraries imported successfully!


In [3]:
# Cell 2: Load Environment Variables and Setup API
# Populate the process environment before looking up the key.
load_dotenv()

# Fail fast when the key is absent or empty; the model client below needs it.
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
if GOOGLE_API_KEY is None or GOOGLE_API_KEY == "":
    raise ValueError("Please set your GOOGLE_API_KEY in the .env file")

# Chat model handle shared by the SQL toolkit and the agents created later.
llm = ChatGoogleGenerativeAI(
    temperature=0,
    model="gemini-2.0-flash-lite",
    convert_system_message_to_human=True,
    google_api_key=GOOGLE_API_KEY,
)

print("Gemini API initialized successfully!")

Gemini API initialized successfully!


In [4]:
# Cell 3: Load and Inspect Data
# Read the raw customer CSV into the working DataFrame.
df = pd.read_csv('../data/synthetic_bank_customer_data.csv')

# Overview: size, column names, a sample of rows, dtypes and null counts.
print("Dataset Info:")
print(f"Shape: {df.shape}", f"Columns: {list(df.columns)}", sep="\n")
print("\nFirst 5 rows:", df.head(), sep="\n")
print("\nData types:", df.dtypes, sep="\n")
print("\nMissing values:", df.isnull().sum(), sep="\n")


Dataset Info:
Shape: (4014, 9)
Columns: ['Customer ID', 'Name', 'Surname', 'Gender', 'Age', 'Region', 'Job Classification', 'Date Joined', 'Balance']

First 5 rows:
   Customer ID     Name  Surname  Gender  Age            Region  \
0    100000001    Simon    Walsh    Male   21           England   
1    400000002  Jasmine   Miller  Female   34  Northern Ireland   
2    100000003     Liam    Brown    Male   46           England   
3    300000004   Trevor     Parr    Male   32             Wales   
4    100000005  Deirdre  Pullman  Female   38           England   

  Job Classification Date Joined    Balance  
0       White Collar   05.Jan.15  113810.15  
1        Blue Collar   06.Jan.15   36919.73  
2       White Collar   07.Jan.15  101536.83  
3       White Collar   08.Jan.15    1421.52  
4        Blue Collar   09.Jan.15   35639.79  

Data types:
Customer ID             int64
Name                   object
Surname                object
Gender                 object
Age                    

In [5]:
# Cell 4: Data Preprocessing and Cleaning
# Normalise column names: trim whitespace, spaces -> underscores, lowercase.
df.columns = [col.strip().replace(' ', '_').lower() for col in df.columns]
print("Cleaned column names:", list(df.columns))

# Parse the join date with an explicit format rather than relying on
# inference. Raw values look like '05.Jan.15' (day.month-abbreviation.2-digit
# year); a value that does not match now raises instead of being guessed.
# Note: %b matches month abbreviations in the current locale.
if 'date_joined' in df.columns:
    df['date_joined'] = pd.to_datetime(df['date_joined'], format='%d.%b.%y')

# Bucket ages into labelled groups (pd.cut intervals are right-closed by
# default). The last edge is open-ended so ages above 100 fall into '65+'
# instead of silently becoming NaN.
df['age_group'] = pd.cut(
    df['age'],
    bins=[0, 25, 35, 50, 65, float('inf')],
    labels=['18-25', '26-35', '36-50', '51-65', '65+'],
)

print("\nProcessed data info:")
print(df.head())
print(f"\nAge groups distribution:\n{df['age_group'].value_counts()}")

Cleaned column names: ['customer_id', 'name', 'surname', 'gender', 'age', 'region', 'job_classification', 'date_joined', 'balance']

Processed data info:
   customer_id     name  surname  gender  age            region  \
0    100000001    Simon    Walsh    Male   21           England   
1    400000002  Jasmine   Miller  Female   34  Northern Ireland   
2    100000003     Liam    Brown    Male   46           England   
3    300000004   Trevor     Parr    Male   32             Wales   
4    100000005  Deirdre  Pullman  Female   38           England   

  job_classification date_joined    balance age_group  
0       White Collar  2015-01-05  113810.15     18-25  
1        Blue Collar  2015-01-06   36919.73     26-35  
2       White Collar  2015-01-07  101536.83     36-50  
3       White Collar  2015-01-08    1421.52     26-35  
4        Blue Collar  2015-01-09   35639.79     36-50  

Age groups distribution:
age_group
36-50    1767
26-35    1379
51-65     565
18-25     303
65+         0
N

In [6]:
# Cell 5: Create SQLite Database
# Location of the SQLite file that the LangChain connection opens later.
db_path = '../data/bank_customers.db'
conn = sqlite3.connect(db_path)

# try/finally guarantees the connection is released even if the write or one
# of the verification queries raises.
try:
    # Write the DataFrame to the 'customers' table, replacing any earlier copy.
    df.to_sql('customers', conn, if_exists='replace', index=False)

    print(f"Database created at: {db_path}")

    # Verify the table exists.
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    print(f"Tables in database: {tables}")

    # Verify every DataFrame row was written.
    cursor.execute("SELECT COUNT(*) FROM customers;")
    row_count = cursor.fetchone()[0]
    print(f"Number of records: {row_count}")
    assert row_count == len(df), "row count in SQLite does not match the DataFrame"
finally:
    conn.close()


Database created at: ../data/bank_customers.db
Tables in database: [('customers',)]
Number of records: 4014


In [7]:
# Cell 6: Setup LangChain SQL Database Connection
# Point LangChain's SQLDatabase wrapper at the SQLite file via a sqlite:/// URI.
db = SQLDatabase.from_uri("sqlite:///" + db_path)

# Show the schema text the wrapper reports for the database.
print("Database schema:", db.get_table_info(), sep="\n")

# Smoke-test the connection with a simple row count.
result = db.run("SELECT COUNT(*) FROM customers;")
print("Test query result: " + str(result))

Database schema:

CREATE TABLE customers (
	customer_id INTEGER, 
	name TEXT, 
	surname TEXT, 
	gender TEXT, 
	age INTEGER, 
	region TEXT, 
	job_classification TEXT, 
	date_joined TIMESTAMP, 
	balance REAL, 
	age_group TEXT
)

/*
3 rows from customers table:
customer_id	name	surname	gender	age	region	job_classification	date_joined	balance	age_group
100000001	Simon	Walsh	Male	21	England	White Collar	2015-01-05 00:00:00	113810.15	18-25
400000002	Jasmine	Miller	Female	34	Northern Ireland	Blue Collar	2015-01-06 00:00:00	36919.73	26-35
100000003	Liam	Brown	Male	46	England	White Collar	2015-01-07 00:00:00	101536.83	36-50
*/
Test query result: [(4014,)]


In [8]:
# Cell 7: Create SQL Database Toolkit and Agent
# Toolkit bound to both the database wrapper and the chat model.
toolkit = SQLDatabaseToolkit(llm=llm, db=db)

# Baseline SQL agent; ask_data_question sends its questions to this executor.
agent_executor = create_sql_agent(
    toolkit=toolkit,
    llm=llm,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    handle_parsing_errors=True,
    verbose=True,
)

print("SQL Agent created successfully!")

SQL Agent created successfully!


In [9]:
# Cell 8: Define Analysis Functions
def ask_data_question(question, agent=None):
    """Send a natural-language question to a SQL agent and print the exchange.

    Args:
        question: Question text handed to the agent's ``run`` method.
        agent: Optional agent to query instead of the notebook-level
            ``agent_executor`` (for example ``enhanced_agent``). Any object
            exposing ``run(question)`` is accepted.

    Returns:
        The agent's response, or ``None`` when the call raised. The error is
        printed rather than re-raised so later cells can still run.
    """
    try:
        print(f"Question: {question}")
        print("-" * 50)

        # Resolve the default lazily so the global is only touched when the
        # caller did not supply an agent.
        active_agent = agent_executor if agent is None else agent
        response = active_agent.run(question)

        print(f"Answer: {response}")
        print("=" * 70)
        return response
    except Exception as e:
        # Deliberate best-effort: report the failure and signal it with None.
        print(f"Error: {str(e)}")
        return None

from langchain.agents import create_sql_agent, AgentType

def create_enhanced_agent(max_iterations=5):
    """Create a SQL agent whose prompt includes analyst guidance and a data dictionary.

    The guidance text is appended to LangChain's default SQL agent prefix and
    passed as ``prefix``, so the agent keeps the library's standard SQL
    instructions and additionally receives the formatting rules and column
    descriptions below.

    Args:
        max_iterations: Upper bound on agent reasoning/tool steps, forwarded
            to ``create_sql_agent``. Defaults to 5.

    Returns:
        The agent executor returned by ``create_sql_agent``.
    """
    # Imported locally because the module location differs between LangChain
    # releases (langchain_community in newer ones, langchain in older ones).
    try:
        from langchain_community.agent_toolkits.sql.prompt import SQL_PREFIX
    except ImportError:
        from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX

    # NOTE: keep this text free of curly braces; the prefix is run through
    # str.format by create_sql_agent to fill in the dialect and top_k.
    system_prompt = (
        "You are a data analyst AI assistant. When answering questions about the bank customer data:\n"
        "1. Always provide specific numbers and percentages\n"
        "2. Format your responses clearly with proper structure\n"
        "3. If creating queries, make sure they are accurate and efficient\n"
        "4. Provide insights and context for the numbers you present\n"
        "5. Round decimal numbers to 2 places for better readability\n"
        "\n"
        "The database has a table called 'customers' with these columns:\n"
        "- customer_id: unique identifier\n"
        "- name, surname: customer names\n"
        "- gender: customer gender\n"
        "- age: customer age\n"
        "- age_group: grouped ages (18-25, 26-35, 36-50, 51-65, 65+)\n"
        "- region: geographical region\n"
        "- job_classification: job category\n"
        "- date_joined: when customer joined\n"
        "- balance: account balance"
    )

    enhanced_agent = create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        verbose=True,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        handle_parsing_errors=True,
        max_iterations=max_iterations,
        # Previously the guidance was built but never handed to the agent.
        prefix=f"{SQL_PREFIX}\n\n{system_prompt}",
    )

    return enhanced_agent

enhanced_agent = create_enhanced_agent()

In [10]:
# Cell 9: Sample Analysis Questions
# Catalogue of the questions this notebook puts to the agent.
analysis_questions = [
    "What is the average balance by gender?",
    "What is the average balance by age group?",
    "What is the average balance by job classification?",
    "Which regions have the highest concentration of specific job classifications?",
    "What is the distribution of customers across different regions?",
    "Which age group has the highest average balance?",
    "What is the gender distribution in each region?",
    "Show me the top 5 job classifications by average balance",
    "What percentage of customers fall into each age group?",
    "Which region has the highest total balance?",
]

print("Sample analysis questions prepared!")
# Echo the list back as a 1-based numbered menu.
for i, question in zip(range(1, len(analysis_questions) + 1), analysis_questions):
    print(str(i) + ". " + question)

Sample analysis questions prepared!
1. What is the average balance by gender?
2. What is the average balance by age group?
3. What is the average balance by job classification?
4. Which regions have the highest concentration of specific job classifications?
5. What is the distribution of customers across different regions?
6. Which age group has the highest average balance?
7. What is the gender distribution in each region?
8. Show me the top 5 job classifications by average balance
9. What percentage of customers fall into each age group?
10. Which region has the highest total balance?


In [35]:
# Cell 10: Execute Analysis - Average Balance by Gender
# Print the section banner, then pass the question to the agent helper.
print("ANALYSIS 1: Average Balance by Gender", "=" * 50, sep="\n")
response1 = ask_data_question(
    "What is the average balance by gender? "
    "Please also show the count of customers for each gender."
)

ANALYSIS 1: Average Balance by Gender
Question: What is the average balance by gender? Please also show the count of customers for each gender.
--------------------------------------------------


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I need to find the relevant tables first.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3mcustomers[0m[32;1m[1;3mThought: Now I need to query the schema of the `customers` table to see what columns are available.
Action: sql_db_schema
Action Input: customers[0m[33;1m[1;3m
CREATE TABLE customers (
	customer_id INTEGER, 
	name TEXT, 
	surname TEXT, 
	gender TEXT, 
	age INTEGER, 
	region TEXT, 
	job_classification TEXT, 
	date_joined TIMESTAMP, 
	balance REAL, 
	age_group TEXT
)

/*
3 rows from customers table:
customer_id	name	surname	gender	age	region	job_classification	date_joined	balance	age_group
100000001	Simon	Walsh	Male	21	England	White Collar	2015-01-05 00:00:00	113810.15	18-25
400000002	Jasmine

In [11]:
# Cell 11: Execute Analysis - Average Balance by Age Group
# Print the section banner, then pass the question to the agent helper.
print("ANALYSIS 2: Average Balance by Age Group", "=" * 50, sep="\n")
response2 = ask_data_question(
    "What is the average balance by age group? "
    "Please show both the average balance and count of customers in each age group."
)


ANALYSIS 2: Average Balance by Age Group
Question: What is the average balance by age group? Please show both the average balance and count of customers in each age group.
--------------------------------------------------


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mcustomers[0m[32;1m[1;3mI should query the schema of the customers table to see what columns are available.
Action: sql_db_schema
Action Input: customers[0m[33;1m[1;3m
CREATE TABLE customers (
	customer_id INTEGER, 
	name TEXT, 
	surname TEXT, 
	gender TEXT, 
	age INTEGER, 
	region TEXT, 
	job_classification TEXT, 
	date_joined TIMESTAMP, 
	balance REAL, 
	age_group TEXT
)

/*
3 rows from customers table:
customer_id	name	surname	gender	age	region	job_classification	date_joined	balance	age_group
100000001	Simon	Walsh	Male	21	England	White Collar	2015-01-05 00:00:00	113810.15	18-25
400000002	Jasmine	Miller	Female	34	Northern Ireland	Bl

In [12]:
# Cell 12: Execute Analysis - Average Balance by Job Classification
# Print the section banner, then pass the question to the agent helper.
print("ANALYSIS 3: Average Balance by Job Classification", "=" * 50, sep="\n")
response3 = ask_data_question(
    "What is the average balance by job classification? "
    "Please show the top 10 job classifications by average balance along with customer counts."
)


ANALYSIS 3: Average Balance by Job Classification
Question: What is the average balance by job classification? Please show the top 10 job classifications by average balance along with customer counts.
--------------------------------------------------


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mcustomers[0m[32;1m[1;3mI should query the schema of the customers table to understand the available columns.
Action: sql_db_schema
Action Input: customers[0m[33;1m[1;3m
CREATE TABLE customers (
	customer_id INTEGER, 
	name TEXT, 
	surname TEXT, 
	gender TEXT, 
	age INTEGER, 
	region TEXT, 
	job_classification TEXT, 
	date_joined TIMESTAMP, 
	balance REAL, 
	age_group TEXT
)

/*
3 rows from customers table:
customer_id	name	surname	gender	age	region	job_classification	date_joined	balance	age_group
100000001	Simon	Walsh	Male	21	England	White Collar	2015-01-05 00:00:00	113810.15	18-25
400000002	Jasmine	Mille