In [1]:
from datasets import load_dataset, load_metric, concatenate_datasets
import torch
from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments

In [2]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Data Loading and Formatting

In [3]:
# Load the three WikiSQL splits (train / validation / test) one after another.
train_data, val_data, test_data = (
    load_dataset('wikisql', split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [4]:
# Peek at one raw training example (question text, table, and the SQL record used by format_dataset).
train_data[0]

{'phase': 1,
 'question': 'Tell me what the notes are for South Australia ',
 'table': {'header': ['State/territory',
   'Text/background colour',
   'Format',
   'Current slogan',
   'Current series',
   'Notes'],
  'page_title': '',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-1000181-1',
  'section_title': '',
  'caption': '',
  'rows': [['Australian Capital Territory',
    'blue/white',
    'Yaa·nna',
    'ACT · CELEBRATION OF A CENTURY 2013',
    'YIL·00A',
    'Slogan screenprinted on plate'],
   ['New South Wales',
    'black/yellow',
    'aa·nn·aa',
    'NEW SOUTH WALES',
    'BX·99·HI',
    'No slogan on current series'],
   ['New South Wales',
    'black/white',
    'aaa·nna',
    'NSW',
    'CPX·12A',
    'Optional white slimline series'],
   ['Northern Territory',
    'ochre/white',
    'Ca·nn·aa',
    'NT · OUTBACK AUSTRALIA',
    'CB·06·ZZ',
    'New series began in June 2011'],
   ['Queensland',
    'maroon/white',
    'nnn·aaa

In [5]:
def format_dataset(example):
    """Reduce one raw WikiSQL row to an ``input``/``target`` text pair.

    Args:
        example: raw row with a ``question`` string and a ``sql`` dict that
            holds a ``human_readable`` SQL string.

    Returns:
        dict with ``input`` (the question) and ``target`` (the readable SQL).
    """
    question = example['question']
    sql_text = example['sql']['human_readable']
    return dict(input=question, target=sql_text)


In [6]:
# Keep only the question (input) and readable SQL (target). Every raw column is
# removed, so re-running this cell on the already-formatted splits fails
# (format_dataset reads example['question']).
train_data, val_data, test_data = (
    split.map(format_dataset, remove_columns=split.column_names)
    for split in (train_data, val_data, test_data)
)

train_data[0]


{'input': 'Tell me what the notes are for South Australia ',
 'target': 'SELECT Notes FROM table WHERE Current slogan = SOUTH AUSTRALIA'}

# Tokenization

In [7]:
# Tokenizer for the t5-small checkpoint; shared by the length analysis and the feature encoding.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='t5-small')

### Finding an Appropriate `MAX_LENGTH`

In [8]:
# map article and summary len to dict as well as if sample is longer than 512 tokens
def map_to_length(x):
    x["input_len"] = len(tokenizer(x["input"]).input_ids)
    x["input_longer_128"] = int(x["input_len"] > 128)
    x["input_longer_64"] = int(x["input_len"] > 64)
    x["input_longer_32"] = int(x["input_len"] > 32)

    x["out_len"] = len(tokenizer(x["target"]).input_ids)
    x["out_longer_128"] = int(x["out_len"] > 128)
    x["out_longer_64"] = int(x["out_len"] > 64)
    x["out_longer_32"] = int(x["out_len"] > 32)
    return x


In [9]:
# Token-length features for the training split, computed with 4 worker processes.
train_stats = train_data.map(function=map_to_length, num_proc=4)

In [10]:
# Token-length features for the validation split, computed with 4 worker processes.
val_stats = val_data.map(function=map_to_length, num_proc=4)

In [11]:
# Token-length features for the test split, computed with 4 worker processes.
test_stats = test_data.map(function=map_to_length, num_proc=4)

In [12]:
# Stack the three annotated splits so the statistics cover the whole corpus.
all_merged = concatenate_datasets(
    dsets=[train_stats, val_stats, test_stats],
)

##### Token-Length Analysis

In [13]:
def compute_and_print_stats(x, sample_size):
    """Print token-length statistics for a whole dataset passed as one batch.

    Meant as a ``Dataset.map(..., batched=True, batch_size=-1)`` callback over
    the columns produced by ``map_to_length``, so every column is a list.

    Args:
        x: batch dict with ``input_len`` / ``out_len`` lists and the 0/1 flag
            lists ``{input,out}_longer_{32,64,128}``.
        sample_size: number of rows in the dataset; the denominator for the
            means and the over-length ratios.

    Returns:
        None; the statistics are only printed.
    """
    # Report only when the batch is the entire (non-empty) dataset. Smaller
    # calls -- presumably a probe batch issued by Dataset.map, TODO confirm --
    # are skipped so partial-sample stats are never printed.
    if sample_size <= 0 or len(x["input_len"]) != sample_size:
        return None

    def _describe(label, prefix):
        # One report line: max, mean, then over-length ratios in ascending
        # threshold order, all floats formatted to 5 decimals for consistency.
        lengths = x[f"{prefix}_len"]
        parts = [
            f"{label} Max: {max(lengths)}",
            f"{label} Mean: {sum(lengths) / sample_size:.5f}",
        ]
        for threshold in (32, 64, 128):
            ratio = sum(x[f"{prefix}_longer_{threshold}"]) / sample_size
            parts.append(f"{label}>{threshold}:{ratio:.5f}")
        return ", ".join(parts)

    print(_describe("Input", "input"))
    print(_describe("Output", "out"))
    return None

In [14]:
# All Data
output = all_merged.map(
    function=lambda batch: compute_and_print_stats(batch, len(all_merged)),
    batched=True,
    batch_size=-1,  # intended as a single batch holding every row
)

Map:   0%|          | 0/80654 [00:00<?, ? examples/s]

Input Max: 90, Input Mean: 17.73463, Input>32:0.039539266496391993,  Input>128:0.00000, Input>64:0.00027 
Output Max: 176, Output Mean:21.57647, Output>32:0.05963746373397476, Output>128:0.00002, Output>64:0.00035


In [15]:
# Train Data
output = train_stats.map(
    function=lambda batch: compute_and_print_stats(batch, len(train_stats)),
    batched=True,
    batch_size=-1,  # intended as a single batch holding every row
)

Map:   0%|          | 0/56355 [00:00<?, ? examples/s]

Input Max: 90, Input Mean: 17.71997, Input>32:0.03857687871528702,  Input>128:0.00000, Input>64:0.00027 
Output Max: 176, Output Mean:21.57257, Output>32:0.05971076213290746, Output>128:0.00004, Output>64:0.00032


In [16]:
# Val Data
output = val_stats.map(
    function=lambda batch: compute_and_print_stats(batch, len(val_stats)),
    batched=True,
    batch_size=-1,  # intended as a single batch holding every row
)

Map:   0%|          | 0/8421 [00:00<?, ? examples/s]

Input Max: 79, Input Mean: 17.78126, Input>32:0.038950243439021495,  Input>128:0.00000, Input>64:0.00059 
Output Max: 79, Output Mean:21.45707, Output>32:0.05640660254126588, Output>128:0.00000, Output>64:0.00059


### Tokenizing and Padding

In [17]:
# Chosen from the length statistics above: across all splits only ~0.027% of
# inputs and ~0.035% of targets exceed 64 tokens; those get truncated.
MAX_LENGTH = 64

In [18]:
def convert_to_features(example_batch):
    """Tokenize a batch of ``input``/``target`` strings into T5 training features.

    Both sides are truncated and padded to ``MAX_LENGTH`` tokens.

    Args:
        example_batch: batched-map dict with list-of-str columns ``input``
            (question) and ``target`` (SQL).

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``labels`` and
        ``decoder_attention_mask`` (one int list per example). In ``labels``
        every padding position is -100; real target tokens are unchanged.
    """
    input_encodings = tokenizer.batch_encode_plus(
        example_batch['input'], padding='max_length', max_length=MAX_LENGTH, truncation=True
    )
    target_encodings = tokenizer.batch_encode_plus(
        example_batch['target'], padding='max_length', max_length=MAX_LENGTH, truncation=True
    )

    # Padding ids (0 for this tokenizer) are valid label ids, so leaving them in
    # makes the loss train the model to emit padding. -100 is the ignore index of
    # the Hugging Face seq2seq loss. NOTE: a compute_metrics that decodes labels
    # must map -100 back to tokenizer.pad_token_id before decoding.
    labels = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(target_encodings['input_ids'], target_encodings['attention_mask'])
    ]

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': labels,
        'decoder_attention_mask': target_encodings['attention_mask'],
    }

In [19]:
# Tokenize every split in batches; the text columns are dropped, leaving only model features.
finaltrain_data, finalval_data, finaltest_data = (
    split.map(convert_to_features, batched=True, remove_columns=split.column_names)
    for split in (train_data, val_data, test_data)
)

Map:   0%|          | 0/8421 [00:00<?, ? examples/s]

In [20]:
# Feature columns exposed as torch tensors to the trainer.
columns = [
    'input_ids',
    'attention_mask',
    'labels',
    'decoder_attention_mask',
]

In [21]:
# Switch each encoded split (in place) to return the feature columns as torch tensors.
for encoded_split in (finaltrain_data, finalval_data, finaltest_data):
    encoded_split.set_format(type='torch', columns=columns)

In [22]:
# Sanity-check one encoded training example: its padded input ids and their shape
# (ids and shape now come from the same example).
sample_input_ids = finaltrain_data[1]['input_ids']
sample_input_ids, sample_input_ids.shape

(tensor([ 363,   19,    8,  750,  939,  213,    8,  126,  939, 1553,   16, 1515,
         2722,   58,    1,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0]),
 torch.Size([64]))

# Training

In [None]:
# Training configuration: evaluate and log every 100 steps, checkpoint every
# 200 steps (a multiple of eval_steps) and keep at most 3 checkpoints.
batch_size = 8
# NOTE(review): model_name looks like a leftover from a title-generation example
# (the tokenizer and output dir here are t5-small / WikiSQL) and is not used in
# this cell; confirm whether later cells rely on it before renaming or removing.
model_name = "t5-base-medium-title-generation"
model_dir = "./t5-small-finetuned-wikisql/"

args = Seq2SeqTrainingArguments(
    model_dir,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=200,
    learning_rate=4e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    # Generate sequences during evaluation, presumably so text metrics such as
    # the ROUGE metric loaded in this notebook can be computed.
    predict_with_generate=True,
    fp16=False,
    load_best_model_at_end=True,
    # Must match a key returned by the Trainer's compute_metrics (defined elsewhere).
    metric_for_best_model="rouge1"
)

In [None]:
# Batch collator for the pre-tokenized seq2seq features; built from the tokenizer alone (no model argument).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:
# ROUGE scorer, presumably consumed by a compute_metrics function later on.
# NOTE(review): datasets.load_metric is deprecated in favour of the separate
# `evaluate` library (evaluate.load("rouge")); confirm against the installed version.
metric = load_metric(path="rouge")