# Lesson 5: Leveraging Assistants API for SQL Databases

## Setup

In [144]:
from openai import AzureOpenAI
import json
import os
from dotenv import load_dotenv
import time

# Load variables from a local .env file (if present) before reading them.
# By default load_dotenv() does not override variables that are already set.
load_dotenv()

# Directory holding the lesson's data files, taken from the environment.
data_path = os.getenv('DATA_PATH')
if data_path is None:
    # Fail early with a clear message instead of the opaque TypeError that
    # os.path.join(None, ...) would raise on the next line.
    raise EnvironmentError("The DATA_PATH environment variable is not set.")

# Full path of the CSV used throughout the notebook.
csv_path = os.path.join(data_path, "all-states-history.csv")

## Import the helper functions

To access the ``Helper.py`` file, please go to the ``File`` menu and select ``Open...``.

In [145]:
from Helper import client, tools_sql, get_positive_cases_for_state_on_date, get_hospitalized_for_state_on_date

## Launch the Assistants API

**Note**: The pre-configured cloud resource grants you access to the Azure OpenAI GPT model. The key and endpoint provided below are intended for teaching purposes only. Your notebook environment is already set up with the necessary keys, which may differ from those used by the instructor during the filming.

In [158]:
# Build the Azure OpenAI client from environment-provided credentials.
# Note: this rebinds the `client` name that was imported from Helper.
client = AzureOpenAI(
    api_version="2024-05-01-preview",
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
)

# Register an assistant that is given the SQL helper tool definitions.
assistant = client.beta.assistants.create(
    tools=tools_sql,
    instructions="You are an assistant answering questions about a Covid dataset. Use the provided SQL functions to query the database.",
    model="gpt-35-turbo",
)

# Open a new, empty conversation thread.
thread = client.beta.threads.create()

In [159]:
# Post the user's question onto the thread, then echo the message object
# returned by the API.
message = client.beta.threads.messages.create(
    role="user",
    content="""how many hospitalized people we had in Alaska the 3/7/2021? and how many positive cases?""",
    thread_id=thread.id,
)
print(message)

ThreadMessage(id='msg_jdVtocRpvDPOv25g791lqUBL', assistant_id=None, content=[MessageContentText(text=Text(annotations=[], value='how many hospitalized people we had in Alaska the 3/7/2021? and how many positive cases?'), type='text')], created_at=1736966707, file_ids=None, metadata={}, object='thread.message', role='user', run_id=None, thread_id='thread_ScXUUYitKGklgkjOrTalmU0l', attachments=[])


In [161]:
# Fetch the messages currently stored on the thread ...
messages = client.beta.threads.messages.list(thread_id=thread.id)

# ... and pretty-print the whole page as indented JSON.
print(messages.model_dump_json(indent=2))

{
  "data": [
    {
      "id": "msg_jdVtocRpvDPOv25g791lqUBL",
      "assistant_id": null,
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "how many hospitalized people we had in Alaska the 3/7/2021? and how many positive cases?"
          },
          "type": "text"
        }
      ],
      "created_at": 1736966707,
      "file_ids": null,
      "metadata": {},
      "object": "thread.message",
      "role": "user",
      "run_id": null,
      "thread_id": "thread_ScXUUYitKGklgkjOrTalmU0l",
      "attachments": []
    }
  ],
  "object": "list",
  "first_id": "msg_jdVtocRpvDPOv25g791lqUBL",
  "last_id": "msg_jdVtocRpvDPOv25g791lqUBL",
  "has_more": false
}


In [162]:
# Create a run that pairs the thread with the assistant defined above.
run = client.beta.threads.runs.create(
    assistant_id=assistant.id,
    thread_id=thread.id,
)

In [163]:
from sqlalchemy import create_engine
import pandas as pd

# Load the CSV and replace every missing value with 0.
df = pd.read_csv(csv_path).fillna(0)

# SQLite database file stored in the same data directory.
database_file_path = os.path.join(data_path, "test.db")
engine = create_engine(f'sqlite:///{database_file_path}')

# Write the frame to the 'all_states_history' table, replacing any existing
# table of that name; the DataFrame index is not written as a column.
df.to_sql('all_states_history', con=engine, if_exists='replace', index=False)

21

In [164]:
# Display the DataFrame loaded from the CSV (missing values were filled with 0).
df

Unnamed: 0,date,state,death,hospitalized,negative,positive,recovered
0,3/7/2021,AK,305,1293.0,0.0,56886,0.0
1,3/7/2021,AL,10148,45976.0,1931711.0,499819,295690.0
2,3/7/2021,AR,5319,14926.0,2480716.0,324818,315517.0
3,3/7/2021,AS,0,0.0,2140.0,0,0.0
4,3/7/2021,AZ,16328,57907.0,3073010.0,826454,0.0
5,3/7/2021,CA,54124,0.0,0.0,3501394,0.0
6,3/7/2021,CO,5989,23904.0,2199458.0,436602,0.0
7,3/7/2021,CT,7704,0.0,0.0,285330,0.0
8,3/7/2021,DC,1030,0.0,0.0,41419,29570.0
9,3/7/2021,DE,1473,0.0,545070.0,88354,0.0


