In [1]:
%pip install accelerate peft bitsandbytes transformers trl jsonlines pandas numpy python-dotenv numba



In [2]:
import os
import torch
import pandas as pd
import numpy as np
from accelerate import Accelerator
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
    Trainer
)
from peft import LoraConfig, PeftModel, PeftConfig, get_peft_model, prepare_model_for_kbit_training
import time
from dotenv import load_dotenv
import random
# Seed every RNG the notebook imports, not just Python's `random`: any sampling
# done by `model.generate` draws from torch's generator, so seeding `random`
# alone does not make generations repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Turn off Hugging Face tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Identifier of the LoRA adapter handed to PeftConfig / PeftModel.from_pretrained
# further down (presumably a local directory next to the notebook — confirm).
lora_model = "Llama-2-7b-chat-hf-text-to-sql"

# Load variables from a .env file if one exists; the boolean return value is
# what the cell displays.
load_dotenv()

False

In [3]:
# Hugging Face access token, passed as `token=` when the base model and
# tokenizer are loaded.
try:
    # Colab keeps notebook secrets in `userdata`; the module only exists there.
    from google.colab import userdata
except ImportError:
    # Outside Colab, fall back to the environment (which load_dotenv() may have
    # populated from a .env file). `None` lets the Hugging Face libraries use
    # whatever credentials are already cached locally.
    my_token = os.environ.get("HF_TOKEN")
else:
    my_token = userdata.get('token')

In [4]:
# Read the text-to-SQL dataset from local disk and display its splits.
dataset_path = "text-to-sql-dataset.hf"
dataset = load_from_disk(dataset_path)
dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'response', 'source', 'text'],
        num_rows: 157324
    })
    test: Dataset({
        features: ['instruction', 'input', 'response', 'source', 'text'],
        num_rows: 52442
    })
    valid: Dataset({
        features: ['instruction', 'input', 'response', 'source', 'text'],
        num_rows: 52442
    })
})

In [5]:
# Peek at the first record of the test split to see the prompt/response layout.
dataset["test"][0]

{'instruction': "What's the average crowd size when the Home team is melbourne?",
 'input': 'CREATE TABLE table_78356 (\n    "Home team" text,\n    "Home team score" text,\n    "Away team" text,\n    "Away team score" text,\n    "Venue" text,\n    "Crowd" real,\n    "Date" text\n)',
 'response': 'SELECT AVG("Crowd") FROM table_78356 WHERE "Home team" = \'melbourne\'',
 'source': 'wikisql',
 'text': 'Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What\'s the average crowd size when the Home team is melbourne? ### Input: CREATE TABLE table_78356 (\n    "Home team" text,\n    "Home team score" text,\n    "Away team" text,\n    "Away team score" text,\n    "Venue" text,\n    "Crowd" real,\n    "Date" text\n) ### Response: SELECT AVG("Crowd") FROM table_78356 WHERE "Home team" = \'melbourne\''}

In [6]:
# --- Quantized base model + trained LoRA adapter, loaded for inference ---

# dtype requested for the base model load; kept under its original name in case
# other cells refer to it.
compute_dtype = torch.float16

# 8-bit bitsandbytes quantization. The previous `bnb_8bit_quant_type`,
# `bnb_8bit_compute_dtype` and `bnb_8bit_use_double_quant` arguments are not
# BitsAndBytesConfig parameters (only the `bnb_4bit_*` variants exist, and NF4 /
# double quantization are 4-bit features), so they had no effect. Plain 8-bit
# loading is what was actually in force and is what is kept here.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

# The adapter's saved config records which base checkpoint it was trained on.
# (`load_from_disk` is not a PEFT `from_pretrained` argument and was dropped.)
lora_config = PeftConfig.from_pretrained(lora_model)

model = AutoModelForCausalLM.from_pretrained(
    lora_config.base_model_name_or_path,
    token=my_token,
    quantization_config=quant_config,
    torch_dtype=compute_dtype,
)

# Attach the trained adapter directly to the plain base model. Calling
# `get_peft_model` first would add a fresh, untrained adapter and wrap the
# model a second time, so the trained weights would be loaded onto an
# already-wrapped model instead of the base model they were saved against.
model = PeftModel.from_pretrained(model, lora_model)
model.eval()  # inference only: make sure dropout is disabled

tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path, token=my_token)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [7]:
# Prompt in the same layout as the dataset's `text` field, cut off right after
# "### Response:" so the model has to produce the SQL itself.
prompt = """Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What\'s the average crowd size when the Home team is melbourne? ### Input: CREATE TABLE table_78356 (\n    "Home team" text,\n    "Home team score" text,\n    "Away team" text,\n    "Away team score" text,\n    "Venue" text,\n    "Crowd" real,\n    "Date" text\n) ### Response:"""

# Tokenize and move the tensors onto the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
print(inputs["input_ids"].shape)

# NOTE(review): `top_k` only matters when sampling is active; whether it is
# depends on the model's default generation config, since `do_sample` is not
# passed here — confirm and consider setting it explicitly.
generation_kwargs = {"max_new_tokens": 200, "top_k": 5}
output_sequences = model.generate(**inputs, **generation_kwargs)

# The decoded text contains the prompt followed by the generated continuation.
decoded = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
print(decoded)

torch.Size([1, 132])
['Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What\'s the average crowd size when the Home team is melbourne? ### Input: CREATE TABLE table_78356 (\n    "Home team" text,\n    "Home team score" text,\n    "Away team" text,\n    "Away team score" text,\n    "Venue" text,\n    "Crowd" real,\n    "Date" text\n) ### Response: SELECT Crowd FROM table_78356 WHERE Home team = \'Melbourne\';\n\n\n']
