# Prompt Tuning - NL2SQL using Santacoder

### Load dataset

In [1]:
from datasets import load_dataset

dataset = load_dataset("wikisql")

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset wikisql (/mnt/data/logesh/hf_cache/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)
100%|██████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 353.15it/s]


In [60]:
dataset['train'][0]

{'phase': 1,
 'question': 'Tell me what the notes are for South Australia ',
 'table': {'header': ['State/territory',
   'Text/background colour',
   'Format',
   'Current slogan',
   'Current series',
   'Notes'],
  'page_title': '',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-1000181-1',
  'section_title': '',
  'caption': '',
  'rows': [['Australian Capital Territory',
    'blue/white',
    'Yaa·nna',
    'ACT · CELEBRATION OF A CENTURY 2013',
    'YIL·00A',
    'Slogan screenprinted on plate'],
   ['New South Wales',
    'black/yellow',
    'aa·nn·aa',
    'NEW SOUTH WALES',
    'BX·99·HI',
    'No slogan on current series'],
   ['New South Wales',
    'black/white',
    'aaa·nna',
    'NSW',
    'CPX·12A',
    'Optional white slimline series'],
   ['Northern Territory',
    'ochre/white',
    'Ca·nn·aa',
    'NT · OUTBACK AUSTRALIA',
    'CB·06·ZZ',
    'New series began in June 2011'],
   ['Queensland',
    'maroon/white',
    'nnn·aaa

## Initialize model

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
import torch
from datasets import load_dataset
import os
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda"
model_name_or_path = "bigcode/santacoder"
tokenizer_name_or_path = "bigcode/santacoder"
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    prompt_tuning_init_text="Convert question in natural language to SQL",
    tokenizer_name_or_path=model_name_or_path,
)


text_column = "question"
label_column = "human_readable"
max_length = 64
lr = 3e-2
num_epochs = 3
batch_size = 8

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

## Converting wikisql to instruction dataset

```Format: Question : <Question> \n Table Columns : <List_of_columns> \n SQL : <SQL_TO_BE_GENERATED>```

In [7]:
def fetch_table_context(table):
    header_type = [f"{header}:{typ}" for header,typ in zip(table['header'],table['types'])]
    return ",".join(header_type)


def preprocess_function(examples):
    batch_size = len(examples[text_column])
    inputs = [f"{text_column} : {x} \n Table Columns : {fetch_table_context(t)} \n SQL : " for x,t in zip(examples[text_column],examples['table'])] 
    print(inputs[0])
    targets = [str(x[label_column])+'\n' for x in examples["sql"]]
    model_inputs = tokenizer(inputs)
    labels = tokenizer(targets)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i] + [tokenizer.pad_token_id]
        # print(i, sample_input_ids, label_input_ids)
        model_inputs["input_ids"][i] = sample_input_ids + label_input_ids
        labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids
        model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])
    # print(model_inputs)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i]
        model_inputs["input_ids"][i] = [tokenizer.pad_token_id] * (
            max_length - len(sample_input_ids)
        ) + sample_input_ids
        model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[
            "attention_mask"
        ][i]
        labels["input_ids"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids
        model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
        labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
train_processed_datasets = dataset['train'].map(
    preprocess_function,
    batched=True,
    num_proc=1,
    load_from_cache_file=False,
    remove_columns=dataset["train"].column_names,
    desc="Running tokenizer on dataset",
)

test_processed_datasets = dataset['test'].map(
    preprocess_function,
    batched=True,
    num_proc=1,
    load_from_cache_file=False,
    remove_columns=dataset["test"].column_names,
    desc="Running tokenizer on dataset",
)

Running tokenizer on dataset:   2%|▍                        | 1000/56355 [00:00<00:15, 3642.70 examples/s]

question : Tell me what the notes are for South Australia  
 Table Columns : State/territory:text,Text/background colour:text,Format:text,Current slogan:text,Current series:text,Notes:text 
 SQL : 
question : What is the number for years 1985-88 
 Table Columns : Player:text,No.:real,Nationality:text,Position:text,Years for Jazz:text,School/Club Team:text 
 SQL : 


Running tokenizer on dataset:   5%|█▎                       | 3000/56355 [00:00<00:14, 3586.90 examples/s]

question : What is the barrel length for a cold model le6921sp? 
 Table Columns : Colt model no.:text,Name:text,Stock:text,Fire control:text,Rear sight:text,Forward assist:text,Barrel length:text,Barrel profile:text,Barrel twist:text,Hand guards:text,Bayonet Lug:text,Muzzle device:text 
 SQL : 


Running tokenizer on dataset:   7%|█▊                       | 4000/56355 [00:01<00:13, 3753.78 examples/s]

question : what amount of try bonus where the game was won by 11? 
 Table Columns : Club:text,Played:text,Won:text,Drawn:text,Lost:text,Points for:text,Points against:text,Tries for:text,Tries against:text,Try bonus:text,Losing bonus:text,Points:text 
 SQL : 
question : What is the highest value of PF when Ends Lost is 51? 
 Table Columns : Locale:text,Skip:text,W:real,L:real,PF:real,PA:real,Ends Won:real,Ends Lost:real,Blank Ends:real,Stolen Ends:real,Shot Pct.:text 
 SQL : 


Running tokenizer on dataset:  11%|██▋                      | 6000/56355 [00:01<00:13, 3657.15 examples/s]

question : How many times is a score for stolen ends recorded for France? 
 Table Columns : Country:text,Skip:text,W:real,L:real,PF:real,PA:real,Ends Won:real,Ends Lost:real,Blank Ends:real,Stolen Ends:real,Shot %:real 
 SQL : 


Running tokenizer on dataset:  12%|███                      | 7000/56355 [00:01<00:13, 3653.25 examples/s]

question : How many seasons did "strangled, not stirred" air? 
 Table Columns : No. in series:real,No. in season:real,Title:text,Directed by:text,Written by:text,Original air date:text,Production code:real 
 SQL : 


Running tokenizer on dataset:  14%|███▌                     | 8000/56355 [00:02<00:14, 3278.42 examples/s]

question : What's team #2 in the round where team $1 is Ilisiakos? 
 Table Columns : Team #1:text,Agg. score:text,Team #2:text,1st leg:text,2nd leg:text 
 SQL : 


Running tokenizer on dataset:  16%|███▉                     | 9000/56355 [00:02<00:13, 3483.29 examples/s]

question : What is the total of countys where Obama is popular by 35.44%? 
 Table Columns : County:text,Obama%:text,Obama#:real,McCain%:text,McCain#:real 
 SQL : 


Running tokenizer on dataset:  18%|████▎                   | 10000/56355 [00:02<00:13, 3562.94 examples/s]

question : Who is the head coach for the score of 4-3? 
 Table Columns : Tournament:real,Conference:text,Championship Game Opponent:text,Score:text,Location:text,Head Coach:text 
 SQL : 
question : How did the game number 50 end? 
 Table Columns : Game:real,Date:text,Team:text,Score:text,High points:text,High rebounds:text,High assists:text,Location Attendance:text,Record:text 
 SQL : 


Running tokenizer on dataset:  21%|█████                   | 12000/56355 [00:03<00:12, 3534.52 examples/s]

question : How many individuals watched the show that had a bbc ranking of 6? 
 Table Columns : Episode no.:real,Airdate:text,Viewers:text,BBC Three weekly ranking:text,Cable rank:text 
 SQL : 


Running tokenizer on dataset:  23%|█████▌                  | 13000/56355 [00:03<00:11, 3668.95 examples/s]

question : what are all the state/nation where the race number is 36 
 Table Columns : Position:real,Race number:text,Sail number:text,Yacht:text,State/country:text,Yacht type:text,LOA (Metres):text,Skipper:text,Elapsed time d:hh:mm:ss:text 
 SQL : 


Running tokenizer on dataset:  25%|█████▉                  | 14000/56355 [00:04<00:12, 3368.60 examples/s]

question : What is the Galician (reintegrationist) word of the Galician (Official) is adeus*? 
 Table Columns : English:text,Galician ( Official ):text,Galician ( Reintegrationist ):text,Portuguese:text,Spanish:text 
 SQL : 


Running tokenizer on dataset:  27%|██████▍                 | 15000/56355 [00:04<00:11, 3542.10 examples/s]

question : What is the production code for the episode that had 23.9 million u.s. viewers? 
 Table Columns : No. in series:real,No. in season:real,Title:text,Directed by:text,Written by:text,Original air date:text,Production code:real,U.S. viewers (millions):text 
 SQL : 


Running tokenizer on dataset:  28%|██████▊                 | 16000/56355 [00:04<00:11, 3466.06 examples/s]

question : What is the year listed when tied is listed as 11? 
 Table Columns : Year:text,Position:real,Games played:real,Won:real,Tied:real,Lost:real,Goals Scored:real,Goals Against:real,Points:real,Postseason place:text 
 SQL : 


Running tokenizer on dataset:  30%|███████▏                | 17000/56355 [00:04<00:11, 3338.88 examples/s]

question : How many weeks have an attendance of 64,116? 
 Table Columns : Week:real,Date:text,Opponent:text,Result:text,Venue:text,Attendance:text 
 SQL : 


Running tokenizer on dataset:  32%|███████▋                | 18000/56355 [00:05<00:10, 3574.13 examples/s]

question : In which venue did 0 pens and 1 try occur? 
 Table Columns : Player:text,Tries:text,Conv:text,Pens:text,Drop:text,Venue:text,Date:text 
 SQL : 
question : On what date is Hawthorn the home team? 
 Table Columns : Home team:text,Home team score:text,Away team:text,Away team score:text,Venue:text,Crowd:real,Date:text 
 SQL : 


Running tokenizer on dataset:  35%|████████▌               | 20000/56355 [00:05<00:10, 3613.02 examples/s]

question : Which Result has a Score of 4-1, and a Competition of world cup qualifying? 
 Table Columns : Date:text,Result:text,Score:text,Brazil scorers:text,Competition:text 
 SQL : 


Running tokenizer on dataset:  37%|████████▉               | 21000/56355 [00:05<00:09, 3856.41 examples/s]

question : What's the lowest Floors with Feet that's larger htan 262, has a Name of Standard Bank Building, and Metres that's larger htan 138.8? 
 Table Columns : Name:text,City:text,Years as tallest:text,Metres:real,Feet:real,Floors:real 
 SQL : 


Running tokenizer on dataset:  39%|█████████▎              | 22000/56355 [00:06<00:09, 3504.32 examples/s]

question : Which Outcome has a Score of 6–4, 2–6, 6–3? 
 Table Columns : Outcome:text,Date:text,Tournament:text,Surface:text,Opponent:text,Score:text 
 SQL : 


Running tokenizer on dataset:  41%|█████████▊              | 23000/56355 [00:06<00:08, 3737.28 examples/s]

question : Which Operator has a Width of 2.65 m, and a Type designation of m5000? 
 Table Columns : City:text,Operator:text,Type designation:text,Number of vehicles:real,Width:text 
 SQL : 


Running tokenizer on dataset:  43%|██████████▏             | 24000/56355 [00:06<00:08, 3818.09 examples/s]

question : What is the score of the game that 33,531 people went too? 
 Table Columns : Date:text,Opponent:text,Score:text,Loss:text,Attendance:text,Record:text 
 SQL : 


Running tokenizer on dataset:  44%|██████████▋             | 25000/56355 [00:07<00:08, 3504.95 examples/s]

question : For which song was the score 6.5 + 6.0 + 6.0 + 5.5 = 24.0? 
 Table Columns : Index:text,Name:text,Song:text,Group Song:text,Score:text 
 SQL : 


Running tokenizer on dataset:  46%|███████████             | 26000/56355 [00:07<00:08, 3752.73 examples/s]

question : Which nation's total is less than 19 when there's less than 1 bronze? 
 Table Columns : Rank:text,Nation:text,Gold:real,Silver:real,Bronze:real,Total:real 
 SQL : 


Running tokenizer on dataset:  48%|███████████▍            | 27000/56355 [00:07<00:07, 3787.68 examples/s]

question : What is the first locomotive that has a SLM number lower than 924? 
 Table Columns : Built:real,Number:real,Type:text,SLM Number:real,Wheel arrangement:text,Location:text,Notes:text 
 SQL : 


Running tokenizer on dataset:  50%|███████████▉            | 28000/56355 [00:07<00:08, 3535.12 examples/s]

question : What is the largest amount of top division titles featuring the tammeka club? 
 Table Columns : Club:text,Position in 2012:text,First season in top division:real,Number of seasons in Meistriliiga:real,First season of current spell in top division:real,Top division titles:real 
 SQL : 


Running tokenizer on dataset:  51%|████████████▎           | 29000/56355 [00:08<00:07, 3674.16 examples/s]

question : What was the top score for grier jones? 
 Table Columns : Place:text,Player:text,Country:text,Score:real,To par:text 
 SQL : 
question : What is the average pick for Princeton after round 3? 
 Table Columns : Round:real,Pick:real,Player:text,Nationality:text,College:text 
 SQL : 


Running tokenizer on dataset:  55%|█████████████▏          | 31000/56355 [00:08<00:06, 3641.06 examples/s]

question : What is richard virenque's lowest rank? 
 Table Columns : Rank:real,Name:text,Country:text,Wins:real,Years:text 
 SQL : 


Running tokenizer on dataset:  57%|█████████████▋          | 32000/56355 [00:08<00:06, 3736.20 examples/s]

question : What is the average number of matches of leonardo in seasons after 1? 
 Table Columns : Name:text,Seasons:real,Matches:real,Win %:text,Draw:real,Draw %:text,Lose:real,Lose %:text 
 SQL : 


Running tokenizer on dataset:  59%|██████████████          | 33000/56355 [00:09<00:06, 3479.27 examples/s]

question : What is the Area of the Parish with a Population of 2,113? 
 Table Columns : Official Name:text,Status:text,Area km 2:real,Population:real,Census Ranking:text 
 SQL : 


Running tokenizer on dataset:  60%|██████████████▍         | 34000/56355 [00:09<00:06, 3668.32 examples/s]

question : What is the highest number of rebounds of the game with a 6-14 record? 
 Table Columns : Game:real,Date:text,Opponent:text,Score:text,High points:text,High rebounds:text,High assists:text,Location/Attendance:text,Record:text 
 SQL : 
question : Who is the 2nd round opponent when Team 2 is Red Star (D1)? 
 Table Columns : Team 1:text,Score:text,Team 2:text,1st round:text,2nd round:text 
 SQL : 


Running tokenizer on dataset:  64%|███████████████▎        | 36000/56355 [00:10<00:05, 3494.07 examples/s]

question : What class had 1 made and fleet number of 406? 
 Table Columns : Class:text,Wheel arrangement:text,Fleet number(s):text,Manufacturer:text,Year made:text,Quantity made:text,Quantity preserved:text 
 SQL : 
question : what is the event for the year less than 1913 with the position of 2nd? 
 Table Columns : Year:real,Competition:text,Venue:text,Position:text,Event:text 
 SQL : 


Running tokenizer on dataset:  66%|███████████████▊        | 37000/56355 [00:10<00:05, 3824.35 examples/s]

question : What Constructor had 66 Laps? 
 Table Columns : Driver:text,Constructor:text,Laps:real,Time/Retired:text,Grid:real 
 SQL : 


Running tokenizer on dataset:  69%|████████████████▌       | 39000/56355 [00:10<00:04, 3611.99 examples/s]

question : Name the polyunsaturated fat with a saturated fat of 25g 
 Table Columns : Total fat:text,Saturated fat:text,Monounsaturated fat:text,Polyunsaturated fat:text,Smoke point:text 
 SQL : 


Running tokenizer on dataset:  71%|█████████████████       | 40000/56355 [00:11<00:04, 3793.78 examples/s]

question : What away team plays at Victoria Park? 
 Table Columns : Home team:text,Home team score:text,Away team:text,Away team score:text,Venue:text,Crowd:real,Date:text 
 SQL : 
question : What was Collingwood's score at the home match against Richmond? 
 Table Columns : Home team:text,Home team score:text,Away team:text,Away team score:text,Venue:text,Crowd:real,Date:text 
 SQL : 


Running tokenizer on dataset:  75%|█████████████████▉      | 42000/56355 [00:11<00:03, 3699.83 examples/s]

question : On waht date did Antoinette Jeanne Yvonne Boegner get married? 
 Table Columns : Name:text,Birth:text,Marriage:text,Became Duke:text,Ceased to be Duke:text,Death:text,Spouse:text 
 SQL : 


Running tokenizer on dataset:  76%|██████████████████▎     | 43000/56355 [00:12<00:03, 3818.12 examples/s]

question : When the Away team score equaled 15.20 (110) what was the Date of the game? 
 Table Columns : Home team:text,Home team score:text,Away team:text,Away team score:text,Venue:text,Crowd:real,Date:text 
 SQL : 


Running tokenizer on dataset:  78%|██████████████████▋     | 44000/56355 [00:12<00:03, 3523.91 examples/s]

question : What is Party, when Results is "Re-Elected", when First Elected is greater than 1990, and when District is "Minnesota 4"? 
 Table Columns : District:text,Incumbent:text,Party:text,First elected:real,Results:text 
 SQL : 


Running tokenizer on dataset:  80%|███████████████████▏    | 45000/56355 [00:12<00:03, 3660.51 examples/s]

question : Who is the winner in des moines, iowa where p.h. finkbank was the runner-up? 
 Table Columns : Year:text,Winner:text,Runner-up:text,Venue:text,Location:text 
 SQL : 


Running tokenizer on dataset:  82%|███████████████████▌    | 46000/56355 [00:12<00:02, 3850.90 examples/s]

question : Which Score has a To par of –3, and a Player of santiago luna? 
 Table Columns : Place:text,Player:text,Country:text,Score:real,To par:text 
 SQL : 


Running tokenizer on dataset:  83%|████████████████████    | 47000/56355 [00:13<00:02, 3521.50 examples/s]

question : What 8:00 am has a 3:00 pm of space goofs (mon) spider-man (tue-fri)? 
 Table Columns : 7:00 am:text,7:30 am:text,8:00 am:text,9:00 am:text,11:00 am:text,noon:text,12:30 pm:text,1:00 pm:text,1:30 pm:text,2:00 pm:text,3:00 pm:text,4:30 pm:text,5:00 pm:text,6:30 pm:text 
 SQL : 


Running tokenizer on dataset:  85%|████████████████████▍   | 48000/56355 [00:13<00:02, 3649.09 examples/s]

question : What is the height for the 2008 club Arona? 
 Table Columns : Name:text,Height:text,Weight:text,Spike:text,2008 club:text 
 SQL : 


Running tokenizer on dataset:  87%|████████████████████▊   | 49000/56355 [00:13<00:01, 3795.12 examples/s]

question : What is the team when the college is virginia tech? 
 Table Columns : Pick:real,Team:text,Player:text,Position:text,College:text 
 SQL : 


Running tokenizer on dataset:  89%|█████████████████████▎  | 50000/56355 [00:14<00:01, 3351.60 examples/s]

question : What is the latest year the world championships were held in Thun? 
 Table Columns : Year:real,Place:text,Gold:text,Silver:text,Bronze:text 
 SQL : 


Running tokenizer on dataset:  90%|█████████████████████▋  | 51000/56355 [00:14<00:01, 3421.74 examples/s]

question : How many picks on average did Jay Bruchak have before round 6? 
 Table Columns : Round:real,Pick:real,Player:text,Nationality:text,College:text 
 SQL : 
question : Result of 1st, and a Venue of melbourne , australia, and a Extra of 100 m happened in which year? 
 Table Columns : Year:real,Tournament:text,Venue:text,Result:text,Extra:text 
 SQL : 


Running tokenizer on dataset:  94%|██████████████████████▌ | 53000/56355 [00:14<00:00, 3525.96 examples/s]

question : What is the highest Isolation (km) when the elevation was smaller than 1320, and a Municipality of hinnøya? 
 Table Columns : Peak:text,Elevation (m):real,Prominence (m):real,Isolation (km):real,Municipality:text,County:text 
 SQL : 


Running tokenizer on dataset:  96%|██████████████████████▉ | 54000/56355 [00:15<00:00, 3697.01 examples/s]

question : What's the Total for a Mexico City game with a Gold of less than 4 and a Bronze of less than 2? 
 Table Columns : Year:text,Edition:text,Host city:text,Gold:real,Silver:real,Bronze:real,Total:real 
 SQL : 


Running tokenizer on dataset:  98%|███████████████████████▍| 55000/56355 [00:15<00:00, 3369.50 examples/s]

question : Name the average apps for smederevo 
 Table Columns : Season:text,Team:text,Country:text,Division:real,Apps:real,Goals:real 
 SQL : 


Running tokenizer on dataset:  99%|███████████████████████▊| 56000/56355 [00:15<00:00, 3578.88 examples/s]

question : What was Olin Dutra's score? 
 Table Columns : Place:text,Player:text,Country:text,Score:text,To par:text,Money ( $ ):text 
 SQL : 
question : Name the power for 1.8 duratorq 
 Table Columns : Model/Engine:text,Capacity:text,Cylinders/Valves:text,Power/rpm:text,Torque (Nm)/rpm:text 
 SQL : 


Running tokenizer on dataset:   6%|█▌                       | 1000/15878 [00:00<00:03, 3990.84 examples/s]

question : What is terrence ross' nationality 
 Table Columns : Player:text,No.:text,Nationality:text,Position:text,Years in Toronto:text,School/Club Team:text 
 SQL : 


Running tokenizer on dataset:  13%|███▏                     | 2000/15878 [00:00<00:04, 2920.18 examples/s]

question : Who is the director of the episode that corresponds to the total episodes number 14?  
 Table Columns : Total#:real,Series#:real,Title:text,Writer:text,Director:text,Original air date:text 
 SQL : 


Running tokenizer on dataset:  19%|████▋                    | 3000/15878 [00:00<00:03, 3250.14 examples/s]

question : What is the black caribbean population when the other black population is 2243? 
 Table Columns : Rank:real,London Borough:text,Black African Population:real,Black Caribbean Population:real,Other Black Population:real,Total Black Population:real 
 SQL : 


Running tokenizer on dataset:  25%|██████▎                  | 4000/15878 [00:01<00:03, 3060.75 examples/s]

question : How many votes for brown in the place that had 84.1% for coakley? 
 Table Columns : Municipality:text,Coakley votes:real,Coakley %:text,Brown votes:real,Brown %:text,Kennedy votes:real,Kennedy %:text,Total vote:real,Turnout %:text 
 SQL : 


Running tokenizer on dataset:  31%|███████▊                 | 5000/15878 [00:01<00:03, 3203.34 examples/s]

question : What is the production code for episode 96 in the series? 
 Table Columns : No. in series:text,No. in season:text,Title:text,Directed by:text,Written by:text,Original air date:text,Production code:real 
 SQL : 


Running tokenizer on dataset:  38%|█████████▍               | 6000/15878 [00:01<00:03, 3069.63 examples/s]

question : Which Slalom has a Giant Slalom of 8? 
 Table Columns : Season:real,Overall:real,Slalom:text,Giant Slalom:text,Super G:text,Downhill:text,Combined:text 
 SQL : 


Running tokenizer on dataset:  44%|███████████              | 7000/15878 [00:02<00:02, 3370.80 examples/s]

question : Rank of 5, and a Silver larger than 0 had what sum of total? 
 Table Columns : Rank:text,Nation:text,Gold:real,Silver:real,Bronze:real,Total:real 
 SQL : 
question : Who is the professional partner of celebrity małgorzata foremniak with a season less than 7, an average greater than 34.66, and a rank less than 10? 
 Table Columns : Rank:real,Celebrity:text,Professional Partner:text,Season:real,Average:real 
 SQL : 


Running tokenizer on dataset:  57%|██████████████▏          | 9000/15878 [00:02<00:01, 3573.32 examples/s]

question : Who was the opponent when he went more than 1 round with a record of 12-7? 
 Table Columns : Res.:text,Record:text,Opponent:text,Method:text,Event:text,Round:real,Time:text,Location:text 
 SQL : 


Running tokenizer on dataset:  63%|███████████████         | 10000/15878 [00:02<00:01, 3774.72 examples/s]

question : What is the area in km2 for the community that had a population in 2011 larger than 35,916 and a density of less than 47.1 inhabitants/km 2? 
 Table Columns : Name:text,Seat:text,Population (2011):real,Area (km 2 ):real,Density (inhabitants/km 2 ):real 
 SQL : 


Running tokenizer on dataset:  69%|████████████████▋       | 11000/15878 [00:03<00:01, 3469.89 examples/s]

question : What is the lowest Game, when Opponent is "Boston Bruins"? 
 Table Columns : Game:real,March:real,Opponent:text,Score:text,Record:text 
 SQL : 


Running tokenizer on dataset:  76%|██████████████████▏     | 12000/15878 [00:03<00:01, 3598.53 examples/s]

question : Tell me the winning driver for jim clark as pole position and fastest lap 
 Table Columns : Race:text,Circuit:text,Date:text,Pole position:text,Fastest lap:text,Winning driver:text,Constructor:text,Tyre:text,Report:text 
 SQL : 


Running tokenizer on dataset:  82%|███████████████████▋    | 13000/15878 [00:03<00:00, 3712.01 examples/s]

question : Which site has the CERCLIS ID fld004092532? 
 Table Columns : CERCLIS ID:text,Name:text,County:text,Partially deleted:text,Deleted:text 
 SQL : 


Running tokenizer on dataset:  88%|█████████████████████▏  | 14000/15878 [00:04<00:00, 3304.89 examples/s]

question : What country is Steve Jones from? 
 Table Columns : Place:text,Player:text,Country:text,Score:text,To par:text,Money ( $ ):real 
 SQL : 


Running tokenizer on dataset:  94%|██████████████████████▋ | 15000/15878 [00:04<00:00, 3370.49 examples/s]

question : What is the lowest number of f/laps with more than 4 podiums and more than 14 races? 
 Table Columns : Season:real,Series:text,Team:text,Races:real,Wins:real,Poles:real,F/Laps:real,Podiums:real,Points:real,Position:text 
 SQL : 


                                                                                                          

question : What Golden point(s) scorer has the Away Brisbane Broncos and Home South Sydney Rabbitohs? 
 Table Columns : Home:text,Score:text,Away:text,Venue:text,Golden point(s) scorer:text 
 SQL : 




In [9]:
train_dataloader = DataLoader(
    train_processed_datasets, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)
eval_dataloader = DataLoader(test_processed_datasets, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)

## Train

In [10]:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,trust_remote_code=True)
model = get_peft_model(model, peft_config)
print(model.print_trainable_parameters())

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


