In [11]:
!pip install -q -U transformers datasets accelerate peft bitsandbytes

In [2]:
import bitsandbytes as bnb
print(bnb.__file__)

/home/ubuntu/.local/lib/python3.10/site-packages/bitsandbytes/__init__.py


In [12]:
import torch
import os

from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [21]:
def load_and_prepare_datasets(tokenizer, max_length=512, dataset_sample=None):
    """
    Load Spider and WikiSQL datasets, tokenize them, and optionally sample.
    """
    spider_dataset = load_dataset("spider", split="train")
    wikisql_dataset = load_dataset("wikisql", split="train")
    
    if dataset_sample:
        spider_dataset = spider_dataset.select(range(min(dataset_sample, len(spider_dataset))))
        wikisql_dataset = wikisql_dataset.select(range(min(dataset_sample, len(wikisql_dataset))))
    
    combined_dataset = concatenate_datasets([spider_dataset, wikisql_dataset])
    
    def preprocess_function(examples):
        questions = examples.get("question", [])
        if not questions:
            questions = [""] * len(examples[list(examples.keys())[0]])
        
        targets = []
        if "query" in examples:
            targets = examples["query"]
        elif "sql" in examples:
            sql_field = examples["sql"]
            if isinstance(sql_field, list) and len(sql_field) > 0 and isinstance(sql_field[0], dict):
                targets = [str(d.get("human_readable", "")) for d in sql_field]
            else:
                targets = [""] * len(questions)
        else:
            targets = [""] * len(questions)
        
        # Ensure all inputs are strings
        inputs = [str(q) if q is not None else "" for q in questions]
        targets = [str(t) if t is not None else "" for t in targets]
        
        model_inputs = tokenizer(
            inputs,
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets,
                max_length=max_length,
                truncation=True,
                padding="max_length",
            )
        
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    tokenized_dataset = combined_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=combined_dataset.column_names,
    )
    
    return tokenized_dataset

