In [1]:
import os
import json
import pandas as pd
from sqlalchemy import create_engine, text
from openai import OpenAI
from thefuzz import fuzz, process

In [2]:
def make_initial_query(client, query, prompt_path="initial_prompt.txt", model="gpt-4o",
                       temperature=1, max_tokens=600):
    """Ask the chat model to turn a natural-language question into a JSON object.

    The system prompt is read from ``prompt_path``. The model is called in JSON
    mode (``response_format={"type": "json_object"}``), and its reply is parsed
    and returned.

    Args:
        client: An OpenAI client, or anything exposing
            ``client.chat.completions.create(...)``.
        query: The user's natural-language question, sent as the user message.
        prompt_path: Path to the system-prompt file. It is resolved relative to
            the current working directory.
        model: Chat model name.
        temperature: Sampling temperature. The default of 1 keeps the original
            behaviour; lower values make the generated SQL more repeatable.
        max_tokens: Completion token limit for the reply.

    Returns:
        dict: The parsed JSON object. Later cells in this notebook read its
        ``"sql"`` and ``"stations"`` keys.

    Raises:
        ValueError: if the reply was cut off by ``max_tokens`` or had no content.
            This includes ``json.JSONDecodeError`` (a ``ValueError`` subclass)
            when the reply is not valid JSON.
    """
    with open(prompt_path, "r", encoding="utf-8") as prompt_file:
        system_prompt = prompt_file.read()

    response = client.chat.completions.create(
        model=model,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )

    choice = response.choices[0]
    # In JSON mode, a reply that stopped because of the token limit is partial
    # JSON. Say so explicitly instead of surfacing an opaque decode error.
    if getattr(choice, "finish_reason", None) == "length":
        raise ValueError(
            f"model reply was truncated at max_tokens={max_tokens}; "
            "raise max_tokens or shorten the prompt"
        )
    content = choice.message.content
    if content is None:
        raise ValueError("model returned no message content")
    return json.loads(content)

In [3]:
def correct_station_names(engine, response_json):
    """Replace the model's guessed station names in the generated SQL with real ones.

    For each entry in ``response_json["stations"]``, the guessed name is
    fuzzy-matched with ``fuzz.partial_token_sort_ratio`` against the
    ``station_complex`` names in the database. If the entry carries a route,
    only stations served by a matching route are considered. If that route
    matches nothing, the search widens to every station. Every occurrence of the
    guess in ``response_json["sql"]`` is then replaced with the best match.

    Args:
        engine: SQLAlchemy engine for the subway database.
        response_json: Parsed model output from ``make_initial_query``. It is
            modified in place. Assumes each ``"stations"`` entry is a sequence
            ``[guessed_name, route_or_falsy, ...]``. NOTE(review): confirm this
            against initial_prompt.txt.

    Returns:
        The same ``response_json`` dict, with ``"sql"`` rewritten. A missing or
        empty ``"stations"`` list leaves it unchanged.
    """
    # Bound parameters keep the model-supplied route text (untrusted input) out
    # of the SQL string. They also avoid the %-escaping ambiguity of an f-string.
    route_stations_sql = text(
        "SELECT s.station_complex "
        "FROM stations s "
        "JOIN station_routes sr ON s.station_complex_id = sr.station_complex_id "
        "WHERE sr.route_name ILIKE :route_pattern"
    )
    all_stations_sql = text("SELECT station_complex FROM stations")

    # The `with` block guarantees the connection is released. The original
    # opened a new connection per station and never closed it.
    with engine.connect() as conn:
        all_station_names = None  # loaded lazily, at most once per call

        for station in response_json.get("stations") or []:
            guessed_name = station[0]
            route = station[1] if len(station) > 1 else None
            if not guessed_name:
                # str.replace("", x) would splice x between every character.
                continue

            candidates = []
            if route:
                candidates = pd.read_sql(
                    route_stations_sql,
                    con=conn,
                    params={"route_pattern": f"%{route}%"},
                )["station_complex"].tolist()
            if not candidates:
                # No route given, or the route matched nothing: search every
                # station instead of indexing into an empty match list.
                if all_station_names is None:
                    all_station_names = pd.read_sql(all_stations_sql, con=conn)["station_complex"].tolist()
                candidates = all_station_names
            if not candidates:
                continue  # empty stations table: nothing to match against

            best_name = process.extract(
                guessed_name,
                candidates,
                limit=1,
                scorer=fuzz.partial_token_sort_ratio,
            )[0][0]
            # NOTE(review): this replaces every occurrence of the guess anywhere
            # in the SQL text (same as before), including inside longer strings.
            response_json["sql"] = response_json["sql"].replace(guessed_name, best_name)
    return response_json

In [4]:
# Clients for the LLM and the subway database. Credentials come from the
# environment rather than being committed to the notebook:
#   OPENAI_API_KEY - required; passed explicitly so a missing key fails here.
#   SUBWAY_DB_URL  - optional full SQLAlchemy URL. The default has no password;
#                    the PostgreSQL driver can then take it from PGPASSWORD or
#                    ~/.pgpass.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
DB_URL = os.environ.get("SUBWAY_DB_URL", "postgresql://conductor@localhost/subway")
engine = create_engine(DB_URL)

In [5]:
# Ask the model for SQL that answers the question, then snap the station names
# it guessed onto real station_complex names before the query is run.
query = "Which station used OMNY more on 2024-01-19? union square (l) or 96th street (6)"
response_json = correct_station_names(
    engine,
    make_initial_query(client, query),
)

The best matches between 'Union Square' and all stations is: ('14 St-Union Sq', 70)
The best matches between '96th Street' and all stations is: ('96 St', 75)


{'sql': "WITH ridership_union_square_l AS ( \n    SELECT s.station_complex_id, s.station_complex, sr.route_name, SUM(r.total_omny_ridership) AS total_omny_ridership \n    FROM ridership r \n    JOIN stations s ON r.station_complex_id = s.station_complex_id \n    JOIN station_routes sr ON s.station_complex_id = sr.station_complex_id \n    WHERE s.station_complex = '14 St-Union Sq' AND sr.route_name = 'L' AND r.transit_timestamp::date = '2024-01-19' \n    GROUP BY s.station_complex_id, s.station_complex, sr.route_name \n), ridership_96th_street_6 AS ( \n    SELECT s.station_complex_id, s.station_complex, sr.route_name, SUM(r.total_omny_ridership) AS total_omny_ridership \n    FROM ridership r \n    JOIN stations s ON r.station_complex_id = s.station_complex_id \n    JOIN station_routes sr ON s.station_complex_id = sr.station_complex_id \n    WHERE s.station_complex = '96 St' AND sr.route_name = '6' AND r.transit_timestamp::date = '2024-01-19' \n    GROUP BY s.station_complex_id, s.statio

In [6]:
# Run the generated SQL and display the rows.
# The statement was written by the LLM, so treat it as untrusted: it runs inside
# a READ ONLY transaction, so Postgres rejects any INSERT/UPDATE/DELETE/DDL it
# might contain. Rows are fetched while the connection is still open; the
# original consumed the result after the `with` block had closed it.
with engine.connect() as conn:
    conn.execute(text("SET TRANSACTION READ ONLY"))
    results = conn.execute(text(response_json["sql"]))
    query_result = results.mappings().all()
pd.DataFrame(query_result)

Unnamed: 0,omny_ridership_96th_street,omny_ridership_union_square,station_with_more_omny_ridership
0,0,31505,14 St-Union Sq on the L line
1,6314,0,96 St on the 6 line
