In [145]:
from langchain import LLMChain

from gensql.langchain.prompt_layer import PromptLayerCustom
from gensql.langchain.prompt import *
from langchain.embeddings.openai import OpenAIEmbeddings

from qdrant_client import QdrantClient
from typing import List
from qdrant_client.http import models
from langchain.vectorstores import Qdrant
import os
from pathlib import Path
from dotenv import load_dotenv
from underthesea import text_normalize
from tqdm import tqdm
import pandas as pd

# Load environment variables from the project's .env file.
# NOTE(review): the right-hand operand is an absolute path, so pathlib discards
# the Path(".") prefix and env_path is simply "/Users/hieunguyen/DATN/.env".
env_path = Path(".") / "/Users/hieunguyen/DATN/.env"
load_dotenv(dotenv_path=env_path)

# Re-assigns the variable to its own value; when the variable is unset,
# os.environ.get returns None and this assignment raises TypeError.
os.environ["EMBEDDING_OPENAI_API_KEY"] = os.environ.get("EMBEDDING_OPENAI_API_KEY")

import promptlayer
# promptlayer.utils.URL_API_PROMPTLAYER
# NOTE(review): the PromptLayer key is read from OPENAI_API_KEY — confirm this
# is intentional and not meant to be a dedicated PromptLayer key.
promptlayer.api_key = os.environ.get("OPENAI_API_KEY")
# NOTE(review): `openai` here is the module langchain.embeddings.openai, not the
# openai SDK package; presumably the SDK was intended — verify this has any effect.
from langchain.embeddings import openai
openai.api_key = os.environ.get("OPENAI_API_KEY")
import json

In [146]:
# API key passed as `openai_api_key` to the PromptLayerCustom LLMs built below;
# None when OPENAI_API_KEY_WITHOUT_AZURE is not set in the environment.
openai_key = os.environ.get("OPENAI_API_KEY_WITHOUT_AZURE")


In [147]:
class CheckSQL:
    """LLM chain that runs ``CHECKSQL_TEMPLATE`` against a question and a SQL statement."""

    def __init__(self, model_name) -> None:
        """Build the PromptLayer-tracked LLM and the chain that wraps it."""
        self.model_name = model_name

        llm_settings = dict(
            model_name=self.model_name,
            openai_api_key=openai_key,
            temperature=0,
            max_tokens=512,
            pl_tags=['filter_check_SQL', 'api-server'],
        )
        self.llm = PromptLayerCustom(**llm_settings)
        self.llm_chain = LLMChain(llm=self.llm, prompt=CHECKSQL_TEMPLATE)

    async def get_response(self, question, time, SQL, script_tables):
        """Run the chain asynchronously with the four prompt variables and return its output."""
        chain_inputs = {
            "question": question,
            "time": time,
            "SQL": SQL,
            "script_tables": script_tables,
        }
        return await self.llm_chain.arun(chain_inputs)
    
class GenerateSQL:
    """LLM chain that runs ``GENERATE_SQL_TEMPLATE`` to turn a question into SQL."""

    def __init__(self, model_name) -> None:
        """Build the PromptLayer-tracked LLM and the chain that wraps it."""
        self.model_name = model_name

        llm_settings = dict(
            model_name=self.model_name,
            openai_api_key=openai_key,
            temperature=0,
            max_tokens=512,
            pl_tags=['filter_generate_SQL', 'api-server'],
        )
        self.llm = PromptLayerCustom(**llm_settings)
        self.llm_chain = LLMChain(llm=self.llm, prompt=GENERATE_SQL_TEMPLATE)

    async def get_response(self, question, time, script_tables):
        """Run the chain asynchronously with the three prompt variables and return its output."""
        chain_inputs = {
            "question": question,
            "time": time,
            "script_tables": script_tables,
        }
        return await self.llm_chain.arun(chain_inputs)

