In [1]:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from chromadb.config import Settings
import chromadb

persist_directory="vectorstore/"
top_k = 4

CHROMA_SETTINGS = Settings(
    persist_directory=persist_directory,
    anonymized_telemetry=False
)

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS , path=persist_directory)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS, client=chroma_client)

retriever = db.as_retriever(search_kwargs={"k": top_k})


In [2]:
from langchain.llms import GPT4All
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

model_path = "../privateGPT/models/ggml-model-gpt4all-falcon-q4_0.bin"

llm = GPT4All(model=model_path, max_tokens=1000, backend='gptj',
                          n_batch=8, callbacks=[StreamingStdOutCallbackHandler()], verbose=False)


Found model file at  ../privateGPT/models/ggml-model-gpt4all-falcon-q4_0.bin
falcon_model_load: loading model from '../privateGPT/models/ggml-model-gpt4all-falcon-q4_0.bin' - please wait ...
falcon_model_load: n_vocab   = 65024
falcon_model_load: n_embd    = 4544
falcon_model_load: n_head    = 71
falcon_model_load: n_head_kv = 1
falcon_model_load: n_layer   = 32
falcon_model_load: ftype     = 2
falcon_model_load: qntvr     = 0
falcon_model_load: ggml ctx size = 3872.64 MB
falcon_model_load: memory_size =    32.00 MB, n_mem = 65536
falcon_model_load: ..

objc[2463]: Class GGMLMetalClass is implemented in both /opt/miniconda3/envs/gnn/lib/python3.11/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libreplit-mainline-metal.dylib (0x294720228) and /opt/miniconda3/envs/gnn/lib/python3.11/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libllamamodel-mainline-metal.dylib (0x2945e8228). One of the two will be used. Which one is undefined.


...................... done
falcon_model_load: model size =  3872.59 MB / num tensors = 196


In [3]:
from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    )
from langchain.memory import ConversationSummaryBufferMemory

# TODO: move SystemMessagePromptTemplate to first question
prompt = ChatPromptTemplate (messages=[
    SystemMessagePromptTemplate.from_template("""### Instructions:
Your task is convert a question into a SQL query, given a MYSQL database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question.
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
- Input can contain example query given the question. In the format `Question: Give me col1 of table1;
Query: SELECT table1.col1 FROM table1`.

### Context:
This query will run on a database whose schema is represented in this string:
CREATE TABLE fact_retention_model \{
  fathercodename VARCHAR(50), -- Name of the main fund, 主基金名称
  raisetype VARCHAR(50), -- Type of issurance or raise, 发行类别
  product_type_bi VARCHAR(50), -- Type of product , 产品类型
  agencyname_adj VARCHAR(50), -- Name of agency, 销售渠道
  custtype VARCHAR(50), -- Type of customer, 客户类型
  tougu_flag TINYINT(1), -- If has tougu, 是否投顾
  business_module VARCHAR(50), -- Business module, 业务模块
  cdate DATETIME, -- Date of confirmation, 确认日
  invest_manager VARCHAR(50), -- Name of fund invest manager, 基金经理
  parentareaname VARCHAR(50), -- Name of parent area, 大区
  salescenter VARCHAR(50), -- Name of sales center 营销中心
  retail_brokername VARCHAR(50), -- Name of retail broker, 客户经理
  provname VARCHAR(50), -- Name of province, 省
  CITYNAME VARCHAR(50), -- Name of city, 市
  relaname VARCHAR(50), -- Name of real client, 事实客户
  age INT, -- Age, 年龄
  port_code VARCHAR(10) PRIMARY KEY, -- Code of fund, 基金代码
  fathercode VARCHAR(10), -- Code of main fund, 主基金代码
  fundname VARCHAR(10), -- Name of fund 基金名称
  shares FLOAT, -- Latest shares, 最新份额
  asset FLOAT, -- Latest size of fund, 最新规模
  asset_fof FLOAT -- Latest size of fund of fund(fof), 最新规模FOF双算
\}

The example value of each field can be as listed:
fathercodename['中欧盛世', '新蓝筹', '时代先锋', '新常态', '潜力价值', '阿尔法', '时代先锋', '创新成长'],
raisetype['公募基金', '一对一专户', '投顾'],
product_type_bi['混合偏债', '债券类', '权益类', '货币类'],
agencyname_adj['南京证券', '光大证券', '招商银行', '中信证券', '工商银行', '浦发银行'],
custtype['机构', '个人'],
tougu_flag['0', '1'],
business_module['机构一部', '银行', '券商'],
cdate['2022-12-30', '2022-06-30'],
invest_manager['袁维德', '蓝小康', '洪慧梅', '刘金辉', '许文星'],
parentareaname['南方区', '北方区', '华东区'],
salescenter['江南营销中心', '华北营销中心', '西北营销中心', '东北营销中心', '深圳营销中心', '其他'],
retail_brokername['王鹏', '卢肇昱', '冯文欣', '吕霖', '段家庆'],
provname['湖南', '辽宁', '江苏', '河北', '陕西', '天津', '浙江'],
CITYNAME['天津', '无锡', '郑州', '聊城', '兰州', '莆田', '锦州', '重庆', '厦门'],
relaname['南京证券股份有限公司（非经纪业务）', '幸福人寿保险股份有限公司', '广发证券股份有限公司（非经纪业务）', '泰康资产管理有限责任公司'],
age[43, 45, 51, 59, 101, 60, 33],
port_code['005242', '013221', '166005', '150071', '166023'],
fathercode['001117', '001980', '578183', '166020', '166007'],
fundname['中欧创新成长灵活配置混合型证券投资基金A', '中欧增强回报债券（LOF）A', '中欧消费主题股票型证券投资基金A', '中欧远见两年定期开放混合A'],
shares[10002.334, 2566.31, 1026.524, 1539.786, 2789.145, 0],
asset[1424.9179644, 2137.3769466, 3560.064678, 1424.0258712, 2136.0388068],
asset_fof[3562.294911, 1424.9179644, 2137.3769466, 3560.064678, 1424.0258712]
"""),
    MessagesPlaceholder(variable_name="chat_history"),
    HumanMessagePromptTemplate.from_template("{question}")
])

