## 5: Leveraging the Assistants API for SQL Databases

## Setup

In [11]:
from openai import AzureOpenAI
import json
import os

## Import the helper functions

In [12]:
import Helper
from Helper import get_positive_cases_for_state_on_date
from Helper import get_hospitalized_increase_for_state_on_date

## Launch the Assistants API

In [19]:

# Read connection settings from the environment. Endpoint, key and model are
# required; the API version falls back to the preview release this notebook
# was written against when AZURE_OPENAI_API_VERSION is unset or empty.
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") or "2024-02-15-preview"
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_MODEL = os.environ["AZURE_OPENAI_MODEL"]

# Actually ensure the required variables are set: fail fast with a clear
# message naming what is missing, rather than a generic error from the client.
_missing_vars = [
    name
    for name, value in (
        ("AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT),
        ("AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY),
    )
    if not value
]
if _missing_vars:
    raise EnvironmentError(
        "Missing required environment variable(s): " + ", ".join(_missing_vars)
    )

# Connect to the Azure OpenAI endpoint using the configured API version.
client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
)

# I) Create the assistant; its tool schemas come from Helper.tools_sql
# (presumably describing the helper functions imported above — verify in Helper).
assistant = client.beta.assistants.create(
    name='Saad Assistant',
    instructions="""You are an assistant answering questions 
                    about a Covid dataset.""",
    model=AZURE_OPENAI_MODEL,
    tools=Helper.tools_sql,
)

# II) Create an empty thread to hold the conversation
thread = client.beta.threads.create()
print(thread)

Thread(id='thread_caBk1HOUQBS3fv6g7r4Xoik7', created_at=1718544541, metadata={}, object='thread')


In [20]:
# III) Post the user's question to the thread as a "user" message.
user_question = """how many hospitalized people we had in Alaska
            the 2021-03-05?"""
message = client.beta.threads.messages.create(
    thread_id=thread.id,
    content=user_question,
    role="user",
)
print(message)

ThreadMessage(id='msg_mN1FicdB42nySJXdsLo5hdKl', assistant_id=None, content=[MessageContentText(text=Text(annotations=[], value='how many hospitalized people we had in Alaska\n            the 2021-03-05?'), type='text')], created_at=1718544545, file_ids=[], metadata={}, object='thread.message', role='user', run_id=None, thread_id='thread_caBk1HOUQBS3fv6g7r4Xoik7')


In [21]:
# Fetch the messages currently on the thread and pretty-print them as JSON.
messages = client.beta.threads.messages.list(thread_id=thread.id)
messages_json = messages.model_dump_json(indent=2)
print(messages_json)

{
  "data": [
    {
      "id": "msg_mN1FicdB42nySJXdsLo5hdKl",
      "assistant_id": null,
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "how many hospitalized people we had in Alaska\n            the 2021-03-05?"
          },
          "type": "text"
        }
      ],
      "created_at": 1718544545,
      "file_ids": [],
      "metadata": {},
      "object": "thread.message",
      "role": "user",
      "run_id": null,
      "thread_id": "thread_caBk1HOUQBS3fv6g7r4Xoik7"
    }
  ],
  "object": "list",
  "first_id": "msg_mN1FicdB42nySJXdsLo5hdKl",
  "last_id": "msg_mN1FicdB42nySJXdsLo5hdKl",
  "has_more": false
}


In [22]:
# IV) Start a run so the assistant processes the thread; the run's status
# is polled in the function-calling cell that follows.
run = client.beta.threads.runs.create(
    assistant_id=assistant.id,
    thread_id=thread.id,
)

## Leverage function calling with the Assistants API

In [23]:
import time
from IPython.display import clear_output

# Map each tool name the model may request to the local Python function that
# implements it. Built once, outside the polling loop.
available_functions = {
    "get_positive_cases_for_state_on_date": get_positive_cases_for_state_on_date,
    "get_hospitalized_increase_for_state_on_date": get_hospitalized_increase_for_state_on_date,
}
# NOTE(review): the hospitalized helper returns `hospitalizedIncrease` (see the
# captured output), i.e. a daily increase rather than a total hospitalized count.

# Statuses after which a run no longer changes. "incomplete" is only reported
# by newer Assistants API versions; listing it is harmless otherwise.
TERMINAL_STATUSES = ("completed", "cancelled", "expired", "failed", "incomplete")
POLL_INTERVAL_S = 5   # seconds between status checks
MAX_WAIT_S = 600      # stop polling after 10 minutes instead of looping forever

