In [2]:
%load_ext dotenv
%dotenv

import os
from langchain_openai import ChatOpenAI
from langchain_together import ChatTogether

USE_CHINESE=False

In [3]:
if not USE_CHINESE:
    # Use Mixtral
    model = ChatOpenAI(
        base_url="https://api.together.xyz/v1",
        api_key=os.environ["KEY_TOGETHERAI"],
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",)
else:
    # Use Qwen
    model = ChatTogether(
        together_api_key=os.environ['KEY_TOGETHERAI'],
        model="Qwen/Qwen1.5-72B-Chat",
    )

In [4]:
%reload_ext dotenv

In [5]:
# https://python.langchain.com/v0.2/docs/tutorials/rag/
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain

In [6]:
db = SQLDatabase.from_uri("sqlite:///data/chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [7]:
chain = create_sql_query_chain(model, db)
response = chain.invoke({"question": "name of all albums of the artist named 'Alanis Morissette'"})

In [8]:
print(response)

SELECT "Album"."Title"
FROM "Artist"
JOIN "Album" ON "Artist"."ArtistId" = "Album"."ArtistId"
WHERE "Artist"."Name" = 'Alanis Morissette'
LIMIT 5;


In [9]:
db.run(response)

"[('Jagged Little Pill',)]"