### 이 노트북의 실습은 아래 원문 워크샵 시나리오를 바탕으로 합니다.

https://github.com/aws-samples/text-to-sql-bedrock-workshop

## Step 0: 라이브러리 설치 및 Athena 연결

In [None]:
!python -m ensurepip --upgrade
!pip install "sqlalchemy" --quiet
!pip install "boto3~=1.34"  --quiet
!pip install "jinja2" --quiet
!pip install "botocore" --quiet
!pip install "pandas" --quiet
!pip install "PyAthena" --quiet
!pip install "faiss-cpu" --quiet

In [None]:
import boto3
import sys

#sys.path.append('../')
from libs.din_sql import din_sql_lib as dsl

In [None]:
# Athena settings for this workshop. The catalog name and the S3 results
# location are available in the CloudFormation stack outputs.
DB_NAME = "tpcds1"
ATHENA_CATALOG_NAME = "txt2sql-2-tpc_ds"
ATHENA_RESULTS_S3_LOCATION = (
    "txt2sql-2-us-west-2-389394968670-workshop/athena_results/"
)

In [None]:
# Bedrock model ID handed to the DIN_SQL helper.
# Alternative model ID: 'anthropic.claude-3-haiku-20240307-v1:0'
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

din_sql = dsl.DIN_SQL(
    bedrock_model_id=model_id,
)

In [None]:
# Point the helper at the Athena catalog, database and query-results location.
din_sql.athena_connect(
    catalog_name=ATHENA_CATALOG_NAME,
    db_name=DB_NAME,
    s3_prefix=ATHENA_RESULTS_S3_LOCATION,
)

## Step 1: Schema Linking 모듈

### 사용할 DB의 스키마 정보 확보

In [None]:
# Ask the helper for the field information of the database and print it.
return_sql = din_sql.find_fields(
    db_name=DB_NAME,
)
print(return_sql)

### 스키마 선별을 위한 프롬프트 구성 & LLM 요청

In [None]:
# Natural-language question that the pipeline will turn into SQL.
question = "Which customer spent the most money in the web store?"

# Prompt layout: instruction + few-shot samples + question.
schema_links_prompt = din_sql.schema_linking_prompt_maker(question, DB_NAME)

# "Word-in-mouth" prefix: seeds the answer so it continues as chain-of-thought.
word_in_mouth_schema_link = (
    f'A. Let’s think step by step. In the question "{question}", we are asked:'
)

# Send the prompt to the LLM; '</links>' is passed as the stop sequence.
schema_links = din_sql.llm_generation(
    schema_links_prompt,
    stop_sequences=['</links>'],
    word_in_mouth=word_in_mouth_schema_link,
)

### 결과 확인

In [None]:
# Show the seeded prefix immediately followed by the model's continuation.
# Note: '</links>' was passed as a stop sequence for this request, so the
# answer is cut off right after the schema links are emitted.
print(word_in_mouth_schema_link, schema_links, sep="")

In [None]:
# Keep what follows the '<links>' tag and drop newline characters from it.
links_segment = schema_links.split('<links>')[1]
links = links_segment.replace('\n', '')
links

#### 이제 자연어 질문 처리에 필요한 스키마 링크가 확보되었습니다.

## Step 2: Classification 모듈

### 쿼리 난이도 분류를 위한 프롬프트 구성 & LLM 요청

In [None]:
# Prompt layout: instruction + few-shot samples + schema links + question.
classification_prompt = din_sql.classification_prompt_maker(
    question,
    DB_NAME,
    links,
)

# "Word-in-mouth" prefix: seeds the answer so it continues as chain-of-thought.
word_in_mouth_classification = "A: Let’s think step by step."

# Send the prompt to the LLM; '</label>' is passed as the stop sequence.
classification = din_sql.llm_generation(
    classification_prompt,
    stop_sequences=['</label>'],
    word_in_mouth=word_in_mouth_classification,
)

