In [62]:
from transformers import Trainer, TrainingArguments, DataCollatorForTokenClassification, \
                          DistilBertForTokenClassification, DistilBertTokenizerFast, pipeline
from functools import reduce
from datasets import Dataset
import torch

In [41]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

In [42]:
snips_file = open('data/snips.train.txt', 'rb')
snips_rows = snips_file.readlines()

utterances = []
tokenized_utterances = []
labels_for_tokens = []
sequence_labels = []

utterance, tokenized_utterance, labels_for_utterances = '', [], []
for snip_row in snips_rows:
    if len(snip_row) == 2:
        continue
    if ' ' not in snip_row.decode():
        sequence_labels.append(snip_row.decode().strip())
        utterances.append(utterance.strip())
        tokenized_utterances.append(tokenized_utterance)
        labels_for_tokens.append(labels_for_utterances)
        utterance, tokenized_utterance, labels_for_utterances = '', [], []
        continue
    token, token_label = snip_row.decode().split(' ')
    token_label = token_label.strip()
    utterance += f'{token} '
    tokenized_utterance.append(token)
    labels_for_utterances.append(token_label)

unique_token_labels = list(set(reduce(lambda x, y: x + y, labels_for_tokens)))
labels_for_tokens = [[unique_token_labels.index(_) for _ in l] for l in labels_for_tokens]

snips_dataset = Dataset.from_dict(
    dict(
        utterance=utterances,
        label=sequence_labels,
        tokens=tokenized_utterances,
        token_labels=labels_for_tokens
    )
)
snips_dataset = snips_dataset.train_test_split(test_size=0.2)

In [43]:
tokenized_inputs = tokenizer(snips_dataset['train'][0]['tokens'], truncation=True, is_split_into_words=True)
tokenized_inputs

{'input_ids': [101, 2424, 1996, 13675, 21104, 1997, 2158, 1024, 2242, 10433, 2112, 1016, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [44]:
tokenizer.convert_ids_to_tokens([101, 2064, 2017, 5587, 2198, 8040, 11253, 12891, 14751, 8694, 1996, 7367, 2571, 25509, 2953, 2000, 1996, 2377, 9863, 102])

['[CLS]',
 'can',
 'you',
 'add',
 'john',
 'sc',
 '##of',
 '##ield',
 'newest',
 'tune',
 'the',
 'se',
 '##le',
 '##kt',
 '##or',
 'to',
 'the',
 'play',
 '##list',
 '[SEP]']

In [45]:
tokenized_inputs.word_ids(batch_index=0)

[None, 0, 1, 2, 2, 3, 4, 4, 5, 6, 7, 8, None]

In [46]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples['token_labels']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [47]:
snips_dataset['train'][0]

{'utterance': 'find the crucible of man: something wicked part 2',
 'label': 'SearchCreativeWork',
 'tokens': ['find',
  'the',
  'crucible',
  'of',
  'man:',
  'something',
  'wicked',
  'part',
  '2'],
 'token_labels': [3, 69, 28, 28, 28, 28, 28, 28, 28]}

In [48]:
tok_clf_tokenized_snips = snips_dataset.map(tokenize_and_align_labels, batched=True)
tok_clf_tokenized_snips

Map:   0%|          | 0/10467 [00:00<?, ? examples/s]

Map:   0%|          | 0/2617 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['utterance', 'label', 'tokens', 'token_labels', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 10467
    })
    test: Dataset({
        features: ['utterance', 'label', 'tokens', 'token_labels', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2617
    })
})

In [49]:
tok_clf_tokenized_snips['train'][0]

{'utterance': 'find the crucible of man: something wicked part 2',
 'label': 'SearchCreativeWork',
 'tokens': ['find',
  'the',
  'crucible',
  'of',
  'man:',
  'something',
  'wicked',
  'part',
  '2'],
 'token_labels': [3, 69, 28, 28, 28, 28, 28, 28, 28],
 'input_ids': [101,
  2424,
  1996,
  13675,
  21104,
  1997,
  2158,
  1024,
  2242,
  10433,
  2112,
  1016,
  102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [-100, 3, 69, 28, -100, 28, 28, -100, 28, 28, 28, 28, -100]}

In [50]:
tokenizer.decode([11253])

'##of'

In [51]:
tok_clf_tokenized_snips['train'] = tok_clf_tokenized_snips['train'].remove_columns(
    ['utterance', 'label', 'tokens', 'token_labels']
)
tok_clf_tokenized_snips['test'] = tok_clf_tokenized_snips['test'].remove_columns(
    ['utterance', 'label', 'tokens', 'token_labels']
)
tok_clf_tokenized_snips

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 10467
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2617
    })
})

