In [1]:
import sys

# Make the project's src/ directory importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
# NOTE(review): "../src" is relative, so it is resolved against the kernel's
# current working directory at import time — assumes the notebook runs from
# its own folder.
if "../src" not in sys.path:
    sys.path.append("../src")

from text2sql import hello

# Smoke test: confirms the text2sql package is importable from src/.
print(hello.message)

hello, world!


In [2]:
import json

from text2sql.data import (
    get_sqlite_database_file,
    query_sqlite_database,
    get_sqlite_schema,
    schema_to_basic_format,
    schema_to_sql_create,
    schema_to_datagrip_format
)

In [3]:
import glob
import os

In [4]:
def list_supported_databases(dataset_base_path: str, dataset: str) -> list[str]:
    """Find all sqlite databases in a dataset directory and return their names.

    Looks for ``*.sqlite`` files directly inside ``<dataset_base_path>/<dataset>``
    (flat layout) and exactly one directory level below it (nested layout,
    e.g. ``<dataset>/<db>/<db>.sqlite``). Deeper files are not searched.

    Args:
        dataset_base_path: directory containing one sub-directory per dataset.
        dataset: name of the dataset sub-directory to scan.

    Returns:
        Database names (file basenames with the ``.sqlite`` extension removed),
        de-duplicated and ordered by their file basename. A name present in both
        layouts is listed once. Returns an empty list when nothing matches,
        including when the directory does not exist.
    """
    # Escape the directory part so glob metacharacters that occur in real
    # paths (e.g. "data[v2]") are matched literally instead of as patterns.
    dataset_dir = glob.escape(os.path.join(dataset_base_path, dataset))
    patterns = (
        os.path.join(dataset_dir, "*.sqlite"),       # flat: <dataset>/<db>.sqlite
        os.path.join(dataset_dir, "*", "*.sqlite"),  # nested one level: <dataset>/<dir>/<db>.sqlite
    )
    found_files = {
        os.path.basename(path)
        for pattern in patterns
        for path in glob.glob(pattern)
    }
    # Sort on the full basename (with extension) before stripping it, so the
    # order matches sorting the file names themselves.
    return [os.path.splitext(file_name)[0] for file_name in sorted(found_files)]

