In [1]:
from datasets import load_dataset, load_metric, concatenate_datasets
import torch
#from transformers import AutoTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments

#from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [35]:
device = torch.device('mps')

# Feature Extraction

In [3]:
train_data = load_dataset('wikisql', split='train')
val_data = load_dataset('wikisql', split='validation')
test_data = load_dataset('wikisql', split='test')

In [4]:
train_data[0]

{'phase': 1,
 'question': 'Tell me what the notes are for South Australia ',
 'table': {'header': ['State/territory',
   'Text/background colour',
   'Format',
   'Current slogan',
   'Current series',
   'Notes'],
  'page_title': '',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-1000181-1',
  'section_title': '',
  'caption': '',
  'rows': [['Australian Capital Territory',
    'blue/white',
    'Yaa·nna',
    'ACT · CELEBRATION OF A CENTURY 2013',
    'YIL·00A',
    'Slogan screenprinted on plate'],
   ['New South Wales',
    'black/yellow',
    'aa·nn·aa',
    'NEW SOUTH WALES',
    'BX·99·HI',
    'No slogan on current series'],
   ['New South Wales',
    'black/white',
    'aaa·nna',
    'NSW',
    'CPX·12A',
    'Optional white slimline series'],
   ['Northern Territory',
    'ochre/white',
    'Ca·nn·aa',
    'NT · OUTBACK AUSTRALIA',
    'CB·06·ZZ',
    'New series began in June 2011'],
   ['Queensland',
    'maroon/white',
    'nnn·aaa

In [5]:
START_TOK = '[SOS] '
def format_dataset(example):
    return {'input': START_TOK+example['question'], 'target': example['sql']['human_readable']}


In [6]:
train_data = train_data.map(format_dataset, remove_columns=train_data.column_names)
val_data = val_data.map(format_dataset, remove_columns=val_data.column_names)
test_data = test_data.map(format_dataset, remove_columns=test_data.column_names)

train_data[0]


{'input': '[SOS] Tell me what the notes are for South Australia ',
 'target': 'SELECT Notes FROM table WHERE Current slogan = SOUTH AUSTRALIA'}

# Tokenization

In [7]:
CHECKPOINT = 'google-t5/t5-small'
tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Finding appropriate Max_Length

In [8]:
# map article and summary len to dict as well as if sample is longer than 512 tokens
def map_to_length(x):
    x["input_len"] = len(tokenizer(x["input"]).input_ids)
    x["input_longer_128"] = int(x["input_len"] > 128)
    x["input_longer_64"] = int(x["input_len"] > 64)
    x["input_longer_32"] = int(x["input_len"] > 32)

    x["out_len"] = len(tokenizer(x["target"]).input_ids)
    x["out_longer_128"] = int(x["out_len"] > 128)
    x["out_longer_64"] = int(x["out_len"] > 64)
    x["out_longer_32"] = int(x["out_len"] > 32)
    return x


In [9]:
train_stats = train_data.map(map_to_length, num_proc=4)

Map (num_proc=4):   0%|          | 0/56355 [00:00<?, ? examples/s]

In [10]:
val_stats = val_data.map(map_to_length, num_proc=4)

Map (num_proc=4):   0%|          | 0/8421 [00:00<?, ? examples/s]

In [11]:
test_stats = test_data.map(map_to_length, num_proc=4)

Map (num_proc=4):   0%|          | 0/15878 [00:00<?, ? examples/s]

In [12]:
all_merged = concatenate_datasets([train_stats,
                                   val_stats,
                                  test_stats])

##### Some Analysis on lengths

In [13]:
def compute_and_print_stats(x, sample_size):
    if len(x["input_len"]) == sample_size:
        print(
            "Input Max: {}, Input Mean: {:.5f}, Input>32:{},  Input>128:{:.5f}, Input>64:{:.5f} \nOutput Max: {}, Output Mean:{:.5f}, Output>32:{}, Output>128:{:.5f}, Output>64:{:.5f}".format(
                max(x["input_len"]),
                sum(x["input_len"]) / sample_size,
                sum(x["input_longer_32"]) / sample_size,
                sum(x["input_longer_128"]) / sample_size,
                sum(x["input_longer_64"]) / sample_size,
                max(x["out_len"]),
                sum(x["out_len"]) / sample_size,
                sum(x["out_longer_32"]) / sample_size,
                sum(x["out_longer_128"]) / sample_size,
                sum(x["out_longer_64"]) / sample_size,
            )
        )

In [14]:
# All Data
output = all_merged.map(
  lambda x: compute_and_print_stats(x, all_merged.shape[0]), 
  batched=True,
  batch_size=-1,
)

Map:   0%|          | 0/80654 [00:00<?, ? examples/s]

Input Max: 94, Input Mean: 21.73463, Input>32:0.07684677759317579,  Input>128:0.00000, Input>64:0.00046 
Output Max: 176, Output Mean:21.57647, Output>32:0.05963746373397476, Output>128:0.00002, Output>64:0.00035


In [15]:
# Train Data
output = train_stats.map(
  lambda x: compute_and_print_stats(x, train_stats.shape[0]), 
  batched=True,
  batch_size=-1,
)

Map:   0%|          | 0/56355 [00:00<?, ? examples/s]

Input Max: 94, Input Mean: 21.71997, Input>32:0.07614231212847129,  Input>128:0.00000, Input>64:0.00043 
Output Max: 176, Output Mean:21.57257, Output>32:0.05971076213290746, Output>128:0.00004, Output>64:0.00032


In [16]:
# Val Data
output = val_stats.map(
  lambda x: compute_and_print_stats(x, val_stats.shape[0]), 
  batched=True,
  batch_size=-1,
)

Map:   0%|          | 0/8421 [00:00<?, ? examples/s]

Input Max: 83, Input Mean: 21.78126, Input>32:0.07552547203420021,  Input>128:0.00000, Input>64:0.00071 
Output Max: 79, Output Mean:21.45707, Output>32:0.05640660254126588, Output>128:0.00000, Output>64:0.00059


### Tokenizing and Padding

In [17]:
BUFFER = 2 # start end tokens
MAX_LENGTH = 64 + BUFFER

In [18]:
def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['input'], padding='max_length', max_length=MAX_LENGTH, truncation=True)
    target_encodings = tokenizer.batch_encode_plus(example_batch['target'], padding='max_length', max_length=MAX_LENGTH, truncation=True)
    
    encodings = {
        'input_ids': input_encodings['input_ids'], 
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids'],
        'decoder_attention_mask': target_encodings['attention_mask']
    }


    return encodings

In [19]:
finaltrain_data = train_data.map(convert_to_features, batched=True, remove_columns=train_data.column_names, num_proc=4)
finalval_data = val_data.map(convert_to_features, batched=True, remove_columns=val_data.column_names, num_proc=4)
finaltest_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names, num_proc=4)

Map (num_proc=4):   0%|          | 0/56355 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/8421 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/15878 [00:00<?, ? examples/s]

In [20]:
columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']

In [21]:
finaltrain_data.set_format(type='torch', columns=columns, device=device)
finalval_data.set_format(type='torch', columns=columns, device=device)
finaltest_data.set_format(type='torch', columns=columns, device=device)

In [22]:
finaltrain_data[1]['input_ids'], finaltrain_data[0]['input_ids'].shape

(tensor([ 784,  134, 3638,  908,  363,   19,    8,  750,  939,  213,    8,  126,
          939, 1553,   16, 1515, 2722,   58,    1,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0]),
 torch.Size([66]))

In [23]:
tokenizer.decode(finaltrain_data[0]['input_ids'])

'[SOS] Tell me what the notes are for South Australia</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

# Training

In [24]:
batch_size = 8
model_name = "t5-base-medium-title-generation"
model_dir = f"./t5-small-finetuned-wikisql/"

In [None]:
args = Seq2SeqTrainingArguments(
    model_dir,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=200,
    learning_rate=4e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=False,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1"
)

In [39]:
args = Seq2SeqTrainingArguments(model_dir)

In [36]:
model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT, device_map=device)

In [41]:
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
#    compute_metrics=compute_metrics,
    train_dataset=finaltrain_data,
    eval_dataset=finalval_data,
)

