In [1]:
import pandas as pd
import numpy as np
import json
import utils

In [2]:
# Oracle-dialect DDL (VARCHAR2 / NUMBER / CLOB types) describing the `user` table.
# The raw text is passed to utils.generate_embeddings in the next cell; the column
# names extracted from that call are the targets that generated queries get mapped onto.
# NOTE(review): USER is a reserved word in Oracle and `id` has no PRIMARY KEY. In this
# notebook the string is only used as input text, so it is left byte-for-byte unchanged.
# Confirm before executing this DDL against a real database.
db_schema = """CREATE TABLE user (
    id NUMBER ,
    username VARCHAR2(50) NOT NULL,
    email VARCHAR2(100),
    password_hash CHAR(64),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP,
    last_login DATE,
    age NUMBER CHECK (age >= 0),
    is_active CHAR(1) DEFAULT 'Y',
    profile_views NUMBER(19),
    account_balance NUMBER(10,2),
    referral_code VARCHAR2(20),
    subscription_type VARCHAR2(30),
    subscription_expiry DATE,
    email_verified CHAR(1),
    phone_number VARCHAR2(20),
    country VARCHAR2(50),
    city VARCHAR2(50),
    zip_code VARCHAR2(10),
    address_line1 CLOB,
    address_line2 CLOB,
    bio CLOB,
    avatar_url CLOB,
    favorite_color VARCHAR2(30),
    preferred_language VARCHAR2(30),
    timezone VARCHAR2(50),
    login_attempts NUMBER DEFAULT 0,
    password_reset_token CHAR(64),
    password_reset_expiry TIMESTAMP,
    session_token CHAR(64),
    session_expiry TIMESTAMP,
    device_type VARCHAR2(30),
    os_version VARCHAR2(50),
    browser_info CLOB,
    two_factor_enabled CHAR(1),
    two_factor_secret CHAR(32),
    date_of_birth DATE,
    gender VARCHAR2(10),
    employment_status VARCHAR2(30),
    job_title VARCHAR2(100),
    company_name VARCHAR2(100),
    website_url CLOB,
    linkedin_url CLOB,
    github_url CLOB,
    twitter_handle VARCHAR2(50),
    instagram_handle VARCHAR2(50),
    youtube_channel CLOB,
    notification_preferences CLOB, 
    metadata CLOB
);
"""

In [3]:
# utils.generate_embeddings yields two values; only the second (`col_descriptions`,
# an iterable of strings judging by how it is parsed below) is needed here.
_, col_descriptions = utils.generate_embeddings('user', db_schema)


def _column_name(description):
    """Extract the bare column name from one description string.

    Keeps the text before the first ',' and returns its second '.'-separated segment.
    NOTE(review): assumes descriptions look like '<table>.<column>, ...'; TODO confirm
    against utils.generate_embeddings. Raises IndexError when there is no '.'.
    """
    qualified = description.split(',')[0]
    return qualified.split('.')[1]


col_names = list(map(_column_name, col_descriptions))

In [4]:
# Single-column frame of valid column names; the post-processing cells pass it in as 'columns'.
col_df = pd.DataFrame(dict(name=col_names))

## Example 1: misnamed column (`views` → `profile_views`)

In [5]:
# `views` is not a column of the schema (the real column is `profile_views`).
generated_query = 'SELECT SUM(views) AS total_views FROM user WHERE account_balance > 1000'

In [6]:
# Post-process the generated SQL against the `user` table and its known columns.
# NOTE(review): the model name suggests identifiers are matched by embedding
# similarity; see utils.queryPostprocessing for the actual strategy.
embedding_model_name = 'mixedbread-ai/mxbai-embed-large-v1'
qp = utils.queryPostprocessing(
    generated_query,
    {'table_name': 'user', 'columns': col_df},
    embedding_model_name,
)
processed_query = qp.formatQuerySQLglot()

In [7]:
# The recorded output shows `views` rewritten to the existing column `profile_views`.
processed_query

'SELECT SUM(profile_views) AS total_views FROM user WHERE account_balance > 1000'

## Example 2: wrong table and column names (`data`, `user`, `contact`, `balance`)

In [10]:
# None of these identifiers exist in the schema: the table is `user`, not `data`,
# and there are no columns named `user`, `contact` or `balance`.
generated_query = 'SELECT user, contact FROM data WHERE balance > 1000'

In [11]:
# Same post-processing as Example 1, now with a wrong table name and wrong columns.
# NOTE(review): the model name suggests embedding-based identifier matching; see
# utils.queryPostprocessing for the actual strategy.
embedding_model_name = 'mixedbread-ai/mxbai-embed-large-v1'
qp = utils.queryPostprocessing(
    generated_query,
    {'table_name': 'user', 'columns': col_df},
    embedding_model_name,
)
processed_query = qp.formatQuerySQLglot()
processed_query

'SELECT username, phone_number FROM user WHERE account_balance > 1000'