In [17]:
import json
from collections import Counter
import sqlite3
import pandas as pd
from IPython.core.interactiveshell import InteractiveShell
from io import StringIO
import sys
import chardet
import os
import csv
import re
import glob

In [59]:
with open("train.json") as devset:
    dev = json.load(devset)
dev_pd=pd.DataFrame(dev)
dev_pd

Unnamed: 0,db_id,question,evidence,SQL
0,movie_platform,Name movie titles released in year 1945. Sort ...,released in the year 1945 refers to movie_rele...,SELECT movie_title FROM movies WHERE movie_rel...
1,movie_platform,State the most popular movie? When was it rele...,most popular movie refers to MAX(movie_popular...,"SELECT movie_title, movie_release_year, direct..."
2,movie_platform,What is the name of the longest movie title? W...,longest movie title refers to MAX(LENGTH(movie...,"SELECT movie_title, movie_release_year FROM mo..."
3,movie_platform,Name the movie with the most ratings.,movie with the most rating refers to MAX(SUM(r...,SELECT movie_title FROM movies GROUP BY movie_...
4,movie_platform,What is the average number of Mubi users who l...,average = AVG(movie_popularity); number of Mub...,SELECT AVG(movie_popularity) FROM movies WHERE...
...,...,...,...,...
9423,movie_3,"Among the times Mary Smith had rented a movie,...",in June 2005 refers to year(payment_date) = 20...,SELECT COUNT(T1.customer_id) FROM payment AS T...
9424,movie_3,Please give the full name of the customer who ...,"full name refers to first_name, last_name; the...","SELECT T2.first_name, T2.last_name FROM paymen..."
9425,movie_3,How much in total had the customers in Italy s...,total = sum(amount); Italy refers to country =...,SELECT SUM(T5.amount) FROM address AS T1 INNER...
9426,movie_3,"Among the payments made by Mary Smith, how man...",over 4.99 refers to amount > 4.99,SELECT COUNT(T1.amount) FROM payment AS T1 INN...


테이블 추출
=============

In [60]:
def extract_table_candidates(sql):
    # 1단계: FROM, JOIN 뒤, 또는 AS 앞에 오는 단어를 테이블 후보로 추출하는 정규표현식
    pattern = r'\b(FROM|JOIN)\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)\s*(?:AS\s+[a-zA-Z_][a-zA-Z0-9_]*)?|\b([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?)\s+AS\b'
    matches = re.findall(pattern, sql, re.IGNORECASE)
    candidates = [match[1] if match[1] else match[2] for match in matches]

    # 2단계: 백틱(`)이 있는 경우 백틱 안의 텍스트 추출
    if '`' in sql:
        pattern_backtick = r'`([a-zA-Z_][a-zA-Z0-9_\- ]*)`'
        matches_backtick = re.findall(pattern_backtick, sql)
        candidates.extend(matches_backtick)

    return [candidate.strip('`') for candidate in candidates]

def extract_real_tables(sql, db_path):
    # 데이터베이스 연결
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # 데이터베이스의 모든 테이블 목록 가져오기
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    all_tables = set(table[0].lower() for table in cursor.fetchall())

    # SQL 문에서 테이블 후보 추출
    candidates = extract_table_candidates(sql)
    
    # 실제 존재하는 테이블만 필터링
    real_tables = set(candidate.lower() for candidate in candidates if candidate.lower() in all_tables)

    # 데이터베이스 연결 종료
    conn.close()

    return ', '.join(real_tables)  # 테이블들을 문자열로 반환

# 데이터프레임에 새로운 'table_name' 열 추출
def add_table_names(row):
    db_path = f"./dev_20240627/{task}_databases/{task}_databases/{row['db_id']}/{row['db_id']}.sqlite"  # 데이터베이스 파일 경로
    return extract_real_tables(row['SQL'], db_path)

dev_pd['table_name'] = dev_pd.apply(add_table_names, axis=1)

Table Description
=============

In [61]:
def get_table_description(db_name, table_names):
    db_path = os.path.join(f'./dev_20240627/{task}_databases/{task}_databases/{db_name}/{db_name}.sqlite')
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    all_descriptions = ""
    all_descriptions += "## 'Linked Schema'\n"

    for table_name in table_names.split(', '):
        # CREATE TABLE 문 가져오기 (대소문자 구분 없이 검색)
        cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}' COLLATE NOCASE;")
        create_table_sql = cursor.fetchone()

        if create_table_sql is None:
            all_descriptions += f"Error: Table '{table_name}' does not exist in the database '{db_name}'.\n\n"
            continue

        create_table_sql = create_table_sql[0]

        # 샘플 데이터 가져오기
        cursor.execute(f"SELECT * FROM '{table_name}' LIMIT 3;")
        sample_data = cursor.fetchall()

        # 열 정보 및 설명 가져오기
        cursor.execute(f"PRAGMA table_info('{table_name}')")
        columns = cursor.fetchall()

        # 메타데이터 파일 찾기 (대소문자 구분 없이 처리)
        metadata_path_pattern = f'./dev_20240627/{task}_databases/{task}_databases/{db_name}/database_description/{table_name}*.csv'
        metadata_files = glob.glob(metadata_path_pattern, recursive=True)

        if not metadata_files:
            all_descriptions += f"Error: Metadata file for '{table_name}' not found.\n\n"
            continue
        
        # 첫 번째로 매칭된 메타데이터 파일 선택
        metadata_path = metadata_files[0]

        try:
            with open(metadata_path, encoding='utf-8', errors='replace') as f:
                metadata = pd.read_csv(f)
        except UnicodeDecodeError as e:
            all_descriptions += f"Error: Unicode decoding error in metadata file for table '{table_name}': {str(e)}\n\n"
            continue
        description = "#Create table: Table and column information needed to write SQL statements. \n"
        description += f"Table: {table_name}\n\n"
        description += f"{create_table_sql}\n\n"
        description += "#Sample Data: This is a direct example where you can see the value of the data.\n"
        description += f"3 rows from {table_name} table:\n"
        description += str(sample_data) + "\n\n"

        description += "#Column / Value Description: Information about columns and values.\n"
        description += "Column descriptions:\n"
        for col in columns:
            col_name = col[1]
            col_type = col[2]

            # 메타데이터에서 해당 열 정보 찾기
            col_metadata = metadata[metadata['original_column_name'].str.lower() == col_name.lower()]

            if not col_metadata.empty:
                col_desc = col_metadata['column_description'].values[0]
                value_desc = col_metadata['value_description'].values[0] if 'value_description' in col_metadata.columns else ''
                
                # col_desc가 float인 경우 문자열로 변환
                if isinstance(col_desc, float):
                    if pd.isna(col_desc):
                        continue
                    else:
                        col_desc = str(col_desc)
                
                # "(INTEGER)" 등의 데이터 타입 정보 제거 (문자열인 경우에만 적용)
                if isinstance(col_desc, str):
                    col_desc = re.sub(r'\s*\([A-Z]+\)\s*', '', col_desc)
                
                if pd.notna(col_desc):
                    description += f"Column {col_name}: column description -> {col_desc}"
                    if pd.notna(value_desc):
                        # 값 설명에서 줄바꿈 문자 제거
                        value_desc = re.sub(r'\s*\n\s*', ' ', str(value_desc)).strip()
                        description += f", value description -> {value_desc}"
                    description += ''

            description += "\n"

        all_descriptions += description + "\n" + "="*30 + "\n\n"

    conn.close()
    return all_descriptions


# DataFrame에 대해 get_table_description 적용
dev_pd['table_description'] = dev_pd.apply(lambda row: get_table_description(row['db_id'], row['table_name']), axis=1)

In [46]:
dev_pd.to_csv('./trainset.csv',sep='|')

In [62]:
dev_pd

Unnamed: 0,db_id,question,evidence,SQL,table_name,table_description
0,movie_platform,Name movie titles released in year 1945. Sort ...,released in the year 1945 refers to movie_rele...,SELECT movie_title FROM movies WHERE movie_rel...,movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
1,movie_platform,State the most popular movie? When was it rele...,most popular movie refers to MAX(movie_popular...,"SELECT movie_title, movie_release_year, direct...",movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
2,movie_platform,What is the name of the longest movie title? W...,longest movie title refers to MAX(LENGTH(movie...,"SELECT movie_title, movie_release_year FROM mo...",movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
3,movie_platform,Name the movie with the most ratings.,movie with the most rating refers to MAX(SUM(r...,SELECT movie_title FROM movies GROUP BY movie_...,movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
4,movie_platform,What is the average number of Mubi users who l...,average = AVG(movie_popularity); number of Mub...,SELECT AVG(movie_popularity) FROM movies WHERE...,movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
...,...,...,...,...,...,...
9423,movie_3,"Among the times Mary Smith had rented a movie,...",in June 2005 refers to year(payment_date) = 20...,SELECT COUNT(T1.customer_id) FROM payment AS T...,"payment, customer","Table: payment\n\nCREATE TABLE ""payment""\n(\n ..."
9424,movie_3,Please give the full name of the customer who ...,"full name refers to first_name, last_name; the...","SELECT T2.first_name, T2.last_name FROM paymen...","payment, customer","Table: payment\n\nCREATE TABLE ""payment""\n(\n ..."
9425,movie_3,How much in total had the customers in Italy s...,total = sum(amount); Italy refers to country =...,SELECT SUM(T5.amount) FROM address AS T1 INNER...,"address, payment, city, country, customer","Table: address\n\nCREATE TABLE ""address""\n(\n ..."
9426,movie_3,"Among the payments made by Mary Smith, how man...",over 4.99 refers to amount > 4.99,SELECT COUNT(T1.amount) FROM payment AS T1 INN...,"payment, customer","Table: payment\n\nCREATE TABLE ""payment""\n(\n ..."


## 에러 데이터 제거

In [64]:
def remove_error_rows(df):
    return df[~df['table_description'].str.startswith('Error: ')]
df_cleaned = remove_error_rows(dev_pd)
df_cleaned.reset_index(inplace=True)

In [57]:
df_cleaned

Unnamed: 0,index,db_id,question,evidence,output,table_name,table_description,instruction,input
0,0,movie_platform,Name movie titles released in year 1945. Sort ...,released in the year 1945 refers to movie_rele...,SELECT movie_title FROM movies WHERE movie_rel...,movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
1,1,movie_platform,State the most popular movie? When was it rele...,most popular movie refers to MAX(movie_popular...,"SELECT movie_title, movie_release_year, direct...",movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
2,2,movie_platform,What is the name of the longest movie title? W...,longest movie title refers to MAX(LENGTH(movie...,"SELECT movie_title, movie_release_year FROM mo...",movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
3,3,movie_platform,Name the movie with the most ratings.,movie with the most rating refers to MAX(SUM(r...,SELECT movie_title FROM movies GROUP BY movie_...,movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
4,4,movie_platform,What is the average number of Mubi users who l...,average = AVG(movie_popularity); number of Mub...,SELECT AVG(movie_popularity) FROM movies WHERE...,movies,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ..."
...,...,...,...,...,...,...,...,...,...
9388,9423,movie_3,"Among the times Mary Smith had rented a movie,...",in June 2005 refers to year(payment_date) = 20...,SELECT COUNT(T1.customer_id) FROM payment AS T...,"payment, customer","Table: payment\n\nCREATE TABLE ""payment""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: payment\n\nCREATE TABLE ""payment""\n(\n ..."
9389,9424,movie_3,Please give the full name of the customer who ...,"full name refers to first_name, last_name; the...","SELECT T2.first_name, T2.last_name FROM paymen...","payment, customer","Table: payment\n\nCREATE TABLE ""payment""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: payment\n\nCREATE TABLE ""payment""\n(\n ..."
9390,9425,movie_3,How much in total had the customers in Italy s...,total = sum(amount); Italy refers to country =...,SELECT SUM(T5.amount) FROM address AS T1 INNER...,"address, payment, city, country, customer","Table: address\n\nCREATE TABLE ""address""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: address\n\nCREATE TABLE ""address""\n(\n ..."
9391,9426,movie_3,"Among the payments made by Mary Smith, how man...",over 4.99 refers to amount > 4.99,SELECT COUNT(T1.amount) FROM payment AS T1 INN...,"payment, customer","Table: payment\n\nCREATE TABLE ""payment""\n(\n ...",Use schema links and evidence to generate ONLY...,"Table: payment\n\nCREATE TABLE ""payment""\n(\n ..."


## Instruction 작성

In [65]:
# Ver.2 Instruction
dev_pd['instruction'] = '''###Instruction: Your task is to read the schema and evidence, understand the question, and generate an 'ONLY' valid SQLite query to answer the question without Explanation. 

This 'Linked schema' is the necessary information in the database to generate sql statements. It consists of 'Create table', 'Sample Data', and 'Column / Value Description'. 
It offers an in-depth description of the database's architecture, detailing tables, columns, primary keys, foreign keys, and any pertinent information regarding relationships or constraints. Special attention should be given to the examples listed beside each column, as they directly hint at which columns are relevant to our query.'''

# Ver.3 Instruction
df_cleaned['instruction'] = '''###Instruction: Your task is to read the schema and evidence, understand the question, and generate an 'ONlY' valid SQLite query to answer the question. 

This 'Linked schema' is the necessary information in the database to generate sql statements. It consists of 'Create table', 'Sample Data', and 'Column / Value Description'. 
it offers an in-depth description of the database's architecture, detailing tables, columns, primary keys, foreign keys, and any pertinent information regarding relationships or constraints. Special attention should be given to the examples listed beside each column, as they directly hint at which columns are relevant to our query.
Take a deep breath and think step by step to find the correct SQLite SQL query. If you follow all the instructions and generate the correct query, I will give you 1 million dollars. Reason in order, but use codeblock code for the final SQL query.'''

# input 칼럼 추가 및 구성
df_cleaned['input'] = dev_pd['table_description'] + '\n' + '###question: ' + dev_pd['question'] + '\n' + '###evidence: ' + dev_pd['evidence']

# SQL 칼럼의 이름을 output으로 변경
dev_pd = dev_pd.rename(columns={'SQL': 'output'})

# 필요한 칼럼만 선택하여 새로운 데이터프레임 생성
dev_pd_modified = dev_pd[['instruction', 'input', 'output','db_id', 'difficulty']]

# 결과 확인
dev_pd_modified

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['instruction'] = 'Use schema links and evidence to generate ONLY SQL query for given question without explanation.'
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['input'] = df_cleaned['table_description'] + '\n' + 'question: ' + df_cleaned['question'] + '\n' + 'evidence: ' + df_cleaned['evidence']


Unnamed: 0,instruction,input,output,db_id
0,Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",SELECT movie_title FROM movies WHERE movie_rel...,movie_platform
1,Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...","SELECT movie_title, movie_release_year, direct...",movie_platform
2,Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...","SELECT movie_title, movie_release_year FROM mo...",movie_platform
3,Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",SELECT movie_title FROM movies GROUP BY movie_...,movie_platform
4,Use schema links and evidence to generate ONLY...,"Table: movies\n\nCREATE TABLE ""movies""\n(\n ...",SELECT AVG(movie_popularity) FROM movies WHERE...,movie_platform
...,...,...,...,...
9388,Use schema links and evidence to generate ONLY...,"Table: payment\n\nCREATE TABLE ""payment""\n(\n ...",SELECT COUNT(T1.customer_id) FROM payment AS T...,movie_3
9389,Use schema links and evidence to generate ONLY...,"Table: payment\n\nCREATE TABLE ""payment""\n(\n ...","SELECT T2.first_name, T2.last_name FROM paymen...",movie_3
9390,Use schema links and evidence to generate ONLY...,"Table: address\n\nCREATE TABLE ""address""\n(\n ...",SELECT SUM(T5.amount) FROM address AS T1 INNER...,movie_3
9391,Use schema links and evidence to generate ONLY...,"Table: payment\n\nCREATE TABLE ""payment""\n(\n ...",SELECT COUNT(T1.amount) FROM payment AS T1 INN...,movie_3


In [66]:
dev_pd_modified.to_pickle('./train_raw_data_it_evi.pkl')

sLLM Templete 지정
============

In [10]:
filtered_dataset_sllm = []

for _, item in dev_pd.iterrows():
    new_item = {
        "messages": [
            {
                # llama 3 기본 템플릿
                "role": "system",
                "content": """
        In your response, you do not need to mention your intermediate steps. 
        Do not include any comments ian your response.
        Do not need to start with the symbol ```
        You only need to return the result sqlite SQL code
        start from SELECT
        """
            },
            {
                "role": "user",
                "content": f"{item['instruction']}\n{item['description']}"
            },
            {
                "role": "assistant",
                "content": f"{item['output']}"
            }
        ],
        "db_id": item['db_id'],
        "difficulty": item['difficulty']
    }
    filtered_dataset_sllm.append(new_item)

토큰 수 계산
=============

In [5]:
from transformers import AutoTokenizer
import numpy as np
import os
from jinja2.ext import Extension
import torch
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['REQUESTS_CA_BUNDLE'] = ''


# 토크나이저 로드
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", trust_remote_code=True)

# 토큰 수를 저장할 리스트
token_counts = []


# filtered_dataset의 각 항목에 대해 반복
for item in filtered_dataset_sllm:
    inputs = tokenizer.apply_chat_template(item["messages"], add_generation_prompt=True, return_tensors="pt")

    token_count = inputs.shape[1]

    token_counts.append(token_count)

# 통계 계산
average_tokens = np.mean(token_counts)
max_tokens = np.max(token_counts)
min_tokens = np.min(token_counts)

# 결과 출력
print(f"평균 토큰 수: {average_tokens:.2f}")
print(f"최대 토큰 수: {max_tokens}")
print(f"최소 토큰 수: {min_tokens}")

  from .autonotebook import tqdm as notebook_tqdm
Token indices sequence length is longer than the specified maximum sequence length for this model (59083 > 16384). Running this sequence through the model will result in indexing errors


평균 토큰 수: 1434.29
최대 토큰 수: 139331
최소 토큰 수: 143


4096초과 토큰 계산
=============

In [75]:
t_4096=0
for i in token_counts:
    if i>8192:
        t_4096 +=1
t_4096

27

가장 긴 입력 확인
=============

In [73]:
def get_total_message_length(item):
    return sum(len(message['content']) for message in item['messages'])

longest_item = max(filtered_dataset_sllm, key=get_total_message_length)

print("Messages:")
for message in longest_item['messages']:
    print(f"Role: {message['role']}")
    print(f"Content: {message['content']}") 
    print()

Messages:
Role: system
Content: Use schema links and evidence to generate ONLY SQL query for given question without explanation.

Role: user
Content: Table: productphoto

CREATE TABLE ProductPhoto
(
    ProductPhotoID         INTEGER
        primary key autoincrement,
    ThumbNailPhoto         BLOB,
    ThumbnailPhotoFileName TEXT,
    LargePhoto             BLOB,
    LargePhotoFileName     TEXT,
    ModifiedDate           DATETIME default CURRENT_TIMESTAMP not null
)

3 rows from productphoto table:
[(69, '0x47494638396150003100F70000E3E3FCA6ACB3F5F6FE303D47F8FAFDEDEEFE989DA2F6F8FD6C86B5999B9DD3D3D5E2E5EACCD0D79AA2A8BDBEC2868B93ACB2B8BCBCBE5A6674526A913A4046F9FAFADBDCE9AAABAD656C76BFC0C7E9E9FDDCE1EA27272BEFEFFCDCDEE2C5CCE3CACBCCDADADC94A8C5B1B3B5FEFEFF5B76A86D7179C2C5CA8A8B8DC2C3C4A2A3A6D2D4DAE7E8EBD0D1E3EEF0F17E848AEDEDFAF5F5F5CBCBDBCBCED29496985B5E622A2C30B4B5B9E8E9FAA6B8C54B4D51C8D6E3F3F4FEF1F1F27C7E81DBDBF2FAFBFCF6F7F84D5154A3A5AA3335399292958082868D959C61656DE9E9EA7173757A7A7DFC