## 设置环境变量

In [1]:
import sys
import os

# Absolute path of ../../src, resolved against the current working directory.
# Presumably this is where the local `data_retrieval` package lives — the cells
# below import it — so it has to be on the module search path.
module_path = os.path.abspath('../../src')

# Guard the append so that re-running this cell does not keep adding duplicate
# entries to sys.path.
if module_path not in sys.path:
    sys.path.append(module_path)

利用大模型测试 Agent（默认使用 OpenAI 的 gpt-3.5-turbo，也可以切换为代码注释中的本地 Ollama 模型）

In [2]:
import os
from data_retrieval.tools.mock_tools import add, multiply, divide
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama


# Chat model handed to the agent cell and the Text2SQL cell further down.
model = ChatOpenAI(model="gpt-3.5-turbo")
# Alternative models via ChatOllama; uncomment one of these to use it instead
# of the OpenAI model above.
# model = ChatOllama(model="phi3")
# model = ChatOllama(model="gemma:7b")

In [3]:
from langchain_core.callbacks import StdOutCallbackHandler
from data_retrieval import AFAgent, AgentPrompt

# Assemble the agent from the three mock arithmetic tools, the chat model
# created earlier, and an English prompt. Personality and background are
# left empty.
arithmetic_tools = [add, multiply, divide]
english_prompt = AgentPrompt(lang="en")
agent = AFAgent(
    tools=arithmetic_tools,
    llm=model,
    prompt=english_prompt,
    personality="",
    background="",
)

# Variant that passes a stdout callback handler alongside the input:
# agent.invoke({"input": "what is 33 divide 2, and plus 3?"}, {"callbacks": [StdOutCallbackHandler()]})

# Plain call without callbacks; the returned value is shown as the cell output.
agent.invoke("what is 33 divide 2?")

{'name': 'divide',
 'arguments': {'first': 33, 'second': 2, 'tables': []},
 'output': 16.5}

## 测试数据源

初始化一个数据库，用于测试 Agent

```sql
CREATE TABLE movie(title, year, score)
```

In [7]:
# Start from a clean slate: delete fake.db if a previous run left it behind.
import os

# Attempt the removal directly instead of checking os.path.exists() first;
# this avoids the gap between the check and the delete.
try:
    os.remove('fake.db')
except FileNotFoundError:
    # Nothing to clean up — the database file does not exist yet.
    pass

In [11]:
# Build the demo database: (re)create the movie table and seed it with two rows.
import sqlite3

# The inline "--" SQL comments are part of the statement text that SQLite
# stores in sqlite_master, so they are kept exactly as written.
create_movie_table_sql = """CREATE TABLE movie(
    title varchar (128) NOT NULL PRIMARY KEY,  -- 电影名称
    year int NOT NULL,  -- 上映年份
    score float NOT NULL  -- 电影评分
)"""
seed_movies_sql = """
    INSERT INTO movie VALUES
        ('Monty Python and the Holy Grail', 1975, 8.2),
        ('And Now for Something Completely Different', 1971, 7.5)
"""

con = sqlite3.connect("fake.db")
cur = con.cursor()

# Dropping the table first lets this cell run against an already populated fake.db.
for statement in ("DROP TABLE IF EXISTS movie", create_movie_table_sql, seed_movies_sql):
    cur.execute(statement)

con.commit()

In [12]:
# Read the first row of SQLite's built-in schema table.
res = cur.execute("SELECT * FROM sqlite_master")
res.fetchone()

# The schema table has the following layout:
# CREATE TABLE sqlite_schema(
#   type text,
#   name text,
#   tbl_name text,
#   rootpage integer,
#   sql text
# );
# For details see https://www.sqlite.org/schematab.html

('table',
 'movie',
 'movie',
 2,
 'CREATE TABLE movie(\n    title varchar (128) NOT NULL PRIMARY KEY,  -- 电影名称\n    year int NOT NULL,  -- 上映年份\n    score float NOT NULL  -- 电影评分\n)')

In [14]:
# Fetch the score column for every movie; each row comes back as a 1-tuple.
res = cur.execute("SELECT score FROM movie")
res.fetchall()

[(8.2,), (7.5,)]

In [15]:
# Three more movies, added with a parameterized bulk insert.
data = [
    ("Monty Python Live at the Hollywood Bowl", 1982, 7.9),
    ("Monty Python's The Meaning of Life", 1983, 7.5),
    ("Monty Python's Life of Brian", 1979, 8.0),
]
# `title` is the table's PRIMARY KEY, so a plain INSERT raises
# sqlite3.IntegrityError when this cell is executed a second time.
# OR IGNORE skips rows whose title already exists, so the cell can be re-run.
cur.executemany("INSERT OR IGNORE INTO movie VALUES(?, ?, ?)", data)
con.commit()