### 결과 확인

In [None]:
# Show the seeded prefix immediately followed by the model's continuation.
print(word_in_mouth_classification, classification, sep="")

In [None]:
# Keep what follows the '<label>' tag and drop newline characters from it.
label_segment = classification.split('<label>')[1]
predicted_class = label_segment.replace('\n', '')
predicted_class

## Step 3: SQL Generation 모듈

### 쿼리 생성을 위한 프롬프트 구성 & LLM 요청

In [None]:
# Build the SQL-generation prompt. The intent is to use a different prompt
# depending on the expected difficulty (EASY / NON-NESTED / NESTED).
# NOTE(review): predicted_class is not consulted here; the medium prompt maker
# is always used - confirm whether that is intended.
sql_tag_start = '```sql'
sql_generation_prompt = din_sql.medium_prompt_maker(
    test_sample_text=question,
    database=DB_NAME,
    schema_links=links,
    sql_tag_start=sql_tag_start,
    sql_tag_end='```',
)

# "Word-in-mouth" prefix: ends with the opening SQL tag so the response can be
# split on the fence afterwards.
# Alternative chain-of-thought prefix:
#   "A: Let’s think step by step. For creating the SQL for the given question, we need to join tables. First, create an intermediate representation, then use it to construct the SQL query.\n Intermediate_representation:"
word_in_mouth_medium_prompt = f"SQL: {sql_tag_start}"

# Send the prompt to the LLM; '</example>' is passed as the stop sequence.
sql_qry = din_sql.llm_generation(
    prompt=sql_generation_prompt,
    stop_sequences=['</example>'],
    word_in_mouth=word_in_mouth_medium_prompt,
)

In [None]:
# Show the seeded prefix immediately followed by the model's continuation.
print(word_in_mouth_medium_prompt, sql_qry, sep="")

In [None]:
# Pull the SQL statement out of the model response.
# The response shape depends on the word-in-mouth prefix used for generation:
#   * prefix ending with the ```sql opener -> the response starts with the SQL
#     and the first fence in it is the closing one;
#   * chain-of-thought prefix -> the response carries its own ```sql ... ``` block.
# Checking what the first fence is handles both shapes without toggling code.
sql_fence_open = '```sql'
first_fence_at = sql_qry.find('```')
if first_fence_at != -1 and sql_qry.startswith(sql_fence_open, first_fence_at):
    # The first fence opens a block: the SQL follows it.
    sql_body = sql_qry[first_fence_at + len(sql_fence_open):]
else:
    # The first fence (if any) closes the block that the prefix opened.
    sql_body = sql_qry
SQL = sql_body.split('```')[0].strip()
print(f"{SQL}")

### 생성된 SQL 쿼리 테스트

In [None]:
import pandas as pd

# Run the generated SQL through the helper and render the result as a DataFrame.
result_set = din_sql.query(SQL)
pd.DataFrame(data=result_set)

### 결과 검증

In [None]:
# Hand-written reference query used to check the generated SQL: total
# ws_net_paid per customer over web_sales joined to customer, highest total
# first, limited to 10 rows.
validation_query = """
    SELECT "c"."c_customer_sk"
    , "c"."c_first_name"
    , "c"."c_last_name"
    , SUM("ws"."ws_net_paid") as total_sales
    FROM "customer" "c" 
    JOIN "web_sales" "ws" 
        ON "ws"."ws_bill_customer_sk" = "c"."c_customer_sk"   
    GROUP BY "c"."c_customer_sk"
    , "c"."c_first_name"
    , "c"."c_last_name"
    ORDER BY total_sales desc
    limit 10
"""
# Run the reference query and render its rows for side-by-side comparison.
validation_set = din_sql.query(validation_query)
pd.DataFrame(validation_set)

## Step 4: Self Correction 모듈

### 쿼리 정합성 검증

