In [1]:
%load_ext autoreload

In [6]:
import time
import json
import tempfile
import shutil
from pathlib import Path

from helpers.get_tables import get_tables
from resdsql.preprocessing import main as preprocessing
from resdsql.schema_item_classifier import classify_schema_items
from resdsql.text2sql_data_generator import generate_dataset
from resdsql.text2sql import generate_sql

In [9]:
def predict_sql(question, sem_names, db_path, num_beams=8, num_return_sequences=8, seed=42):
    """Generate a SQL query for a single natural-language question.

    A temporary workspace is created, filled with the files the pipeline
    stages are pointed at, and removed again when the function returns:

    * ``tables.json``  - the result of ``get_tables(db_path, sem_names=...)``
    * ``dev.json``     - one sample ``{'db_id': 'base', 'question': question}``
    * ``database/base/base.sqlite`` - a copy of ``db_path``

    The stages then run in order: ``preprocessing`` -> ``classify_schema_items``
    -> ``generate_dataset`` -> ``generate_sql``, each reading the file written
    by the previous one.

    Args:
        question: Question text stored under the ``'question'`` key of the sample.
        sem_names: Forwarded unchanged to ``get_tables`` as ``sem_names``.
        db_path: Path of the SQLite file; given to ``get_tables`` and copied
            into the workspace.
        num_beams: Forwarded to ``generate_sql``.
        num_return_sequences: Forwarded to ``generate_sql``.
        seed: Forwarded to ``generate_sql``.

    Returns:
        The first line of the file ``generate_sql`` was asked to write,
        with surrounding whitespace stripped.

    Raises:
        IndexError: If that output file is empty.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Keep the context-manager value intact and work with a Path copy.
        work_dir = Path(temp_dir)

        # paths
        tables_path = work_dir / 'tables.json'
        samples_path = work_dir / 'dev.json'
        preprocessed_samples_path = work_dir / 'preprocessed_test.json'
        database_path = work_dir / 'database'
        database_base_path = database_path / 'base'
        with_probs_path = work_dir / 'with_probs.json'
        dataset_path = work_dir / 'dataset.json'
        output_path = work_dir / 'output.sql'

        # The JSON inputs are written ASCII-escaped (json's default). Raw
        # non-ASCII text written with the platform's default encoding can fail
        # with UnicodeEncodeError, or be decoded differently by whatever
        # encoding the downstream loaders use; \uXXXX escapes parse to the
        # same strings under any ASCII-compatible encoding.

        # create tables.json
        tables = get_tables(db_path, sem_names=sem_names)
        with open(tables_path, 'w') as f:
            json.dump(tables, f, indent=4)

        # create dev.json
        samples = [
            {
                'db_id': 'base',
                'question': question,
            }
        ]
        with open(samples_path, 'w') as f:
            json.dump(samples, f, indent=4)

        # create database dir: <workspace>/database/base/base.sqlite
        database_base_path.mkdir(parents=True)
        shutil.copy(str(db_path), str(database_base_path / 'base.sqlite'))

        # preprocessing
        preprocessing(
            mode='test',
            table_path=str(tables_path),
            input_dataset_path=str(samples_path),
            output_dataset_path=str(preprocessed_samples_path),
            db_path=str(database_path),
            target_type='sql',
        )

        # schema item classification (device / seed / save_path are left at
        # the callee's defaults)
        classify_schema_items(
            batch_size=1,
            dev_filepath=str(preprocessed_samples_path),
            output_filepath=str(with_probs_path),
            use_contents=True,
            add_fk_info=True,
            mode='test',
        )

        # build the text2sql input from the classified sample
        generate_dataset(
            input_dataset_path=str(with_probs_path),
            output_dataset_path=str(dataset_path),
            topk_table_num=4,
            topk_column_num=5,
            mode='test',
            use_contents=True,
            add_fk_info=True,
            output_skeleton=True,
            target_type='sql'
        )

        # SQL generation; the result is written to output_path
        generate_sql(
            batch_size=1,
            seed=seed,
            mode="eval",
            dev_filepath=str(dataset_path),
            db_path=str(database_path),
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            target_type="sql",
            output=str(output_path),
        )

        # Only the first line is used, so there is no need to load the rest.
        with open(output_path) as f:
            first_line = f.readline()

        if not first_line:
            # Same exception type as indexing an empty list of lines, but
            # with a message that says what went wrong.
            raise IndexError('generate_sql produced an empty output file')

        return first_line.strip()

In [10]:
# Value handed to predict_sql as `sem_names`. Empty for this run; the
# commented entry is kept as an example of what one entry looks like.
sem_names = {
    # 'student': (
    #     'student',
    #     {'dept_name': 'department name'},
    # ),
}

# SQLite database the question is asked against.
db_path = '/app/college_2.sqlite'

# Question under test (Polish). An alternative is kept below for quick swapping.
question = 'Zwróć imiona wszystkich studentów'
# question = 'Znajdź budynki, które mają pokoje o pojemności większej niż 50.'

# Last expression of the cell, so the returned SQL string is displayed.
predict_sql(
    question=question,
    sem_names=sem_names,
    db_path=db_path,
    num_beams=4,
    num_return_sequences=4,
    seed=15452234,
)

3000


1it [00:00, 10.53it/s]
100%|██████████| 1/1 [00:00<00:00, 10.40it/s]
100%|██████████| 1/1 [00:00<00:00,  1.25it/s]


Text-to-SQL inference spends 0.7985846996307373s.
/tmp/tmpei_1r9ph


KeyboardInterrupt: 