class RepairSQL:
    """LLM chain that runs ``REPAIR_TEMPLATE`` to rewrite a SQL statement given an explanation."""

    def __init__(self, model_name) -> None:
        """Build the PromptLayer-tracked LLM and the chain that wraps it."""
        self.model_name = model_name

        llm_settings = dict(
            model_name=self.model_name,
            openai_api_key=openai_key,
            temperature=0,
            max_tokens=512,
            pl_tags=['filter_repair_SQL', 'api-server'],
        )
        self.llm = PromptLayerCustom(**llm_settings)
        self.llm_chain = LLMChain(llm=self.llm, prompt=REPAIR_TEMPLATE)

    async def get_response(self, question, time, old_SQL, explanation, script_tables):
        """Run the chain asynchronously with the five prompt variables and return its output."""
        chain_inputs = {
            "question": question,
            "time": time,
            "old_SQL": old_SQL,
            "explanation": explanation,
            "script_tables": script_tables,
        }
        return await self.llm_chain.arun(chain_inputs)


In [148]:
class QdrantSchema:
    """Index the field explanations of a JSON schema file in Qdrant and build
    ``Create table`` scripts from the fields most similar to a question.

    The schema file maps table names to objects with ``schema``, ``description``
    and ``fields``; each field carries ``field``, ``explanation``, ``key``,
    ``type`` and ``default``.
    """

    def __init__(self, schema_file, qdrant_host, qdrant_port, collection_name) -> None:
        """Load ``schema_file`` and synchronise it with the Qdrant collection."""
        self.schema_file = schema_file
        list_field, metadata = self.get_data()
        self.doc_search = self.init_embedding_store(list_field, metadata, qdrant_host, qdrant_port, collection_name)

    def get_data(self):
        """Read the schema file and return ``(texts, metadatas)``.

        ``texts`` holds each field explanation with its ``', ví dụ:'`` (example)
        suffix removed; ``metadatas`` holds one dict per field. The parsed JSON is
        also stored on ``self.schema_json``.
        """
        # The explanations are Vietnamese text; do not rely on the platform default encoding.
        with open(self.schema_file, encoding='utf-8') as fp:
            data = json.load(fp)

        self.schema_json = data

        list_field = []
        metadata = []
        for table, infor in data.items():
            for item in infor["fields"]:
                list_field.append(item['explanation'].split(', ví dụ:')[0])
                metadata.append({
                    'schema': infor['schema'],
                    'table': table,
                    'field': item['field'],
                    'explanation': item['explanation'],
                    'key': item['key'],
                    'type': item['type'],
                })

        return list_field, metadata

    @staticmethod
    def _collection_exists(qdrant_client, collection_name):
        """Return True when ``collection_name`` is listed by the Qdrant server."""
        collections = qdrant_client.get_collections().collections
        return any(cl.name == collection_name for cl in collections)

    @staticmethod
    def _fetch_all_points(qdrant_client, collection_name, emb):
        """Return the stored points (payload only) via a search with a very large limit."""
        return qdrant_client.search(
            collection_name=collection_name,
            query_vector=emb,
            with_vectors=False,
            with_payload=True,
            limit=1000000
        )

    def get_list_exist(self, qdrant_client, collection_name, emb):
        """Return ``(is_exist, page_contents)`` for the collection.

        ``page_contents`` is the list of ``page_content`` payload values already
        stored; it is empty when the collection does not exist.
        """
        if not self._collection_exists(qdrant_client, collection_name):
            return False, []
        data = self._fetch_all_points(qdrant_client, collection_name, emb)
        return True, [doc.payload['page_content'] for doc in data]

    def remove_not_exist(self, qdrant_client, collection_name, sample_list, emb):
        """Delete stored points whose ``page_content`` is not in ``sample_list``.

        Does nothing when the collection does not exist (the previous version
        raised ``UnboundLocalError`` in that case).
        """
        if not self._collection_exists(qdrant_client, collection_name):
            return

        samples = set(sample_list)
        for doc in self._fetch_all_points(qdrant_client, collection_name, emb):
            content = str(doc.payload['page_content'])
            if content not in samples:
                print(f"Delete sample: {content}")
                qdrant_client.delete(
                    collection_name=collection_name,
                    points_selector=models.PointIdsList(
                        points=[doc.id],
                    ),
                    wait=True
                )

    def init_embedding_store(self, texts: List[str], metadatas: List[dict], qdrant_host: str, qdrant_port: int, collection_name: str) -> "Qdrant":
        """Create or refresh the collection and return the LangChain ``Qdrant`` store.

        Stale points are removed, and only texts that are not stored yet are
        embedded and uploaded, in batches of 15.
        """
        embeddings = OpenAIEmbeddings(deployment="text-embedding-ada-002")
        qdrant_client = QdrantClient(host=qdrant_host, port=qdrant_port, timeout=None)
        # Any query vector works here: it is only used to list the stored points.
        sample_emb = embeddings.embed_query('Doanh thu của 1 quý trước')
        is_exist, exists = self.get_list_exist(qdrant_client, collection_name, sample_emb)
        if is_exist:
            self.remove_not_exist(qdrant_client, collection_name, texts, sample_emb)
        else:
            # A freshly created collection has nothing to prune.
            qdrant_client.recreate_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=1536, distance=models.Distance.COSINE
                ))

        docsearch = Qdrant(embeddings=embeddings, client=qdrant_client, collection_name=collection_name)

        already_stored = set(exists)
        pending = [(text, metadata) for text, metadata in zip(texts, metadatas)
                   if text not in already_stored]

        batch_size = 15
        for start in tqdm(range(0, len(pending), batch_size)):
            chunk = pending[start:start + batch_size]
            docsearch.add_texts([text for text, _ in chunk],
                                [metadata for _, metadata in chunk],
                                batch_size=batch_size)

        return docsearch

    def return_tables(self, question, top_k=1):
        """Return the schema script relevant to ``question``.

        Runs a similarity search for the ``top_k`` closest field explanations and
        builds one ``Create table`` block per table that owns a hit. Each block
        lists the matched fields plus every field whose ``default`` equals 1.

        Returns a dict with ``script_tables`` (the script), ``question`` and
        ``fields`` (matched field names, in hit order). An empty search result
        yields an empty script and an empty field list.
        """
        data = self.doc_search.similarity_search_with_score(query=text_normalize(question), k=top_k)
        print(f'QD: {data}')

        hits = [hit[0].metadata for hit in data]
        hit_tables = {meta.get('table') for meta in hits}
        hit_fields = [meta.get('field') for meta in hits]
        hit_field_set = set(hit_fields)

        script = ''
        for table in self.schema_json:
            if table not in hit_tables:
                continue
            table_info = self.schema_json[table]
            script += f'''Create table {table_info['schema']}.{table} -- {table_info['description']} ( \n'''
            for field in table_info['fields']:
                if field['field'] in hit_field_set or field['default'] == 1:
                    script += f'''\t{field['field']}\t{field['type']} {field['key'] if field['key'] else ""} -- {field['explanation']} \n'''
            script += ')\n\n'
        return {'script_tables': script, 'question': question, 'fields': hit_fields}

