In [1]:
import os
from openai import AzureOpenAI
from dotenv import load_dotenv
import json
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
from sqlalchemy import text

load_dotenv()

True

In [2]:
client = AzureOpenAI(
  azure_endpoint = os.getenv('AZURE_ENDPOINT'),
  api_key=os.getenv("OPENAI_API_KEY"),
  api_version=os.getenv('OPENAI_API_VERSION')
)

In [3]:
def get_current_weather(location, unit="fahrenheit"):
    """Get the current weather in a given location. 
    The default unit when not specified is fahrenheit"""
    if "new york" in location.lower():
        return json.dumps(
            {"location": "New York", "temperature": "40", "unit": unit}
        )
    elif "san francisco" in location.lower():
        return json.dumps(
            {"location": "San Francisco", "temperature": "50", "unit": unit}
        )
    elif "las vegas" in location.lower():
        return json.dumps(
            {"location": "Las Vegas", "temperature": "70", "unit": unit}
        )
    else:
        return json.dumps(
            {"location": location, "temperature": "unknown"}
        )

get_current_weather("New York")

'{"location": "New York", "temperature": "40", "unit": "fahrenheit"}'

In [4]:
messages = [
    {"role": "user",
     "content": """What's the weather like in San Francisco,
                   New York, and Las Vegass?"""
    }
]

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": """Get the current weather in a given
                              location.The default unit when not
                              specified is fahrenheit""",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": """The city and state,
                                        e.g. San Francisco, CA""",
                    },
                    "unit": {
                        "type": "string",
                        "default":"fahrenheit",
                        "enum": [ "fahrenheit", "celsius"],
                        "description": """The messuring unit for
                                          the temperature.
                                          If not explicitly specified
                                          the default unit is 
                                          fahrenheit"""
                    },
                },
                "required": ["location"],
            },
        },
    }
]

In [5]:
response = client.chat.completions.create(
    model=os.getenv('AZURE_DEPLOYMENT'),
    messages=messages,
    tools=tools,
    tool_choice="auto", 
)

response_message = response.choices[0].message
tool_calls = response_message.tool_calls

if tool_calls:
    print (tool_calls)
    
    available_functions = {
        "get_current_weather": get_current_weather,
    } 
    messages.append(response_message)  
    
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        function_to_call = available_functions[function_name]
        function_args = json.loads(tool_call.function.arguments)
        function_response = function_to_call(
            location=function_args.get("location"),
            unit=function_args.get("unit"),
        )
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": function_response,
            }
        )  
    print (messages)

[ChatCompletionMessageToolCall(id='call_vk2MU9zAPsZ67GfbbkU5k5ky', function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_7iVPqMoWyqD70QHd7koXRGCm', function=Function(arguments='{"location": "New York, NY"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_4JMfoZJTyMWmootlPl7Xl0u4', function=Function(arguments='{"location": "Las Vegas, NV"}', name='get_current_weather'), type='function')]
[{'role': 'user', 'content': "What's the weather like in San Francisco,\n                   New York, and Las Vegass?"}, ChatCompletionMessage(content=None, refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_vk2MU9zAPsZ67GfbbkU5k5ky', function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_7iVPqMoWyqD70QHd7koXRGCm

In [6]:
second_response = client.chat.completions.create(
            model=os.getenv('AZURE_DEPLOYMENT'),
            messages=messages,
        )
print (second_response)

ChatCompletion(id='chatcmpl-AZlmFtagjwJFuVOgv71JnnK9u9sX9', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="Here's the current weather for the three locations:\n\n- **San Francisco, CA**: 50°F\n- **New York, NY**: 40°F\n- **Las Vegas, NV**: 70°F", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), content_filter_results={'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}})], created=1733088499, model='gpt-4o-2024-08-06', object='chat.completion', service_tier=None, system_fingerprint='fp_04751d0b65', usage=CompletionUsage(completion_tokens=44, prompt_tokens=159, total_tokens=203, completion_tokens_details=None, prompt_tokens_details=None), prompt_filter_results=[{'prompt_index': 0, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'sa

In [7]:
df = pd.read_csv("./data/all-states-history.csv").fillna(value = 0)

In [8]:
database_file_path = "./test.db"

engine = create_engine(f'sqlite:///{database_file_path}')

df.to_sql(
    'all_student_performance',
    con=engine,
    if_exists='replace',
    index=False)

20780

In [9]:
def get_hospitalized_increase_for_state_on_date(state_abbr, specific_date):
    try:
        query = f"""
        SELECT date, hospitalizedIncrease
        FROM all_states_history
        WHERE state = '{state_abbr}' AND date = '{specific_date}';
        """
        query = text(query)

        with engine.connect() as connection:
            result = pd.read_sql_query(query, connection)
        if not result.empty:
            return result.to_dict('records')[0]
        else:
            return np.nan
    except Exception as e:
        print(e)
        return np.nan

In [10]:
def get_positive_cases_for_state_on_date(state_abbr, specific_date):
    try:
        query = f"""
        SELECT date, state, positiveIncrease AS positive_cases
        FROM all_states_history
        WHERE state = '{state_abbr}' AND date = '{specific_date}';
        """
        query = text(query)

        with engine.connect() as connection:
            result = pd.read_sql_query(query, connection)
        if not result.empty:
            return result.to_dict('records')[0]
        else:
            return np.nan
    except Exception as e:
        print(e)
        return np.nan

In [11]:
get_hospitalized_increase_for_state_on_date("AK","2021-03-05")

{'date': '2021-03-05', 'hospitalizedIncrease': 3}

In [12]:
messages = [
    {"role": "user",
     "content": """ how many hospitalized people we had in Alaska
                    the 2021-03-05?"""
    }
]

In [13]:
tools_sql = [
    {
        "type": "function",
        "function": {
            "name": "get_hospitalized_increase_for_state_on_date",
            "description": """Retrieves the daily increase in
                              hospitalizations for a specific state
                              on a specific date.""",
            "parameters": {
                "type": "object",
                "properties": {
                    "state_abbr": {
                        "type": "string",
                        "description": """The abbreviation of the state
                                          (e.g., 'NY', 'CA')."""
                    },
                    "specific_date": {
                        "type": "string",
                        "description": """The specific date for
                                          the query in 'YYYY-MM-DD'
                                          format."""
                    }
                },
                "required": ["state_abbr", "specific_date"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_positive_cases_for_state_on_date",
            "description": """Retrieves the daily increase in 
                              positive cases for a specific state
                              on a specific date.""",
            "parameters": {
                "type": "object",
                "properties": {
                    "state_abbr": {
                        "type": "string",
                        "description": """The abbreviation of the 
                                          state (e.g., 'NY', 'CA')."""
                    },
                    "specific_date": {
                        "type": "string",
                        "description": """The specific date for the
                                          query in 'YYYY-MM-DD'
                                          format."""
                    }
                },
                "required": ["state_abbr", "specific_date"]
            }
        }
    }
]

In [14]:
response = client.chat.completions.create(
    model=os.getenv('AZURE_DEPLOYMENT'),
    messages=messages,
    tools=tools_sql,
    tool_choice="auto",
)

response_message = response.choices[0].message
tool_calls = response_message.tool_calls

if tool_calls:
    print (tool_calls)
    
    available_functions = {
        "get_positive_cases_for_state_on_date": get_positive_cases_for_state_on_date,
        "get_hospitalized_increase_for_state_on_date":get_hospitalized_increase_for_state_on_date
    }  
    messages.append(response_message)  
   
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        function_to_call = available_functions[function_name]
        function_args = json.loads(tool_call.function.arguments)
        function_response = function_to_call(
            state_abbr=function_args.get("state_abbr"),
            specific_date=function_args.get("specific_date"),
        )
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": str(function_response),
            }
        ) 
    import json

    # Convert each message to a dictionary if it has a `dict` method; otherwise, keep it as is
    print(json.dumps(
        [message.dict() if hasattr(message, 'dict') else message for message in messages],
        indent=2
    ))

[ChatCompletionMessageToolCall(id='call_IxVFFxeEBv8AmZzBjcx3uxyC', function=Function(arguments='{"specific_date":"2021-03-05","state_abbr":"AK"}', name='get_hospitalized_increase_for_state_on_date'), type='function')]
[
  {
    "role": "user",
    "content": " how many hospitalized people we had in Alaska\n                    the 2021-03-05?"
  },
  {
    "content": null,
    "refusal": null,
    "role": "assistant",
    "audio": null,
    "function_call": null,
    "tool_calls": [
      {
        "id": "call_IxVFFxeEBv8AmZzBjcx3uxyC",
        "function": {
          "arguments": "{\"specific_date\":\"2021-03-05\",\"state_abbr\":\"AK\"}",
          "name": "get_hospitalized_increase_for_state_on_date"
        },
        "type": "function"
      }
    ]
  },
  {
    "tool_call_id": "call_IxVFFxeEBv8AmZzBjcx3uxyC",
    "role": "tool",
    "name": "get_hospitalized_increase_for_state_on_date",
    "content": "{'date': '2021-03-05', 'hospitalizedIncrease': 3}"
  }
]


/tmp/ipykernel_16982/1994234385.py:40: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  [message.dict() if hasattr(message, 'dict') else message for message in messages],


In [15]:
second_response = client.chat.completions.create(
            model=os.getenv('AZURE_DEPLOYMENT'),
            messages=messages,
        )
print(second_response.model_dump_json(indent=2))

{
  "id": "chatcmpl-AZlmH83OstxAKN8l2Z5uaLaIDSj7k",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "On March 5, 2021, Alaska had an increase of 3 hospitalizations.",
        "refusal": null,
        "role": "assistant",
        "audio": null,
        "function_call": null,
        "tool_calls": null
      },
      "content_filter_results": {
        "hate": {
          "filtered": false,
          "severity": "safe"
        },
        "protected_material_code": {
          "filtered": false,
          "detected": false
        },
        "protected_material_text": {
          "filtered": false,
          "detected": false
        },
        "self_harm": {
          "filtered": false,
          "severity": "safe"
        },
        "sexual": {
          "filtered": false,
          "severity": "safe"
        },
        "violence": {
          "filtered": false,
          "severity": "safe"
        }
   