In [26]:
class SqliteDataset:
    """A directory of SQLite databases that together form one text-to-SQL dataset.

    The available database names are discovered once, at construction time,
    via ``list_supported_databases``. Every method that takes a
    ``database_name`` checks it against that list before doing anything else.
    """

    # mode -> (include_types, include_relations) for the schema_to_basic_format modes
    _BASIC_MODE_FLAGS = {
        "basic": (False, False),
        "basic_types": (True, False),
        "basic_relations": (False, True),
        "basic_types_relations": (True, True),
    }
    # Every mode accepted by describe_database_schema, in the order shown in error messages.
    SUPPORTED_SCHEMA_MODES = (
        "basic",
        "basic_types",
        "basic_relations",
        "basic_types_relations",
        "sql",
        "datagrip",
    )

    def __init__(self, base_data_path: str, dataset_name: str):
        """Scan ``<base_data_path>/<dataset_name>`` for sqlite databases.

        Args:
            base_data_path: directory containing one sub-directory per dataset.
            dataset_name: name of the dataset sub-directory (e.g. "bird").
        """
        self.base_data_path = base_data_path
        self.dataset_name = dataset_name
        self.databases = list_supported_databases(base_data_path, dataset_name)

    def _require_database(self, database_name: str) -> None:
        """Raise ValueError unless ``database_name`` is one of this dataset's databases."""
        if database_name not in self.databases:
            raise ValueError(f"Database '{database_name}' not found in dataset '{self.dataset_name}'")

    def get_databases(self) -> list[str]:
        """Return the names of the sqlite databases in the dataset.

        A copy is returned so callers cannot accidentally modify the list that
        the other methods validate names against.
        """
        return list(self.databases)

    def get_database_path(self, database_name: str) -> str:
        """Return the path to the sqlite database file for ``database_name``.

        Raises:
            ValueError: if ``database_name`` is not in this dataset.
        """
        self._require_database(database_name)
        return get_sqlite_database_file(self.base_data_path, self.dataset_name, database_name)

    def get_database_schema(self, database_name: str) -> dict:
        """Return the schema of ``database_name`` as produced by ``get_sqlite_schema``.

        Raises:
            ValueError: if ``database_name`` is not in this dataset.
        """
        self._require_database(database_name)
        return get_sqlite_schema(self.base_data_path, self.dataset_name, database_name)

    def describe_database_schema(self, database_name: str, mode: str = "basic") -> str:
        """Return a text rendering of a database's schema.

        Args:
            database_name: one of ``get_databases()``.
            mode: output format. The four ``basic*`` modes call
                ``schema_to_basic_format`` with ``include_types`` /
                ``include_relations`` enabled as the mode name says; ``"sql"``
                calls ``schema_to_sql_create``; ``"datagrip"`` calls
                ``schema_to_datagrip_format``.

        Raises:
            ValueError: if ``mode`` is not in ``SUPPORTED_SCHEMA_MODES`` (checked
                before any database access), or if ``database_name`` is unknown.
        """
        if mode not in self.SUPPORTED_SCHEMA_MODES:
            raise ValueError(
                f"Unknown schema mode '{mode}', supported modes are: {list(self.SUPPORTED_SCHEMA_MODES)}"
            )
        schema = self.get_database_schema(database_name)
        if mode in self._BASIC_MODE_FLAGS:
            include_types, include_relations = self._BASIC_MODE_FLAGS[mode]
            return schema_to_basic_format(
                database_name,
                schema,
                include_types=include_types,
                include_relations=include_relations,
            )
        if mode == "sql":
            return schema_to_sql_create(database_name, schema)
        # Only "datagrip" remains after the membership check above.
        return schema_to_datagrip_format(database_name, schema)

    def query_database(self, database_name: str, query: str) -> list[dict]:
        """Run ``query`` against ``database_name`` and return the rows as dictionaries.

        NOTE(review): ``query`` is passed through unchanged; if it comes from a
        model, consider whether ``query_sqlite_database`` should open the file
        read-only — not visible from here, confirm.

        Raises:
            ValueError: if ``database_name`` is not in this dataset.
        """
        # Called for its validation (and any checks get_sqlite_database_file
        # performs); the path itself is not needed because query_sqlite_database
        # takes the base path / dataset / name arguments directly.
        self.get_database_path(database_name)
        return query_sqlite_database(self.base_data_path, self.dataset_name, database_name, query)

In [27]:
# Root directory of the sqlite text-to-SQL datasets. Set the
# TEXT2SQL_SQLITE_DATASETS_DIR environment variable to point elsewhere;
# the fallback is the original author's local path.
dataset_base_path = os.environ.get(
    "TEXT2SQL_SQLITE_DATASETS_DIR",
    "/home/derek/PythonProjects/gena/data/text2sql_datasets/sqlite_datasets",
)
dataset = "bird"  # dataset sub-directory under dataset_base_path
database = "language_corpus"  # a database within the dataset selected for inspection

In [28]:
# Index the sqlite databases of the configured dataset and list their names.
bird_train_dataset = SqliteDataset(base_data_path=dataset_base_path, dataset_name=dataset)
print(bird_train_dataset.get_databases())