In [149]:
from db.session import get_postgres_engine
# Module-level singletons used by symbol_filter below.
SQLgenerator = GenerateSQL(model_name='gpt-4')
SQLrepairer = RepairSQL(model_name='gpt-4')

# Constructing QdrantSchema loads the schema file and synchronises the
# 'filter_schema' collection (see QdrantSchema.__init__).
# NOTE(review): os.environ.get returns a str (or None), while
# init_embedding_store annotates qdrant_port as int — confirm the client accepts it.
qdrant_schema = QdrantSchema(schema_file='/Users/hieunguyen/DATN/gensql/data/filter_schema.json',
                             qdrant_host=os.environ.get('QDRANT_HOST'),
                             qdrant_port=os.environ.get('QDRANT_PORT'),
                             collection_name='filter_schema')

# Engine used by process_SQL and symbol_filter to open sessions.
postgres_engine = get_postgres_engine()
# Mapping read by process_table to rename result columns.
# NOTE(review): absolute, machine-specific paths; opened without an explicit encoding.
with open("/Users/hieunguyen/DATN/gensql/data/mapping_field_name.json", "r") as json_file:
    mapping_field_name = json.load(json_file)



100%|██████████| 1/1 [00:01<00:00,  1.01s/it]


In [150]:
import re
from sqlalchemy.orm import sessionmaker

