Load env file

In [1]:
from dotenv import load_dotenv
import os

load_dotenv()

True

Create the DB object

In [2]:
from langchain_community.utilities.sql_database import SQLDatabase

database_uri = "mysql+pymysql://root:"+ os.getenv('DB_PASS')+ "@/hr"
db = SQLDatabase.from_uri(database_uri)

In [3]:
print(db.dialect)

mysql


In [4]:
print(db.get_usable_table_names())

['countries', 'departments', 'dependents', 'employees', 'jobs', 'locations', 'regions']


In [5]:
db.run("SELECT * FROM countries LIMIT 5;")

"[('AR', 'Argentina', 2), ('AU', 'Australia', 3), ('BE', 'Belgium', 1), ('BR', 'Brazil', 2), ('CA', 'Canada', 2)]"

Create the LLM Object

In [6]:
from torch import cuda, bfloat16
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(0))

True
1
0
NVIDIA GeForce RTX 4060 Laptop GPU


In [8]:
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
print(device)

cuda:0


In [9]:
# Initialize HF items, need auth token for these
hf_auth  =  os.getenv('HASH_KEY')
model_id = 'defog/llama-3-sqlcoder-8b'

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    use_auth_token=hf_auth
)

# Set quantization configuration to load a large model with less GPU memory
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True  # Enable FP32 CPU offloading
)

# Load the model with quantization configuration
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    use_auth_token=hf_auth
)
model.eval()
print(f"Model loaded on {device}")

Loading checkpoint shards: 100%|██████████| 4/4 [00:20<00:00,  5.12s/it]


Model loaded on cuda:0


In [10]:
#Initialize the tokenizer

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    use_auth_token=hf_auth
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [11]:
#Create the transformers pipleline

generate_text = transformers.pipeline(
    model=model, tokenizer=tokenizer,
    return_full_text=True,
    task='text-generation',
    temperature       = 0.1, # 'randomness' of outputs, 0.0 is the min and 1.0 the max
    max_new_tokens    = 512, # max number of tokens to generate in the output
    repetition_penalty= 1.1,  # without this output begins repeating
)

In [12]:
res = generate_text("Get all data from a table") 
print(res[0]["generated_text"])

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Get all data from a table
SELECT * FROM table_name;


In [13]:
#Create the HuggingFacePipeline

from langchain.llms import HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=generate_text)

  warn_deprecated(


In [14]:
answer = llm(prompt="Get all data from a table")
print(answer)

  warn_deprecated(
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Get all data from a table
SELECT * FROM table_name;


Connect DB with SQL

In [15]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [16]:
print(response)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [17]:
import re
def extract_query_from_response(text):
    # Use regex to find all occurrences of SQL queries
    matches = re.findall(r'SQLQuery:\s*(.*)', text)
    
    if len(matches) >= 2:
        sql_query = matches[1]
        return sql_query
    else:
        print("Less than two SQL Queries found")


In [18]:
sql_query = extract_query_from_response(response)
sql_query

'SELECT COUNT(*) AS total_employees FROM employees e;'

In [19]:
print(db.run(sql_query))

[(40,)]


In [20]:
response = chain.invoke({"question": "How many employees are there who started work after year 1995"})
print(response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [21]:
sql_query = extract_query_from_response(response)
print(db.run(sql_query))

[(25,)]


In [22]:
response = chain.invoke({"question": "How many employees are Accountant"})
print(response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [23]:
sql_query = extract_query_from_response(response)
print(db.run(sql_query))

[(5,)]