In [42]:
trainer.train()

Step,Training Loss
500,0.7802
1000,0.3053
1500,0.2623
2000,0.2474
2500,0.2316
3000,0.2182
3500,0.2106
4000,0.213
4500,0.1995
5000,0.1955


TrainOutput(global_step=21135, training_loss=0.19169367024710704, metrics={'train_runtime': 2722.0773, 'train_samples_per_second': 62.109, 'train_steps_per_second': 7.764, 'total_flos': 2949576308490240.0, 'train_loss': 0.19169367024710704, 'epoch': 3.0})

In [43]:
trainer.save_model()

In [44]:
'''model(input_ids = finaltest_data['input_ids'].to(device),
     attention_mask = finaltest_data['attention_mask'].to(device),
     decoder_input_ids = finaltest_data['labels'].to(device),
     decoder_attention_mask = finaltest_data['decoder_attention_mask'].to(device))'''

"model(input_ids = finaltest_data['input_ids'].to(device),\n     attention_mask = finaltest_data['attention_mask'].to(device),\n     decoder_input_ids = finaltest_data['labels'].to(device),\n     decoder_attention_mask = finaltest_data['decoder_attention_mask'].to(device))"

In [45]:
tokenizer.save_pretrained('./t5-small-finetuned-wikisql')

