In [1]:
import os 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from langchain.agents import load_tools
from langchain_experimental.agents import create_pandas_dataframe_agent 
from langchain.chat_models import ChatOpenAI
from datetime import datetime
import sqlite3


In [3]:
import os
import glob

conn = sqlite3.connect('/workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')


# Specify the directory path
directory_path = '/workspaces/LLM_Testing/Data/GO_GTFS'

# Create a cursor object using the connection
cursor = conn.cursor()

# Loop through all the files in the directory
for file_name in os.listdir(directory_path):
    if file_name.endswith('.txt'):
        file_path = os.path.join(directory_path, file_name)
        table_name = os.path.splitext(file_name)[0]
        
        #read dataframe
        df = pd.read_csv(file_path,low_memory=False)
        #export to db
        df.to_sql(table_name, con=conn, if_exists='replace', index=False)


# Commit the changes and close the connection
conn.commit()
conn.close()

In [25]:

# Connect to the SQLite database
conn = sqlite3.connect('/workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')

# Create a cursor object
cursor = conn.cursor()

# Query to retrieve the list of tables
query = "SELECT name FROM sqlite_master WHERE type='table';"

# Execute the query
cursor.execute(query)

# Fetch all results
tables = cursor.fetchall()

# Print the list of tables
print("List of tables in the database:")
for table in tables:
    print(table[0])

# Close the connection
conn.close()


List of tables in the database:
agency
trips
stop_amentities
stops
fare_rules
transfers
stop_times
routes
calendar_dates
feed_info
fare_attributes
shapes


In [11]:
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit


# from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType

db = SQLDatabase.from_uri(f'sqlite:////workspaces/LLM_Testing/Data/GO_GTFS/gtfs.db')



agent = create_sql_agent(
    llm = ChatOpenAI(model_name="gpt-4",temperature=0),
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)


In [12]:
agent('how many stations are there?')



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3magency, calendar_dates, fare_attributes, fare_rules, feed_info, routes, shapes, stop_amentities, stop_times, stops, transfers, trips[0m
Thought:[32;1m[1;3mThe 'stops' table seems to be the most relevant to the question as it likely contains information about the stations.
Action: sql_db_schema
Action Input: "stops"[0m
Observation: [33;1m[1;3m
CREATE TABLE stops (
	stop_id TEXT, 
	stop_name TEXT, 
	stop_lat REAL, 
	stop_lon REAL, 
	zone_id INTEGER, 
	stop_url TEXT, 
	location_type INTEGER, 
	parent_station REAL, 
	wheelchair_boarding INTEGER, 
	stop_code REAL
)

/*
3 rows from stops table:
stop_id	stop_name	stop_lat	stop_lon	zone_id	stop_url	location_type	parent_station	wheelchair_boarding	stop_code
WR	West Harbour GO	43.266775	-79.866222	18	https://www.gotransit.com/find-a-station-or-stop/WR/station-details	0	None	1	None
WH	Whitby GO	43.86484	

{'input': 'how many stations are there?', 'output': 'There are 911 stations.'}

In [17]:
agent('the database holds GTFS data for GO Transit. on January 4th, how many trips stop at Oakville GO between 7am and 8am that go east')



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3magency, calendar_dates, fare_attributes, fare_rules, feed_info, routes, shapes, stop_amentities, stop_times, stops, transfers, trips[0m
Thought:[32;1m[1;3mThe relevant tables for this query seem to be "trips", "stops", "stop_times", and "calendar_dates". I will now check the schema of these tables to understand their structure and the data they hold.
Action: sql_db_schema
Action Input: "trips,stops,stop_times,calendar_dates"[0m
Observation: [33;1m[1;3m
CREATE TABLE calendar_dates (
	service_id INTEGER, 
	date INTEGER, 
	exception_type INTEGER
)

/*
3 rows from calendar_dates table:
service_id	date	exception_type
20240104	20240104	1
20240103	20240103	1
20231228	20231228	1
*/


CREATE TABLE stop_times (
	trip_id TEXT, 
	arrival_time TEXT, 
	departure_time TEXT, 
	stop_id TEXT, 
	stop_sequence INTEGER, 
	pickup_type INTEGER, 
	drop_off_type INTEG

{'input': 'the database holds GTFS data for GO Transit. on January 4th, how many trips stop at Oakville GO between 7am and 8am that go east',
 'output': 'There was 1 trip that stopped at Oakville GO between 7am and 8am going east on January 4th.'}