def _find_where_clause(sql):
    """Return the text following the first ``WHERE`` in ``sql``, or None.

    The clause is captured up to the next newline, or up to the first ``;`` when
    no newline follows. Empty captures are skipped because replacing an empty
    string would splice text between every character of the statement.
    """
    for pattern in (r'WHERE\s+(.*?)\n', r'WHERE\s+(.*?);'):
        match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
        if match is not None:
            clause = match.group(1).strip()
            if clause:
                return clause
    return None


def process_SQL(response, fields, error=False):
    """Clean up an LLM answer into an executable SQL statement.

    Steps:
      1. Extract the body of the first ```sql fenced block, if there is one.
      2. Make sure the statement ends with ``;`` (an empty answer becomes ``;``
         instead of raising ``IndexError``).
      3. Unless ``error`` is true:
         * replace the date in ``trading_time = 'YYYY-MM-DD'`` with the latest
           non-holiday date within the 7 preceding days, looked up in
           ``public.dim_date`` (left untouched when no such date is found);
         * add every referenced name from ``fields`` to ``SELECT symbol`` and
           append ``AND <field> IS NOT NULL`` to the WHERE clause;
         * when there is no ``GROUP BY``, also select the referenced time columns.

    Args:
        response: raw LLM output.
        fields: candidate column names to look for in the statement.
        error: when true only steps 1 and 2 are applied.

    Returns:
        The processed SQL string.
    """
    if "```sql" in response:
        response = response.split("```sql")[1].split("```")[0].strip()
    if not response.endswith(';'):
        response += ';'
    if error:
        return response

    match = re.search(r"trading_time = '(\d{4}-\d{2}-\d{2})'", response)
    if match:
        # The regex only admits digits and dashes, so interpolating it below is safe.
        query_day = match.group(1).strip()
        print(f'Query day: {query_day}')
        session = sessionmaker(postgres_engine)
        SQL = f"""SELECT date
                        FROM public.dim_date
                        WHERE date BETWEEN TIMESTAMP '{query_day}' - INTERVAL '7 days' AND TIMESTAMP '{query_day}'
                        and is_holiday = 'N'
                        ORDER BY date desc
                        limit 1;"""
        print(f'SQL: {SQL}')
        with session() as s:
            result = s.execute(SQL)
            # Consume the result while the session is still open.
            columns = list(result.keys())
            rows = result.unique().all()
        print(f'Rows: {rows}')
        if rows:
            last_trading_day = dict(zip(columns, rows[0]))['date'].strftime("%Y-%m-%d")
            print('Last trading day: ', last_trading_day)
            response = response.replace(query_day, last_trading_day)
            print(f'Query day shift -1: {query_day};\t Last trading day: {last_trading_day}')
        else:
            print(f'No trading day found within 7 days before {query_day}; keeping it unchanged')

    new_select = 'SELECT symbol'
    appended = ['symbol']
    where_clause = _find_where_clause(response)
    where_clause_appended = where_clause
    print(f'Where clause: {where_clause}')

    for field in fields:
        if field in appended or field not in response:
            continue
        # Only a space-separated mention not already in a column list is added to SELECT.
        if f' {field}' in response and f', {field}' not in response:
            new_select += f', {field}'
        if where_clause is not None and f'{field} IS NOT NULL' not in response:
            where_clause_appended += f'\nAND {field} IS NOT NULL'
            print(f'Where clause + NOT NULL: {where_clause_appended}')
        appended.append(field)

    if 'GROUP BY' not in response:
        for field in ['qtr_number', 'year_number', 'quarter_year_name', 'trading_time']:
            if (f' {field}' in response) and (field not in appended):
                if f', {field}' not in response:
                    new_select += f', {field}'
                    appended.append(field)

    response = response.replace('SELECT symbol', new_select)
    print(f'New select: {response}')
    if where_clause is not None:
        response = response.replace(where_clause, where_clause_appended)
    return response

In [151]:
from decimal import Decimal
from typing import Dict


