# Natural SQL 7B

In [None]:
pip install transformers==4.35.2 accelerate sqlparse

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Single source of truth for the checkpoint used by both tokenizer and model.
checkpoint = "chatdb/natural-sql-7b"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the causal LM with float16 weights; device placement is delegated to
# the library via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, device_map="auto", torch_dtype=torch.float16
)

In [11]:
questions = [
    'Show me the day with the most users joining',
    'Show me the project that has a task with the most comments',
    'What is the ratio of users with gmail addresses vs without?',
]

# Token id passed to generate() as both the end-of-sequence and the padding id.
# NOTE(review): presumably the checkpoint's end-of-text token -- confirm
# against the tokenizer before changing it.
END_TOKEN_ID = 100001


def build_prompt(question):
    """Return the text-to-SQL prompt for *question*.

    The prompt embeds the question, the PostgreSQL schema the query is meant
    to run against, and ends with an opened ```sql fence so that the model's
    continuation is the query itself. The text (including its indentation)
    is kept exactly as it was when the prompt was built inline in the loop.
    """
    return f"""
    ### Task 

    Generate a SQL query to answer the following question: `{question}` 
    
    ### PostgreSQL Database Schema 
    The query will run on a database with the following schema: 
    ```
    CREATE TABLE users (
        user_id SERIAL PRIMARY KEY,
        username VARCHAR(50) NOT NULL,
        email VARCHAR(100) NOT NULL,
        password_hash TEXT NOT NULL,
        created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
    );
    
    CREATE TABLE projects (
        project_id SERIAL PRIMARY KEY,
        project_name VARCHAR(100) NOT NULL,
        description TEXT,
        start_date DATE,
        end_date DATE,
        owner_id INTEGER REFERENCES users(user_id)
    );
    
    CREATE TABLE tasks (
        task_id SERIAL PRIMARY KEY,
        task_name VARCHAR(100) NOT NULL,
        description TEXT,
        due_date DATE,
        status VARCHAR(50),
        project_id INTEGER REFERENCES projects(project_id)
    );
    
    CREATE TABLE taskassignments (
        assignment_id SERIAL PRIMARY KEY,
        task_id INTEGER REFERENCES tasks(task_id),
        user_id INTEGER REFERENCES users(user_id),
        assigned_date DATE NOT NULL DEFAULT CURRENT_TIMESTAMP
    );
    
    CREATE TABLE comments (
        comment_id SERIAL PRIMARY KEY,
        content TEXT NOT NULL,
        created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
        task_id INTEGER REFERENCES tasks(task_id),
        user_id INTEGER REFERENCES users(user_id)
    );
    ```
    
    ### Answer 
    Here is the SQL query that answers the question: `{question}` 
    ```sql
    """


for question in questions:
    prompt = build_prompt(question)

    print("Question: " + question)
    print("SQL: ")

    # Move the encoded prompt to wherever the model lives instead of assuming
    # "cuda": the model is loaded with device_map="auto", so its device is
    # not fixed by this notebook.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Greedy decoding (no sampling, one beam), a single returned sequence,
    # capped at 400 newly generated tokens.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=END_TOKEN_ID,
        pad_token_id=END_TOKEN_ID,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep only the text after the last ```sql marker in the decoded output;
    # the prompt itself ends with that marker.
    print(outputs[0].split("```sql")[-1])


Question: Show me the day with the most users joining
SQL: 

     SELECT created_at::date AS join_date, COUNT(*) AS user_count
     FROM users
     GROUP BY join_date
     ORDER BY user_count DESC
     LIMIT 1;
Question: Show me the project that has a task with the most comments
SQL: 

     SELECT p.project_id, p.project_name, COUNT(c.comment_id) AS comment_count
     FROM projects p
     JOIN tasks t ON p.project_id = t.project_id
     JOIN comments c ON t.task_id = c.task_id
     GROUP BY p.project_id
     ORDER BY comment_count DESC
     LIMIT 1;
Question: What is the ratio of users with gmail addresses vs without?
SQL: 

     SELECT
        SUM(CASE WHEN email LIKE '%@gmail.com%' THEN 1 ELSE 0 END) AS gmail_users,
        SUM(CASE WHEN email NOT LIKE '%@gmail.com%' THEN 1 ELSE 0 END) AS non_gmail_users,
        (SUM(CASE WHEN email LIKE '%@gmail.com%' THEN 1 ELSE 0 END)::FLOAT / NULLIF(SUM(CASE WHEN email NOT LIKE '%@gmail.com%' THEN 1 ELSE 0 END), 0)) AS gmail_ratio
    FROM users