('./t5-small-finetuned-wikisql/tokenizer_config.json',
 './t5-small-finetuned-wikisql/special_tokens_map.json',
 './t5-small-finetuned-wikisql/spiece.model',
 './t5-small-finetuned-wikisql/added_tokens.json')

## Evaluation

In [49]:
finaltest_data.set_format("torch", device=device)

In [100]:
%%time
predictions = trainer.predict(finaltest_data.select(range(50)), metric_key_prefix='bleu')

CPU times: user 441 ms, sys: 3.48 s, total: 3.92 s
Wall time: 16.2 s


In [101]:
predictions.metrics

{'bleu_loss': 0.4112585484981537,
 'bleu_runtime': 16.1639,
 'bleu_samples_per_second': 3.093,
 'bleu_steps_per_second': 0.433}

In [120]:
def translate_to_sql(local_model, text):
    inputs = tokenizer(text, padding='longest', max_length=MAX_LENGTH, truncation=True, return_tensors='pt')
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    output = local_model.generate(input_ids, attention_mask=attention_mask, max_length=64)


    return tokenizer.decode(output[0], skip_special_tokens=True)

def generate_sql_on_test(data, local_model):
    length = data.shape[0]
    query = data['input']
    expected = data['target']
        
    for i in range(length):
        print(f"QUERY - {query[i]}")
        translated = translate_to_sql(local_model, query[i])
        print(f"Prediction - {translated}")
        print(f"Expected = {expected[i]}")
        print("="*50)

In [121]:
generate_sql_on_test(test_data.select(range(5)), model.to("cpu"))

QUERY - [SOS] What is terrence ross' nationality
Prediction - SELECT Nationality FROM table WHERE Nation = terrence ross
Expected = SELECT Nationality FROM table WHERE Player = Terrence Ross
QUERY - [SOS] What clu was in toronto 1995-96
Prediction - SELECT clu FROM table WHERE Location = toronto 1995-96
Expected = SELECT School/Club Team FROM table WHERE Years in Toronto = 1995-96
QUERY - [SOS] which club was in toronto 2003-06
Prediction - SELECT Club FROM table WHERE Location = toronto 2003-06
Expected = SELECT School/Club Team FROM table WHERE Years in Toronto = 2003-06
QUERY - [SOS] how many schools or teams had jalen rose
Prediction - SELECT COUNT Schools/Teams FROM table WHERE Player = Jalen Rose
Expected = SELECT COUNT School/Club Team FROM table WHERE Player = Jalen Rose
QUERY - [SOS] Where was Assen held?
Prediction - SELECT Venue FROM table WHERE Player = assen
Expected = SELECT Round FROM table WHERE Circuit = Assen