def process_table(query_result: List[Dict]) -> pd.DataFrame:
    """Turn raw query rows into a display-ready DataFrame.

    * Columns holding ``decimal.Decimal`` values are converted to float. The old
      ``dtype == Decimal`` test never matched, because pandas stores Decimals in
      ``object`` columns.
    * Values are rounded to 2 decimals.
    * Columns are renamed through ``mapping_field_name``; unmapped columns keep
      their name.
    * Numeric columns whose display name contains ``'Tỷ'`` (billion) are divided
      by 1e9 and formatted with 2 decimals.
    * Other numeric columns, except ``year_number`` and ``qtr_number``, are
      scaled the same way and get a ``' (Tỷ)'`` suffix when any absolute value
      reaches 1e9; otherwise they only get thousands separators.

    Args:
        query_result: list of row dicts, as returned by ``symbol_filter``.

    Returns:
        The formatted DataFrame (formatted numeric columns hold strings).
    """
    df = pd.DataFrame(query_result)

    for column in df.columns:
        series = df[column]
        if series.dtype != object:
            continue
        present = series.dropna()
        if not present.empty and all(isinstance(value, Decimal) for value in present):
            df[column] = pd.Series(
                [float(value) if isinstance(value, Decimal) else float('nan') for value in series],
                index=df.index,
                dtype=float,
            )

    df = df.round(2)

    rename_dict = {}
    for column in df.columns:
        new_name = mapping_field_name.get(column, column)
        series = df[column]
        is_number = pd.api.types.is_float_dtype(series) or pd.api.types.is_integer_dtype(series)

        if 'Tỷ' in str(new_name):
            # Non-numeric columns are left as-is instead of failing on the division.
            if is_number:
                df[column] = (series / 1e9).astype(float).apply(lambda x: "{:,.2f}".format(x))
        elif is_number and column not in ('year_number', 'qtr_number'):
            if (series.abs() >= 1e9).any():
                df[column] = (series / 1e9).astype(float).apply(lambda x: "{:,.2f}".format(x))
                new_name = f'{new_name} (Tỷ)'
            else:
                df[column] = series.apply(lambda x: "{:,}".format(x))

        rename_dict[column] = new_name

    return df.rename(columns=rename_dict)

In [152]:
from datetime import datetime, timedelta