In [152]:
# Read the table back out of SQLite and list the column names it contains.
# Note: this rebinds `df` to the frame returned by the query.
df = pd.read_sql("SELECT * FROM all_states_history", engine)
print(list(df.columns))

['date', 'state', 'death', 'hospitalized', 'negative', 'positive', 'recovered']


In [153]:
# Monitor run status and handle function calls.
#
# Poll the run once per second until it reaches a terminal state. While the
# run is in "requires_action", execute each requested tool call locally and
# submit the results so the run can continue.

# Explicit dispatch table: only these helpers can be invoked by name
# (avoids looking arbitrary names up in globals()).
available_functions = {
    "get_positive_cases_for_state_on_date": get_positive_cases_for_state_on_date,
    "get_hospitalized_for_state_on_date": get_hospitalized_for_state_on_date,
}

while True:
    run = client.beta.threads.runs.retrieve(
        thread_id=thread.id,
        run_id=run.id
    )

    if run.status == "requires_action":
        tool_outputs = []
        for tool_call in run.required_action.submit_tool_outputs.tool_calls:
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)

            function_to_call = available_functions.get(function_name)
            if function_to_call is None:
                # Still answer this tool call: report the problem back rather
                # than silently omitting an output for it.
                function_response = f"Error: unknown function '{function_name}'"
            else:
                function_response = function_to_call(
                    state_abbr=function_args.get("state_abbr"),
                    specific_date=function_args.get("specific_date"),
                )

            tool_outputs.append({
                "tool_call_id": tool_call.id,
                "output": str(function_response)
            })

        print("Tool Outputs:", tool_outputs)  # Debugging print

        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=tool_outputs
        )

    elif run.status == "completed":
        # Print every assistant message stored on the thread.
        messages = client.beta.threads.messages.list(
            thread_id=thread.id
        )
        for msg in messages.data:
            if msg.role == "assistant":
                print(msg.content[0].text.value)
        break

    elif run.status in ["failed", "cancelled", "expired", "incomplete"]:
        # "incomplete" is included so the loop cannot spin forever on it.
        print(f"Run failed with status: {run.status}")
        last_error = getattr(run, "last_error", None)
        if last_error is not None:
            # Surface the error detail so the failure can be diagnosed.
            print(f"Last error: {last_error}")
        break

    time.sleep(1)

Tool Outputs: [{'tool_call_id': 'call_erIMI5XUHnTTaXXyahchwhTB', 'output': "{'date': '3/7/2021', 'state': 'AK', 'hospitalized': 1293.0}"}]
Tool Outputs: [{'tool_call_id': 'call_zDgeERaEpglUmRMFxOQKymlK', 'output': "{'date': '3/7/2021', 'state': 'AK', 'positive_cases': 56886}"}]
Run failed with status: failed


## Leverage function calling with the Assistants API

In [154]:
import time
from IPython.display import clear_output

# Poll the run every 5 seconds, reporting elapsed time and status, and answer
# any tool calls the run requests. Afterwards, print the thread's messages.

start_time = time.time()

# Dispatch table mapping tool names to the local helper functions.
# Built once, outside the loop, because it never changes.
available_functions = {
    "get_positive_cases_for_state_on_date": get_positive_cases_for_state_on_date,
    "get_hospitalized_for_state_on_date": get_hospitalized_for_state_on_date,
}

# Statuses after which the run will not progress any further.
# "incomplete" is included so the loop cannot spin forever on it.
terminal_statuses = ["completed", "cancelled", "expired", "failed", "incomplete"]

status = run.status

while status not in terminal_statuses:
    time.sleep(5)
    run = client.beta.threads.runs.retrieve(
        thread_id=thread.id, run_id=run.id
    )

    # Read the clock once so minutes and seconds come from the same instant.
    minutes, seconds = divmod(int(time.time() - start_time), 60)
    print("Elapsed time: {} minutes {} seconds".format(minutes, seconds))

    status = run.status
    print(f'Status: {status}')

    if status == "requires_action":
        tool_outputs = []
        for tool_call in run.required_action.submit_tool_outputs.tool_calls:
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)

            function_to_call = available_functions.get(function_name)
            if function_to_call is None:
                # Report an unknown tool back to the run instead of crashing
                # the cell with a KeyError.
                function_response = f"Error: unknown function '{function_name}'"
            else:
                function_response = function_to_call(
                    state_abbr=function_args.get("state_abbr"),
                    specific_date=function_args.get("specific_date"),
                )

            print(function_response)
            print(tool_call.id)
            tool_outputs.append(
                {
                    "tool_call_id": tool_call.id,
                    "output": str(function_response),
                }
            )

        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=tool_outputs
        )


