In [2]:
import os 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from langchain.agents import load_tools
from langchain_experimental.agents import create_pandas_dataframe_agent 
from langchain.chat_models import ChatOpenAI
from datetime import datetime
import sqlite3


In [7]:
# #shorten stop_times file - only needed one time.
# file_loc = '/workspaces/LLM_Testing/Data/GO_GTFS/stop_times.txt'
# df = pd.read_csv(file_loc, low_memory=False)
# df = df.sort_values(by = 'trip_id',ascending=False)
# # Calculate 10% of the length of the DataFrame
# ten_percent_length = int(len(df) * 0.1)

# # Keep only the first 10% of rows
# df = df.iloc[:ten_percent_length]

# # over write file
# df.to_csv(file_loc)

In [8]:
import os
import glob

conn = sqlite3.connect('/workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')

def create_table_from_file(con, file_path, table_name):
    # Read the file into a pandas DataFrame
    df = pd.read_csv(file_path, sep='\t',low_memory=False)  # Adjust the separator if needed
    # Convert the DataFrame to a table in the SQLite database
    df.to_sql(table_name, con=con, if_exists='replace', index=False)


# Specify the directory path
directory_path = '/workspaces/LLM_Testing/Data/GO_GTFS'

# Create a cursor object using the connection
cursor = conn.cursor()

# Loop through all the files in the directory
for file_name in os.listdir(directory_path):
    if file_name.endswith('.txt'):
        file_path = os.path.join(directory_path, file_name)
        table_name = os.path.splitext(file_name)[0]
        create_table_from_file(conn, file_path, table_name)

# Commit the changes and close the connection
conn.commit()
conn.close()

In [25]:

# Connect to the SQLite database
conn = sqlite3.connect('/workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')

# Create a cursor object
cursor = conn.cursor()

# Query to retrieve the list of tables
query = "SELECT name FROM sqlite_master WHERE type='table';"

# Execute the query
cursor.execute(query)

# Fetch all results
tables = cursor.fetchall()

# Print the list of tables
print("List of tables in the database:")
for table in tables:
    print(table[0])

# Close the connection
conn.close()


List of tables in the database:
agency
trips
stop_amentities
stops
fare_rules
transfers
stop_times
routes
calendar_dates
feed_info
fare_attributes
shapes


In [16]:
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit


# from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType

db = SQLDatabase.from_uri(f'sqlite:////workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')
llm = ChatOpenAI(model_name="gpt-4",temperature=0)

# tools = load_tools(["google-serper", "llm-math"], llm=llm)
# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

agent = create_sql_agent(
    llm=OpenAI(temperature=0),
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)


