In [3]:
import os
import sys
from langchain.document_loaders.json_loader import JSONLoader
from langchain.docstore.document import Document
import json
import re
from langchain.vectorstores import FAISS
from langchain.embeddings import BedrockEmbeddings
from functools import reduce
from langchain.prompts import PromptTemplate
from sqlalchemy import MetaData
from sqlalchemy import create_engine


import re
import pandas as pd
import numpy as np
import json
import sqlite3

data_path = './data'
with open(f'{data_path}/tables.json', 'rb') as ofp:
    meta = json.load(ofp)
data = meta[0]

data = [i for i in meta if i['db_id'] == 'department_store']

data = data[0]
columns = data["column_names_original"]
col_df = pd.DataFrame(columns).iloc[1:]
col_df.rename(columns={0: 'table_idx', 1: 'col_name'}, inplace=True)
col_df

types_df = pd.DataFrame(data["column_types"]).iloc[1:]
types_df.rename(columns={0: 'type'}, inplace=True)
types_df

merged_col = pd.concat([col_df, types_df], axis=1)

tables_df = pd.DataFrame(data["table_names_original"])
tables_df.reset_index(inplace=True)
tables_df.columns = ['table_idx', 'table_name']

meta = pd.merge(tables_df, merged_col, on=['table_idx'])
meta = meta.drop(columns=['table_idx'])

In [6]:
from botocore.config import Config
import boto3
DB_NAME = "text2sql"
DB_FAISS_PATH = './vectorstore/db_faiss'

bedrock_region = athena_region = boto3.session.Session().region_name
retry_config = Config(retries = {'max_attempts': 100})
session = boto3.Session(region_name=bedrock_region)
bedrock = session.client('bedrock-runtime', region_name=bedrock_region, config=retry_config)

In [28]:
files = os.listdir('./data/rag')
df = pd.DataFrame()

for f_name in files:
    with open(f'./data/rag/{f_name}', 'rb') as ofp:
        df_tmp = pd.DataFrame(json.load(ofp))
        df = pd.concat([df, df_tmp])

df