trainable params: 16384 || all params: 1124902912 || trainable%: 0.0014564812505348018
None


In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [12]:
model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

100%|█████████████████████████████████████████████████████████████████| 7045/7045 [39:53<00:00,  2.94it/s]
100%|█████████████████████████████████████████████████████████████████| 1985/1985 [05:24<00:00,  6.11it/s]


epoch=0: train_ppl=tensor(1.3023, device='cuda:0') train_epoch_loss=tensor(0.2641, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')


100%|█████████████████████████████████████████████████████████████████| 7045/7045 [40:05<00:00,  2.93it/s]
100%|████████████████████████████████████████████████████████████| 1985/1985 [05:27<00:00,  6.07it/s]


epoch=1: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')


100%|████████████████████████████████████████████████████████████| 7045/7045 [40:11<00:00,  2.92it/s]
100%|████████████████████████████████████████████████████████████| 1985/1985 [05:26<00:00,  6.08it/s]

epoch=2: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')





In [13]:
model.save_pretrained("./prompt_model")


### Inference Test

In [57]:

dataset['test'][40]['question'], fetch_table_context(dataset['test'][40]['table'])

('Which Frequency is used for WEGP calls?',
 'Calls:text,Frequency:text,Branding:text,Format:text,Market/Rank:text,Timeslot:text,Group owner:text')

In [58]:
def infer(model,input_text):
    inputs = tokenizer(input_text,return_tensors="pt")
    with torch.no_grad():
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model.generate(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=40, eos_token_id=stop_words_ids[0])
    
    return tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)

def parse(text):
    return text.split("SQL :")[1]

In [59]:
input_text = """question : What is the average score of students in maths 
 Table Columns : id:INT,name:text,subject:Text,score:INT
 SQL : """
 
predictions = infer(model,input_text)
print(parse(predictions[0]))

Setting `pad_token_id` to `eos_token_id`:185 for open-end generation.


 SELECT AVG score FROM table WHERE subject = maths



question : What is the average score of students in maths 
 Table Columns : id:INT,name:text,subject:Text,score:INT
 SQL : SELECT AVG score FROM table WHERE subject = maths
t.println(result