In [52]:
tok_data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
tok_data_collator

DataCollatorForTokenClassification(tokenizer=DistilBertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}), padding=True, max_length=None, pad_to_multiple_of=None, label_pad_token_id=-100, return_tensors='pt')

In [53]:
tok_clf_model = DistilBertForTokenClassification.from_pretrained(
    'distilbert-base-uncased', num_labels=len(unique_token_labels)
)
tok_clf_model.config.id2label = {i: l for i, l in enumerate(unique_token_labels)}

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN t

In [54]:
tok_clf_model.config.id2label[0], tok_clf_model.config.id2label[1]

('I-track', 'B-city')

In [57]:
epochs = 5
train_args = TrainingArguments(
    output_dir='snip_tok_clf/results',
    num_train_epochs=epochs,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    load_best_model_at_end=True,

    logging_steps=10,
    log_level='info',
    evaluation_strategy='epoch',
    save_strategy='epoch',
)
trainer = Trainer(
    model=tok_clf_model,
    args=train_args,
    train_dataset=tok_clf_tokenized_snips['train'],
    eval_dataset=tok_clf_tokenized_snips['test'],
    data_collator=tok_data_collator,
)

In [58]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 2617
  Num examples = 2617
  Batch size = 32
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 4.189548969268799,
 'eval_runtime': 1.9956,
 'eval_samples_per_second': 1311.373,
 'eval_steps_per_second': 41.09}

In [59]:
trainer.train()

***** Running training *****
  Num examples = 10467
  Num Epochs = 5
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 1640
  Number of trainable parameters = 66418248


Epoch,Training Loss,Validation Loss
1,0.2009,0.163156


TrainOutput(global_step=1640, training_loss=0.18485487074386783, metrics={'train_runtime': 102.0873, 'train_samples_per_second': 512.649, 'train_steps_per_second': 16.065, 'total_flos': 289773121921920.0, 'train_loss': 0.18485487074386783, 'epoch': 5.0})

In [60]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 2617
  Batch size = 32
  Num examples = 2617
  Batch size = 32


{'eval_loss': 0.0967433974146843,
 'eval_runtime': 1.0897,
 'eval_samples_per_second': 2401.583,
 'eval_steps_per_second': 75.25,
 'epoch': 5.0}

In [63]:
device = torch.device(torch.cuda.current_device() if torch.cuda.is_available() else 'cpu')
pipe = pipeline('token-classification', tok_clf_model, tokenizer=tokenizer, device=device)
pipe('Add Two Coins by Dispatch to my road trip playlist')

[{'entity': 'B-entity_name',
  'score': 0.9867512,
  'index': 2,
  'word': 'two',
  'start': 4,
  'end': 7},
 {'entity': 'I-entity_name',
  'score': 0.9941801,
  'index': 3,
  'word': 'coins',
  'start': 8,
  'end': 13},
 {'entity': 'B-artist',
  'score': 0.9822044,
  'index': 5,
  'word': 'dispatch',
  'start': 17,
  'end': 25},
 {'entity': 'B-playlist_owner',
  'score': 0.99733764,
  'index': 7,
  'word': 'my',
  'start': 29,
  'end': 31},
 {'entity': 'B-playlist',
  'score': 0.9975948,
  'index': 8,
  'word': 'road',
  'start': 32,
  'end': 36},
 {'entity': 'I-playlist',
  'score': 0.9981085,
  'index': 9,
  'word': 'trip',
  'start': 37,
  'end': 41}]

In [64]:
pipe = pipeline('token-classification', tok_clf_model, tokenizer=tokenizer, device=device)
pipe('Rate The Principles of Data Science 5 out of 5')

[{'entity': 'B-object_name',
  'score': 0.99552095,
  'index': 2,
  'word': 'the',
  'start': 5,
  'end': 8},
 {'entity': 'I-object_name',
  'score': 0.99341583,
  'index': 3,
  'word': 'principles',
  'start': 9,
  'end': 19},
 {'entity': 'I-object_name',
  'score': 0.9973712,
  'index': 4,
  'word': 'of',
  'start': 20,
  'end': 22},
 {'entity': 'I-object_name',
  'score': 0.99743384,
  'index': 5,
  'word': 'data',
  'start': 23,
  'end': 27},
 {'entity': 'I-object_name',
  'score': 0.99751866,
  'index': 6,
  'word': 'science',
  'start': 28,
  'end': 35},
 {'entity': 'B-rating_value',
  'score': 0.998423,
  'index': 7,
  'word': '5',
  'start': 36,
  'end': 37},
 {'entity': 'B-rating_value',
  'score': 0.6980755,
  'index': 10,
  'word': '5',
  'start': 45,
  'end': 46}]