start_time = time.time()
status = run.status

while status not in TERMINAL_STATUSES:
    time.sleep(POLL_INTERVAL_S)
    run = client.beta.threads.runs.retrieve(
        thread_id=thread.id, run_id=run.id
    )
    status = run.status

    # Sample the clock once so the minutes and seconds shown always agree.
    elapsed = time.time() - start_time
    minutes, seconds = divmod(int(elapsed), 60)
    print(f"Elapsed time: {minutes} minutes {seconds} seconds")
    print(f'Status: {status}')

    if status == "requires_action":
        # The model requested one or more tool calls: run each locally and
        # submit the results so the run can continue.
        tool_outputs = []
        for tool_call in run.required_action.submit_tool_outputs.tool_calls:
            function_name = tool_call.function.name
            function_to_call = available_functions.get(function_name)
            if function_to_call is None:
                # Report an unknown tool back to the model instead of aborting with KeyError.
                function_response = f"Error: unknown function '{function_name}'"
            else:
                function_args = json.loads(tool_call.function.arguments)
                function_response = function_to_call(
                    state_abbr=function_args.get("state_abbr"),
                    specific_date=function_args.get("specific_date"),
                )
            print(function_response)
            print(tool_call.id)
            tool_outputs.append(
                {"tool_call_id": tool_call.id, "output": str(function_response)}
            )

        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=tool_outputs,
        )
    elif status not in TERMINAL_STATUSES and elapsed > MAX_WAIT_S:
        raise TimeoutError(
            f"Run {run.id} still '{status}' after {minutes} minutes {seconds} seconds"
        )

if status != "completed":
    # Surface why the run stopped; otherwise the message list simply lacks an answer.
    print(f"Run finished with status '{status}': {run.last_error}")

messages = client.beta.threads.messages.list(
    thread_id=thread.id
)

print(messages)

Elapsed time: 0 minutes 6 seconds
Status: requires_action
{'date': '2021-03-05', 'hospitalizedIncrease': 3}
call_qtVu9Gj7hMOwixQV2n2sxe3S
Elapsed time: 0 minutes 12 seconds
Status: completed
SyncCursorPage[ThreadMessage](data=[ThreadMessage(id='msg_676CQezZz0Q2bjp306wa5OlW', assistant_id='asst_z2ETTQh3XLV6JcE8xDVB4Q1G', content=[MessageContentText(text=Text(annotations=[], value='On March 5th, 2021, Alaska had an increase of 3 hospitalized individuals due to COVID-19.'), type='text')], created_at=1718544565, file_ids=[], metadata={}, object='thread.message', role='assistant', run_id='run_mywyVaCKBC9gMJndAkTdFbC0', thread_id='thread_caBk1HOUQBS3fv6g7r4Xoik7'), ThreadMessage(id='msg_mN1FicdB42nySJXdsLo5hdKl', assistant_id=None, content=[MessageContentText(text=Text(annotations=[], value='how many hospitalized people we had in Alaska\n            the 2021-03-05?'), type='text')], created_at=1718544545, file_ids=[], metadata={}, object='thread.message', role='user', run_id=None, thread_id=

In [24]:
# Pretty-print the thread's messages, including the assistant's reply, as JSON.
messages_json = messages.model_dump_json(indent=2)
print(messages_json)

{
  "data": [
    {
      "id": "msg_676CQezZz0Q2bjp306wa5OlW",
      "assistant_id": "asst_z2ETTQh3XLV6JcE8xDVB4Q1G",
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "On March 5th, 2021, Alaska had an increase of 3 hospitalized individuals due to COVID-19."
          },
          "type": "text"
        }
      ],
      "created_at": 1718544565,
      "file_ids": [],
      "metadata": {},
      "object": "thread.message",
      "role": "assistant",
      "run_id": "run_mywyVaCKBC9gMJndAkTdFbC0",
      "thread_id": "thread_caBk1HOUQBS3fv6g7r4Xoik7"
    },
    {
      "id": "msg_mN1FicdB42nySJXdsLo5hdKl",
      "assistant_id": null,
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "how many hospitalized people we had in Alaska\n            the 2021-03-05?"
          },
          "type": "text"
        }
      ],
      "created_at": 1718544545,
      "file_ids": [],
      "metadata": {

## Add the code interpreter

In [25]:
# Upload the raw CSV so the code interpreter can read it. The `with` block
# guarantees the local file handle is closed once the upload call returns.
with open("./data/all-states-history.csv", "rb") as csv_file:
    file = client.files.create(
        file=csv_file,
        purpose='assistants'
    )

# This assistant has only the code_interpreter tool, so it answers by running
# Python over the uploaded file rather than calling the helper functions.
# Note: this rebinds `assistant` and `thread`; the function-calling objects
# created earlier are no longer referenced by these names.
assistant = client.beta.assistants.create(
    instructions="""You are an assistant answering questions about
                a Covid dataset.""",
    model=AZURE_OPENAI_MODEL,
    tools=[{"type": "code_interpreter"}],
    file_ids=[file.id],
)

# Fresh thread holding the same question asked of the function-calling assistant.
thread = client.beta.threads.create()
print(thread)
message = client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content="""how many hospitalized people we had in Alaska
            the 2021-03-05?"""
)
print(message)