async def symbol_filter(native_query: str = "", max_try: int = 2, debug: bool = False, date_time: str = None):
    executed_params = {}
    executed_params['native_query'] = native_query
    response = {}
    if debug:
        response['executed_params'] = executed_params

    # get current time
    if date_time is None:
        now = datetime.now()
        date_time = (now + timedelta(hours=7)).strftime("%Y-%m-%d")# + f' Q{1 + now.month//4}-{now.year}'
    print(f'Current time: {date_time}')

    # get relevant schema
    top_k = 15
    rel_schema = qdrant_schema.return_tables(native_query, top_k=top_k)
    rel_schema['fields'].append('close_price')
    print(f'Relevant schema: {rel_schema}')

    native_query = native_query.replace('pe', 'p/e').replace('pb', 'p/b').replace('PE', 'P/E').replace('PE', 'P/E')
    if 'sma' not in native_query:
        native_query = native_query.replace('ma', 'sma')
    if 'SMA' not in native_query:
        native_query = native_query.replace('MA', 'SMA')
    print(f'Native query: {native_query}')

    # Generate SQL and query DB
    error = False
    while max_try > 0:
        #  Generate SQL
        if not error:
            try:
                SQL = await SQLgenerator.get_response(question=native_query,
                                                time=date_time,
                                                script_tables=rel_schema['script_tables'])
                if "SELECT" not in SQL:
                    print(f'Generated SQL:\n{SQL}')
                    response['response'] = 'Hiện tại em không có thông tin để trả lời câu hỏi.'
                    print(f"Response: {response['response']}")
                    return response
                print('*1.1*SQL:', SQL)
                SQL = process_SQL(SQL, fields=rel_schema['fields'])
                print('*1*SQL:', SQL)

                if "IS NOT NULL" not in SQL:
                    print(f'Old SQL:\n{SQL}')
                    SQL = await SQLrepairer.get_response(question=native_query,
                                            time=date_time,
                                            old_SQL=SQL,
                                            explanation="Bạn cần lọc bỏ các giá trị NULL",
                                            script_tables=rel_schema['script_tables'])
                    print('*2.1*SQL :', SQL)
                    SQL = process_SQL(SQL, fields=rel_schema['fields'])
                    print('*2*SQL :', SQL)

            except Exception as e:
                response['error'] = 'API_RATE_LIMIT'
                print(e)
                return response
        else:
            try:
                SQL = await SQLrepairer.get_response(question=native_query,
                                            time=date_time,
                                            old_SQL=SQL,
                                            explanation=explanation,
                                            script_tables=rel_schema['script_tables'])
                print('*3.1*SQL:', SQL)
                SQL = process_SQL(SQL, fields=rel_schema['fields'], error=True)
                print('*3*SQL:', SQL)
                
            except Exception as e:
                response['error'] = 'API_RATE_LIMIT'
                print(e)
                return response
        max_try -= 1
        print(f'Generated SQL:\n{SQL}')
        
        # Query DB

        try:
            session = sessionmaker(postgres_engine)
            with session() as s:
                result = s.execute(SQL)
            rows = result.unique().all()[:10]
            query_result = [dict(zip(result.keys(), rows[i])) for i in range(len(rows))]
            print(f'Query result:\n{query_result}')
            return query_result

        # repair
        except Exception as e:
            with session() as s:
                s.execute("ROLLBACK")
            error = True
            query_result = str(e)
            max_try -= 1
            explanation = query_result
            if 'does not exist' in query_result:
                print(f'Câu SQL không phù hợp với câu hỏi: {query_result}')
                response['response'] = 'Hiện tại em không có thông tin để trả lời câu hỏi.'
                print(f"Response: {response['response']}")
                return query_result
            print(f'Query result:\n{query_result}')

        print(f'Generated SQL:\n{SQL}')
        

In [157]:
# Run the pipeline for a sample question ("Top 5 symbols with the highest P/E, P/B").
# NOTE(review): despite the name, symbol_filter returns a list of row dicts on
# success (and a dict, str or None on failure), not a DataFrame.
df_result = await symbol_filter('Top 5 mã có pe, pb cao nhất', debug=True)


Current time: 2024-07-01
QD: [(Document(page_content='Tỷ lệ tăng trưởng doanh thu so với 5 năm trước', metadata={'explanation': 'Tỷ lệ tăng trưởng doanh thu so với 5 năm trước, ví dụ: -0.67', 'field': 'ty_le_tang_truong_doanh_thu_5_y', 'key': None, 'schema': 'dwh_market', 'table': 'fact_symbol_chi_so_tang_truong_yearly', 'type': 'double precision'}), 0.8148963), (Document(page_content='Doanh thu của 5 năm trước', metadata={'explanation': 'Doanh thu của 5 năm trước, ví dụ: 4,236,069,360,770 (đơn vị: đồng)', 'field': 'doanh_thu_5_y_ago', 'key': None, 'schema': 'dwh_market', 'table': 'fact_symbol_chi_so_tang_truong_yearly', 'type': 'double precision'}), 0.8112711), (Document(page_content='Tỷ lệ tăng trưởng lợi nhuận so với 5 năm trước', metadata={'explanation': 'Tỷ lệ tăng trưởng lợi nhuận so với 5 năm trước, ví dụ: 1.46', 'field': 'ty_le_tang_truong_loi_nhuan_5_y', 'key': None, 'schema': 'dwh_market', 'table': 'fact_symbol_chi_so_tang_truong_yearly', 'type': 'double precision'}), 0.81080

In [158]:
# Format the raw rows for display; the bare last expression renders the table.
df_process = process_table(df_result)
df_process

Unnamed: 0,mack,pe,pb
0,VIT,10851.7,1.66
1,LBM,6094.69,1.65
2,BTH,3532.85,1.34
3,PXA,2197.11,0.58
4,VRC,2028.03,0.41
