This notebook needs to be run on a GPU. If you cannot run it locally, you can use Colab, in which case you will need to install the following packages; otherwise, you can comment the install commands out.

In [None]:
!pip install -q -U accelerate=='0.25.0' peft=='0.7.1' bitsandbytes=='0.41.3.post2' trl=='0.7.4'
!pip install -q git+https://github.com/huggingface/transformers 

In [None]:
import torch
import sqlparse
import sqlite3

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

from typing import (
    Tuple, 
    Optional, 
    List
)

from peft import (
    LoraConfig, 
    PeftConfig, 
    PeftModel, 
)

# import other scripts..
from context_generation import generate_prompt
from create_database import setup_database

# Load Model

In [None]:
class TextToSQL:
    """
    Generates SQL from natural-language questions with a causal language model.

    The tokenizer and the base model are both loaded from ``tokenizer_path``;
    adapter weights are then loaded from ``model_path``. All loading happens
    in the constructor.
    """

    def __init__(
        self,
        tokenizer_path: str = "codellama/CodeLlama-7b-Instruct-hf",
        model_path: str = "/kaggle/input/fine-tuned-7b"
    ) -> None:
        """
        Stores both paths and loads the model straight away.

        Args:
            tokenizer_path: identifier used for the tokenizer and the base model.
            model_path: location of the fine-tuned adapter weights.
        """
        self._tokenizer_path = tokenizer_path
        self._model_path = model_path
        self.setup_model()

    def setup_model(self) -> None:
        """
        Sets up the model and tokenizer with specific configurations.

        Populates ``self.tokenizer`` and ``self.model`` and discards any
        text-generation pipeline cached from an earlier call.
        """
        # 4-bit NF4 quantisation, float16 compute, no double quantisation
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=False,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_path)
        # the EOS token doubles as the padding token
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # to ensure we set up model correctly
        model_id = self._tokenizer_path
        peft_model_id = self._model_path

        # NOTE(review): trust_remote_code=True lets the model repository run its
        # own code at load time; confirm it is needed for the configured model.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, # loads base model first
            quantization_config=bnb_config,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=True,
        )

        # load fine-tuned adapters
        self.model.load_adapter(peft_model_id)

        # a pipeline built around a previously loaded model must not be reused
        self._pipe = None

    def _get_pipeline(self):
        """Returns the text-generation pipeline, building it on first use."""
        # getattr keeps this safe for instances that never ran setup_model
        pipe = getattr(self, "_pipe", None)
        if pipe is None:
            pipe = pipeline(
                task="text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=300,
                do_sample=False,
                num_beams=5,
            )
            self._pipe = pipe
        return pipe

    def generate_query(
        self,
        question: str,
        context_file: Optional[str] = "context.sql"
    ) -> List[dict]:
        """
        Generates a SQL query based on the provided question and context.

        Args:
            question: natural-language question to answer.
            context_file: path handed to ``generate_prompt`` together with
                the question.

        Returns:
            The pipeline output for one returned sequence. The cells in this
            notebook read it as ``output[0]['generated_text']``.
        """
        # generate_prompt is imported from the local context_generation script
        prompt = generate_prompt(
            context_file, question
        )

        pipe = self._get_pipeline()

        generated_query = pipe(
            prompt,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        return generated_query

Loading the model can take a minute or so; on this run it took 32 seconds.

In [None]:
%%time

# load model
# TextToSQL() uses its default paths; the constructor calls setup_model(),
# which loads the tokenizer, the quantised base model and the adapters.
model = TextToSQL()

# Set up database

We set up the database that we will query to test our model.

In [None]:
# set up database which is called example.db
# setup_database is imported from the local create_database script.
# NOTE(review): presumably it writes the example.db file opened in the next
# cell; confirm the output path against that script.
setup_database()

In [None]:
# open example.db; `cursor` is the cursor used by the query cells below
conn = sqlite3.connect('/kaggle/working/example.db')
cursor = conn.cursor()

# smoke test: read student_details back and display the rows
result = cursor.execute("SELECT * FROM student_details").fetchall()
result

# Evaluation

We evaluate the models against 10 questions written for a simple education database.

In [None]:
import pandas as pd

# Read testing_data.json from the attached evaluation-data input into a DataFrame.
# NOTE(review): `test` is not referenced by any later cell visible in this notebook.
test = pd.read_json('/kaggle/input/evaluation-data/testing_data.json')

In [None]:
import os

import google.generativeai as genai

## load gemini pro..
# The key is read from the GOOGLE_API_KEY environment variable so that it is
# never saved in the notebook itself. Set that variable before running this
# cell; when it is unset the key is an empty string, as before.
api_key = os.environ.get('GOOGLE_API_KEY', '')
genai.configure(api_key = api_key)
gem_model = genai.GenerativeModel('gemini-pro')

def generate_gemini(prompt, context):
    """
    Asks the Gemini model for an SQL query.

    The text sent to the model is the fixed instruction line, followed by
    ``prompt`` and then ``context`` on a new line.

    Args:
        prompt: the natural-language question.
        context: text appended after the question.

    Returns:
        The ``text`` attribute of the model response.
    """
    # instruct + question + context, in that order
    full_prompt = "Generate an SQL query\n" + prompt + "\n" + context

    generation_config = genai.types.GenerationConfig(
        temperature=0.9,
        top_k=40,
        top_p=0.95,
        max_output_tokens=300,
    )

    # every listed harm category gets the same "block_none" threshold
    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS",
    )
    safety_settings = [
        {"category": category, "threshold": "block_none"}
        for category in harm_categories
    ]

    response = gem_model.generate_content(
        full_prompt,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    return response.text

In [None]:
# get context:
# Read the whole context.sql file into one string; the Gemini-Pro cells pass
# it to generate_gemini as the `context` argument.
with open('/kaggle/input/context/context.sql', 'r') as f:
    context = f.read()

# Query 1

In [None]:
# Question for Query 1; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "What is the average age of students enrolled in the Computing unit?"
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
# NOTE(review): the literal here is lowercase 'computing' while the ChatGPT-4
# cell uses 'Computing'. SQLite compares text case-sensitively by default, so
# the two may return different results. It is model output, so it is left as is.
query_results = sqlparse.format(
    """ SELECT AVG(sd.age) AS average_age
FROM student_details sd
INNER JOIN unit_enrolment ue ON sd.student_id = ue.student_id
WHERE ue.unit_title = 'computing';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """ SELECT AVG(student_details.age) AS average_age
FROM student_details
JOIN unit_enrolment ON student_details.student_id = unit_enrolment.student_id
WHERE unit_enrolment.unit_title = 'Computing';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 2

In [None]:
# Question for Query 2; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "How many students are enrolled in the Psychology unit?"
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT COUNT(*) AS student_count
FROM unit_enrolment
WHERE unit_title = 'Psychology';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT COUNT(*) AS number_of_students
FROM unit_enrolment
WHERE unit_title = 'Psychology';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 3

In [None]:
# Question for Query 3; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "Provide all the names of the top three students with the highest grades?"
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT sd.given_name, sd.last_name, AVG(g.grade) AS average_grade
FROM student_details sd
JOIN grade g ON sd.student_id = g.student_id
GROUP BY sd.student_id, sd.given_name, sd.last_name
ORDER BY average_grade DESC
LIMIT 3;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT student_details.student_id, student_details.given_name, student_details.last_name, AVG(grade.grade) AS average_grade
FROM student_details
JOIN grade ON student_details.student_id = grade.student_id
GROUP BY student_details.student_id, student_details.given_name, student_details.last_name
ORDER BY average_grade DESC
LIMIT 3;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 4

In [None]:
# Question for Query 4; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "List names of male students who are younger than 20 years."
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT given_name, last_name
FROM student_details
WHERE gender = 'male' AND age < 20;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT given_name, last_name
FROM student_details
WHERE gender = 'male' AND age < 20;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 5

In [None]:
# Question for Query 5; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "Find the average grade of students in the Engineering unit."
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT AVG(g.grade) AS average_grade
FROM unit_enrolment ue
JOIN grade g ON ue.student_id = g.student_id
WHERE ue.unit_title = 'Engineering';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT AVG(grade.grade) AS average_grade
FROM grade
JOIN unit_enrolment ON grade.student_id = unit_enrolment.student_id
WHERE unit_enrolment.unit_title = 'Engineering';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 6

In [None]:
# Question for Query 6; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "Count the number of female students enrolled in the 'Biology' unit."
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT COUNT(*) AS female_students_count
FROM student_details sd
JOIN unit_enrolment ue ON sd.student_id = ue.student_id
WHERE ue.unit_title = 'Biology' AND sd.gender = 'female';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT COUNT(*) AS number_of_female_students
FROM student_details
JOIN unit_enrolment ON student_details.student_id = unit_enrolment.student_id
WHERE student_details.gender = 'female' AND unit_enrolment.unit_title = 'Biology';
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 7

In [None]:
# Question for Query 7; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "Who is the youngest student in the Mathematics unit?"
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# Keep the first paragraph, take the text after the last "### Answer:" marker,
# then cut at the first '#' so trailing non-SQL text cannot reach SQLite
# (the same clean-up the Query 8-10 cells apply).
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT sd.given_name, sd.last_name
FROM student_details sd
JOIN unit_enrolment ue ON sd.student_id = ue.student_id
WHERE ue.unit_title = 'Mathematics'
ORDER BY sd.age ASC
LIMIT 1;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT student_details.given_name, student_details.last_name, student_details.age
FROM student_details
JOIN unit_enrolment ON student_details.student_id = unit_enrolment.student_id
WHERE unit_enrolment.unit_title = 'Mathematics'
ORDER BY student_details.age ASC
LIMIT 1;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 8

In [None]:
# Question for Query 8; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "List all students with an average grade above 75."
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# first paragraph -> text after the last "### Answer:" marker -> text before
# the first '#' -> surrounding whitespace removed
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT sd.given_name, sd.last_name
FROM student_details sd
JOIN grade g ON sd.student_id = g.student_id
GROUP BY sd.student_id, sd.given_name, sd.last_name
HAVING AVG(g.grade) > 75;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT student_details.student_id, student_details.given_name, student_details.last_name, AVG(grade.grade) AS average_grade
FROM student_details
JOIN grade ON student_details.student_id = grade.student_id
GROUP BY student_details.student_id, student_details.given_name, student_details.last_name
HAVING AVG(grade.grade) > 75;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 9

In [None]:
# Question for Query 9; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "How many students are older than 22 years?"
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# first paragraph -> text after the last "### Answer:" marker -> text before
# the first '#' -> surrounding whitespace removed
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT COUNT(*) AS older_students_count
FROM student_details
WHERE age > 22;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT COUNT(*) AS number_of_students
FROM student_details
WHERE age > 22;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

# Query 10

In [None]:
# Question for Query 10; the Proposed Model and Gemini-Pro cells below read `prompt`.
prompt = "Which units have more than 5 students enrolled?"
print(prompt)

In [None]:
### Proposed Model

context_file = "/kaggle/input/context/context.sql"
# generate_query returns a list with one entry; 'generated_text' holds the output
query_results = model.generate_query(prompt, context_file)[0]
# first paragraph -> text after the last "### Answer:" marker -> text before
# the first '#' -> surrounding whitespace removed
query_results = (
    query_results['generated_text']
    .split('\n\n')[0]
    .split('### Answer:')[-1]
    .split('#')[0]
    .strip()
)
result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### Gemini-Pro

# Trim whitespace first so a trailing newline cannot shield closing backticks.
query_results = generate_gemini(prompt, context).strip()
# A reply that opens with a backtick is treated as fenced (```sql ... ```):
# its first line is the fence's language tag and is discarded. An unfenced
# reply keeps every line instead of losing its opening SELECT clause.
fenced = query_results.startswith('`')
query_results = query_results.strip('`')
if fenced and '\n' in query_results:
    query_results = query_results.split('\n', 1)[1]
query_results = sqlparse.format(
    query_results, strip_comments=True
)

cursor.execute(query_results)
result = cursor.fetchall()
print(result)

In [None]:
### ChatGPT-3.5

# SQL recorded for ChatGPT-3.5, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT unit_title, COUNT(*) AS enrollment_count
FROM unit_enrolment
GROUP BY unit_title
HAVING COUNT(*) > 5;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

In [None]:
### ChatGPT-4

# SQL recorded for ChatGPT-4, kept verbatim; comments are stripped before it runs.
query_results = sqlparse.format(
    """SELECT unit_title, COUNT(student_id) AS student_count
FROM unit_enrolment
GROUP BY unit_title
HAVING COUNT(student_id) > 5;
""",
    strip_comments=True,
)

result = cursor.execute(query_results).fetchall()
print(result)

To view the chats with the OpenAI models, please see the links below.

**GPT-3.5 CHAT:** https://chat.openai.com/share/2785f3b8-a7eb-4bcc-9c92-cc4897c08c6a

**GPT-4 CHAT:** https://chat.openai.com/share/ddcfba7a-e8dd-459b-8f27-9f85f3e3be0c