In [1]:
!pip uninstall -qqy jupyterlab kfp  # Remove unused conflicting packages
!pip install -qU "google-genai==1.7.0" "chromadb==0.6.3"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
from google import genai
from google.genai import types
from typing import List, Dict, Optional, Union
from google.protobuf import struct_pb2

from IPython.display import Markdown

genai.__version__

'1.7.0'

In [3]:
!pip install python-dotenv
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv()

GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
print(GOOGLE_API_KEY)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
AIzaSyDgPvHKl4G7qSpn1iXY6KzzEgAC4Jex-ws


In [4]:
client = genai.Client(api_key=GOOGLE_API_KEY)

In [5]:
for m in client.models.list():
    if "embedContent" in m.supported_actions:
        print(m.name)

models/embedding-001
models/text-embedding-004
models/gemini-embedding-exp-03-07
models/gemini-embedding-exp


In [6]:
# Define a retry policy. The model might make multiple consecutive calls automatically
# for a complex query, this ensures the client retries if it hits quota limits.
from google.api_core import retry

is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

if not hasattr(genai.models.Models.generate_content, '__wrapped__'):
  genai.models.Models.generate_content = retry.Retry(
      predicate=is_retriable)(genai.models.Models.generate_content)

In [7]:
!pip install psycopg2-binary
import psycopg2


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [8]:
conn = psycopg2.connect(
    host="localhost",
    port=15432,  # Your custom port
    database=os.getenv('DB_NAME', 'gobo'),
    user=os.getenv('DB_USER', 'postgres'),
    password=os.getenv('DB_PASSWORD', 'admin')
)

# Test the connection
try:
    cursor = conn.cursor()
    cursor.execute("SELECT version();")
    db_version = cursor.fetchone()
    print("PostgreSQL database version:", db_version)
finally:
    if conn:
        conn.close()

PostgreSQL database version: ('PostgreSQL 14.15 (Debian 14.15-1.pgdg120+1) on aarch64-unknown-linux-gnu, compiled by gcc (Debian 12.2.0-14) 12.2.0, 64-bit',)


In [9]:
# some tool functions

def get_db_connection():
    """Establish connection to PostgreSQL database"""
    try:
        return psycopg2.connect(
            host="localhost",
            port=15432,
            database="gobo",
            user='postgres',  # Get username from 'postgres' in .env
            password='admin'   # Get password from 'admin' in .env
        )
    except Exception as e:
        print("Connection error:", e)
        raise

def list_schemas() -> List[str]:
    """Returns a list of all schemas in the database, excluding system schemas."""
    query = """
    SELECT schema_name 
    FROM information_schema.schemata
    WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
    ORDER BY schema_name;
    """
    with get_db_connection() as conn:
        with conn.cursor() as cursor:
            cursor.execute(query)
            return [row[0] for row in cursor.fetchall()]

def list_tables(schema_name: str) -> List[str]:
    """Returns a list of all tables in a specified schema.
    
    Args:
        schema_name: The name of the schema to list tables from
    """
    query = """
    SELECT table_name 
    FROM information_schema.tables
    WHERE table_schema = %s
    ORDER BY table_name;
    """
    with get_db_connection() as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, (schema_name,))
            return [row[0] for row in cursor.fetchall()]

def list_columns(schema_name: str, table_name: str) -> List[Dict[str, str]]:
    """Returns basic column information for a specified table.
    
    Args:
        schema_name: The schema containing the table
        table_name: The table to list columns from
        
    Returns:
        List of dictionaries with 'name' and 'type' keys
    """
    query = """
    SELECT column_name, data_type 
    FROM information_schema.columns
    WHERE table_schema = %s AND table_name = %s
    ORDER BY ordinal_position;
    """
    with get_db_connection() as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, (schema_name, table_name))
            return cursor.fetchall()


