## Pre-requisites

In [21]:
# !pip install ollama
# !pip install sqlparse
# !pip install langchain
# !pip install langchain-core
# !pip install langchain_community

## Set up ollama on your system (Step-1)
#### Linux system
* Command --> curl -fsSL https://ollama.com/install.sh | sh
* Website --> https://ollama.com/download

## Setup LLM (Step-2)
* Command --> ollama pull granite-code:8b-instruct-q4_0 
* Website --> https://ollama.com/library/granite-code:8b-instruct-q4_0
* We are using 8b 4bit quantized model . Model is present  in 3b , 8b, 20b,34b variants . Base and quantized versions are present. You can use depending on the system configuration.

### Import Libraries

In [33]:
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
import time

## Access LLM running on your system . 
* base_url will be http://localhost:11434 in your case . I am running it on server.

In [34]:
llm = Ollama(base_url="http://172.27.222.2:11434",
    model="granite-code:8b-instruct-q4_0"
) 

## Setup  mysql and access the database

In [35]:
mysql_uri = 'mysql+mysqlconnector://root@localhost:3306/test'
db = SQLDatabase.from_uri(mysql_uri)

In [36]:
db.get_table_names()   # These are the tables present in database named 'test'

['EmployeeDetails', 'Employees']

In [57]:
chain = create_sql_query_chain(llm, db)   #This chain generates SQL queries for the given database.

#### Generate a sql query from natural language .

In [75]:
response = chain.invoke({'question':'Retrieve the youngest and oldest employees details along with their ages.'})
print(response)

SELECT BirthDate, CONCAT(FirstName, ' ', LastName) AS FullName, DATEDIFF(CURDATE(), BirthDate) AS Age
FROM EmployeeDetails
ORDER BY Age ASC, FullName ASC
LIMIT 2;


In [76]:
response

"SELECT BirthDate, CONCAT(FirstName, ' ', LastName) AS FullName, DATEDIFF(CURDATE(), BirthDate) AS Age\nFROM EmployeeDetails\nORDER BY Age ASC, FullName ASC\nLIMIT 2;"

### Evaluation 
* Our database contain two tables named Employees and EmployeeDetails . 
* For evaluation purpose , we have generated 20 questions using ChatGPT , 10-simple , 5-medium , 5-complex along with queries.
* We will  match ground truth query with LLM generated query . We will execute both the queries (chatgpt generated and LLM generated manually in mysql and check the response generated .)

In [79]:
sql_questions = [
    "Retrieve all columns for employees with an EmployeeID less than 10.",
    "Retrieve the first name and salary of employees in the IT department.",
    "Count the number of employees in each department.",
    "Retrieve employees whose salary is greater than $70,000.",
    "Retrieve employees sorted by last name in descending order.",
    "Calculate the average salary of employees.",
    "Retrieve employees birth after January 1st, 1982.",
    "Retrieve employees with the first name starting with 'J'.",
    "Retrieve employees sorted by department and position.",
    "Retrieve employees with no assigned department.",
    "Retrieve employees along with their detailed information (from EmployeeDetails).",
    "Calculate the total number of employees and average salary per department.",
    "Retrieve the employee details who live in cities that have more than one employee, along with the count of employees in those cities.",
    "Retrieve employees with salaries between $60,000 and $80,000, sorted by salary.",
    "Retrieve the top 3 highest paid employees.",
    "Retrieve the department with the highest average salary.",
    "Retrieve employees who are earning more than the average salary of their department.",
    "Retrieve the total number of employees in each city and their average salary.",
    "Retrieve the youngest and oldest employees' details along with their ages.",
    "Retrieve employees who have the same last name as any other employee."
]

In [80]:
len(sql_questions)  # Total number of sql questions

20

#### Lets create a function that  takes sql_question one by one and return sql  query 

In [81]:
model_response=  []
def sql_query(questions_list):
    for i in questions_list:
        response = chain.invoke({'question':i})
        model_response.append(response)

sql_query(sql_questions)

#### Create a Dataframe

