In [1]:
#Dependencies
!pip install dspy-ai[chromadb] -Uqq
!pip install termcolor -Uqq
!pip install sqlalchemy -Uqq
!pip install lxml -Uqq
!pip install xlrd -Uqq


## AIM OF THE TUTORIAL

* Build an end-to-end Text-to-SQL pipeline inspired by this [video](https://www.youtube.com/watch?v=L1o1VPVfbb0&pp=ygUYYWR2YW5jZWQgUkFHIGxsYW1hIGluZGV4) from Llama Index. In that video, they use the Llama Index Query Pipeline to build a Text-to-SQL pipeline. Here, we will build a Text-to-SQL pipeline from scratch on our own dataset. We will go from scraping the dataset, to building a SQLite database, to using DSPy signatures to implement a text-to-SQL pipeline (从 Llama Index 的 [视频](https://www.youtube.com/watch?v=L1o1VPVfbb0&pp=ygUYYWR2YW5jZWQgUkFHIGxsYW1hIGluZGV4) 中获得灵感，构建端到端的文本到 SQL 管道。在该视频中，他们使用 Llama Index Query Pipeline 来构建文本到 SQL 管道。在这里，我们将基于自己的数据集，从零开始构建文本到 SQL 管道。我们将从抓取数据集开始，到构建 SQLite 数据库，再到使用 DSPy 签名来实现文本到 SQL 管道)

## ABOUT THE DATASET
* You can find the dataset [here](https://pages.stern.nyu.edu/~adamodar/New_Home_Page/datacurrent.html). The dataset contains industry-level financial metrics such as WACC, tax rates, EBITDA, etc. It covers multiple regions, `['US', 'Europe', 'Japan', 'AUS_NZ_CANADA', 'Emerging', 'China', 'India', 'Global']`, with multiple tables for each region. There are nearly 250 tables, each with multiple columns. We will build a text-to-SQL pipeline from scratch on this dataset, starting by embedding the table schemas and rows with the ChromaDB vector database. (该数据集包含按行业划分的各种财务指标，如 WACC、税率、EBITDA 等。数据覆盖多个区域：`['美国', '欧洲', '日本', 'AUS_NZ_CANADA', '新兴', '中国', '印度', '全球']`，每个区域都有多个表。共有近 250 个表，每个表中有多列。我们将基于该数据集从头开始构建一个文本到 SQL 管道，从使用 ChromaDB 向量数据库嵌入表架构和行开始。)

In [1]:
from experiment_project.utils.files.read import read_yaml
from experiment_project.utils.initial.util import init_sys_env
#imports
from bs4 import BeautifulSoup 
import urllib.request
import ssl
from dotenv import load_dotenv
import openai
import os
import requests
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import pandas as pd
import json
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re
from sqlalchemy import inspect
import sqlalchemy
from sqlalchemy import text 
import dspy
from typing import List
from termcolor import colored
import chromadb
# Project-local bootstrap helper; its effect is not visible from this notebook.
init_sys_env()
# NOTE(review): absolute, machine-specific path — the notebook will not run elsewhere
# unless this is changed. Consider reading it from an environment variable.
secret_env_file = '/mnt/d/project/dy/extra/autogen/env_secret_config.yaml'

# Parsed YAML secrets; later cells read api_configs.get('openai') for the API key and model name.
api_configs = read_yaml(secret_env_file)


In [2]:
from dsp.modules.cache_utils import cache_turn_on

cache_turn_on

## SCRAPING THE LINKS OF THE EXCEL FILES FROM THE WEBSITE(从网站上抓取 excel 文件的链接)

In [3]:
# Swap the default HTTPS context factory used by urllib for ssl._create_stdlib_context.
# NOTE(review): both names are private ssl APIs; this appears to disable certificate
# verification for every urllib HTTPS request in the kernel — confirm that is intended.
ssl._create_default_https_context = ssl._create_stdlib_context
html_link = "https://pages.stern.nyu.edu/~adamodar/New_Home_Page/datacurrent.html"

with urllib.request.urlopen(html_link) as url:
    # Raw bytes of the page body
    s = url.read()
    # Parse the downloaded HTML with the lxml parser
    soup = BeautifulSoup(s,"lxml")

# All <table> elements on the page; the second one (index 1) is the one whose
# anchors are harvested below — presumably the data-file listing, verify against the page.
html_table = soup.find_all("table")
req_table = html_table[1]
# Every <a> tag inside the chosen table
hrefs_list = req_table.find_all('a')

In [4]:
# Region key -> list of .xls links found for that region
req_href = {"US":[],"Europe":[],"Japan":[],"AUS_NZ_CANADA":[],"Emerging":[],"China":[],"India":[],"Global":[]}

# (substring looked for in the link text, region key), checked in this order;
# the first marker that matches wins, mirroring an if/elif chain.
region_markers = [
    ("US", "US"),
    ("Europe", "Europe"),
    ("Japan", "Japan"),
    ("Aus", "AUS_NZ_CANADA"),
    ("Emerging", "Emerging"),
    ("China", "China"),
    ("India", "India"),
    ("Global", "Global"),
]

for i in hrefs_list:
    name = i.get_text().strip()
    # Anchors without an href are skipped explicitly instead of relying on a
    # bare `except: pass`, which would also hide unrelated errors.
    href_attr = i.get('href')
    # Only keep links to excel files
    if not isinstance(href_attr, str) or not href_attr.endswith('.xls'):
        continue
    for marker, region_key in region_markers:
        if marker in name:
            req_href[region_key].append(href_attr)
            break

In [10]:
# #Download the excel files from the website and store it in a folder named DATA
# ssl._create_default_https_context = ssl._create_stdlib_context
# 
# import os
# os.makedirs("DATA",exist_ok=True)
# for country,excel_files in req_href.items():
#     country_path = os.path.join("DATA",country) 
#     os.makedirs(country_path,exist_ok=True)
#     for file in excel_files:
#         file_name = file.split("/")[-1].split(".")[0]
#         full_file_name = os.path.join(country_path,f"{file_name}.xls")
#         resp = requests.get(file,verify=False)
#         output = open(full_file_name, 'wb')
#         output.write(resp.content)
#         output.close()

In [11]:
req_href

In [12]:
import os
import ssl
import requests
from requests.exceptions import ChunkedEncodingError
from time import sleep
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter

# Create the root folder for the downloaded workbooks
os.makedirs("DATA", exist_ok=True)
# Session whose adapters retry failed connection attempts with back-off
session = requests.Session()
retry = Retry(connect=5, backoff_factor=1)
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)

# (connect, read) timeouts in seconds, so a stalled server cannot hang the cell forever
REQUEST_TIMEOUT = (10, 60)
# Maximum number of tries per file when the body is cut off mid-transfer
MAX_ATTEMPTS = 5

# Download every collected link into DATA/<region>/<name>.xls
for country, excel_files in req_href.items():
    country_path = os.path.join("DATA", country)
    os.makedirs(country_path, exist_ok=True)
    for file in excel_files:
        file_name = file.split("/")[-1].split(".")[0]
        full_file_name = os.path.join(country_path, f"{file_name}.xls")
        success = False
        attempts = 0
        # True once this run has opened the target file for writing
        file_touched = False
        while not success and attempts < MAX_ATTEMPTS:
            try:
                # NOTE(review): verify=False skips TLS certificate verification.
                # The `with` releases the streamed connection back to the pool.
                with session.get(file, verify=False, stream=True, timeout=REQUEST_TIMEOUT) as resp:
                    resp.raise_for_status()  # fail on HTTP error status codes
                    file_touched = True
                    with open(full_file_name, 'wb') as output:
                        for chunk in resp.iter_content(chunk_size=8192):
                            if chunk:
                                output.write(chunk)
                success = True
            except ChunkedEncodingError as e:
                # Body was truncated mid-stream: wait briefly and try again
                attempts += 1
                print(f"ChunkedEncodingError: {e}. Retrying {attempts}/{MAX_ATTEMPTS}...")
                sleep(2)
            except requests.exceptions.RequestException as e:
                # Any other request failure is not retried here
                print(f"RequestException: {e}")
                break
        # Do not leave a truncated workbook behind: later cells would try to parse it
        if not success and file_touched and os.path.exists(full_file_name):
            os.remove(full_file_name)


In [13]:
# Sanity check: compare the number of files on disk with the number of scraped links per region
for country in os.listdir("DATA"):
    downloaded_files = os.listdir(os.path.join("DATA", country))
    expected_links = req_href[country]
    dir_len, country_len = len(downloaded_files), len(expected_links)
    print(f'FOR {country} WE HAVE DIRECTORY LEN = {dir_len} and ACTUAL LEN = {country_len}')

## CLEANING THE DATASET

In [3]:
sample_excel = pd.ExcelFile("DATA/US/capex.xls")

In [4]:
sn = 'Variables & FAQ'
sample_excel.parse(sn).head(10)

In [5]:
sample_excel.parse(sample_excel.sheet_names[1]).head(10)

* In the dataset above, there are two sheets. The first sheet is the variables and summary, and the second sheet is the table with the data. 
* We will clean the first sheet to get the table name and summary
* 在上面的数据集中，有两张表。 第一张表是变量和摘要，第二张表是包含数据的表格。
* 我们将清理第一张表以获取表名称和摘要

In [7]:
# import pandas as pd
# from tqdm import tqdm
# import json
# pd.set_option('display.max_rows', 50)
# 
# def sanitize_column_name(col_name):
#     # Remove special characters and replace spaces with underscores
#     return re.sub(r"\W+", "_", col_name)
# 
# dir = "DATA"
# processed_dir = "Processed Data"
# all_infos_dict = []
# os.makedirs(processed_dir,exist_ok=True)
# for country in os.listdir(dir):
#     print(country)
#     file_name = os.path.join(dir,country)
#     os.makedirs(os.path.join(processed_dir,country),exist_ok=True)
#     os.makedirs(file_name,exist_ok=True)
#     for excel_file in tqdm(os.listdir(file_name)):
#         full_file_name = os.path.join(file_name,excel_file)
#         xls = pd.ExcelFile(full_file_name)
#         sns = xls.sheet_names
#         for sheet_name in sns:
#             if "Var" in sheet_name or "var" in sheet_name:
#                 info_df = xls.parse(sheet_name)
#                 info_df.dropna(how="all",inplace=True)
#                 info_dict = {}
#                 for cols in info_df.columns:
#                     if "End" not in cols and 'Unnamed' not in cols:
#                         info_dict['Summary'] = cols
#                 info_dict['Vars'] = info_df.values[1:].tolist()
#                 all_infos_dict.append(info_dict)
#             elif "Industry" in sheet_name or "industry" in sheet_name:
#                 data_df = xls.parse(sheet_name)
#         try:
#             data_df.dropna(axis=1,thresh=5,inplace=True)
#             data_df.dropna(inplace=True)
#             new_header = data_df.iloc[0] #grab the first row for the header
#         except:
#             print(full_file_name)
#             print(data_df)
#         data_df = data_df[1:] #take the data less the header row
#         data_df.reset_index(inplace=True,drop=True)
#         new_header = [sanitize_column_name(str(col)) for col in new_header]
#         data_df.columns = new_header #set the header row as the df header
#         save_name = full_file_name.split(".")[0].split("/")[-1]
#         save_file_path = os.path.join(os.path.join(processed_dir,country),save_name)
#         data_df.to_csv(save_file_path+".csv",index=False)
#         with open(save_file_path+".json", "w") as outfile: 
#             json.dump(info_dict, outfile)

In [8]:
import os
import re
import pandas as pd
from tqdm import tqdm
import json

# 设置 Pandas 显示选项
pd.set_option('display.max_rows', 50)

def sanitize_column_name(col_name):
    """Return ``col_name`` with every run of non-word characters collapsed into one underscore."""
    non_word_run = re.compile(r"\W+")
    return non_word_run.sub("_", col_name)

# Directory layout: raw downloads live in `dir`, cleaned output goes to `processed_dir`
dir = "DATA"
processed_dir = "Processed Data"
# One metadata dict per parsed "Variables" sheet, kept for inspection in later cells
all_infos_dict = []

# Create the output root
os.makedirs(processed_dir, exist_ok=True)

# Walk every region folder
for country in os.listdir(dir):
    print(country)
    country_path = os.path.join(dir, country)
    processed_country_path = os.path.join(processed_dir, country)
    os.makedirs(processed_country_path, exist_ok=True)

    # Walk every workbook of the region
    for excel_file in tqdm(os.listdir(country_path)):
        full_file_name = os.path.join(country_path, excel_file)

        info_dict = {}
        data_df = pd.DataFrame()

        try:
            # Opening the workbook sits inside the try so that one unreadable file
            # is reported and skipped instead of aborting the whole run; the
            # context manager closes the file handle.
            with pd.ExcelFile(full_file_name) as xls:
                sheet_names = xls.sheet_names
                for sheet_name in sheet_names:
                    if "Var" in sheet_name or "var" in sheet_name:
                        # Variables/summary sheet: the summary text is carried by a column label
                        info_df = xls.parse(sheet_name).dropna(how="all")
                        for cols in info_df.columns:
                            # Column labels are not guaranteed to be strings
                            col_label = str(cols)
                            if "End" not in col_label and 'Unnamed' not in col_label:
                                info_dict['Summary'] = col_label
                        info_dict['Vars'] = info_df.values[1:].tolist()
                        all_infos_dict.append(info_dict)
                    elif "Industry" in sheet_name or "industry" in sheet_name:
                        # Data sheet
                        data_df = xls.parse(sheet_name)

            # Drop columns with fewer than 5 non-null cells, then rows containing any null
            data_df = data_df.dropna(axis=1, thresh=5).dropna()
            new_header = data_df.iloc[0]  # first remaining row holds the real header
            data_df = data_df[1:].reset_index(drop=True)  # everything below the header row
            data_df.columns = [sanitize_column_name(str(col)) for col in new_header]
        except Exception as e:
            print(f"Error processing file: {full_file_name}")
            print(e)
            continue

        # Persist the cleaned table and its metadata side by side
        save_name = os.path.splitext(excel_file)[0]
        save_file_path = os.path.join(processed_country_path, save_name)
        data_df.to_csv(save_file_path + ".csv", index=False)
        with open(save_file_path + ".json", "w") as outfile:
            # default=str keeps cells json cannot encode natively (e.g. timestamps) from crashing the dump
            json.dump(info_dict, outfile, default=str)


## A SAMPLE METADATA JSON

In [9]:
all_infos_dict[0]

## DATAFRAME AFTER PREPROCESSING

In [10]:
df = pd.read_csv("Processed Data/US/capex.csv")
df.head()

## BUILD TABLE NAMES AND METADATA

* Here we pass the first 10 rows of each dataframe to a DSPy signature, which generates a table name and a table summary. These help us dynamically select the correct table for a given query. (我们将每个数据帧的前 10 行传给 DSPy 签名，由它生成表名称和表摘要。这有助于我们根据查询动态选择正确的表。)

In [11]:
df.head(10).to_csv()

In [3]:
load_dotenv(override=True)
openai.api_key = api_configs.get('openai').get('api_key')

In [4]:
# 这段代码展示了如何使用 DSPy 和 OpenAI API 来生成 SQL 表的元数据。通过定义 SQLTableMetadata 签名和 CoT 模块，可以将 pandas 数据框的前 10 行作为输入，生成合适的 SQL 表名和表的摘要。 
import dspy
turbo = dspy.OpenAI(model=api_configs.get('openai').get('model'), max_tokens=320)
dspy.settings.configure(lm=turbo)

# DSPy signature: one input (a CSV preview of a dataframe) and two outputs
# (a table name and a table summary).
# NOTE: the class docstring is left verbatim because DSPy presumably uses it as the
# task instruction sent to the model — confirm before editing. In English it reads:
# "Given a data table, generate a suitable table name and description".
class SQLTableMetadata(dspy.Signature):
    """给定数据表，生成合适的表名和描述"""
    pandas_dataframe_str = dspy.InputField(desc="First 10 rows of a pandas dataframe delimited by newline character") # CSV text of the dataframe preview
    table_name = dspy.OutputField(desc="suitable table name") # generated table name
    table_summary = dspy.OutputField(desc="a summary about the table") # generated table summary

class CoT(dspy.Module):
    """Module that runs ``SQLTableMetadata`` through ``dspy.ChainOfThought``."""

    def __init__(self):
        """Build the chain-of-thought predictor for the table-metadata signature."""
        super().__init__()
        self.prog = dspy.ChainOfThought(SQLTableMetadata)
    
    def forward(self, pandas_dataframe_str):
        """Return the prediction for the given dataframe preview string.

        The result is whatever ``self.prog`` returns; later cells read its
        ``table_name`` and ``table_summary`` attributes.
        """
        return self.prog(pandas_dataframe_str=pandas_dataframe_str)

cot = CoT()

# cot(pandas_dataframe_str = df.head(10).to_csv())


In [14]:
processed_dir = "Processed Data"
# CSV previews that were actually sent to the model during this run
dfs_str = []
for country in os.listdir(processed_dir):
    country_folder = os.path.join(processed_dir,country)
    progress = tqdm(os.listdir(country_folder),desc=f"Building the summary and name for {country}")
    for files in progress:
        # Metadata sidecars (.json) are handled together with their .csv
        if not files.endswith(".csv"):
            continue
        file_name = files.split(".")[0]
        csv_file_path = os.path.join(country_folder,files)
        df = pd.read_csv(csv_file_path,index_col=False)
        json_file_path = os.path.join(country_folder,f"{file_name}.json")
        with open(json_file_path,'r') as f:
            data = json.load(f)
        # Skip files that already carry both keys and a non-empty summary
        already_described = (
            'table_name' in data
            and 'table_summary' in data
            and data['table_summary'] != ""
        )
        if already_described:
            continue
        preview_csv = df.head(10).to_csv()
        dfs_str.append(preview_csv)
        table_preds = cot(pandas_dataframe_str = preview_csv)
        data['table_name'] = table_preds.table_name
        data['table_summary'] = table_preds.table_summary
        # Write the enriched metadata back to the sidecar file
        with open(json_file_path,'w') as f:
            json.dump(data, f)

## NEXT TASKS
1. Build database with each region for each table
2. Embed the table summary and table SCHEMA. Also, embed the table rows
3. Retrieval at table level and embed the rows to retrieve relevant rows from the retrieved schema of table
4. Text-to-SQL pipeline

1. 为每个表建立每个区域的数据库
2. 嵌入表摘要和表SCHEMA。 另外，嵌入表格行
3. 在表级别检索并嵌入行以从表的检索模式中检索相关行
4. 文本到 SQL 管道

## BUILD THE SQLITE DATABASE FROM THE CSV FILES (根据 csv 文件建立 sqlite 数据库)

The code below was adapted from this [tutorial](https://docs.llamaindex.ai/en/stable/examples/pipeline/query_pipeline_sql/)
(以下代码改编自该[教程](https://docs.llamaindex.ai/en/stable/examples/pipeline/query_pipeline_sql/))

In [5]:
# Turn an arbitrary column label into an identifier-like name.
# NOTE(review): this notebook defines a function with this name in an earlier cell as well.
def sanitize_column_name(col_name):
    """Replace each run of non-word characters in ``col_name`` with a single underscore."""
    word_pieces = re.split(r"\W+", col_name)
    return "_".join(word_pieces)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create ``table_name`` in the database behind ``engine`` and load ``df`` into it.

    Args:
        df: Frame whose columns become the table columns (labels are sanitized first).
        table_name: Name of the table to create.
        engine: SQLAlchemy engine of the target database.
        metadata_obj: SQLAlchemy ``MetaData`` the table is registered on.

    Column types are derived from the pandas dtypes: ``object`` -> ``String``,
    floating point -> ``Float``, everything else -> ``Integer``. Rows already in
    the table are removed before loading, so re-running the notebook against an
    existing database file does not duplicate data.
    """
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    def _sql_type(dtype):
        # Floats used to be declared Integer, which misreports the schema
        if dtype == "object":
            return String
        if pd.api.types.is_float_dtype(dtype):
            return sqlalchemy.Float
        return Integer

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, _sql_type(dtype))
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database (create_all skips tables that already exist)
    metadata_obj.create_all(engine)

    # One dict per row, keyed by column name
    records = [row.to_dict() for _, row in df.iterrows()]

    with engine.connect() as conn:
        # Clear leftovers from a previous run so the table mirrors the DataFrame
        conn.execute(table.delete())
        if records:
            # Single executemany call instead of one round-trip per row
            conn.execute(table.insert(), records)
        conn.commit()

## DATABASE CREATION (创建数据库)

In [6]:
processed_dir  = "Processed Data"
def sqlalchemy_engine(region:str):
    """Create a SQLAlchemy engine for the given region and load its CSV tables.

    Args:
        region: Name of a sub-folder of ``processed_dir``.

    Returns:
        The engine bound to ``sqlite:///{region}.db``, with one table per ``.csv``
        file found in the region folder (table name = file name up to the first dot).

    Raises:
        AssertionError: if ``region`` is not a folder inside ``processed_dir``.
    """
    assert region in os.listdir(processed_dir), f"{region} is not a valid region from {os.listdir(processed_dir)}"
    # Create a SQLAlchemy database for each region
    engine = create_engine(f"sqlite:///{region}.db")
    metadata_obj = MetaData()
    region_path = os.path.join(processed_dir,region)
    csv_names = [entry for entry in os.listdir(region_path) if entry.endswith(".csv")]
    # Iterating through tqdm closes the bar when the loop ends; each frame is
    # loaded right before use instead of keeping all of them in memory.
    for csv_name in tqdm(csv_names,desc=f"Creating tables for {region}"):
        df = pd.read_csv(os.path.join(region_path,csv_name),index_col=False)
        table_name = csv_name.split(".")[0]
        create_table_from_dataframe(df,table_name, engine, metadata_obj)
    return engine

In [9]:
us_engine = sqlalchemy_engine("US")
india_engine = sqlalchemy_engine("India")
china_engine = sqlalchemy_engine("China")
europe_engine = sqlalchemy_engine("Europe")
global_engine = sqlalchemy_engine("Global")
aus_nz_canada_engine = sqlalchemy_engine("AUS_NZ_CANADA")
japan_engine = sqlalchemy_engine("Japan")
emerging_engine = sqlalchemy_engine("Emerging")

In [10]:

def get_table_infos(sql_engine:sqlalchemy.engine.base.Engine,region:str):
    """Get all the tables info in the database based on the given region.

    Args:
        sql_engine: Engine of the region database to inspect.
        region: Region folder inside ``processed_dir`` holding the ``<table>.json`` metadata files.

    Returns:
        Dict mapping each table name to a one-element list holding a dict with
        ``table_info`` (table name plus ``column (TYPE)`` pairs) and ``table_summary``
        (the ``Summary`` and ``table_summary`` texts from the JSON file, each followed
        by ``". "``; a missing or empty entry is left out).
    """
    inspector = inspect(sql_engine)
    table_infos_dict = {}
    for tb in inspector.get_table_names():
        schema_str = ", ".join(
            f"{col['name']} ({col['type']})" for col in inspector.get_columns(tb)
        )
        with open(os.path.join(processed_dir,region,f"{tb}.json")) as f:
            table_info = json.load(f)
        # Skip absent pieces rather than embedding the literal text "None"
        # or failing with KeyError when the generated summary is missing.
        summary_parts = [table_info.get("Summary"), table_info.get("table_summary")]
        table_summary = "".join(f"{part}. " for part in summary_parts if part)
        table_infos_dict[tb] = [
            {
                "table_info": f"Table {tb} has columns: {schema_str}",
                "table_summary": table_summary,
            }
        ]
    return table_infos_dict

In [11]:
# 获取每一个类型中的每个表的详细信息
us_tb_dict = get_table_infos(us_engine,"US")
india_tb_dict = get_table_infos(india_engine,"India")
china_tb_dict = get_table_infos(china_engine,"China")
europe_tb_dict = get_table_infos(europe_engine,"Europe")
global_tb_dict = get_table_infos(global_engine,"Global")
aus_nz_canada_tb_dict = get_table_infos(aus_nz_canada_engine,"AUS_NZ_CANADA")
japan_tb_dict = get_table_infos(japan_engine,"Japan")
emerging_tb_dict = get_table_infos(emerging_engine,"Emerging")

In [12]:
us_tb_dict

In [13]:
us_tb_dict['DollarUS']

## EMBEDDINGS

1. Embed the table summary and table SCHEMA to get the table that the user is looking for (1.嵌入表格摘要和表格 SCHEMA，以获取用户要查找的表格)
2. Embed the table rows for each table, so as to get relevant rows from the retrieved table  (2.嵌入每个表的表行，以便从检索到的表中获取相关行)

## EMBED THE TABLE SUMMARY AND TABLE SCHEMA  (嵌入表格摘要和表格模式)

In [14]:
# 使用了 chromadb 库和 OpenAI 的嵌入模型。这个函数的主要目的是将表的概要信息和表的模式（schema）嵌入到向量空间中，以便于后续的查询和匹配操作
import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from chromadb.utils.batch_utils import create_batches

load_dotenv(override=True)
emb_fn = embedding_functions.OpenAIEmbeddingFunction(
                api_key=api_configs.get('openai').get('api_key'),
                model_name="text-embedding-3-small")
# EMBEDDING_MODEL = "mixedbread-ai/mxbai-embed-large-v1"
# emb_fn = embedding_functions.HuggingFaceEmbeddingFunction(model_name=EMBEDDING_MODEL,api_key=os.environ["HF_API_KEY"])
def embed_table_info(region:str,tb_dict):
    """Embed the table summary and table SCHEMA to get the table that the user is looking for.

    Args:
        region: Region name; the Chroma store lives in the folder ``{region}_TABLE``.
        tb_dict: Mapping ``table name -> [{"table_info": ..., "table_summary": ...}]``
            as produced by ``get_table_infos``.

    Each table becomes one document (schema text + summary) in the ``table``
    collection, with the table name and schema text stored as metadata.
    """
    client = chromadb.PersistentClient(path=f"{region}_TABLE")

    # get_or_create + upsert keep the cell re-runnable: create_collection raises
    # once the persistent collection exists.
    table_collection = client.get_or_create_collection(name="table",embedding_function=emb_fn)

    table_docs = []
    table_metadata = []

    for table_name,table_data in tb_dict.items():
        table_docs.append(table_data[0]['table_info'] + ". " + table_data[0]['table_summary'])
        table_metadata.append({"table_name":table_name,'table_metadata':table_data[0]['table_info']})
    table_ids = [f"id{i}" for i in range(len(table_docs))]
    assert len(table_docs) == len(table_metadata)
    print(len(table_docs),len(table_metadata))
    # Split the payload into batches; each batch tuple is indexed as
    # 0 = ids, 2 = metadatas, 3 = documents.
    batches = create_batches(api=client,ids=table_ids, documents=table_docs, metadatas=table_metadata)
    for batch in tqdm(batches,desc="Embedding table info"):
        table_collection.upsert(ids=batch[0],
                    documents=batch[3],
                    metadatas=batch[2])

# embed_table_info("US",us_tb_dict)

## For some strange reason, the `create_batches` was not batching the below documents, hence I had to do it manually (由于某些奇怪的原因，"create_batches "没有对以下文件进行批处理，因此我不得不手动操作)

In [15]:
# Embeds every row of every processed table of a region into the Chroma "rows"
# collection (using the notebook-level embedding function) for later retrieval.
def embed_rows(region:str,batch_size:int=24):
    """Embed the rows of all CSV tables of ``region`` into the ``rows`` collection.

    Args:
        region: Region name; CSVs are read from ``processed_dir/region`` and the
            Chroma store lives in the folder ``{region}_TABLE``.
        batch_size: Number of rows sent to the collection per call.

    The embedded document of a row is its string-valued cells joined by ``", "``;
    the metadata carries the table name, region, row index and all cell values
    joined by ``", "`` (``full_rows``). Double quotes are stripped from both.
    """
    client = chromadb.PersistentClient(path=f"{region}_TABLE")
    # get_or_create + upsert keep the cell re-runnable: create_collection raises
    # once the persistent collection exists.
    rows_collection = client.get_or_create_collection(name="rows",embedding_function=emb_fn)

    rows_docs = []
    rows_metadata = []
    region_path = os.path.join(processed_dir,region)
    for df_path in os.listdir(region_path):
        # The folder also holds the .json metadata sidecars; only CSVs are tables
        if not df_path.endswith(".csv"):
            continue
        df = pd.read_csv(os.path.join(region_path,df_path),index_col=False)
        table_name = df_path.split(".")[0]
        for idx,row in df.iterrows():
            text_cells = [rv for rv in row.values if isinstance(rv,str)]
            row_doc = ", ".join(text_cells).replace('"',"")
            # join() leaves no trailing separator, so nothing is sliced off here;
            # slicing used to cut two characters from the last value.
            full_rows_str = ", ".join(str(rv) for rv in row.values).replace('"',"")
            rows_docs.append(row_doc)
            rows_metadata.append({"table_name":table_name,"region":region,"index":idx,"full_rows":full_rows_str})
    row_ids = [f"id{i}" for i in range(len(rows_docs))]
    assert len(rows_docs) == len(rows_metadata) == len(row_ids)
    # Manual batching with a caller-chosen batch size
    for start in tqdm(range(0,len(rows_docs),batch_size),desc="Embedding rows"):
        end = min(start+batch_size,len(rows_docs))
        rows_collection.upsert(ids=row_ids[start:end],
                    documents=rows_docs[start:end],
                    metadatas=rows_metadata[start:end])

In [17]:
region = "US"
embed_table_info(region,us_tb_dict)
embed_rows(region,2000)

In [21]:
region = "India"
embed_table_info(region,india_tb_dict)
embed_rows(region,2000)

In [19]:
region = "China"
embed_table_info(region,china_tb_dict)
embed_rows(region,2000)

In [20]:
region = "Europe"
embed_table_info(region,europe_tb_dict)
embed_rows(region,1000)

In [22]:
region = "Global"
embed_table_info(region,global_tb_dict)
embed_rows(region,2000)

In [23]:
region = "Emerging"
embed_table_info(region,emerging_tb_dict)
embed_rows(region,2000)

In [24]:
region = "Japan"
embed_table_info(region,japan_tb_dict)
embed_rows(region,2000)

In [25]:
region = "AUS_NZ_CANADA"
embed_table_info(region,aus_nz_canada_tb_dict)
embed_rows(region,2000)

## TEXT-TO-SQL PIPELINE

### LOAD DATABASE

In [26]:
db_dict = {
    "US":us_engine,
    "India":india_engine,
    "China":china_engine,
    "Europe":europe_engine,
    "Global":global_engine,
    "AUS_NZ_CANADA":aus_nz_canada_engine,
    "Japan":japan_engine,
    "Emerging":emerging_engine,
}

def get_collections_db(region:str):
    """Return ``[engine, table_collection, row_collection]`` for ``region``.

    The engine comes from ``db_dict``; both collections are opened from the
    persistent Chroma store in the folder ``{region}_TABLE``.
    """
    chroma_client = chromadb.PersistentClient(path=f"{region}_TABLE")
    table_collection, row_collection = (
        chroma_client.get_collection(name=collection_name,embedding_function=emb_fn)
        for collection_name in ("table", "rows")
    )
    return [db_dict[region],table_collection,row_collection]

In [27]:
# Region -> [engine, table collection, row collection], built in the listed order
db_collection_dict = {
    region_name: get_collections_db(region_name)
    for region_name in (
        "US",
        "India",
        "China",
        "Europe",
        "Global",
        "AUS_NZ_CANADA",
        "Japan",
        "Emerging",
    )
}

In [28]:
load_dotenv(override=True)
text_to_sql = dspy.OpenAI(model=api_configs.get('openai').get('model'), max_tokens=1024)
sql_to_answer = dspy.OpenAI(model=api_configs.get('openai').get('model'),max_tokens=1024)

# DSPy signature for converting a natural-language question into a SQLite query.
# NOTE: the docstring and the `desc` strings are runtime text (presumably sent to the
# model as instructions — confirm), so they are deliberately left untouched.
class TextToSQLAnswer(dspy.Signature):
    """Convert natural language text to SQL using suitable schema(s) from multiple schema choices"""

    # Input: the user's question
    question:str = dspy.InputField(desc="natural language input which will be converted to SQL")
    # Input: candidate tables (name + columns) with sample rows, as one text block
    relevant_table_schemas_rows:str = dspy.InputField(desc="Multiple possible tables which has table name and corresponding columns, along with relevant rows from the table (values in the same order as columns above)")
    # Output: the generated SQLite query
    sql:str = dspy.OutputField(desc="Generate syntactically correct sqlite query with correct column names using suitable tables(s) and its rows.\n Don't forget to add distinct.\n Please rename the returned columns into suitable names.\n DON'T OUTPUT anything else other than the sqlite query")

# DSPy signature for turning the question, the executed SQL query and the rows it
# returned into a natural-language answer.
# NOTE: the docstring and `desc` strings are runtime text and are left untouched.
class SQLReturnToAnswer(dspy.Signature):
    """Answer the question using the rows from the SQL query"""
    # Input: the user's question
    question:str = dspy.InputField()
    # Input: the query that produced the rows
    sql:str = dspy.InputField(desc="sqlite query that generated the rows")
    # Input: the query result rendered as text
    relevant_rows:str = dspy.InputField(desc="relevant rows to answer the question")
    # Output: the final answer
    answer:str = dspy.OutputField(desc="answer to the question using relevant rows and the sql query")

# DSPy signature used when executing a generated query fails: given the failing
# query, the error text and the same schema/row context, produce a corrected query.
# NOTE: the docstring and `desc` strings are runtime text and are left untouched.
class SQLRectifier(dspy.Signature):
    """Correct the SQL query to resolve the error using the proper table names, columns and rows"""  
    # Input: the query that raised the error
    input_sql:str = dspy.InputField(desc="sqlite query that needs to be fixed")
    # Input: the error message raised while executing the query
    error_str: str = dspy.InputField(desc="error that needs to be resolved")
    # Input: candidate tables (name + columns) with sample rows, as one text block
    relevant_table_schemas_rows:str = dspy.InputField(desc="Multiple possible tables which has table name and corresponding columns, along with relevant rows from the table (values in the same order as columns above)")
    # Output: the corrected SQLite query
    sql:str = dspy.OutputField(desc="corrected sqlite query to resolve the error and remove and any invalid syntax in the query.\n Don't output anything else other than the sqlite query")

dspy.settings.configure(lm=text_to_sql)

# Filter out the SQL Query
def process_sql_str(sql_str:str):
    """Extract the bare SQL statement from a model response.

    Handles a Markdown code fence (optionally tagged ``sql``/``sqlite``/``sqlite3``),
    text around the fence, and a lone language tag on the first line. Only the
    fence and tag are removed: the substring ``sql`` inside the query itself
    (e.g. ``sqlite_master``) is preserved.

    Args:
        sql_str: Raw text produced by the model.

    Returns:
        The query text with surrounding whitespace stripped.
    """
    language_tag = r"(?:sqlite3?|sql)"
    # Content of the first fenced block; the closing fence is optional
    fenced = re.search(
        r"```[ \t]*(?:" + language_tag + r"\b)?[ \t]*\n?(.*?)(?:```|\Z)",
        sql_str,
        flags=re.DOTALL | re.IGNORECASE,
    )
    body = fenced.group(1) if fenced else sql_str
    # Remove any stray fence markers
    body = body.replace("```","")
    # A language tag alone on the first line (fence omitted by the model)
    body = re.sub(r"\A\s*" + language_tag + r"[ \t]*\n", "", body, flags=re.IGNORECASE)
    return body.strip()

<p align="center">
  <img src="https://raw.githubusercontent.com/Athe-kunal/Text-to-SQL/main/Schema.png" alt="Sublime's custom image"/>
</p>

In [31]:
# Helpers that query the table collection and the row collection for entries
# related to a question.
def get_table_results(table_collection_,question:str,n_results:int=5):
    """Query the table collection with the question text.

    Args:
        table_collection_: Collection exposing ``query(query_texts=..., n_results=...)``.
        question: Natural-language question used as the query text.
        n_results: Number of matches to request (default 5).

    Returns:
        Whatever the collection's ``query`` call returns.
    """
    table_results = table_collection_.query(
        query_texts = question,
        n_results = n_results
    )
    return table_results

def get_row_results(row_collection_,question,table_name:str,n_results:int=5):
    """Query the row collection for rows of one table that match the question.

    Args:
        row_collection_: Collection exposing ``query(query_texts=..., where=..., n_results=...)``.
        question: Natural-language question used as the query text.
        table_name: Only rows whose ``table_name`` metadata equals this value are searched.
        n_results: Number of matches to request (default 5).

    Returns:
        Whatever the collection's ``query`` call returns.
    """
    # The leftover debug print of the matched documents was removed; callers can
    # inspect the returned value instead.
    row_results = row_collection_.query(
        query_texts = question,
        where = {"table_name":table_name},
        n_results = n_results
    )
    return row_results

In [34]:
from typing import Any
import re

class TextToSQLQueryModule(dspy.Module):
    """Text to SQL to final answer module.

    Retrieves candidate tables and sample rows for a question, asks the language
    model for a SQLite query, executes it (with automatic repair attempts on
    errors) and finally asks the model to phrase an answer from the returned rows.
    """
    def __init__(self,region:str,use_cot:bool=True,max_retries:int=3):
        """Text to Answer init module

        Args:
            region (str): Region for which the module will be used.
            use_cot (bool, optional): Whether to use chain of thought for sql query generation. Defaults to True.
            max_retries (int, optional): Number of max retries for SQLError. Defaults to 3.

        Stores the region, the database engine and the table/row collections
        looked up in ``db_collection_dict``, and builds the DSPy predictors.
        """
        super().__init__()
        self.region = region
        db,table_collection,row_collection = db_collection_dict[region]
        self.table_collection = table_collection
        self.use_cot = use_cot
        self.db = db
        self.row_collection = row_collection
        if self.use_cot:
            self.sqlAnswer = dspy.ChainOfThought(TextToSQLAnswer)
        else:
            self.sqlAnswer = dspy.Predict(TextToSQLAnswer)
        self.final_output = dspy.Predict(SQLReturnToAnswer)
        self.max_tries = max_retries
        # Initialize the sql rectifier with CoT reasoning
        self.sql_rectifier = dspy.ChainOfThought(SQLRectifier,rationale_type=dspy.OutputField(
            prefix="reasoning: let's think step by step in order to",
            desc="${produce the answer}. We ..."
        ))
    
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward every call straight to ``forward``."""
        return self.forward(*args, **kwargs)

    def _build_schema_context(self,question_emb):
        """Render the retrieved tables (columns + sample rows) as one prompt text block."""
        # Retrieve the relevant tables from table schema and table summary
        docs = self.table_collection.query(
            query_embeddings = question_emb,
            n_results = 5
        )
        relevant_rows_schemas = ""
        marker = "columns: "
        for metadata_name in docs['metadatas'][0]:
            table_metadata = metadata_name['table_metadata']
            table_name = metadata_name['table_name']
            # Retrieve the relevant rows from the current table
            rows = self.row_collection.query(
                query_embeddings = question_emb,
                n_results = 5,
                # where clause to filter the rows
                where = {"table_name":table_name}
            )
            relevant_rows_schemas += f'Table name: {table_name} \n'
            relevant_rows_schemas += "/* \n"
            # The column list follows the last "columns: " marker. If the marker is
            # absent, the whole metadata string is used rather than failing on an
            # unset (or stale, from the previous table) offset.
            marker_pos = table_metadata.rfind(marker)
            cols_start = marker_pos + len(marker) if marker_pos != -1 else 0
            relevant_rows_schemas += "col : " + " | ".join(table_metadata[cols_start:].split(", ")) + "\n"
            for row_idx,row in enumerate(rows['metadatas'][0]):
                relevant_rows_schemas += f'row {row_idx+1} : {" | ".join(row["full_rows"].split(", "))}\n'
            relevant_rows_schemas += "*/" + '\n\n'
        return relevant_rows_schemas
        
    def forward(self,question):
        """Turn a natural-language question into SQL, run it and answer the question.

        Steps:
        1. Embed the question.
        2. Retrieve the related tables and rows.
        3. Generate the SQL query.
        4. Execute it, asking the rectifier for a corrected query after each failure.
        5. Render the returned rows as text and generate the final answer.

        Returns:
            The ``SQLReturnToAnswer`` prediction on success. If every attempt
            fails, the tuple ``(last_sql_prediction, last_error)`` is returned instead.
        """
        # Embed the question with embedding function
        question_emb = emb_fn([question])[0]
        relevant_rows_schemas = self._build_schema_context(question_emb)
        print(colored(relevant_rows_schemas,"yellow"))
        sql_query = self.sqlAnswer(question=question,relevant_table_schemas_rows=relevant_rows_schemas)

        num_tries = 0
        print(sql_query)
        while num_tries <= self.max_tries:
            with self.db.connect() as conn:
                try:
                    # Try executing the sql query for the database
                    result = conn.execute(text(process_sql_str(sql_query.sql)))
                    # Read column names and rows while the connection is still open;
                    # the result must not be consumed after the `with` block exits.
                    key = tuple(result.keys())
                    fetched_rows = result.fetchall()
                    break
                except Exception as error:
                    # If there is an sql error, then try again with the sql rectifier
                    print(colored(str(error),'red'))
                    sql_query = self.sql_rectifier(input_sql=sql_query.sql,error_str=str(error),relevant_table_schemas_rows=relevant_rows_schemas)
                    print(colored(sql_query.rationale,'green'))
                    print()
                    print(colored(sql_query.sql,'green'))
                    # If all the retries are exhausted, give up and hand back the last query and error
                    num_tries += 1
                    if num_tries == self.max_tries+1:
                        return sql_query,error
        # With the retrieved rows from the database, then try to answer the question with dspy context
        with dspy.context(lm=sql_to_answer):
            row_str = ""
            for row in fetched_rows:
                # One line per row: " col = value, col = value"
                row_str += ",".join(f" {k} = {r}" for r,k in zip(row,key))
                row_str += "\n"
            print(f"Extracted rows: {row_str}")
            final_answer = self.final_output(question=question,sql=sql_query.sql,relevant_rows=row_str)
            return final_answer
tsql_ = TextToSQLQueryModule("US")
question = "What is the ebitda of software and packaging industry?" 
# "EBITDA" 是 "Earnings Before Interest, Taxes, Depreciation, and Amortization" 的缩写，中文翻译为“息税折旧摊销前利润”。这是一种常用的财务指标，用于衡量公司的经营业绩，不包括利息、税项、折旧和摊销的影响。
# 软件和包装行业的息税折旧摊销前利润是多少？
sq = tsql_(question = question)

In [35]:
print(sq)

In [36]:
# tsql = TextToSQLQueryModule("US")
# sq = tsql(question="What is the effective tax rate of the healthcare industry?")
sq = tsql_(question="What is the EBITDA value and number of firms for all the Software industries, semiconductor industry and aerospace?")
# 所有软件行业、半导体行业和航空航天行业的息税折旧摊销前利润（EBITDA）值和公司数量是多少？
# sq = tsql("What is the debt to EBITDA ratio for software industry?")

In [38]:
print(sq.answer)
#贝塔值（Beta Value）是一个衡量股票或投资组合相对于整个市场波动性的指标。具体来说，它表示某个资产的价格变动相对于市场整体价格变动的敏感程度。贝塔值在金融学中是资本资产定价模型（CAPM，Capital Asset Pricing Model）中的一个重要参数，用于评估系统性风险

In [54]:
tsql = TextToSQLQueryModule("India")
sq = tsql(question="What is the beta value and number of firms for all the Software industries and semiconductor industry?")
print(sq.answer)

In [55]:
print(sq.answer)

In [40]:
tsql = TextToSQLQueryModule("China")
sq = tsql(question="What is the beta value and number of firms for all the Software industries and semiconductor industry?")
# 所有软件行业和半导体行业的贝塔值和公司数量是多少？
print(sq.answer)

In [41]:
print(sq.answer)

In [58]:
tsql = TextToSQLQueryModule("US")
sq = tsql(question="What is the average tax rate of all healthcare industries?")
print(sq.answer)

In [60]:
print(sq.answer)

In [61]:
sq = tsql(question="Give me the average tax rate of all healthcare industries where revenues per employee is more than 1 million?")
print(sq.answer)

In [62]:
tsql = TextToSQLQueryModule("Europe")
sq = tsql(question="Give me the average tax rate of all banking industries where number of firms is more than 500?")
print(sq.answer)