In [1]:
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.agents import create_sql_agent
from langchain.agents import AgentExecutor

In [2]:
SQL_SERVER = "localhost"
SQL_DBNAME = "nba"
SQL_USER = "nba_sql"
SQL_PWD = "nba_sql"

OPENAI_API_KEY = "OPENAI_API_KEY"

In [3]:
llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)

In [None]:
llm = ChatOpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)

In [4]:
sqlconn = f"postgresql+psycopg2://{SQL_USER}:{SQL_PWD}@{SQL_SERVER}:5432/{SQL_DBNAME}"
db = SQLDatabase.from_uri(sqlconn)

In [5]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [6]:
agent_executor = create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        verbose=True
)

In [7]:
agent_executor.run('What is the actual team of LeBron James?')



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mplayer_game_log, shot_chart_detail, play_by_play, event_message_type, team, player_season, game, season, player, player_general_traditional_total[0m
Thought:[32;1m[1;3m I should look at the schema of the team table to see what columns I can query.
Action: schema_sql_db
Action Input: "team"[0m
Observation: [33;1m[1;3m
CREATE TABLE team (
	team_id INTEGER NOT NULL, 
	abbreviation VARCHAR(255), 
	nickname VARCHAR(255), 
	yearfounded VARCHAR(255), 
	city VARCHAR(255), 
	CONSTRAINT team_pkey PRIMARY KEY (team_id)
)

/*
3 rows from team table:
team_id	abbreviation	nickname	yearfounded	city
1610612737	ATL	Hawks	1949	Atlanta
1610612738	BOS	Celtics	1946	Boston
1610612739	CLE	Cavaliers	1970	Cleveland
*/[0m
Thought:[32;1m[1;3m I should query the team table to get the team of LeBron James.
Action: query_sql_db
Action Input: SELECT nickname FROM team WH

"LeBron James' team is the Cavaliers."

In [8]:
agent_executor.run("What is the name of player with the most points in the last season?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: ""[0m
Observation: [38;5;200m[1;3mplayer_game_log, shot_chart_detail, play_by_play, event_message_type, team, player_season, game, season, player, player_general_traditional_total[0m
Thought:[32;1m[1;3m I should look at the schema of the player_season table to see what columns I can query.
Action: schema_sql_db
Action Input: "player_season"[0m
Observation: [33;1m[1;3m
CREATE TABLE player_season (
	id SERIAL NOT NULL, 
	player_id INTEGER NOT NULL, 
	season_id INTEGER NOT NULL, 
	team_id INTEGER, 
	age INTEGER, 
	player_height VARCHAR(255), 
	player_height_inches INTEGER, 
	player_weight VARCHAR(255), 
	gp INTEGER, 
	pts REAL, 
	reb REAL, 
	ast REAL, 
	net_rating REAL, 
	oreb_pct REAL, 
	dreb_pct REAL, 
	usg_pct REAL, 
	ts_pct REAL, 
	ast_pct REAL, 
	CONSTRAINT player_season_pkey PRIMARY KEY (id), 
	CONSTRAINT player_season_player_id_fkey FOREIGN KEY(player_id) REFERENCES player

'The player with the most points in the last season is 1628369.'