In [10]:
def get_column_details(schema_name: str, table_name: str) -> List[Dict[str, Union[str, bool, List[str]]]]:
    """Returns detailed column information including constraints.
    
    Args:
        schema_name: The schema containing the table
        table_name: The table to analyze
        
    Returns:
        List of dictionaries with column details including:
        - name: Column name
        - type: Data type
        - nullable: Whether column allows NULL values
        - default: Default value if any
        - description: Column comment/description
        - constraints: List of constraints
    """

    query = """
    SELECT 
        c.column_name,
        c.data_type,
        c.is_nullable = 'YES',
        c.column_default,
        pgd.description,
        (SELECT string_agg(tc.constraint_type, ', ')
         FROM information_schema.table_constraints tc
         JOIN information_schema.key_column_usage kcu
           ON tc.constraint_name = kcu.constraint_name
           AND tc.table_schema = kcu.table_schema
         WHERE tc.table_schema = %s
           AND tc.table_name = %s
           AND kcu.column_name = c.column_name) as constraints
    FROM information_schema.columns c
    LEFT JOIN pg_catalog.pg_statio_all_tables st
      ON c.table_schema = st.schemaname AND c.table_name = st.relname
    LEFT JOIN pg_catalog.pg_description pgd
      ON pgd.objoid = st.relid AND pgd.objsubid = c.ordinal_position
    WHERE c.table_schema = %s AND c.table_name = %s
    ORDER BY c.ordinal_position;
    """
    
    with get_db_connection() as conn:
        with conn.cursor() as cursor:
            try:
                cursor.execute(query, (schema_name, table_name, schema_name, table_name))
                
                columns = []
                for row in cursor.fetchall():
                    # Ensure we have all expected fields
                    if len(row) >= 6:  # We expect 6 columns
                        columns.append({
                            'name': row[0],
                            'type': row[1],
                            'nullable': row[2],
                            'default': row[3],
                            'description': row[4],
                            'constraints': row[5].split(', ') if row[5] else []
                        })
                    else:
                        print(f"Warning: Unexpected row format for column {row[0] if row else 'unknown'}")
                
                return columns
                
            except Exception as e:
                print(f"Error getting column details for {schema_name}.{table_name}: {e}")
                return None

In [11]:
def get_table_relationships(schema_name: Optional[str] = None) -> List[Dict[str, str]]:
    """Returns foreign key relationships between tables.
    
    Args:
        schema_name: Optional filter for specific schema
        
    Returns:
        List of relationship dictionaries with:
        - source_table: The table containing the FK
        - source_column: The FK column
        - target_table: The referenced table
        - target_column: The referenced column
        - relationship_type: 1:1, 1:many, or many:many
    """
    # Base query to get all foreign key relationships
    query = """
    WITH fk_relationships AS (
        SELECT
            tc.table_schema as source_schema,
            tc.table_name as source_table,
            kcu.column_name as source_column,
            ccu.table_schema as target_schema,
            ccu.table_name as target_table,
            ccu.column_name as target_column
        FROM 
            information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
              ON tc.constraint_name = kcu.constraint_name
              AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage AS ccu
              ON ccu.constraint_name = tc.constraint_name
              AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
    ),
    relationship_counts AS (
        SELECT
            source_schema,
            source_table,
            source_column,
            target_schema,
            target_table,
            target_column,
            (SELECT COUNT(*) 
             FROM information_schema.table_constraints
             WHERE table_schema = fr.target_schema
               AND table_name = fr.target_table
               AND constraint_type = 'PRIMARY KEY') as target_pk_count,
            (SELECT COUNT(*) 
             FROM information_schema.table_constraints
             WHERE table_schema = fr.source_schema
               AND table_name = fr.source_table
               AND constraint_type = 'PRIMARY KEY') as source_pk_count,
            (SELECT COUNT(DISTINCT kcu.column_name)
             FROM information_schema.table_constraints tc
             JOIN information_schema.key_column_usage kcu
               ON tc.constraint_name = kcu.constraint_name
             WHERE tc.table_schema = fr.source_schema
               AND tc.table_name = fr.source_table
               AND tc.constraint_type = 'FOREIGN KEY'
               AND kcu.table_name = fr.target_table) as fk_count_to_target
        FROM fk_relationships fr
    )
    SELECT
        source_schema,
        source_table,
        source_column,
        target_schema,
        target_table,
        target_column,
        CASE
            WHEN source_pk_count > 0 AND target_pk_count > 0 AND fk_count_to_target > 1 THEN 'many:many'
            WHEN source_pk_count = 0 THEN '1:many'
            WHEN target_pk_count = 0 THEN 'many:1'
            ELSE '1:1'
        END as relationship_type
    FROM relationship_counts
    """ + ("WHERE source_schema = %s" if schema_name else "") + """
    ORDER BY source_schema, source_table, target_table;
    """
    
    with get_db_connection() as conn:
        with conn.cursor() as cursor:
            try:
                cursor.execute(query, (schema_name,) if schema_name else None)
                
                relationships = []
                for row in cursor.fetchall():
                    relationships.append({
                        'source_schema': row[0],
                        'source_table': row[1],
                        'source_column': row[2],
                        'target_schema': row[3],
                        'target_table': row[4],
                        'target_column': row[5],
                        'relationship_type': row[6]
                    })
                
                return relationships
                
            except Exception as e:
                print(f"Error analyzing table relationships: {e}")
                return None