In [None]:
# Build the self-correction prompt for the generated SQL (dialect: presto),
# send it to the LLM, and flatten the response onto a single line.
debug_prompt = din_sql.debugger(question, DB_NAME, SQL, sql_dialect='presto')
revised_sql = din_sql.debugger_generation(prompt=debug_prompt).replace("\n", " ")
print(f"{revised_sql}")

In [None]:
# Pull the revised SQL out of the debugger response.
# The response may either contain its own ```sql ... ``` block or start
# directly with the SQL and end at a closing fence; indexing the result of
# split('```sql') raises IndexError in the second case. Checking what the
# first fence is handles both shapes.
sql_fence_open = '```sql'
first_fence_at = revised_sql.find('```')
if first_fence_at != -1 and revised_sql.startswith(sql_fence_open, first_fence_at):
    # The first fence opens a block: the SQL follows it.
    revised_body = revised_sql[first_fence_at + len(sql_fence_open):]
else:
    # No opener: the SQL runs from the start up to the first fence (if any).
    revised_body = revised_sql
SQL = revised_body.split('```')[0].strip()
print(f"{SQL}")

In [None]:
# Run the revised SQL through the helper and render the result as a DataFrame.
result_set = din_sql.query(SQL)
pd.DataFrame(data=result_set)

## End-to-End 수행

In [None]:
# Re-create the helper and its Athena connection for the end-to-end run.
# Alternative model ID: 'anthropic.claude-3-haiku-20240307-v1:0'
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

din_sql = dsl.DIN_SQL(
    bedrock_model_id=model_id,
)
din_sql.athena_connect(
    catalog_name=ATHENA_CATALOG_NAME,
    db_name=DB_NAME,
    s3_prefix=ATHENA_RESULTS_S3_LOCATION,
)

In [None]:
question = "카탈로그가 가장 많이 판매된 해가 언제인가요?"

# Schema Linking 모듈
print("Schema Linking")
schema_links_prompt = din_sql.schema_linking_prompt_maker(question, DB_NAME)
word_in_mouth_schema_link = f'A. Let’s think step by step. In the question "{question}", we are asked:'
%time schema_links = din_sql.llm_generation(schema_links_prompt, stop_sequences=['</links>'], word_in_mouth=word_in_mouth_schema_link)
links = schema_links.split('<links>')[1].replace('\n','')

# Classification 모듈
print("\nClassification and Decomposition")
classification_prompt = din_sql.classification_prompt_maker(question, DB_NAME, links)
word_in_mouth_classification = "A: Let’s think step by step."
%time classification = din_sql.llm_generation(classification_prompt, stop_sequences=['</label>'], word_in_mouth=word_in_mouth_classification)
predicted_class = classification.split('<label>')[1].replace('\n','')

# SQL Generation 모듈
print("\nSQL Generation")
sql_tag_start='```sql'
word_in_mouth_medium_prompt = f"SQL: {sql_tag_start}"
sql_generation_prompt = din_sql.medium_prompt_maker(test_sample_text=question, database=DB_NAME, schema_links=links, sql_tag_start=sql_tag_start, sql_tag_end='```')
word_in_mouth_medium_prompt = f"SQL: {sql_tag_start}"
%time sql_qry = din_sql.llm_generation(prompt=sql_generation_prompt, stop_sequences=['</example>'], word_in_mouth=word_in_mouth_medium_prompt)
SQL = sql_qry.split('```')[0].strip()

# Self Correction 모듈
print("\nSelf Correction") 
debug_prompt = din_sql.debugger(question, DB_NAME, SQL, sql_dialect='presto')
%time revised_sql = din_sql.debugger_generation(prompt=debug_prompt).replace("\n", " ")

# 쿼리 실행
print("\nQuery Execution") 
#SQL = revised_sql.split('```sql')[1].split('```')[0].strip()
SQL = revised_sql.split('```')[0].strip()
%time result_set = din_sql.query(SQL)
pd.DataFrame(result_set)