['address', 'airline', 'app_store', 'authors', 'beer_factory', 'bike_share_1', 'book_publishing_company', 'books', 'car_retails', 'cars', 'chicago_crime', 'citeseer', 'codebase_comments', 'coinmarketcap', 'college_completion', 'computer_student', 'cookbook', 'craftbeer', 'cs_semester', 'disney', 'donor', 'european_football_1', 'food_inspection', 'food_inspection_2', 'genes', 'hockey', 'human_resources', 'ice_hockey_draft', 'image_and_language', 'language_corpus', 'law_episode', 'legislator', 'mental_health_survey', 'menu', 'mondial_geo', 'movie', 'movie_3', 'movie_platform', 'movielens', 'movies_4', 'music_platform_2', 'music_tracker', 'olympics', 'professional_basketball', 'public_review_platform', 'regional_sales', 'restaurant', 'retail_complains', 'retail_world', 'retails', 'sales', 'sales_in_weather', 'shakespeare', 'shipping', 'shooting', 'simpson_episodes', 'soccer_2016', 'social_media', 'software_company', 'student_loan', 'superstore', 'synthea', 'talkingdata', 'trains', 'univer

In [30]:
# Render the beer_factory schema as CREATE TABLE statements.
# NOTE(review): in the rendered output, column definitions are not separated
# by commas and the PRIMARY KEY / FOREIGN KEY clauses are not comma-separated
# either, so the DDL is not valid SQL — schema_to_sql_create (text2sql.data)
# likely needs fixing before this text is used in prompts or executed.
print(
    bird_train_dataset.describe_database_schema(
        database_name="beer_factory",
        mode="sql",
    )
)

beer_factory CREATE messages:

CREATE TABLE customers (
    CustomerID INTEGER
    First TEXT
    Last TEXT
    StreetAddress TEXT
    City TEXT
    State TEXT
    ZipCode INTEGER
    Email TEXT
    PhoneNumber TEXT
    FirstPurchaseDate DATE
    SubscribedToEmailList TEXT
    Gender TEXT
,
    PRIMARY KEY (CustomerID)
);

CREATE TABLE geolocation (
    LocationID INTEGER
    Latitude REAL
    Longitude REAL
,
    PRIMARY KEY (LocationID)
    FOREIGN KEY (LocationID) REFERENCES location (LocationID)
);

CREATE TABLE location (
    LocationID INTEGER
    LocationName TEXT
    StreetAddress TEXT
    City TEXT
    State TEXT
    ZipCode INTEGER
,
    PRIMARY KEY (LocationID)
    FOREIGN KEY (LocationID) REFERENCES geolocation (LocationID)
);

CREATE TABLE rootbeerbrand (
    BrandID INTEGER
    BrandName TEXT
    FirstBrewedYear INTEGER
    BreweryName TEXT
    City TEXT
    State TEXT
    Country TEXT
    Description TEXT
    CaneSugar TEXT
    CornSyrup TEXT
    Honey TEXT
    Artificia

In [31]:
# Compact one-line-per-table schema summary for airline, with column types.
print(
    bird_train_dataset.describe_database_schema(
        database_name="airline",
        mode="basic_types",
    )
)

airline tables:
Air Carriers ( Code (INTEGER) , Description (TEXT) )
Airports ( Code (TEXT) , Description (TEXT) )
Airlines ( FL_DATE (TEXT) , OP_CARRIER_AIRLINE_ID (INTEGER) , TAIL_NUM (TEXT) , OP_CARRIER_FL_NUM (INTEGER) , ORIGIN_AIRPORT_ID (INTEGER) , ORIGIN_AIRPORT_SEQ_ID (INTEGER) , ORIGIN_CITY_MARKET_ID (INTEGER) , ORIGIN (TEXT) , DEST_AIRPORT_ID (INTEGER) , DEST_AIRPORT_SEQ_ID (INTEGER) , DEST_CITY_MARKET_ID (INTEGER) , DEST (TEXT) , CRS_DEP_TIME (INTEGER) , DEP_TIME (INTEGER) , DEP_DELAY (INTEGER) , DEP_DELAY_NEW (INTEGER) , ARR_TIME (INTEGER) , ARR_DELAY (INTEGER) , ARR_DELAY_NEW (INTEGER) , CANCELLED (INTEGER) , CANCELLATION_CODE (TEXT) , CRS_ELAPSED_TIME (INTEGER) , ACTUAL_ELAPSED_TIME (INTEGER) , CARRIER_DELAY (INTEGER) , WEATHER_DELAY (INTEGER) , NAS_DELAY (INTEGER) , SECURITY_DELAY (INTEGER) , LATE_AIRCRAFT_DELAY (INTEGER) )