memory = ConversationSummaryBufferMemory(
    llm=llm, 
    max_token_limit=2000, 
    memory_key="chat_history",
    return_messages=True,
    verbose=True,
    )


In [5]:
from langchain.chains import ConversationalRetrievalChain

# TODO: override from_llm class method to allow custom doc_chain
qa = ConversationalRetrievalChain.from_llm(
    llm=llm, 
    retriever=retriever, 
    memory=memory,
    condense_question_prompt=prompt,
    condense_question_llm=llm
    )


In [6]:
question = "如何计算保有规模"
result = qa({"question": question})
result

 To answer the question at the end, we need to join the three queries with the same table `fact_retention_model` and group by the `agencyname` column. Then, we can use the `SUM` function to calculate the total asset values for each agency on the given date range. Finally, we can filter out the agencies that do not have any assets on the given date range using the `WHERE` clause.

SELECT agencyname, SUM(asset) FROM fact_retention_model WHERE product_type_bi='权益类' AND cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY) GROUP BY agencyname

This query will return the total asset values for each agency on the given date range.

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

{'question': '如何计算保有规模',
 'chat_history': [HumanMessage(content='如何计算保有规模', additional_kwargs={}, example=False),
  AIMessage(content=" To answer the question at the end, we need to join the three queries with the same table `fact_retention_model` and group by the `agencyname` column. Then, we can use the `SUM` function to calculate the total asset values for each agency on the given date range. Finally, we can filter out the agencies that do not have any assets on the given date range using the `WHERE` clause.\n\nSELECT agencyname, SUM(asset) FROM fact_retention_model WHERE product_type_bi='权益类' AND cdate=(DATE_FORMAT(CURRENT_DATE(), '%Y-%m-01')-INTERVAL 1 DAY) GROUP BY agencyname\n\nThis query will return the total asset values for each agency on the given date range.", additional_kwargs={}, example=False)],
 'answer': " To answer the question at the end, we need to join the three queries with the same table `fact_retention_model` and group by the `agencyname` column. Then, we can us

In [7]:
question2 = "如何获取所有表的结构"
result2 = qa({"question": question2})
result2


ValueError: Missing some input keys: {'\n  fathercodename VARCHAR(50), -- Name of the main fund, 主基金名称\n  raisetype VARCHAR(50), -- Type of issurance or raise, 发行类别\n  product_type_bi VARCHAR(50), -- Type of product , 产品类型\n  agencyname_adj VARCHAR(50), -- Name of agency, 销售渠道\n  custtype VARCHAR(50), -- Type of customer, 客户类型\n  tougu_flag TINYINT(1), -- If has tougu, 是否投顾\n  business_module VARCHAR(50), -- Business module, 业务模块\n  cdate DATETIME, -- Date of confirmation, 确认日\n  invest_manager VARCHAR(50), -- Name of fund invest manager, 基金经理\n  parentareaname VARCHAR(50), -- Name of parent area, 大区\n  salescenter VARCHAR(50), -- Name of sales center 营销中心\n  retail_brokername VARCHAR(50), -- Name of retail broker, 客户经理\n  provname VARCHAR(50), -- Name of province, 省\n  CITYNAME VARCHAR(50), -- Name of city, 市\n  relaname VARCHAR(50), -- Name of real client, 事实客户\n  age INT, -- Age, 年龄\n  port_code VARCHAR(10) PRIMARY KEY, -- Code of fund, 基金代码\n  fathercode VARCHAR(10), -- Code of main fund, 主基金代码\n  fundname VARCHAR(10), -- Name of fund 基金名称\n  shares FLOAT, -- Latest shares, 最新份额\n  asset FLOAT, -- Latest size of fund, 最新规模\n  asset_fof FLOAT -- Latest size of fund of fund(fof), 最新规模FOF双算\n\\'}

In [None]:
result2.get("chat_history")
