## Requirements

In [26]:
# Settings for autoreloading.

%load_ext autoreload
%reload_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
!pip install -q -U bitsandbytes
!pip install -q xformers
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
!pip install -q einops
!pip install -q wandb
!pip install -q scipy

[0m

In [5]:
!apt-get install git-lfs

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.2).
0 upgraded, 0 newly installed, 0 to remove and 2 not upgraded.


In [6]:
!rm -r /workspace/falcon-7b-sql

rm: cannot remove '/workspace/falcon-7b-sql': No such file or directory


In [7]:
!git clone https://github.com/maidacundo/falcon-7b-sql.git

Cloning into 'falcon-7b-sql'...
remote: Enumerating objects: 57, done.[K
remote: Counting objects: 100% (39/39), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 57 (delta 10), reused 32 (delta 6), pack-reused 18[K
Receiving objects: 100% (57/57), 93.66 MiB | 13.93 MiB/s, done.
Resolving deltas: 100% (11/11), done.


In [8]:
%cd falcon-7b-sql/src

/workspace/falcon-7b-sql/src


### Login

In [9]:
import wandb
wandb.login(key='6b22cbf359c5924f4500afc1ae572d6827998186')

[34m[1mwandb[0m: Currently logged in as: [33mmaidacundo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [10]:
from huggingface_hub import login

login(token='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',
      add_to_git_credential=True)

Token is valid (permission: write).
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.[0m
Token has not been saved to git credential helper.
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Model & Dataset

In [27]:
model_id = "tiiuae/falcon-7b"
dataset_id = 'spider'
spider_schema = '/workspace/falcon-7b-sql/data/tables.json'

In [28]:
from utils.dataset_utils import get_dataset
dataset = get_dataset(dataset_id, spider_schema, use_fields=True)

Found cached dataset spider (/root/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa/cache-294f7aa779b7d473.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa/cache-863368dfc5cb4519.arrow


In [29]:
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

lora_config = LoraConfig(
    r=2, # 64
    lora_alpha=8, # 16
    target_modules=["query_key_value"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [30]:
from utils.training_utils import get_model_and_tokenizer
from utils.training_utils import SQL_SPECIAL_TOKENS

model, tokenizer = get_model_and_tokenizer(model_id, bnb_config, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 589824 || all params: 3609348288 || trainable%: 0.016341565095310637


In [31]:
from model.training.data_collator import DialogueDataCollator
collate_fn = DialogueDataCollator(tokenizer, 
                                  use_system_prefix=True,
                                  max_length = 512,
                                  system_prefix='Convert text into SQL statements by providing a database schema and a query, and generate the corresponding SQL statement.'
                                 )

```
{'input_ids': tensor([[65024,    76,    11, 65026,    77,    11, 65025,    78,    11]]),
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]]),
'label_masks': tensor([[False, False, False, False, False, False,  True,  True, False]]),
'targets': tensor([[   76,    11, 65026,    77,    11, 65025,    78,    11, 65024]])}
```



## Training

In [32]:
batch_size = 16
gradient_accumulation_steps = 1
total_training_steps = len(dataset['train']) // (batch_size * gradient_accumulation_steps)

warmup_steps = (total_training_steps) * 0.1 # 10% of total steps for

In [33]:
import transformers

training_args = transformers.TrainingArguments(
        full_determinism=False,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=3,
        learning_rate=1e-4,
        weight_decay=0,
        fp16=True,
        logging_steps=1,
        output_dir="../../falcon_qlora_sql_r2",
        optim="paged_adamw_8bit",
        seed=42,
        push_to_hub=True,
        report_to="wandb",
        save_strategy='steps',
        evaluation_strategy='steps',
        eval_steps=100,
        save_steps=100,
    )

In [34]:
import wandb
wandb.init(project='falcon_qlora_sql', entity='maidacundo', config=training_args)

0,1
eval/loss,█▂▂▂▂▂▂▂▂▂▁▁▁▁▁
eval/runtime,▁▅▅▅▆▆▅▆▅▆▇▇▇▇█
eval/samples_per_second,█▄▄▄▃▃▄▃▄▃▂▂▂▂▁
eval/steps_per_second,█▄▄▄▃▃▄▃▄▄▂▂▂▂▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▄███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
train/loss,▇▃▂▂▂▂▃▂▂▃▃▂▂▃▂▁▂▄▂▂▄▂▄▃▂▂▂▃▂▂▁▂▂█▂▂▂▂▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.17716
eval/runtime,83.884
eval/samples_per_second,12.327
eval/steps_per_second,0.775
train/epoch,3.0
train/global_step,1314.0
train/learning_rate,0.0
train/loss,0.1243
train/total_flos,1.623570697907159e+17
train/train_loss,0.3017


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669267116715975, max=1.0…

In [19]:
small_train = dataset['train'].select(range(64))
small_eval = dataset['validation'].select(range(64))

In [35]:
from model.training.sft_trainer import preprocess_logits_for_metrics, SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    sampler=None,
    train_collate_fn=collate_fn,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    data_collator=collate_fn,
    tokenizer=tokenizer,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
model.config.use_cache = False

/workspace/falcon-7b-sql/src/../../falcon_qlora_sql_r2 is already a clone of https://huggingface.co/maidacundo/falcon_qlora_sql_r2. Make sure you pull the latest changes with `repo.git_pull()`.


In [36]:
trainer.evaluate(dataset['validation'])

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 1.2636007070541382,
 'eval_runtime': 81.9444,
 'eval_samples_per_second': 12.618,
 'eval_steps_per_second': 0.793}

In [37]:
trainer.train()

Step,Training Loss,Validation Loss
100,0.2993,0.2863
200,0.8003,0.335807
300,0.1872,0.242406
400,0.1267,0.236244
500,0.2214,0.256431
600,0.2885,0.218684
700,0.1654,0.198765
800,0.1633,0.206228
900,0.0381,0.186784
1000,0.0633,0.176708


TrainOutput(global_step=1314, training_loss=0.24357877032151956, metrics={'train_runtime': 6948.0175, 'train_samples_per_second': 3.022, 'train_steps_per_second': 0.189, 'total_flos': 1.614661660624896e+17, 'train_loss': 0.24357877032151956, 'epoch': 3.0})

In [38]:
trainer.evaluate(dataset['validation'])

{'eval_loss': 0.17348535358905792,
 'eval_runtime': 83.4912,
 'eval_samples_per_second': 12.385,
 'eval_steps_per_second': 0.779,
 'epoch': 3.0}

In [39]:
trainer.push_to_hub()

Upload file adapter_model.bin:   1%|1         | 32.0k/2.27M [00:00<?, ?B/s]

To https://huggingface.co/maidacundo/falcon_qlora_sql_r2
   eb69a94..bec40ca  main -> main

To https://huggingface.co/maidacundo/falcon_qlora_sql_r2
   bec40ca..b81efcd  main -> main



'https://huggingface.co/maidacundo/falcon_qlora_sql_r2/commit/bec40caccbfd62694c4d30fbbc330198e05866f3'

## Inference

In [42]:
from utils.training_utils import get_tokenizer, get_model, add_embeddings_to_model, SQL_SPECIAL_TOKENS
from peft import PeftModel, prepare_model_for_kbit_training

def get_pretraineed_model_and_tokenizer(model_id: str, bnb_config, lora_id: str):
    tokenizer = get_tokenizer(model_id)
    model = get_model(model_id, bnb_config)
    add_embeddings_to_model(model, tokenizer, SQL_SPECIAL_TOKENS)
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    model = PeftModel.from_pretrained(model, lora_id, torch_dtype=torch.float16)
    return model, tokenizer


In [43]:
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

lora_id = 'maidacundo/falcon_qlora_sql'

model, tokenizer = get_pretraineed_model_and_tokenizer(model_id, bnb_config, lora_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)/adapter_config.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/75.5M [00:00<?, ?B/s]

In [70]:
from torch.utils.data import Dataset, DataLoader

class InferenceDataset(Dataset):

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.dataset[idx]

        return 'Convert text into SQL statements by providing a database schema and a query, and generate the corresponding SQL statement.' + sample['input_text'].split('<|sql|>')[0] + '<|sql|>'

inference_ds = InferenceDataset(dataset['validation'].select(range(50)))

inference_dataloader = DataLoader(inference_ds, batch_size=1, shuffle=False)

In [71]:
import transformers
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausa

In [72]:
from tqdm import tqdm
results = []
for batch in tqdm(inference_dataloader):
  out = pipeline(batch,
                do_sample=False,
                max_length=512,
                temperature=0.2,
                top_k=3,
                top_p=0.9,
                repetition_penalty=1.2,
                eos_token_id=tokenizer(';')['input_ids'][0],
                pad_token_id=tokenizer.eos_token_id
                )
  for res in out:
    prediction = res[0]['generated_text'].split('<|sql|>')[-1]
    print(prediction)
    results.append(prediction)

  2%|▏         | 1/50 [00:00<00:43,  1.12it/s]

select count(*) from singer;


  4%|▍         | 2/50 [00:07<03:34,  4.47s/it]

select sum(t1.age) from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id;


  6%|▌         | 3/50 [00:11<03:05,  3.95s/it]

select song_name,  country,  age from singer order by age desc;


  8%|▊         | 4/50 [00:14<02:50,  3.70s/it]

select song_name,  country,  age from singer order by age desc;


 10%|█         | 5/50 [01:05<15:41, 20.92s/it]

select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age) 


 12%|█▏        | 6/50 [01:58<23:07, 31.52s/it]

select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),  min(age),  max(age) from singer where country  =  "france";select avg(age),


 14%|█▍        | 7/50 [02:09<17:49, 24.87s/it]

select t1.song_name,  t1.song_release_year from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id order by t2.age limit 1;


 16%|█▌        | 8/50 [02:12<12:33, 17.93s/it]

select song_name,  year from singer order by age limit 1;


 18%|█▊        | 9/50 [02:14<08:52, 12.98s/it]

select country from singer where age  >  20;


 20%|██        | 10/50 [02:16<06:24,  9.61s/it]

select country from singer where age  >  20;


 22%|██▏       | 11/50 [02:18<04:46,  7.35s/it]

select country,  count(*) from singer group by country;


 24%|██▍       | 12/50 [02:20<03:38,  5.76s/it]

select country,  count(*) from singer group by country;


 26%|██▌       | 13/50 [02:27<03:42,  6.02s/it]

select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer order by age desc limit 1;


 28%|██▊       | 14/50 [03:19<11:56, 19.91s/it]

select song_name from singer where age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and age  >  (select avg(age) from singer);select song_name from singer where is_male  =  'f' and


 30%|███       | 15/50 [03:29<09:53, 16.95s/it]

select t1.location,  t2.name from stadium as t1 join stadium as t2 on t1.stadium_id  =  t2.stadium_id where t1.capacity between 5000 and 10000;


 32%|███▏      | 16/50 [03:33<07:18, 12.89s/it]

select location,  name from stadium where capacity between 5000 and 10000;


 34%|███▍      | 17/50 [03:35<05:23,  9.80s/it]

select max(capacity),  avg(capacity) from stadium;


 36%|███▌      | 18/50 [03:38<04:04,  7.63s/it]

select avg(capacity),  max(capacity) from stadium;


 38%|███▊      | 19/50 [03:49<04:35,  8.88s/it]

select t1.name,  t1.average from stadium as t1 join stadium as t2 on t1.stadium_id  =  t2.stadium_id group by t1.stadium_id order by avg(t2.attendance) desc limit 1;


 40%|████      | 20/50 [04:01<04:52,  9.74s/it]

select t1.name,  t1.average from stadium as t1 join stadium as t2 on t1.stadium_id  =  t2.stadium_id group by t1.stadium_id order by avg(t2.attendance) desc limit 1;


 42%|████▏     | 21/50 [04:05<03:51,  7.99s/it]

select count(*) from concert where year  =  2014 or year  =  2015;


 44%|████▍     | 22/50 [04:09<03:08,  6.74s/it]

select count(*) from concert where year  =  2014 or year  =  2015;


 46%|████▌     | 23/50 [04:17<03:12,  7.14s/it]

select t1.name,  count(*) from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id group by t1.name;


 48%|████▊     | 24/50 [04:26<03:16,  7.55s/it]

select count(*),  t1.name from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id group by t1.stadium_id;


 50%|█████     | 25/50 [04:38<03:42,  8.91s/it]

select t1.name,  t2.capacity from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  >  2014 group by t1.name order by count(*) desc limit 1;


 52%|█████▏    | 26/50 [04:47<03:37,  9.08s/it]

select t1.name,  t1.capacity from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  >  2013;


 54%|█████▍    | 27/50 [04:50<02:44,  7.14s/it]

select year from concert group by year order by count(*) desc limit 1;


 56%|█████▌    | 28/50 [04:52<02:03,  5.63s/it]

select year from concert order by count(*) desc limit 1;


 58%|█████▊    | 29/50 [04:59<02:05,  5.97s/it]

select name from stadium except select t1.name from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id;


 60%|██████    | 30/50 [05:50<06:31, 19.57s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium where stadium_id not in ( select stadium_id from co

 62%|██████▏   | 31/50 [05:54<04:44, 14.98s/it]

select country from singer where age  >  40 intersect select country from singer where age  <  30;


 64%|██████▍   | 32/50 [06:44<07:39, 25.53s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where


 66%|██████▌   | 33/50 [07:34<09:19, 32.92s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium where stadium_id not in ( select stadium_id from concert where


 68%|██████▊   | 34/50 [07:44<06:54, 25.93s/it]

select t1.concert_name,  t2.theme from concert as t1 join singer_in_concert as t2 on t1.concert_id  =  t2.concert_id group by t1.concert_id;


 70%|███████   | 35/50 [07:54<05:16, 21.10s/it]

select t1.concert_name,  t2.theme,  count(*) from concert as t1 join singer as t2 on t1.concert_id  =  t2.concert_id group by t1.concert_id;


 72%|███████▏  | 36/50 [08:02<04:00, 17.15s/it]

select t1.name,  count(*) from singer as t1 join concert as t2 on t1.singer_id  =  t2.singer_id group by t1.name;


 74%|███████▍  | 37/50 [08:11<03:10, 14.65s/it]

select t1.name,  count(*) from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id group by t1.name;


 76%|███████▌  | 38/50 [08:18<02:31, 12.59s/it]

select t1.name from singer as t1 join concert as t2 on t1.singer_id  =  t2.concert_id where t2.year  =  2014;


 78%|███████▊  | 39/50 [08:27<02:03, 11.25s/it]

select t1.name from singer as t1 join concert as t2 on t1.singer_id  =  t2.singer_id where t2.year  =  2014;


 80%|████████  | 40/50 [09:16<03:47, 22.76s/it]

select t1.name,  t1.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t3.name,  t3.country from singer as t3 join singer_in_concert as t4 on t3.singer_id  =  t4.singer_id where t4.song_name like '%hey%';select t5.name,  t5.country from singer as t5 join singer_in_concert as t6 on t5.singer_id  =  t6.singer_id where t6.song_name like '%hey%';select t7.name,  t7.country from singer as t7 join singer_in_concert as t8 on t7.singer_id  =  t8.singer_id where t8.song_name like '%hey%';select t9.name,  t9.country from singer as t


 82%|████████▏ | 41/50 [10:05<04:35, 30.61s/it]

select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.


 84%|████████▍ | 42/50 [10:24<03:37, 27.22s/it]

select t1.name,  t1.location from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2014 intersect select t1.name,  t1.location from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2015;


 86%|████████▌ | 43/50 [10:44<02:53, 24.85s/it]

select t1.name,  t1.location from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2014 intersect select t1.name,  t1.location from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2015;


 88%|████████▊ | 44/50 [10:48<01:52, 18.72s/it]

select count(*) from stadium where name  =  ( select stadium_name from stadium order by capacity desc limit 1;


 90%|█████████ | 45/50 [11:39<02:21, 28.25s/it]

select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadium order by capacity desc limit 1);select count(*) from concert where stadium_id  =  ( select stadium_id from stadi

 92%|█████████▏| 46/50 [11:40<01:21, 20.32s/it]

select count(*) from pets where weight  >  10;


 94%|█████████▍| 47/50 [11:42<00:44, 14.77s/it]

select count(*) from pets where weight  >  10;


 96%|█████████▌| 48/50 [11:44<00:21, 10.93s/it]

select weight from pets order by pet_age desc limit 1;


 98%|█████████▊| 49/50 [12:49<00:27, 27.09s/it]

select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(weight) from pets where pet_age  =  'youngest';select min(wei

100%|██████████| 50/50 [12:52<00:00, 15.45s/it]

select max(weight),  pet_type from pets group by pet_type;





In [73]:
with open('falcon_r16.txt', 'w') as f:
    for line in results:
        f.write(f"{line}\n")

In [74]:
lora_id = 'maidacundo/falcon_qlora_sql_r2'

model, tokenizer = get_pretraineed_model_and_tokenizer(model_id, bnb_config, lora_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)/adapter_config.json:   0%|          | 0.00/407 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/2.38M [00:00<?, ?B/s]

In [75]:
import transformers
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausa

In [76]:
from tqdm import tqdm
results = []
for batch in tqdm(inference_dataloader):
  out = pipeline(batch,
                do_sample=False,
                max_length=512,
                temperature=0.2,
                top_k=3,
                top_p=0.9,
                repetition_penalty=1.2,
                eos_token_id=tokenizer(';')['input_ids'][0],
                pad_token_id=tokenizer.eos_token_id
                )
  for res in out:
    prediction = res[0]['generated_text'].split('<|sql|>')[-1]
    print(prediction)
    results.append(prediction)

  2%|▏         | 1/50 [00:00<00:42,  1.14it/s]

select count(*) from singer;


  4%|▍         | 2/50 [00:01<00:41,  1.14it/s]

select count(*) from singer;


  6%|▌         | 3/50 [00:04<01:24,  1.80s/it]

select name,  country,  age from singer order by age desc;


  8%|▊         | 4/50 [00:07<01:42,  2.23s/it]

select name,  country,  age from singer order by age desc;


 10%|█         | 5/50 [00:23<05:21,  7.14s/it]

select avg(age),  min(age),  max(age) from singer where country  =  "france";create table singer_in_concert as select t1.singer_id,  t2.concert_id from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id;


 12%|█▏        | 6/50 [00:39<07:27, 10.17s/it]

select avg(age),  min(age),  max(age) from singer where country  =  "france";create table singer_in_concert as select t1.singer_id,  t2.concert_id from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id;


 14%|█▍        | 7/50 [00:49<07:21, 10.27s/it]

select t1.name,  t2.song_release_year from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id order by t1.age limit 1;


 16%|█▌        | 8/50 [01:05<08:15, 11.80s/it]

select song_name,  song_release_year from singer where age  =  ( select min(age) from singer );select t1.song_name,  t2.song_release_year from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id;


 18%|█▊        | 9/50 [01:07<06:06,  8.93s/it]

select country from singer where age  >  20 group by country;


 20%|██        | 10/50 [01:09<04:32,  6.81s/it]

select country from singer where age  >  20;


 22%|██▏       | 11/50 [01:11<03:30,  5.40s/it]

select country,  count(*) from singer group by country;


 24%|██▍       | 12/50 [01:14<02:47,  4.40s/it]

select country,  count(*) from singer group by country;


 26%|██▌       | 13/50 [01:18<02:49,  4.57s/it]

select song_name from singer where age  >  (select avg(age) from singer);select song_name from singer;


 28%|██▊       | 14/50 [02:11<11:21, 18.93s/it]

select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age) from singer );select song_name from singer where age  >  ( select avg(age)


 30%|███       | 15/50 [02:14<08:19, 14.28s/it]

select location,  name from stadium where capacity between 5000 and 10000;


 32%|███▏      | 16/50 [02:18<06:14, 11.03s/it]

select location,  name from stadium where capacity between 5000 and 10000;


 34%|███▍      | 17/50 [02:20<04:40,  8.50s/it]

select max(capacity),  avg(capacity) from stadium;


 36%|███▌      | 18/50 [02:23<03:35,  6.73s/it]

select avg(capacity),  max(capacity) from stadium;


 38%|███▊      | 19/50 [02:35<04:23,  8.51s/it]

select t1.name,  t2.capacity from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id group by t1.stadium_id order by avg(t2.attendance) desc limit 1;


 40%|████      | 20/50 [02:48<04:52,  9.75s/it]

select t1.name,  t2.capacity from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id group by t1.stadium_id order by avg(t2.attendance) desc limit 1;


 42%|████▏     | 21/50 [02:52<03:52,  8.02s/it]

select count(*) from concert where year  =  2014 or year  =  2015;


 44%|████▍     | 22/50 [02:56<03:09,  6.78s/it]

select count(*) from concert where year  =  2014 or year  =  2015;


 46%|████▌     | 23/50 [03:05<03:20,  7.43s/it]

select t1.name,  count(*) from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id group by t1.name;


 48%|████▊     | 24/50 [03:07<02:29,  5.74s/it]

select count(*) from stadium group by stadium_id;


 50%|█████     | 25/50 [03:58<08:03, 19.34s/it]

select t1.name,  t2.capacity from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  >=  2014 group by t1.name having count(*)  =  ( select max(count(*) ) from singer_in_concert as t3 join concert as t4 on t3.concert_id  =  t4.concert_id where t4.year  =  2014);select t1.name,  t2.capacity from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  >=  2014 group by t1.name having count(*)  =  ( select max(count(*) ) from singer_in_concert as t3 join concert as t4 on t3.concert_id  =  t4.concert_id where t4.year  =  2014);select t1.name,  t2.capacity from stadium as t


 52%|█████▏    | 26/50 [04:27<08:58, 22.44s/it]

select t1.name,  t1.capacity from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.concert_id  >  ( select max(t3.concert_id) from concert as t3 join singer_in_concert as t4 on t3.concert_id  =  t4.concert_id where t4.concert_id  >  2013);select t1.name,  t1.capacity from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id;


 54%|█████▍    | 27/50 [04:30<06:19, 16.49s/it]

select year from concert group by year order by count(*) desc limit 1;


 56%|█████▌    | 28/50 [04:32<04:28, 12.18s/it]

select year from concert order by count(*) desc limit 1;


 58%|█████▊    | 29/50 [04:37<03:26,  9.84s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium;


 60%|██████    | 30/50 [04:41<02:44,  8.25s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert );select name from stadium;


 62%|██████▏   | 31/50 [04:45<02:14,  7.06s/it]

select country from singer where age  >  40 intersect select country from singer where age  <  30;


 64%|██████▍   | 32/50 [04:52<02:02,  6.78s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium;


 66%|██████▌   | 33/50 [04:58<01:51,  6.58s/it]

select name from stadium where stadium_id not in ( select stadium_id from concert where year  =  2014 );select name from stadium;


 68%|██████▊   | 34/50 [05:08<02:04,  7.77s/it]

select t1.concert_name,  t1.theme,  count(*) from concert as t1 join singer_in_concert as t2 on t1.concert_id  =  t2.concert_id group by t1.concert_id;


 70%|███████   | 35/50 [05:19<02:08,  8.60s/it]

select t1.concert_name,  t2.theme,  count(*) from concert as t1 join singer_in_concert as t2 on t1.concert_id  =  t2.concert_id group by t1.concert_name;


 72%|███████▏  | 36/50 [05:27<02:00,  8.62s/it]

select t1.name,  count(*) from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id group by t1.name;


 74%|███████▍  | 37/50 [05:37<01:54,  8.84s/it]

select t1.name,  count(*) from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id group by t1.singer_id;


 76%|███████▌  | 38/50 [05:44<01:41,  8.46s/it]

select t1.name from singer as t1 join concert as t2 on t1.concert_id  =  t2.concert_id where t2.year  =  2014;


 78%|███████▊  | 39/50 [05:54<01:35,  8.69s/it]

select t1.name from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.concert_id  =  2014;


 80%|████████  | 40/50 [06:43<03:29, 20.93s/it]

select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1


 82%|████████▏ | 41/50 [07:32<04:23, 29.28s/it]

select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from singer as t1 join singer_in_concert as t2 on t1.singer_id  =  t2.singer_id where t2.song_name like '%hey%';select t1.name,  t2.country from


 84%|████████▍ | 42/50 [07:53<03:33, 26.73s/it]

select t1.name,  t2.location from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2014 intersect select t1.name,  t2.location from stadium as t1 join singer_in_concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2015;


 86%|████████▌ | 43/50 [08:12<02:51, 24.48s/it]

select t1.name,  t2.location from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2014 intersect select t1.name,  t2.location from stadium as t1 join concert as t2 on t1.stadium_id  =  t2.stadium_id where t2.year  =  2015;


 88%|████████▊ | 44/50 [08:17<01:52, 18.78s/it]

select count(*) from concert where stadium_id  =  ( select max(stadium_id) from stadium);select count(*) from concert;


 90%|█████████ | 45/50 [09:08<02:21, 28.24s/it]

select count(*) from concert where stadium_id  =  ( select max(stadium_id) from stadium );select count(*) from concert as t1 join singer_in_concert as t2 on t1.concert_id  =  t2.concert_id join stadium as t3 on t2.stadium_id  =  t3.stadium_id where t3.highest  =  ( select max(highest) from stadium );select count(*) from concert as t1 join singer_in_concert as t2 on t1.concert_id  =  t2.concert_id join stadium as t3 on t2.stadium_id  =  t3.stadium_id where t3.lowest  =  ( select min(lowest) from stadium );select count(*) from concert as t1 join singer_in_concert as t2 on t1.concert_id  =  t2.concert_id join stadium as t3 on t2.stadium_id  =  t3.stadium_id where t3.average  =  ( select avg(average) from stadium );select


 92%|█████████▏| 46/50 [09:09<01:21, 20.33s/it]

select count(*) from pets where weight  >  10;


 94%|█████████▍| 47/50 [09:11<00:44, 14.79s/it]

select count(*) from pets where weight  >  10;


 96%|█████████▌| 48/50 [09:13<00:21, 10.95s/it]

select weight from pets order by pet_age desc limit 1;


 98%|█████████▊| 49/50 [10:18<00:27, 27.11s/it]

select weight from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) from pets where pet_age  =  'youngest';select sum(weight) 

100%|██████████| 50/50 [10:21<00:00, 12.42s/it]

select max(weight),  pettype from pets group by pettype;





In [77]:
with open('falcon_r2.txt', 'w') as f:
    for line in results:
        f.write(f"{line}\n")