In [16]:
# Print (year, title) pairs from the oldest release to the newest.
movies_by_year = cur.execute("SELECT year, title FROM movie ORDER BY year")
for year_and_title in movies_by_year:
    print(year_and_title)

(1971, 'And Now for Something Completely Different')
(1975, 'Monty Python and the Holy Grail')
(1979, "Monty Python's Life of Brian")
(1982, 'Monty Python Live at the Hollywood Bowl')
(1983, "Monty Python's The Meaning of Life")


In [17]:
# Another way to look up the table's metadata: PRAGMA TABLE_INFO returns one row per column.
res = cur.execute("PRAGMA TABLE_INFO(movie);")
res.fetchall()

[(0, 'title', 'varchar (128)', 1, None, 1),
 (1, 'year', 'INT', 1, None, 0),
 (2, 'score', 'float', 1, None, 0)]

In [18]:
# Close the database connection, then show that its cursor can no longer be used.
con.close()
try:
    cur.execute("SELECT * FROM sqlite_master")
except sqlite3.ProgrammingError as exc:
    # The error is the point of this cell. Catching and printing it keeps the
    # demonstration visible without aborting "Restart & Run All".
    print(f"ProgrammingError: {exc}")

ProgrammingError: Cannot operate on a closed database.

## 利用 DataSource 库来查询

In [19]:
import sys
import os

# Absolute path of ../../src, resolved against the current working directory.
# Presumably this is where the local `data_retrieval` package lives — the next
# cell imports it — so it has to be on the module search path.
module_path = os.path.abspath('../../src')

# Guard the append so that re-running this cell does not keep adding duplicate
# entries to sys.path.
if module_path not in sys.path:
    sys.path.append(module_path)

In [29]:
from data_retrieval.datasource.sqlite_ds import SQLiteDataSource

# Open the fake.db file created earlier through the project's SQLite data
# source, passing "movie" as its table list.
db = SQLiteDataSource(db_file="./fake.db", tables=["movie"])

In [30]:
# Fetch the metadata: a mapping from table name to its CREATE TABLE statement.
db.get_metadata()

{'movie': 'CREATE TABLE movie(\n    title varchar (128) NOT NULL PRIMARY KEY,  -- 电影名称\n    year int NOT NULL,  -- 上映年份\n    score float NOT NULL  -- 电影评分\n)'}

In [31]:
# Fetch sample data, asking for 2 rows.
db.get_sample(num=2)

{'movie': (['title', 'year', 'score'],
  [('Monty Python and the Holy Grail', 1975, 8.2),
   ('And Now for Something Completely Different', 1971, 7.5)])}

In [32]:
# Close the database.
db.close()

## 调用 Text2SQL 工具

In [3]:
from data_retrieval.datasource.sqlite_ds import SQLiteDataSource
from data_retrieval.tools.text2sql import Text2SQLTool
from langchain_community.chat_models import ChatOllama

# Optional model overrides. All of them are commented out, so the `model`
# passed to Text2SQLTool in the next cell is the one created in the
# ChatOpenAI cell near the top of the notebook.
# from langchain_openai import ChatOpenAI
# os.environ["OPENAI_API_KEY"] = "your key"
# model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# model = ChatOllama(model="phi3:latest")
# model = ChatOllama(model="gemma:7b")

# Data source over fake.db for the Text2SQL tool, with "movie" as its table list.
sqlite = SQLiteDataSource(
    db_file="./fake.db",
    tables=["movie"]
)

In [4]:
# Background hint passed to the tool. It says the year column stores years and
# that two-digit years given by the user should be converted to four digits.
year_hint = "电影表中的年份字段是年份，如果用户使用两位数的年份，要注意转换成四位数的年份。"

tool = Text2SQLTool(
    language="cn",
    data_source=sqlite,
    llm=model,
    background=year_hint,
)

# Question: "movies released after 1975", limited to the movie table.
question = {"text": "1975年以后上映的电影", "tables": ["movie"]}
tool.invoke(question)

{'sql': 'SELECT title, year, score FROM movie WHERE year >= 1975;',
 'res': [{'title': 'Monty Python and the Holy Grail',
   'year': 1975,
   'score': 8.2},
  {'title': 'Monty Python Live at the Hollywood Bowl',
   'year': 1982,
   'score': 7.9},
  {'title': "Monty Python's The Meaning of Life", 'year': 1983, 'score': 7.5},
  {'title': "Monty Python's Life of Brian", 'year': 1979, 'score': 8.0}]}