In [1]:
#!pip install datasets transformers wandb sentencepiece tqdm

In [2]:
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
from functools import partial

import random
import logging
from tqdm.auto import tqdm
from torch.utils.data import DataLoader


import wandb
import transformers

In [3]:
# Load the pretrained CodeT5-small tokenizer and seq2seq model.
# The local Hugging Face cache is reused on re-runs: the previous
# `force_download=True` re-fetched the whole checkpoint on every execution
# of this cell, which made Restart & Run All needlessly slow.
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-small')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')

Downloading: 100%|██████████| 1.53k/1.53k [00:00<00:00, 1.56MB/s]
Downloading: 100%|██████████| 231M/231M [02:46<00:00, 1.45MB/s] 


In [4]:
# Load the 'spider' dataset; the cells below read its 'question' and 'query' columns.
dataset = load_dataset('spider')

Reusing dataset spider (C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7)
100%|██████████| 2/2 [00:00<00:00, 68.58it/s]


In [5]:
def preprocess_function(examples, tokenizer, max_seq_length):
    """Tokenize a batch of question/query pairs for seq2seq training.

    Args:
        examples: Batch mapping with a ``'question'`` list (model inputs) and a
            ``'query'`` list (generation targets).
        tokenizer: Callable tokenizer accepting ``max_length``, ``padding`` and
            ``truncation`` and exposing ``pad_token_id``.
        max_seq_length: Length every input and target is padded/truncated to.

    Returns:
        The tokenizer output for the questions, with an added ``"labels"``
        entry holding the target token ids where every padding position has
        been replaced by ``-100`` (the ignore index of the loss).
    """
    model_inputs = tokenizer(
        examples['question'],
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
    )
    target_ids = tokenizer(
        examples['query'],
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
    ).input_ids

    # Ask the tokenizer for its pad id instead of assuming it is 0, so the
    # function stays correct if a tokenizer with a different pad id is used.
    pad_token_id = tokenizer.pad_token_id

    model_inputs["labels"] = [
        [token_id if token_id != pad_token_id else -100 for token_id in target]
        for target in target_ids
    ]

    return model_inputs


In [23]:
# Run configuration shared by the cells below.
max_seq_length=128  # pad/truncate length for questions and SQL targets; also generate()'s max_length
overwrite_cache=False  # when True, dataset.map is told not to load from its cache file
preprocessing_num_workers = 8  # num_proc passed to dataset.map
batch_size=4  # batch size of both the train and eval DataLoader
num_train_epochs=5  # number of passes over train_dataloader in the training loop
device='cuda'  # device the model and the batches are moved to
learning_rate=1e-5  # AdamW learning rate
weight_decay=0.01  # AdamW weight decay
lr_scheduler_type = 'linear'  # scheduler name given to transformers.get_scheduler
num_warmup_steps = 0  # warmup steps given to the scheduler
max_train_steps = 20000  # num_training_steps given to the scheduler (the loop itself runs len(train_dataloader) * num_train_epochs steps)
logging_steps=25  # log train_batch_word_accuracy every this many steps
eval_every_step=25  # run evaluate_model every this many steps

In [7]:
# Names of the raw columns; they are all removed once the tokenised features exist.
column_names = dataset["train"].column_names

# Bind tokenizer and sequence length so that `map` only has to supply the batch.
preprocess_function_wrapped = partial(
    preprocess_function,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
)


# Tokenise every split in batches, in parallel, reusing cached results unless
# overwrite_cache is set.
processed_datasets = dataset.map(
    preprocess_function_wrapped,
    batched=True,
    num_proc=preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not overwrite_cache,
    desc="Running tokenizer on dataset",
)

Loading cached processed dataset at C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7\cache-d027d5e0f3c407d4.arrow
Loading cached processed dataset at C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7\cache-befebaf86ec901e0.arrow
Loading cached processed dataset at C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7\cache-db8063f73b503fd8.arrow
Loading cached processed dataset at C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7\cache-7da22a68bff089f3.arrow
Loading cached processed dataset at C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7\cache-bc65a8799c9ebaa8.arrow
Loading cached processed dataset at

In [8]:
# Expose only the model-facing columns, returned as torch tensors.
processed_datasets.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

train_dataset = processed_datasets["train"]

# Evaluate on the validation split when there is one, otherwise on the test split.
if "validation" in processed_datasets:
    eval_dataset = processed_datasets["validation"]
else:
    eval_dataset = processed_datasets["test"]

