In [154]:
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

loading file spiece.model from cache at /home/ubuntu/.cache/huggingface/hub/models--t5-small/snapshots/3479082dc36f8a4730936ef1c9b88cd8b0835c53/spiece.model
loading file tokenizer.json from cache at /home/ubuntu/.cache/huggingface/hub/models--t5-small/snapshots/3479082dc36f8a4730936ef1c9b88cd8b0835c53/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at None
loading configuration file config.json from cache at /home/ubuntu/.cache/huggingface/hub/models--t5-small/snapshots/3479082dc36f8a4730936ef1c9b88cd8b0835c53/config.json
Model config T5Config {
  "_name_or_path": "t5-small",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,


In [155]:
import datasets

from transformers.models.bart.modeling_bart import shift_tokens_right
import random
import torch
random.seed(42)

def random_mask(query):
    words = query.split()
    if len(words) < 3:
        return query
    mask_index = random.randint(1, len(words) - 1)
    return ' '.join(words[:mask_index]) + '<extra_id_0>'

def convert_to_features(batch):
    contexts = ['# '.join(passages['passage_text']) for passages in batch['passages']]
    masked_queries = list(map(random_mask, batch['query']))
    inputs = [query + '; ' + context for context, query in zip(contexts, masked_queries)]

    input_encodings = tokenizer.batch_encode_plus(inputs, pad_to_max_length=True, max_length=512, truncation=True, return_tensors='pt')
    label_encodings = tokenizer.batch_encode_plus(batch['query'], pad_to_max_length=True, max_length=512, truncation=True, return_tensors='pt')
    labels = label_encodings['input_ids']
    # decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id)
    labels[labels[:,:] == model.config.pad_token_id] = -100
    
    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        # 'decoder_input_ids': decoder_input_ids,
        'labels': labels,
        'masked_queries': masked_queries,
    }

    return encodings

from tqdm import tqdm

dataset = datasets.load_dataset('ms_marco', 'v2.1', split='train[:5000]')
dataset = dataset.map(convert_to_features, batched=True, batch_size=8)
dataset = dataset.train_test_split(test_size=0.1)

Found cached dataset ms_marco (/home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)
Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84/cache-0558b24b1bbeb703.arrow


In [156]:
from transformers.trainer import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='./models/bart-summarizer',
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    do_train=True,
    do_eval=True,
    warmup_steps=500,   
    weight_decay=0.01,
    logging_dir='./logs',
    learning_rate=3e-4,
    logging_steps=100,

)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [157]:
trainer.train()

Step,Training Loss
100,4.9589
200,2.1489
300,1.5874
400,1.5092
500,1.5672
600,1.633
700,1.4458
800,1.4135
900,1.3478
1000,1.453


TrainOutput(global_step=4500, training_loss=1.4219490424262153, metrics={'train_runtime': 297.8021, 'train_samples_per_second': 15.111, 'train_steps_per_second': 15.111, 'total_flos': 609038106624000.0, 'train_loss': 1.4219490424262153, 'epoch': 1.0})

In [158]:
trainer.evaluate()

{'eval_loss': 1.1567209959030151,
 'eval_runtime': 6.2418,
 'eval_samples_per_second': 80.105,
 'eval_steps_per_second': 80.105,
 'epoch': 1.0}

In [159]:
for i in range(0, 10):
    print('Actual:    ', dataset['test'][i]['query'])
    print('Query:     ', dataset['test'][i]['masked_queries'])
    to_encode = dataset['test'][i]['masked_queries'] + '; ' + '# '.join(dataset['test'][i]['passages']['passage_text'])
    encoded = tokenizer(
        to_encode,
        pad_to_max_length=True,
        max_length=1024,
        truncation=True,
        return_tensors='pt')
    output = model.generate(input_ids=encoded['input_ids'].to('cuda'), max_length=1024, num_beams=4, early_stopping=True)
    print('Predicted: ', tokenizer.decode(output[0], skip_special_tokens=True))
    print('---------------------')

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}



Actual:     how to make the color rose gold
Query:      how to make<extra_id_0>
Predicted:  how to make rose gold
---------------------
Actual:     names moby dick is called
Query:      names moby<extra_id_0>
Predicted:  names moby dick
---------------------
Actual:     us bank payoff mailing address
Query:      us bank payoff<extra_id_0>


Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}



Predicted:  us bank payoff number
---------------------
Actual:     what is cwr means
Query:      what<extra_id_0>
Predicted:  what is ecn
---------------------
Actual:     how to open 1972 vw beetle hood
Query:      how<extra_id_0>
Predicted:  how to open the hood on your beetle
---------------------
Actual:     what does body habitus mean
Query:      what does body<extra_id_0>


Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}



Predicted:  what does body habitus mean
---------------------
Actual:     what causes learning disabilities?
Query:      what causes learning<extra_id_0>
Predicted:  what causes learning disabilities
---------------------
Actual:     is there a way to merge to cells with writing in them together
Query:      is there a way to merge to<extra_id_0>
Predicted:  is there a way to merge to excel
---------------------
Actual:     the pointer sisters neutron dance
Query:      the pointer sisters<extra_id_0>


Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}



Predicted:  the pointer sisters lyrics
---------------------
Actual:     does adenosine triphosphate give energy
Query:      does adenosine triphosphate<extra_id_0>
Predicted:  does adenosine triphosphate work
---------------------


In [162]:
trainer.save_model('/home/ubuntu/models/t5-small-autocomplete')

Saving model checkpoint to /home/ubuntu/models/t5-small-autocomplete
Configuration saved in /home/ubuntu/models/t5-small-autocomplete/config.json
Configuration saved in /home/ubuntu/models/t5-small-autocomplete/generation_config.json
Model weights saved in /home/ubuntu/models/t5-small-autocomplete/pytorch_model.bin
