In [1]:
import os
from dotenv import load_dotenv
# from google.colab import drive

# drive.mount('/content/drive')

# Detect where the notebook is running and locate the env/.env file.
try:
    import google.colab  # noqa: F401 - imported only to detect Colab
    print("Colab 환경에서 실행 중입니다.")
    # Assumes Google Drive is already mounted at /content/drive - TODO confirm
    # (the drive.mount call at the top of this cell is commented out).
    PATH = '/content/drive/MyDrive/data/'
except ImportError:
    print("로컬 환경에서 실행 중입니다.")
    import platform
    os_name = platform.system()
    if os_name == "Windows":
        print("Windows 로컬 환경에서 실행 중입니다.")
    elif os_name == "Linux":
        print("Linux 환경에서 실행 중입니다. (Colab일 가능성 있음)")
    else:
        print(f"운영 체제: {os_name}")
    # Use the working directory on every local OS. Previously only the Windows
    # branch set PATH/env_path, so Linux/macOS failed with a NameError below.
    PATH = './'

env_path = PATH + "env/.env"
load_dotenv(dotenv_path=env_path)

# Read the key from the loaded environment.
# Author's note: .env entries containing spaces were not recognized, so write
# them as KEY=value without spaces.
api_key = os.getenv('MY_OWN_KEY')

if not api_key:
    raise ValueError(".env 파일에서 API 키를 로드하지 못했습니다.")

# Re-export under OPENAI_API_KEY so ChatOpenAI (constructed later without an
# explicit key) can pick it up from the environment.
os.environ['OPENAI_API_KEY'] = api_key

로컬 환경에서 실행 중입니다.
Windows 로컬 환경에서 실행 중입니다.


In [72]:
import os
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase

# Build the MySQL connection URI from the environment (e.g. the .env file loaded
# in the first cell) instead of hardcoding credentials in the notebook.
# Either set MYSQL_URI directly, or set MYSQL_PASSWORD (plus optional
# MYSQL_USER / MYSQL_HOST / MYSQL_PORT / MYSQL_DATABASE overrides).
#
# For SQLite instead: mysql_uri = 'sqlite:///./Chinook.db'
mysql_uri = os.getenv("MYSQL_URI")
if not mysql_uri:
    mysql_user = os.getenv("MYSQL_USER", "root")
    mysql_password = os.getenv("MYSQL_PASSWORD")
    if mysql_password is None:
        raise ValueError(".env 파일에서 MYSQL_PASSWORD(또는 MYSQL_URI)를 로드하지 못했습니다.")
    mysql_host = os.getenv("MYSQL_HOST", "localhost")
    mysql_port = os.getenv("MYSQL_PORT", "3306")
    mysql_database = os.getenv("MYSQL_DATABASE", "Chinook")
    # Escape user/password so characters like '@', ':' or '/' don't break the URL.
    mysql_uri = (
        f"mysql+mysqlconnector://{quote_plus(mysql_user)}:{quote_plus(mysql_password)}"
        f"@{mysql_host}:{mysql_port}/{mysql_database}"
    )

db = SQLDatabase.from_uri(mysql_uri)

In [73]:
# Sanity-check the connection: SQL dialect, visible tables, then a small sample.
dialect_name = db.dialect
print(dialect_name)
usable_tables = db.get_usable_table_names()
print(usable_tables)

# Left as the last expression so Jupyter displays the first 10 Artist rows.
artist_sample = db.run("SELECT * FROM Artist LIMIT 10;")
artist_sample

mysql
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [74]:
from langchain_openai import ChatOpenAI

# Chat model for the text-to-SQL chain; temperature=0 to minimize randomness.
SQL_CHAIN_MODEL = 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0, model=SQL_CHAIN_MODEL)

In [75]:
from langchain.chains import create_sql_query_chain

# Text-to-SQL chain built from `llm` and `db`: question in, SQL query string out.
chain_sql = create_sql_query_chain(llm=llm, db=db)

# Show the prompt template the chain sends to the model.
sql_prompt = chain_sql.get_prompts()[0]
sql_prompt.pretty_print()

# Generate the query text (executed in a later cell). Korean questions work too;
# this one asks: "How many employees work at the store?"
question_payload = {"question": "가게에서 종사하는 종업원이 몇 명이야?"}
response = chain_sql.invoke(question_payload)
print(response)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [76]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

In [77]:
# Tool that executes a raw SQL string against `db`; reused by `chain` later.
# Built outside the try block: if construction fails, the real error surfaces
# here instead of a confusing NameError on `execute_query` in a later cell.
execute_query = QuerySQLDataBaseTool(db=db)

try:
    # Run the SQL generated by `chain_sql` in the previous cell.
    result = execute_query.run(response)
    print(result)
except Exception as e:
    # Keep the notebook going if the generated query fails, but show why.
    print(f"오류 발생: {str(e)}")

[(8,)]


In [78]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Template for turning (question, generated SQL, SQL result) into a final answer.
ANSWER_TEMPLATE = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """

answer_prompt = PromptTemplate.from_template(ANSWER_TEMPLATE)


In [79]:
# Step 1: keep the incoming {"question": ...} and add the generated SQL as "query".
with_query = RunnablePassthrough.assign(query=chain_sql)
# Step 2: execute that SQL and add its output as "result".
with_result = with_query.assign(result=itemgetter("query") | execute_query)
# Step 3: fill the answer prompt, ask the model, and return plain text.
chain = with_result | answer_prompt | llm | StrOutputParser()

In [80]:
chain.invoke({"question": "How many employees are there"})

'There are 8 employees.'

In [57]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI

# NOTE(review): this rebinds the notebook-global `llm` (gpt-3.5-turbo earlier).
# `chain` already holds the earlier model object, but re-running the cell that
# builds `chain` after this one would silently switch it to this model.
AGENT_MODEL = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=AGENT_MODEL)

# Tool-calling SQL agent over `db`; verbose=True prints each tool invocation.
agent_executor = create_sql_agent(
    llm,
    db=db,
    agent_type="openai-tools",
    verbose=True,
)

In [59]:
# Ask the agent a natural-language question; the verbose trace shows its tool calls.
agent_question = "what's the average age of employees"
result = agent_executor.invoke({"input": agent_question})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Employee'}`


[0m[33;1m[1;3m
CREATE TABLE `Employee` (
	`EmployeeId` INTEGER NOT NULL, 
	`LastName` VARCHAR(20) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`FirstName` VARCHAR(20) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`Title` VARCHAR(30) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`ReportsTo` INTEGER, 
	`BirthDate` DATETIME, 
	`HireDate` DATETIME, 
	`Address` VARCHAR(70) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`City` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`State` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`Country` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`Po