In [12]:
def execute_sql_query(query: str) -> List[List[str]]:
    """
    Executes a SQL query and returns the results as a list of rows
    
    Args:
        query: The SQL query to execute
        
    Returns:
        List of rows where each row is a list of string values
        
    Raises:
        Exception: If there's an error executing the query
    """
    with get_db_connection() as conn:
        with conn.cursor() as cursor:
            try:
                cursor.execute(query)
                if cursor.description:
                    return cursor.fetchall()
                return []
            except Exception as e:
                conn.rollback()
                raise Exception(f"SQL execution error: {str(e)}")

In [13]:
#create the model

db_tools = [list_schemas, list_columns, list_tables, get_column_details, get_table_relationships, execute_sql_query]
# use db tools defined above

instruction = """
You are a SQL expert who is acting as an agent on behalf of the user. You will take users questions and turn them into SQL queries using the tools available.
Once you have the information you need, you will answer the users question using the data returned.

tools overview:
- list_schemas - provides a list of all schemas in this database
- list_tables - provides a list of all tables in a schema
- list_columns - provides a list of all columns given a schema and table
- get_column_details - provides a more robust listing of all cols in a schema and table.  Returns: List of dictionaries with:
    - name: Column name
    - type: Data type
    - nullable: Whether column allows NULL values
    - default: Default value if any
    - description: Column comment/description
    - constraints: List of constraints
- get_table_relationships - Analyzes and returns the foreign key relationships between tables, given a schema. Returns:
    List of relationship dictionaries with:
    - source_table: The table containing the FK
    - source_column: The FK column
    - target_table: The referenced table
    - target_column: The referenced column
    - relationship_type: 1:1, 1:many, or many:many (inferred) 
- execute_sql_query - executes a sql query, and returns a result dict. 
    Args:
        query (str): The SQL query to execute
        params (tuple/dict): Optional query parameters
        return_results (bool): Whether to return results (for SELECT) 
        max_rows (int): Maximum rows to return (prevents large results)
    
    Returns:
        dict: {
            'success': bool,
            'data': list[dict] (for SELECT),
            'rowcount': int (for INSERT/UPDATE/DELETE),
            'columns': list[str] (column names),
            'error': str (if unsuccessful),
            'query_type': str (SELECT/INSERT/etc.)
        }
"""

client = genai.Client(api_key=GOOGLE_API_KEY)

# start a chat with automatic function calling
chat = client.chats.create(
  model="gemini-2.0-flash",
    config=types.GenerateContentConfig(
        system_instruction=instruction,
        tools=db_tools,
    ),
)

In [14]:
try:
    response = chat.send_message("How many tables are in inventory?")
    print(response.text)
        
except Exception as e:
    print(f"Error occurred: {str(e)}")
    if hasattr(e, 'response'):
        print("Detailed error:", e.response.text)

There are 75 tables in the inventory schema.


In [15]:
import textwrap


def print_chat_turns(chat):
    """Prints out each turn in the chat history, including function calls and responses."""
    for event in chat.get_history():
        print(f"{event.role.capitalize()}:")

        for part in event.parts:
            if txt := part.text:
                print(f'  "{txt}"')
            elif fn := part.function_call:
                args = ", ".join(f"{key}={val}" for key, val in fn.args.items())
                print(f"  Function call: {fn.name}({args})")
            elif resp := part.function_response:
                print("  Function response:")
                print(textwrap.indent(str(resp.response['result']), "    "))

        print()


print_chat_turns(chat)

User:
  "How many tables are in inventory?"

Model:
  Function call: list_tables(schema_name=inventory)