# Show whatever messages the thread holds once the run has stopped.
messages = client.beta.threads.messages.list(
    thread_id=thread.id
)

print(messages)

SyncCursorPage[ThreadMessage](data=[ThreadMessage(id='msg_I7qxvorNITYaGet0FDNaweOS', assistant_id=None, content=[MessageContentText(text=Text(annotations=[], value='how many hospitalized people we had in Alaska the 3/7/2021? and how many positive cases?'), type='text')], created_at=1736966536, file_ids=None, metadata={}, object='thread.message', role='user', run_id=None, thread_id='thread_87dlLzztbRgB1geyICfoKP7x', attachments=[])], object='list', first_id='msg_I7qxvorNITYaGet0FDNaweOS', last_id='msg_I7qxvorNITYaGet0FDNaweOS', has_more=False)


In [155]:
# Pretty-print the message list as indented JSON for easier reading.
print(messages.model_dump_json(indent=2))

{
  "data": [
    {
      "id": "msg_I7qxvorNITYaGet0FDNaweOS",
      "assistant_id": null,
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "how many hospitalized people we had in Alaska the 3/7/2021? and how many positive cases?"
          },
          "type": "text"
        }
      ],
      "created_at": 1736966536,
      "file_ids": null,
      "metadata": {},
      "object": "thread.message",
      "role": "user",
      "run_id": null,
      "thread_id": "thread_87dlLzztbRgB1geyICfoKP7x",
      "attachments": []
    }
  ],
  "object": "list",
  "first_id": "msg_I7qxvorNITYaGet0FDNaweOS",
  "last_id": "msg_I7qxvorNITYaGet0FDNaweOS",
  "has_more": false
}


## Add the code interpreter

In [165]:
# 1. Load the CSV for analysis (read locally here; nothing is uploaded)
with open(csv_path, "rb") as file:
    data = pd.read_csv(file)

# JSON snapshot of the frame as loaded, taken before the dates are reformatted
data_json = data.to_json()

# Re-express the date column as 'YYYY-MM-DD' strings
data = data.assign(date=pd.to_datetime(data['date']).dt.strftime('%Y-%m-%d'))

In [166]:
# Display the frame; the date column now holds 'YYYY-MM-DD' strings.
data

Unnamed: 0,date,state,death,hospitalized,negative,positive,recovered
0,2021-03-07,AK,305,1293.0,,56886,
1,2021-03-07,AL,10148,45976.0,1931711.0,499819,295690.0
2,2021-03-07,AR,5319,14926.0,2480716.0,324818,315517.0
3,2021-03-07,AS,0,,2140.0,0,
4,2021-03-07,AZ,16328,57907.0,3073010.0,826454,
5,2021-03-07,CA,54124,,,3501394,
6,2021-03-07,CO,5989,23904.0,2199458.0,436602,
7,2021-03-07,CT,7704,,,285330,
8,2021-03-07,DC,1030,,,41419,29570.0
9,2021-03-07,DE,1473,,545070.0,88354,


In [176]:
# 2. Create an assistant that has the code interpreter tool enabled
assistant = client.beta.assistants.create(
    model="gpt-35-turbo",
    instructions="Analyzing COVID-19 data.",
    tools=[{"type": "code_interpreter"}]
)

# 3. Create a thread and post the question.
#    The dataset records are embedded directly in the message text.
thread = client.beta.threads.create()
message = client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content=f"""How many hospitalized people we had in Alaska
                the 2021-03-07? and how many positive cases?
                Target date: 2021-03-07
                Available data: {data.to_dict('records')}"""
)

# 4. Run and monitor
run = client.beta.threads.runs.create(
    thread_id=thread.id,
    assistant_id=assistant.id
)

# Poll once per second until the run reaches a terminal status.
while True:
    run = client.beta.threads.runs.retrieve(
        thread_id=thread.id,
        run_id=run.id
    )
    print(f"Status: {run.status}")

    if run.status == "completed":
        # Print every assistant message stored on the thread.
        messages = client.beta.threads.messages.list(thread_id=thread.id)
        for msg in messages.data:
            if msg.role == "assistant":
                print("\nResult:", msg.content[0].text.value)
        break
    elif run.status in ("failed", "cancelled", "expired", "incomplete"):
        # Every terminal non-success status ends the loop; checking only
        # "failed" would leave the cell polling forever on the others.
        print("\nFailed")
        last_error = getattr(run, "last_error", None)
        if last_error is not None:
            # Surface the error detail so the failure can be diagnosed.
            print(f"Last error: {last_error}")
        break

    time.sleep(1)

Status: in_progress
Status: completed

Result: According to the available data, on 2021-03-07, Alaska had 1293 hospitalized people and 56886 positive cases of COVID-19.