# Sanity check: show two random training examples, raw and decoded.
for index in random.sample(range(len(train_dataset)), 2):
    sample = train_dataset[index]
    # Drop the -100 placeholders before decoding the labels.
    real_labels = [label for label in sample['labels'] if label != -100]

    print(f"Sample {index} of the training set: {sample}.")
    print(f"Decoded input_ids: {tokenizer.decode(sample['input_ids'])}")
    print(f"Decoded labels: {tokenizer.decode(real_labels)}")
    print("\n")


Sample 268 of the training set: {'input_ids': tensor([    1,   682,   326,   508,   434, 26225,  1031,   716,   741,   486,
         1240, 27141,    18,     2,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,    

In [9]:
# Training batches are reshuffled on every pass over the loader.
train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size
)

# Evaluation keeps the dataset order (shuffle=False); gold.txt and pred.txt are
# both produced by iterating this loader.
eval_dataloader = DataLoader(
    eval_dataset, shuffle=False, batch_size=batch_size
)

In [10]:
# Inspection only: check the type of one attention-mask entry after set_format(type="torch").
type(train_dataset['attention_mask'][0])

torch.Tensor

In [11]:
# Inspection only: display the attention mask of one batch drawn from the training loader.
next(iter(train_dataloader))['attention_mask']

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])

In [12]:
import torch.nn as nn

class CodeT5_NLSQL(nn.Module):
  """Thin ``nn.Module`` wrapper around a seq2seq model.

  The wrapped network is kept as ``self.model``; ``forward`` hands its
  arguments to it as keyword arguments and returns the result untouched.
  """

  def __init__(self, model):
    super().__init__()
    self.model = model

  def forward(self, input_ids, attention_mask, labels=None):
    """Run the wrapped model on one batch and return its output unchanged."""
    return self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
    )

In [13]:
import torch

In [25]:
# Wrap the pretrained model and move the wrapper (and thereby the model) to `device`.
nlsql_model = CodeT5_NLSQL(model)
nlsql_model.to(device)