User:
  Function response:
    ['application', 'audit_check_config', 'audit_check_customization', 'audit_check_severity_enum', 'audit_event', 'audit_failure', 'audit_run', 'audit_suite', 'ci', 'ci_change', 'ci_change_recent', 'ci_to_inventory_source_enum', 'ci_type_enum', 'device', 'device_config', 'device_config_command', 'device_config_family', 'device_config_history', 'device_config_history_2023_06_08', 'device_operation_state_enum', 'device_type_enum', 'dns_record', 'dns_record_type_enum', 'doigov_fortimanager_inventory_device', 'doigov_fortimanager_inventory_device_config', 'doigov_lumen_inventory_device', 'doigov_lumen_inventory_interface', 'doigov_lumen_inventory_service', 'dotgov_unified_inventory_device', 'eox', 'eox_type_enum', 'firmware', 'fqdn', 'infoblox_poller_dot_inventory_dns_a', 'infoblox_poller_dot_inventory_dns_aaaa', 'interface', 'interface_medium_enum', 'interfa

In [16]:
def ask_database(query):
  try:
    response = chat.send_message(query)
    print(response.text)
        
  except Exception as e:
      print(f"Error occurred: {str(e)}")
      if hasattr(e, 'response'):
          print("Detailed error:", e.response.text)

In [17]:
chat = client.chats.create(
  model="gemini-2.0-flash",
    config=types.GenerateContentConfig(
        system_instruction=instruction,
        tools=db_tools,
    ),
)

ask_database('How many devices per vendor do we have in inventory?')

Here is the number of devices per vendor:
net optics .: 1
aruba networks : 7
hpe: 52
a10 networks: 13
fortinet: 954
juniper networks: 24
neoteris: 1
anue systems: 4
_unknown: 106412
cisco: 2347



In [18]:
print_chat_turns(chat)

User:
  "How many devices per vendor do we have in inventory?"

Model:
  Function call: list_schemas()

User:
  Function response:
    ['capacityplanning', 'dotgov', 'dramatiq', 'financial', 'gobo', 'gobofw', 'inventory', 'pg_toast', 'public', 'workflow', 'x2ipv6']

Model:
  Function call: list_tables(schema_name=inventory)

User:
  Function response:
    ['application', 'audit_check_config', 'audit_check_customization', 'audit_check_severity_enum', 'audit_event', 'audit_failure', 'audit_run', 'audit_suite', 'ci', 'ci_change', 'ci_change_recent', 'ci_to_inventory_source_enum', 'ci_type_enum', 'device', 'device_config', 'device_config_command', 'device_config_family', 'device_config_history', 'device_config_history_2023_06_08', 'device_operation_state_enum', 'device_type_enum', 'dns_record', 'dns_record_type_enum', 'doigov_fortimanager_inventory_device', 'doigov_fortimanager_inventory_device_config', 'doigov_lumen_inventory_device', 'doigov_lumen_inventory_interface', 'doigov_lumen_inve

In [19]:
# ask a follow up
ask_database('Looking only at Fortinet Devices in the inventory, Where are they located?')

Fortinet devices are located in the following locations:

*   Unknown: 404
*   Null: 550


In [20]:
ask_database('Looking only at cisco devices in the inventory, Where are they located?')

Cisco devices are located in the following locations:

*   Unknown: 1948
*   Null: 399


In [21]:
ask_database('How many locations exist for the devices in inventory? where are those places?')

There are 1115 locations. A sample of the locations are:

*   Unknown
*   Null


In [22]:
print_chat_turns(chat)

User:
  "How many devices per vendor do we have in inventory?"

Model:
  Function call: list_schemas()

User:
  Function response:
    ['capacityplanning', 'dotgov', 'dramatiq', 'financial', 'gobo', 'gobofw', 'inventory', 'pg_toast', 'public', 'workflow', 'x2ipv6']

Model:
  Function call: list_tables(schema_name=inventory)

User:
  Function response:
    ['application', 'audit_check_config', 'audit_check_customization', 'audit_check_severity_enum', 'audit_event', 'audit_failure', 'audit_run', 'audit_suite', 'ci', 'ci_change', 'ci_change_recent', 'ci_to_inventory_source_enum', 'ci_type_enum', 'device', 'device_config', 'device_config_command', 'device_config_family', 'device_config_history', 'device_config_history_2023_06_08', 'device_operation_state_enum', 'device_type_enum', 'dns_record', 'dns_record_type_enum', 'doigov_fortimanager_inventory_device', 'doigov_fortimanager_inventory_device_config', 'doigov_lumen_inventory_device', 'doigov_lumen_inventory_interface', 'doigov_lumen_inve

In [23]:
!pip install requests


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [25]:
load_dotenv()
GOOGLE_GEOCODING_API_KEY = os.getenv('GOOGLE_GEOCODING_API_KEY')
print(GOOGLE_GEOCODING_API_KEY)

AIzaSyD-i25DSym1v9PYTWVnoTN8wjmgK51tLTY


In [36]:
from typing import Optional
import requests

def get_location_name(lat: float, lon: float) -> Optional[str]:
    """
    Get the city and state from latitude and longitude coordinates using Google Maps API.
    
    Args:
        lat: Latitude coordinate (e.g., 37.7749 for San Francisco)
        lon: Longitude coordinate (e.g., -122.4194 for San Francisco)
        
    Returns:
        str: City and state in "City, State" format (e.g., "San Francisco, CA") if successful.
        None: If the location cannot be determined or an error occurs.
        
    Example:
        >>> get_location_name(40.7128, -74.0060)
        "New York, NY"
    """
    base_url = "https://maps.googleapis.com/maps/api/geocode/json"
    params = {
        "latlng": f"{lat},{lon}",
        "key": GOOGLE_GEOCODING_API_KEY  # Ensure this is defined in your environment
    }
    
    try:
        response = requests.get(base_url, params=params)
        data = response.json()
        
        if data['status'] == 'OK':
            city = None
            state = None
            
            for component in data['results'][0]['address_components']:
                if 'locality' in component['types'] or 'postal_town' in component['types']:
                    city = component['long_name']
                elif 'administrative_area_level_1' in component['types']:
                    state = component['short_name']
                elif 'country' in component['types'] and not state:
                    state = component['short_name']
            
            if city and state:
                return f"{city}, {state}"
            return data['results'][0]['formatted_address']
        
        print(f"Geocoding API error: {data['status']}")
        return None
        
    except Exception as e:
        print(f"Error in reverse geocoding: {e}")
        return None

In [52]:
db_tools = [list_schemas, list_columns, list_tables, get_column_details, get_table_relationships, execute_sql_query, get_location_name]
# use db tools defined above

instruction = """
You are a SQL expert who is acting as an agent on behalf of the user. You will take users questions and turn them into SQL queries using the tools available.
Once you have the information you need, you will answer the users question using the data returned.

tools overview:
- list_schemas - provides a list of all schemas in this database
- list_tables - provides a list of all tables in a schema
- list_columns - provides a list of all columns given a schema and table
- get_column_details - provides a more robust listing of all cols in a schema and table.  Returns: List of dictionaries with:
    - name: Column name
    - type: Data type
    - nullable: Whether column allows NULL values
    - default: Default value if any
    - description: Column comment/description
    - constraints: List of constraints
- get_table_relationships - Analyzes and returns the foreign key relationships between tables, given a schema. Returns:
    List of relationship dictionaries with:
    - source_table: The table containing the FK
    - source_column: The FK column
    - target_table: The referenced table
    - target_column: The referenced column
    - relationship_type: 1:1, 1:many, or many:many (inferred) 
- execute_sql_query - executes a sql query, and returns a result dict. 
    Args:
        query (str): The SQL query to execute
        params (tuple/dict): Optional query parameters
        return_results (bool): Whether to return results (for SELECT) 
        max_rows (int): Maximum rows to return (prevents large results)
    
    Returns:
        dict: {
            'success': bool,
            'data': list[dict] (for SELECT),
            'rowcount': int (for INSERT/UPDATE/DELETE),
            'columns': list[str] (column names),
            'error': str (if unsuccessful),
            'query_type': str (SELECT/INSERT/etc.)
        }
- get_location_name - given a latitude and longitude coordinates, find the location address using reverse geocoding lookup returns a string of the full address.
"""

client = genai.Client(api_key=GOOGLE_API_KEY)

# start a chat with automatic function calling
chat = client.chats.create(
  model="gemini-2.0-flash",
    config=types.GenerateContentConfig(
        system_instruction=instruction,
        tools=db_tools,
    ),
)

In [53]:
ask_database('How many cisco devices do we have in inventory schema?')

There are 2347 Cisco devices in the inventory.



In [54]:
ask_database('Of those devices, please provide a count of the number of devices at each location.')

Here is the breakdown of Cisco devices by location:

*   **Unknown Location:** 1948
*   **Null Location:** 399