In [17]:
agent('how many stations are there?')



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3magency, calendar_dates, fare_attributes, fare_rules, feed_info, routes, shapes, stop_amentities, stop_times, stops, transfers, trips[0m
Thought:[32;1m[1;3m I should query the stops table to get the number of stations.
Action: sql_db_schema
Action Input: stops[0m
Observation: [33;1m[1;3m
CREATE TABLE stops (
	"stop_id,stop_name,stop_lat,stop_lon,zone_id,stop_url,location_type,parent_station,wheelchair_boarding,stop_code" TEXT
)

/*
3 rows from stops table:
stop_id,stop_name,stop_lat,stop_lon,zone_id,stop_url,location_type,parent_station,wheelchair_boarding,stop_code
WR,West Harbour GO,43.266775,-79.866222,18,https://www.gotransit.com/find-a-station-or-stop/WR/stati
WH,Whitby GO,43.86484,-78.93818,93,https://www.gotransit.com/find-a-station-or-stop/WH/station-detai
WE,Weston GO,43.70022,-79.514671,04,https://www.gotransit.com/find-a-station-or-sto

{'input': 'how many stations are there?', 'output': '911'}

In [30]:
agent('the database holds GTFS data for GO Transit in Toronto. how long on average do trips take to get from Oakville to Union?')



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3magency, calendar_dates, fare_attributes, fare_rules, feed_info, routes, shapes, stop_amentities, stop_times, stops, transfers, trips[0m
Thought:[32;1m[1;3m I should query the trips table to get the trip_id, route_id, and service_id for the trips from Oakville to Union.
Action: sql_db_schema
Action Input: trips[0m
Observation: [33;1m[1;3m
CREATE TABLE trips (
	"route_id,service_id,trip_id,trip_headsign,trip_short_name,direction_id,block_id,shape_id,wheelchair_accessible,bikes_allowed,route_variant" TEXT
)

/*
3 rows from trips table:
route_id,service_id,trip_id,trip_headsign,trip_short_name,direction_id,block_id,shape_id,wheelchair_accessible,bikes_allowed,route_variant
01240424-21,20240422,20240422-21-21302,21D - Clarkson GO,,0,strvl016,311638,1,1,21D
01240424-21,20240422,20240422-21-21263,21D - Lisgar GO,,1,strvl016,311637,1,1,21D
01240424-21,2

{'input': 'the database holds GTFS data for GO Transit in Toronto. how long on average do trips take to get from Oakville to Union?',
 'output': "I don't know"}

In [9]:
# sql_query = "SELECT stop_times.arrival_time, stop_times.departure_time FROM stop_times INNER JOIN trips ON stop_times.trip_id = trips.trip_id WHERE trips.trip_headsign LIKE '%Oakville%' AND trips.trip_headsign LIKE '%Union%' LIMIT 10"
# sql_query = "SELECT arrival_time, departure_time FROM stop_times limit 10"
sql_query = "select service_id from trips limit 10"

# Connect to the SQLite database
conn = sqlite3.connect('/workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')

df = pd.read_sql(sql=sql_query,con=conn)
df.head()

DatabaseError: Execution failed on sql 'select service_id from trips limit 10': no such column: service_id

In [37]:
import sqlite3

conn = sqlite3.connect('/workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')

try:
    # Create a cursor object
    cursor = conn.cursor()

    # Query to retrieve the schema of all tables
    query = "SELECT name, sql FROM sqlite_master WHERE type='table'"

    # Execute the query
    cursor.execute(query)

    # Fetch all the results
    tables_schema = cursor.fetchall()

    # Display the schema of each table
    for table_name, schema in tables_schema:
        print(f"Table: {table_name}\nSchema:\n{schema}\n")

except sqlite3.Error as error:
    print("Error while getting table schemas:", error)

finally:
    # Close the cursor and connection
    if conn:
        conn.close()


Table: agency
Schema:
CREATE TABLE "agency" (
"agency_id,agency_name,agency_url,agency_timezone,agency_lang,agency_phone,agency_fare_url" TEXT
)

Table: trips
Schema:
CREATE TABLE "trips" (
"route_id,service_id,trip_id,trip_headsign,trip_short_name,direction_id,block_id,shape_id,wheelchair_accessible,bikes_allowed,route_variant" TEXT
)

Table: stop_amentities
Schema:
CREATE TABLE "stop_amentities" (
"stop_id,shelter,washroom,bike_rack,bench" TEXT
)

Table: stops
Schema:
CREATE TABLE "stops" (
"stop_id,stop_name,stop_lat,stop_lon,zone_id,stop_url,location_type,parent_station,wheelchair_boarding,stop_code" TEXT
)

Table: fare_rules
Schema:
CREATE TABLE "fare_rules" (
"fare_id,origin_id,destination_id" TEXT
)

Table: transfers
Schema:
CREATE TABLE "transfers" (
"from_stop_id,to_stop_id,transfer_type,min_transfer_time" TEXT
)

Table: stop_times
Schema:
CREATE TABLE "stop_times" (
"trip_id,arrival_time,departure_time,stop_id,stop_sequence,pickup_type,drop_off_type,stop_headsign" TEXT
)

Tab