CodeT5_NLSQL(
  (model): T5ForConditionalGeneration(
    (shared): Embedding(32100, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32100, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=512, bias=False)
                (k): Linear(in_features=512, out_features=512, bias=False)
                (v): Linear(in_features=512, out_features=512, bias=False)
                (o): Linear(in_features=512, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 8)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseReluDense(
                (wi): Linear(in_features=512, out_features=2048, bias=False)
                (wo): Linear(in_featur

In [15]:
import tqdm

In [16]:
from tqdm.auto import tqdm

# Write the reference SQL of every evaluation example to gold.txt, one query per line.
# Previously each batch was written with "\n".join(...) and no trailing newline, so the
# last query of one batch and the first query of the next ended up on the same line.
# NOTE(review): the raw dataset columns are removed during preprocessing, so only the
# query text is available here — confirm that `evaluate` does not expect additional
# per-line fields (e.g. a database id) in the gold file.
gold_queries = []

for batch in tqdm(eval_dataloader, desc="Evaluation"):
    for row in batch["labels"].tolist():
        # Drop the -100 placeholders before decoding.
        token_ids = [value for value in row if value != -100]
        gold_queries.append(tokenizer.decode(token_ids, skip_special_tokens=True))

with open("gold.txt", "w") as gold_file:
    gold_file.writelines(query + "\n" for query in gold_queries)

Evaluation: 100%|██████████| 65/65 [00:50<00:00,  1.30it/s]


In [17]:
import numpy as np 
from evaluation import evaluate, build_foreign_key_map_from_json

def evaluate_model(model, dataloader, tokenizer, max_seq_length, device):
  """Generate a query for every example in ``dataloader`` and score the result.

  The decoded predictions are written to ``pred.txt`` (one per line, in
  dataloader order) and scored against ``gold.txt`` by calling
  ``evaluate('gold.txt', 'pred.txt', 'database', 'match', ...)`` with the
  foreign-key map built from ``tables.json``. The model is switched to eval
  mode for the duration of the call and back to train mode before returning.

  Args:
    model: Wrapper exposing the generation-capable network as ``model.model``.
    dataloader: Yields dicts with ``"input_ids"``, ``"attention_mask"`` and
      ``"labels"`` tensors.
    tokenizer: Used to decode generated ids and label ids.
    max_seq_length: ``max_length`` passed to ``generate``.
    device: Device the input tensors are moved to.

  Returns:
    ``(evaluation_results, input_ids, decoded_preds, decoded_labels)`` where
    ``evaluation_results`` is ``{"eval/exact_match": ...}`` and the other three
    come from the last batch (``None`` if the dataloader was empty).
  """
  model.eval()

  all_preds = []
  # Last-batch artefacts handed back to the caller for printing an example.
  input_ids = decoded_preds = decoded_labels = None

  with torch.no_grad():
    for batch in tqdm(dataloader, desc="Evaluation"):
      input_ids = batch["input_ids"].to(device)
      attention_mask = batch["attention_mask"].to(device)

      # Inputs are padded to a fixed length, so the mask must be passed along;
      # otherwise the encoder attends to the padding tokens.
      generated_tokens = model.model.generate(
          input_ids,
          attention_mask=attention_mask,
          max_length=max_seq_length,
      )

      decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
      all_preds.extend(decoded_preds)

      # Labels are only decoded, so they can stay on the CPU; drop the -100
      # placeholders first.
      label_rows = [
          [value for value in label_row if value != -100]
          for label_row in batch["labels"].tolist()
      ]
      decoded_labels = tokenizer.batch_decode(label_rows, skip_special_tokens=True)

  # all_preds already holds decoded strings, so they are written as they are
  # (the previous code passed each string, character by character, to
  # tokenizer.decode a second time). Line breaks inside a prediction are
  # flattened to keep the file at exactly one query per line.
  with open("pred.txt", "w") as pred_file:
    for pred in all_preds:
      pred_file.write(" ".join(pred.splitlines()) + "\n")

  # NOTE(review): relies on the local `evaluation.evaluate` returning a nested
  # scores dict — confirm against evaluation.py.
  scores = evaluate('gold.txt', 'pred.txt', 'database', 'match', build_foreign_key_map_from_json('tables.json'))

  evaluation_results = {
      "eval/exact_match": scores['all']['exact']
  }

  model.train()
  return evaluation_results, input_ids, decoded_preds, decoded_labels


In [18]:
# AdamW over the parameters of `model`, the same object that CodeT5_NLSQL stores
# as `self.model`, so these are the parameters updated by the training loop.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)


# Learning-rate schedule stepped once per optimizer step in the training loop.
# NOTE(review): num_training_steps is max_train_steps, while the loop runs
# len(train_dataloader) * num_train_epochs steps — confirm the mismatch is intended.
lr_scheduler = transformers.get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_train_steps,
)

In [19]:
# import gc

# gc.collect()

# torch.cuda.empty_cache()

In [26]:
# Fine-tune the wrapped model, logging loss / learning rate on every step,
# running the full evaluation every `eval_every_step` steps and a cheap
# token-level training accuracy every `logging_steps` steps.
run = wandb.init(project="CODET5_SQLNL")

global_step = 0

progress_bar = tqdm(range(len(train_dataloader) * num_train_epochs))


# iterate over epochs
for epoch in range(num_train_epochs):
    nlsql_model.train()  # make sure that model is in training mode, e.g. dropout is enabled

    # iterate over batches
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        outputs = nlsql_model(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask
        )

        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        global_step += 1

        wandb.log(
            {
                # Log a plain float rather than a tensor still attached to the graph.
                "train_loss": loss.item(),
                "learning_rate": optimizer.param_groups[0]["lr"],
                "epoch": epoch,
            },
            step=global_step,
        )

        if global_step % eval_every_step == 0:
            eval_results, last_input_ids, last_decoded_preds, last_decoded_labels = evaluate_model(
                model=nlsql_model,
                dataloader=eval_dataloader,
                tokenizer=tokenizer,
                device=device,
                max_seq_length=max_seq_length,
            )
            # Pass the step explicitly: a step-less log call advances wandb's
            # internal step counter, which puts the later
            # `step=global_step` call below out of order.
            wandb.log(
                {"eval/exact_match": eval_results['eval/exact_match']},
                step=global_step,
            )

            print("Generation example:")
            random_index = random.randint(0, len(last_input_ids) - 1)
            print(f"Input sentence: {tokenizer.decode(last_input_ids[random_index], skip_special_tokens=True)}")
            print(f"Generated sentence: {last_decoded_preds[random_index]}")
            print(f"Reference sentence: {last_decoded_labels[random_index]}")

        if global_step % logging_steps == 0:
            # An extra training metric that might be useful for understanding
            # how well the model is doing on the training set.
            # Please pay attention to it during training.
            # If the metric is significantly below 80%, there is a chance of a bug somewhere.
            predictions = logits.argmax(-1)

            # Padding positions in `labels` were replaced by -100 during
            # preprocessing, so that is the value to mask out here; comparing
            # against pad_token_id kept every padded position in the
            # denominator and counted it as a wrong prediction.
            label_nonpad_mask = labels != -100
            num_words_in_batch = label_nonpad_mask.sum().item()

            accuracy = (predictions == labels).masked_select(label_nonpad_mask).sum().item() / num_words_in_batch

            wandb.log(
                {"train_batch_word_accuracy": accuracy},
                step=global_step,
            )


run.finish()  # stop wandb run


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  0%|          | 0/2190 [00:37<?, ?it/s]


KeyboardInterrupt: 