# Start the run; its progress is polled in the next cell.
run = client.beta.threads.runs.create(
    thread_id=thread.id,
    assistant_id=assistant.id,
)

Thread(id='thread_WsnKNELFCyF5dTOoqshiFyLv', created_at=1718544663, metadata={}, object='thread')
ThreadMessage(id='msg_ogbZMz5myHS6tTWle0CkOD5D', assistant_id=None, content=[MessageContentText(text=Text(annotations=[], value='how many hospitalized people we had in Alaska\n            the 2021-03-05?'), type='text')], created_at=1718544664, file_ids=[], metadata={}, object='thread.message', role='user', run_id=None, thread_id='thread_WsnKNELFCyF5dTOoqshiFyLv')


In [26]:
# Imported here too so this cell does not silently depend on an earlier cell.
import time
from IPython.display import clear_output

# Statuses after which a run no longer changes. "incomplete" is only reported
# by newer Assistants API versions; listing it is harmless otherwise.
TERMINAL_STATUSES = ("completed", "cancelled", "expired", "failed", "incomplete")
POLL_INTERVAL_S = 5   # seconds between status checks
MAX_WAIT_S = 600      # stop polling after 10 minutes instead of looping forever

status = run.status
start_time = time.time()
while status not in TERMINAL_STATUSES:
    time.sleep(POLL_INTERVAL_S)
    run = client.beta.threads.runs.retrieve(
        thread_id=thread.id,
        run_id=run.id
    )
    status = run.status

    # Sample the clock once so the minutes and seconds shown always agree.
    elapsed = time.time() - start_time
    minutes, seconds = divmod(int(elapsed), 60)
    print(f"Elapsed time: {minutes} minutes {seconds} seconds")
    print(f'Status: {status}')
    # Replace this progress text when the next output arrives instead of appending.
    clear_output(wait=True)

    if status not in TERMINAL_STATUSES and elapsed > MAX_WAIT_S:
        raise TimeoutError(
            f"Run {run.id} still '{status}' after {minutes} minutes {seconds} seconds"
        )

if status != "completed":
    # Surface why the run stopped; otherwise the message list simply lacks an answer.
    print(f"Run finished with status '{status}': {run.last_error}")

messages = client.beta.threads.messages.list(
    thread_id=thread.id
)

print(messages.model_dump_json(indent=2))

{
  "data": [
    {
      "id": "msg_Di8pMbWucS1TY6zBZM8Uqhw7",
      "assistant_id": "asst_0BRrm0wM5sadf7xpr4xuCX8X",
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "On March 5, 2021, there were 33 people hospitalized due to COVID-19 in Alaska."
          },
          "type": "text"
        }
      ],
      "created_at": 1718544681,
      "file_ids": [],
      "metadata": {},
      "object": "thread.message",
      "role": "assistant",
      "run_id": "run_G4qjylDkSF6MbxbzaTGQSR7F",
      "thread_id": "thread_WsnKNELFCyF5dTOoqshiFyLv"
    },
    {
      "id": "msg_lLY5BPpPlXXFonOmxLjaHs4Z",
      "assistant_id": "asst_0BRrm0wM5sadf7xpr4xuCX8X",
      "content": [
        {
          "text": {
            "annotations": [],
            "value": "The dataset covers various metrics related to COVID-19, including hospitalized cases. To find the number of people hospitalized in Alaska on 2021-03-05, we'll filter the data by the state an