Unnamed: 0,tableName,question,tableSchema
0,Order_Items,• 1. What is the total number of unique orders...,order_item_id|order_id|product_id
1,Order_Items,• 2. Which products were included in a specifi...,order_item_id|order_id|product_id
2,Order_Items,• 3. How many times was a particular product o...,order_item_id|order_id|product_id
3,Order_Items,• 4. Can we identify the most frequently order...,order_item_id|order_id|product_id
4,Order_Items,• 5. What is the distribution of order sizes (...,order_item_id|order_id|product_id
...,...,...,...
5,Customer_Addresses,Which customers have had the most frequent ad...,customer_id|address_id|date_from|date_to
6,Customer_Addresses,How does the frequency of address changes var...,customer_id|address_id|date_from|date_to
7,Customer_Addresses,Can you identify any patterns or trends in th...,customer_id|address_id|date_from|date_to
8,Customer_Addresses,What is the average duration of an address as...,customer_id|address_id|date_from|date_to


In [29]:
import pandas as pd


def create_text(row, max_len=509):
    text = ""
    for col, val in row.items():
        text += f"{col}: {val},"
    if len(text) > max_len:
        text = text[:max_len] + "..."

    # print(text.rstrip("\n"))
    return text.rstrip()


# Assuming your DataFrame is called 'df'
df["text"] = df.apply(create_text, axis=1)
df.head(10)

Unnamed: 0,tableName,question,tableSchema,text
0,Order_Items,• 1. What is the total number of unique orders...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 1. What is ..."
1,Order_Items,• 2. Which products were included in a specifi...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 2. Which pr..."
2,Order_Items,• 3. How many times was a particular product o...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 3. How many..."
3,Order_Items,• 4. Can we identify the most frequently order...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 4. Can we i..."
4,Order_Items,• 5. What is the distribution of order sizes (...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 5. What is ..."
5,Order_Items,• 6. Which orders contained multiple instances...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 6. Which or..."
6,Order_Items,• 7. How can we identify potential duplicate o...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 7. How can ..."
7,Order_Items,• 8. What is the chronological sequence of ord...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 8. What is ..."
8,Order_Items,• 9. Can we determine the most popular product...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 9. Can we d..."
9,Order_Items,• 10. Which orders had the highest number of d...,order_item_id|order_id|product_id,"tableName: Order_Items,question: • 10. Which o..."


In [30]:
# Find all items in df  with plot value longer than 512 characters
def find_long_plot_items(df):
    long_plot_items = df[df["text"].str.len() > 512]
    return long_plot_items


find_long_plot_items(df).count()

tableName      0
question       0
tableSchema    0
text           0
dtype: int64

In [10]:
def get_cfn_outputs(stackname, cfn):
    outputs = {}
    for output in cfn.describe_stacks(StackName=stackname)["Stacks"][0]["Outputs"]:
        outputs[output["OutputKey"]] = output["OutputValue"]
    return outputs

In [11]:
import boto3, json


region_name = "us-west-2"

cfn = boto3.client("cloudformation", region_name)
kms = boto3.client("secretsmanager", region_name)

stackname = "opensearch-workshop"
cfn_outputs = get_cfn_outputs(stackname, cfn)

aos_credentials = json.loads(
    kms.get_secret_value(SecretId=cfn_outputs["OpenSearchSecret"])["SecretString"]
)

aos_host = cfn_outputs["OpenSearchDomainEndpoint"]
aos_host

'search-opensearch-workshop-hispcszxbt5gc2yzbjrilid77q.us-west-2.es.amazonaws.com'

In [12]:
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth

auth = (aos_credentials["username"], aos_credentials["password"])

aos_client = OpenSearch(
    hosts=[{"host": aos_host, "port": 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

In [13]:
import requests

search_model = {"query": {"match": {"name": "OpenSearch-Cohere"}}, "size": 10}

response = requests.get(
    "https://" + aos_host + "/_plugins/_ml/models/_search", auth=auth, json=search_model
)
model_info = json.loads(response.text)
model_id = model_info["hits"]["hits"][0]["_id"]
model_id

'R25VV5ABxRR9v61Qh50K'

In [33]:
pipeline = {
    "description": "An neural search pipeline for movie index - OpenSearch-cohere-060124084807",
    "processors": [
        {
            "text_embedding": {
                "model_id": model_id,
                "field_map": {
                    "text": "vector_field",
                },
            }
        }
    ],
}

pipeline_id = "text2sql_plot"
# aos_client.ingest.delete_pipeline(id=pipeline_id)
aos_client.ingest.put_pipeline(id=pipeline_id, body=pipeline)

{'acknowledged': True}

In [62]:
index_name = "rag_semantic_ver1"

In [63]:
aos_client.indices.delete(index=index_name)

{'acknowledged': True}

In [64]:
rag_semantic = {
    "settings": {
        "max_result_window": 15000,
        # "analysis": {"analyzer": {"analysis-nori": {"type": "nori", "stopwords": "_korean_"}}},
        "index.knn": True,
        "default_pipeline": pipeline_id,
        "index.knn.space_type": "l2",
    },
    "mappings": {
        "properties": {
            "tableName": {
                "type": "text",
                "fields": {
                          "english": {
                            "type": "text",
                            "analyzer": "english"},                
                            },
            },
            "question": {
                "type": "text",
                "fields": {
                          "english": {
                            "type": "text",
                            "analyzer": "english"},                
                            },                
            },
            "tableSchema": {
                "type": "text",
                "fields": {
                          "english": {
                            "type": "text",
                            "analyzer": "english"},                
                            },     #{"keyword": {"type": "keyword", "ignore_above": 256}},
            },
            "vector_field": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {"name": "hnsw", "space_type": "l2", "engine": "faiss"},
                "store": True,
            },

        }
    },
}



aos_client.indices.create(index=index_name, body=rag_semantic)

{'acknowledged': True,
 'shards_acknowledged': True,
 'index': 'rag_semantic_ver1'}

In [65]:
from tqdm import tqdm
from opensearchpy import helpers

json_data = df.to_json(orient="records", lines=True)
docs = json_data.split("\n")[:-1]  # To remove the last empty line


def _generate_data():
    for doc in docs:
        yield {"_index": index_name, "_source": doc}


succeeded = []
failed = []
for success, item in helpers.parallel_bulk(
    aos_client, actions=_generate_data(), chunk_size=10, thread_count=1, queue_size=1
):
    if success:
        succeeded.append(item)
    else:
        failed.append(item)

In [66]:
# Refresh the index to make the changes visible
aos_client.indices.refresh(index=index_name)

count = aos_client.count(index=index_name)
print(count)

{'count': 140, '_shards': {'total': 5, 'successful': 5, 'skipped': 0, 'failed': 0}}


In [67]:
def keyword_search(query_text):
    query = {
        "size": 10,
        "_source": {"excludes": ["vector_field"]},
        "query": {
            "multi_match": {
                "query": query_text,
                "fields": ["tableName", "question", "tableSchema"],
            }
        },
    }

    res = aos_client.search(index=index_name, body=query)

    query_result = []
    for hit in res["hits"]["hits"]:
        row = [
            hit["_score"],
            hit["_source"]["tableName"],
            hit["_source"]["question"],
            hit["_source"]["tableSchema"],            
        ]
        query_result.append(row)

    query_result_df = pd.DataFrame(
        data=query_result, columns=["_score", "tableName", "question", "tableSchema"]
    )
    display(query_result_df)

In [75]:
query_text = "aws 전화번호 알려줘"
keyword_search(query_text)

Unnamed: 0,_score,tableName,question,tableSchema


In [76]:
def semantic_search(query_text):
    query = {
        "size": 10,
        "_source": {"excludes": ["vector_field"]},
        "query": {
            "neural": {"vector_field": {"query_text": query_text, "model_id": model_id, "k": 10}},
        },
    }

    res = aos_client.search(index=index_name, body=query)

    query_result = []
    for hit in res["hits"]["hits"]:
        row = [
            hit["_score"],
            hit["_source"]["tableName"],
            hit["_source"]["question"],
            hit["_source"]["tableSchema"],            
        ]
        query_result.append(row)

    query_result_df = pd.DataFrame(
        data=query_result, columns=["_score", "tableName", "question", "tableSchema"]
    )
    display(query_result_df)

In [77]:
semantic_search(query_text)

Unnamed: 0,_score,tableName,question,tableSchema
0,0.528255,Suppliers,• 2. Which suppliers have a phone number listed?,supplier_id|supplier_name|supplier_phone
1,0.517891,Suppliers,• 4. Can you provide a list of suppliers sorte...,supplier_id|supplier_name|supplier_phone
2,0.509304,Addresses,• 1. What is the total number of unique addres...,address_id|address_details
3,0.506874,Addresses,• 4. Can you provide a list of addresses that ...,address_id|address_details
4,0.502813,Staff,• 4. Can you provide a breakdown of the staff ...,staff_id|staff_gender|staff_name
5,0.499103,Suppliers,• 8. Can you identify any suppliers with dupli...,supplier_id|supplier_name|supplier_phone
6,0.495068,Suppliers,• 10. Which suppliers have phone numbers with ...,supplier_id|supplier_name|supplier_phone
7,0.494885,Addresses,• 2. Which addresses have been recently added ...,address_id|address_details
8,0.490502,Suppliers,• 1. What are the names of all the suppliers i...,supplier_id|supplier_name|supplier_phone
9,0.489407,Suppliers,• 7. How many suppliers have a phone number th...,supplier_id|supplier_name|supplier_phone