In [82]:
import pandas as pd
ground_truth = [
    "SELECT * FROM Employees WHERE EmployeeID < 10;",
    "SELECT FirstName, Salary FROM Employees WHERE Department = 'IT';",
    "SELECT Department, COUNT(*) AS NumEmployees FROM Employees GROUP BY Department;",
    "SELECT * FROM Employees WHERE Salary > 70000;",
    "SELECT * FROM Employees ORDER BY LastName DESC;",
    "SELECT AVG(Salary) AS AvgSalary FROM Employees;",
    "SELECT * FROM Employees WHERE HireDate > '1982-01-01';",
    "SELECT * FROM Employees WHERE FirstName LIKE 'J%';",
    "SELECT * FROM Employees ORDER BY Department, Position;",
    "SELECT * FROM Employees WHERE Department IS NULL;",
    "SELECT E.*, ED.BirthDate, ED.Address, ED.City, ED.State, ED.ZipCode, ED.Phone FROM Employees E JOIN EmployeeDetails ED ON E.EmployeeID = ED.EmployeeID;",
    "SELECT Department, COUNT(*) AS NumEmployees, AVG(Salary) AS AvgSalary FROM Employees GROUP BY Department;",
    "SELECT ed.City, COUNT(ed.EmployeeID) AS EmployeeCount, GROUP_CONCAT(e.FirstName, ' ', e.LastName) AS EmployeeNames FROM EmployeeDetails ed JOIN Employees e ON ed.EmployeeID = e.EmployeeID GROUP BY ed.City HAVING COUNT(ed.EmployeeID) > 1 ORDER BY EmployeeCount DESC;",
    "SELECT * FROM Employees WHERE Salary BETWEEN 60000 AND 80000 ORDER BY Salary;",
    "SELECT * FROM Employees ORDER BY Salary DESC LIMIT 3;",
    "SELECT Department, AVG(Salary) AS AvgSalary FROM Employees GROUP BY Department ORDER BY AvgSalary DESC LIMIT 1;",
    "SELECT E.* FROM Employees E JOIN (SELECT Department, AVG(Salary) AS AvgSalary FROM Employees GROUP BY Department) AS AvgSalaries ON E.Department = AvgSalaries.Department WHERE E.Salary > AvgSalaries.AvgSalary;",
    "SELECT ED.City, COUNT(*) AS NumEmployees, AVG(E.Salary) AS AvgSalary FROM Employees E JOIN EmployeeDetails ED ON E.EmployeeID = ED.EmployeeID GROUP BY ED.City;",
    "SELECT e.FirstName, e.LastName, ed.BirthDate, TIMESTAMPDIFF(YEAR, ed.BirthDate, CURDATE()) AS Age, CASE WHEN ed.BirthDate = (SELECT MIN(BirthDate) FROM EmployeeDetails) THEN 'Oldest' WHEN ed.BirthDate = (SELECT MAX(BirthDate) FROM EmployeeDetails) THEN 'Youngest' END AS AgeCategory FROM EmployeeDetails ed JOIN Employees e ON ed.EmployeeID = e.EmployeeID WHERE ed.BirthDate = (SELECT MIN(BirthDate) FROM EmployeeDetails) OR ed.BirthDate = (SELECT MAX(BirthDate) FROM EmployeeDetails);",
    "SELECT E1.EmployeeID, E1.FirstName, E1.LastName FROM Employees E1 JOIN Employees E2 ON E1.LastName = E2.LastName AND E1.EmployeeID != E2.EmployeeID;"
]

# Create a DataFrame
df = pd.DataFrame({
    'Question': sql_questions,
    'Ground Truth Query':ground_truth ,
    'Granite Response (LLM)': model_response ,
    'Query Level':['simple','simple','simple','simple','simple','simple','simple','simple','simple','simple','medium','medium','medium','medium','medium','complex','complex','complex','complex','complex'],
    'Match':['correct','correct' ,'correct','correct','correct','correct' ,'correct','correct','correct','correct','correct','correct','wrong','correct','correct','correct','correct','wrong','wrong','correct']
})

In [83]:
df

Unnamed: 0,Question,Ground Truth Query,Granite Response (LLM),Query Level,Match
0,Retrieve all columns for employees with an Emp...,SELECT * FROM Employees WHERE EmployeeID < 10;,SELECT * FROM Employees WHERE EmployeeID < 10;,simple,correct
1,Retrieve the first name and salary of employee...,"SELECT FirstName, Salary FROM Employees WHERE ...","SELECT `FirstName`, `Salary` FROM `Employees` ...",simple,correct
2,Count the number of employees in each department.,"SELECT Department, COUNT(*) AS NumEmployees FR...","SELECT Department, COUNT(*) FROM Employees GRO...",simple,correct
3,Retrieve employees whose salary is greater tha...,SELECT * FROM Employees WHERE Salary > 70000;,SELECT * FROM Employees WHERE Salary > 70000;,simple,correct
4,Retrieve employees sorted by last name in desc...,SELECT * FROM Employees ORDER BY LastName DESC;,SELECT * FROM Employees ORDER BY LastName DESC;,simple,correct
5,Calculate the average salary of employees.,SELECT AVG(Salary) AS AvgSalary FROM Employees;,SELECT AVG(`Salary`) FROM `Employees`,simple,correct
6,"Retrieve employees birth after January 1st, 1982.",SELECT * FROM Employees WHERE HireDate > '1982...,SELECT * FROM EmployeeDetails WHERE BirthDate ...,simple,correct
7,Retrieve employees with the first name startin...,SELECT * FROM Employees WHERE FirstName LIKE '...,SELECT * FROM Employees WHERE FirstName LIKE '...,simple,correct
8,Retrieve employees sorted by department and po...,"SELECT * FROM Employees ORDER BY Department, P...","SELECT * FROM Employees ORDER BY Department, P...",simple,correct
9,Retrieve employees with no assigned department.,SELECT * FROM Employees WHERE Department IS NULL;,"SELECT `Employees`.`EmployeeID`, `Employees`.`...",simple,correct


In [85]:
df.to_csv("granite_eval.csv")