In [44]:
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
from functools import partial

import utils
import random
import logging

from torch.utils.data import DataLoader


In [2]:
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')

Downloading: 100%|██████████| 687k/687k [00:00<00:00, 1.01MB/s]
Downloading: 100%|██████████| 287k/287k [00:00<00:00, 416kB/s] 
Downloading: 100%|██████████| 2.00/2.00 [00:00<00:00, 460B/s]
Downloading: 100%|██████████| 12.2k/12.2k [00:00<00:00, 2.92MB/s]
Downloading: 100%|██████████| 1.44k/1.44k [00:00<00:00, 318kB/s]
Downloading: 100%|██████████| 1.53k/1.53k [00:00<00:00, 230kB/s]
Downloading: 100%|██████████| 850M/850M [12:40<00:00, 1.17MB/s]  


In [5]:
dataset = load_dataset('spider')

Reusing dataset spider (C:\Users\wasii\.cache\huggingface\datasets\spider\spider\1.0.0\79778ebea87c59b19411f1eb3eda317e9dd5f7788a556d837ef25c3ae6e5e8b7)
100%|██████████| 2/2 [00:00<00:00, 39.97it/s]


In [35]:
dataset['train']['query'][:-50]

['SELECT count(*) FROM head WHERE age  >  56',
 'SELECT name ,  born_state ,  age FROM head ORDER BY age',
 'SELECT creation ,  name ,  budget_in_billions FROM department',
 'SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department',
 'SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15',
 "SELECT name FROM head WHERE born_state != 'California'",
 "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'",
 'SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3',
 'SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1',
 "SELECT T1.name ,  T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id WHERE T2.temporary_acting  =  'Yes'",
 'SELECT count(DISTINCT temporary_acting) FROM management',
 'SELECT count(*) FROM department WH

In [36]:
def preprocess_function(examples, tokenizer, max_seq_length):
    

    inputs = examples['question']
    targets = examples['query']
    
    model_inputs = tokenizer(inputs, max_length=max_seq_length, padding="max_length", truncation=True)
    target_ids = tokenizer(targets, max_length=max_seq_length, padding="max_length", truncation=True).input_ids
    
    #decoder_input_ids = []

    # for target in target_ids:
    #     decoder_input_ids.append([tokenizer.bos_token_id] + target)
    #     labels.append(target + [tokenizer.eos_token_id])

    # model_inputs["decoder_input_ids"] = decoder_input_ids

    labels_with_ignore_index = []
    
    for labels_example in target_ids:
        labels_example = [label if label != 0 else -100 for label in labels_example]
        labels_with_ignore_index.append(labels_example)
    
    model_inputs["labels"] = labels_with_ignore_index

    return model_inputs


In [42]:
max_seq_length=128
overwrite_cache=True
preprocessing_num_workers = 8
batch_size=32

In [37]:
column_names = dataset["train"].column_names

preprocess_function_wrapped = partial(
    preprocess_function,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
)


processed_datasets = dataset.map(
    preprocess_function_wrapped,
    batched=True,
    num_proc=preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not overwrite_cache,
    desc="Running tokenizer on dataset",
)

In [41]:
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"] if "validation" in processed_datasets else processed_datasets["test"]

# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 2):
    print(f"Sample {index} of the training set: {train_dataset[index]}.")
    print(f"Decoded input_ids: {tokenizer.decode(train_dataset[index]['input_ids'])}")
    print(f"Decoded labels: {tokenizer.decode([label for label in train_dataset[index]['labels'] if label != -100])}")
    print("\n")


Sample 1216 of the training set: {'input_ids': [1, 990, 326, 1122, 1257, 471, 1142, 1257, 434, 777, 13051, 87, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [45]:
train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size
)

eval_dataloader = DataLoader(
    eval_dataset, shuffle=False, batch_size=batch_size
)

In [52]:
next(iter(train_dataloader))['attention_mask']

[tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 