In [22]:
def create_qlora_model(
    base_model_name="salesforce/CodeGen2-7B",
    hf_token=None,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    device_map="auto",
    torch_dtype=torch.float16,
):
    """
    Load the 'salesforce/CodeGen2-7B' model in 4-bit with BitsAndBytesConfig,
    then apply a LoRA adapter.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )

    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        use_auth_token=hf_token,
        use_fast=True,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map=device_map,
        torch_dtype=torch_dtype,
        use_auth_token=hf_token,
        quantization_config=bnb_config,
    )

    target_modules = [
        "attn.q_proj",
        "attn.k_proj",
        "attn.v_proj",
        "attn.out_proj",
    ]

    peft_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    peft_model = get_peft_model(base_model, peft_config)
    return tokenizer, peft_model


In [23]:
def fine_tune_qlora(
    tokenizer,
    model,
    tokenized_dataset,
    output_dir="./qlora-text-to-sql",
    num_epochs=3,
    batch_size=4,
):
    """
    Fine-tune QLoRA model on the combined Spider + WikiSQL dataset.
    Saves the final model locally to <output_dir>-final.
    """
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        evaluation_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=1,
        num_train_epochs=num_epochs,
        logging_steps=100,
        save_steps=500,
        fp16=True,
        report_to="none",  
    )

    # Create the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    print("\nStarting fine-tuning... Please wait.\n")
    trainer.train()
    print("\nTraining is complete!\n")

    trainer.save_model(f"{output_dir}-final")
    tokenizer.save_pretrained(f"{output_dir}-final")

    print(f"\nModel and tokenizer saved at '{output_dir}-final'.")


In [24]:
base_model_name = "salesforce/CodeGen2-7B"
hf_token = None

tokenizer, peft_model = create_qlora_model(
    base_model_name=base_model_name,
    hf_token=hf_token,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    device_map="auto",
    torch_dtype=torch.float16,
)

tokenized_dataset = load_and_prepare_datasets(tokenizer, max_length=512)

fine_tune_qlora(
    tokenizer,
    peft_model,
    tokenized_dataset,
    output_dir="./qlora-text-to-sql",
    num_epochs=3,
    batch_size=4,
)


Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.84s/it]
Some weights of the model checkpoint at salesforce/CodeGen2-7B were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.20.attn.causal_mask', 'transformer.h.21.attn.causal_mask', 'transformer.h.22.attn.causal_mask', 'transformer.h.23.attn.causal_mask', 'transformer.h.24.attn.causal_mask', 'transformer.h.25.attn.causal_mask', 'transformer.h.26.attn.causal_mask', 'transformer.h.27.attn.causal_mask', 'transformer.h.28.attn.causal_mask', 


Starting fine-tuning... Please wait.



Epoch,Training Loss,Validation Loss
1,0.0319,0.040905
2,0.0432,0.03853
3,0.0285,0.037256





Training is complete! Now saving the model locally...


Model and tokenizer saved at './qlora-text-to-sql-final'.
If you're on a local machine, you'll find them in that folder.




In [26]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


model_path = "./qlora-text-to-sql-final"


tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=True,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="cuda",
)
model.eval()


Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.66s/it]
Some weights of the model checkpoint at salesforce/CodeGen2-7B were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.20.attn.causal_mask', 'transformer.h.21.attn.causal_mask', 'transformer.h.22.attn.causal_mask', 'transformer.h.23.attn.causal_mask', 'transformer.h.24.attn.causal_mask', 'transformer.h.25.attn.causal_mask', 'transformer.h.26.attn.causal_mask', 'transformer.h.27.attn.causal_mask', 'transformer.h.28.attn.causal_mask', 

CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-31): 32 x CodeGenBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (out_proj): lora.Linear(
            (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDi

In [27]:
def generate_sql(question, max_new_tokens=128):
    """
    Given a question in English, return the model's SQL generation as a string.
    """
    # Construct the prompt ( add special instructions or schema info).
    # For example:
    prompt = f"Question: {question}\nGenerate the SQL query:\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  
            top_p=0.9,         
            temperature=0.8,     
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text


In [28]:
sample_question = "List the name and age of all students older than 20."
generated_sql = generate_sql(sample_question)
print("Generated SQL:\n", generated_sql)


Generated SQL:
 Question: List the name and age of all students older than 20.
Generate the SQL query:
SELECT student_name , age FROM student WHERE age >= 20



In [31]:
base_model_name = "salesforce/CodeGen2-7B"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
adapter_path = "./qlora-text-to-sql-final"  # This is the folder with LoRA weights

lora_model = PeftModel.from_pretrained(base_model, adapter_path)
lora_model = lora_model.merge_and_unload()

lora_model.save_pretrained("./merged-full-model")


Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  3.63it/s]
Some weights of the model checkpoint at salesforce/CodeGen2-7B were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.20.attn.causal_mask', 'transformer.h.21.attn.causal_mask', 'transformer.h.22.attn.causal_mask', 'transformer.h.23.attn.causal_mask', 'transformer.h.24.attn.causal_mask', 'transformer.h.25.attn.causal_mask', 'transformer.h.26.attn.causal_mask', 'transformer.h.27.attn.causal_mask', 'transformer.h.28.attn.causal_mask', 

In [36]:
test_questions = [
    "Find the total number of students enrolled in the 'Computer Science' department.",
    "Retrieve the names of employees who have a salary greater than 100,000 and work in the 'Engineering' department.",
    "List the top 5 customers who have made the highest total purchases in the last year.",
    "Find the total sales revenue for each product category in the last 6 months.",
    "Retrieve the student names and their average grades for courses where the grade is above 85.",
    "Find all employees who have worked for the company for more than 5 years and have at least one promotion.",
    "List all orders that were placed in the last month and are still pending delivery.",
    "Get the names of customers who have never placed an order.",
    "Find the number of distinct suppliers for each product in the inventory.",
    "Retrieve all flight details for flights scheduled between 6 AM and 12 PM tomorrow."
]

for idx, question in enumerate(test_questions, 1):
    generated_sql = generate_sql(question)
    print(f"\nTest {idx}: {question}")
    print("Generated SQL:\n", generated_sql)
    print("="*80)



Test 1: Find the total number of students enrolled in the 'Computer Science' department.
Generated SQL:
 Question: Find the total number of students enrolled in the 'Computer Science' department.
Generate the SQL query:
SELECT COUNT(id) AS total FROM Students WHERE department = 'Computer Science'


Test 2: Retrieve the names of employees who have a salary greater than 100,000 and work in the 'Engineering' department.
Generated SQL:
 Question: Retrieve the names of employees who have a salary greater than 100,000 and work in the 'Engineering' department.
Generate the SQL query:
SELECT Employee_Name FROM Employees WHERE  (Employees.Salary>= 100000) AND (Employees.Department LIKE '% Engineering' )


Test 3: List the top 5 customers who have made the highest total purchases in the last year.
Generated SQL:
 Question: List the top 5 customers who have made the highest total purchases in the last year.
Generate the SQL query:
SELECT Customers.id, Customers.name, Customers.id, Customers.  ye