***BERT for sequence classification***

**Fine-Tuning BERT, Classification**

Remember the goal is to add a feed-forward layer on top of the pooler layer to classify the sequence under some chosen labels.

This is only one way to structure classification using BERT but it is the most common.

In [1]:
# imports

import numpy as np
from transformers import (Trainer, TrainingArguments,
                          DistilBertForSequenceClassification,
                          DistilBertTokenizerFast,
                          DataCollatorWithPadding, pipeline
                          )
from datasets import load_metric, Dataset
from functools import reduce

  from .autonotebook import tqdm as notebook_tqdm


Data collators generate batches of data for our pipelines.

In [2]:
# Loading the dataset

snips_file = open(r'data\snips.train.txt', 'rb')
snips_rows = snips_file.readlines()

snips_rows[:20]

[b'listen O\r\n',
 b'to O\r\n',
 b'westbam B-artist\r\n',
 b'alumb O\r\n',
 b'allergic B-album\r\n',
 b'on O\r\n',
 b'google B-service\r\n',
 b'music I-service\r\n',
 b'PlayMusic\r\n',
 b'\r\n',
 b'add O\r\n',
 b'step B-entity_name\r\n',
 b'to I-entity_name\r\n',
 b'me I-entity_name\r\n',
 b'to O\r\n',
 b'the O\r\n',
 b'50 B-playlist\r\n',
 b'cl\xc3\xa1sicos I-playlist\r\n',
 b'playlist O\r\n',
 b'AddToPlaylist\r\n']

In general when we wanna label entities we use an O if the entity means nothing, B if it is the beginning of an entity and I if it is a continuation of another token.

In [3]:
# Parsing the txt file into a more manageable format

utterances = []
tokenized_utterances = []
labels_for_tokens = []
sequence_labels = []

utterance, tokenized_utterance, label_for_utterances = '', [], []

for snip_row in snips_rows:
    if len(snip_row) == 2:  # Skip rows with no data
        continue
    if ' ' not in snip_row.decode():  # We've hit a sequence label
        sequence_labels.append(snip_row.decode().strip())
        utterances.append(utterance.strip())
        tokenized_utterances.append(tokenized_utterance)
        labels_for_tokens.append(label_for_utterances)
        utterance = ''
        tokenized_utterance = []
        label_for_utterances = []
        continue
    token, token_label = snip_row.decode().split(' ')
    token_label = token_label.strip()
    utterance += f'{token} '
    tokenized_utterance.append(token)
    label_for_utterances.append(token_label)



In [4]:
len(labels_for_tokens), len(tokenized_utterances), len(utterances), len(sequence_labels)

(13084, 13084, 13084, 13084)

In [5]:
# Instantiating tokenizer

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')

In [6]:
print(tokenized_utterances[0])
print(labels_for_tokens[0])
print(utterances[0])
print(sequence_labels[0])

['listen', 'to', 'westbam', 'alumb', 'allergic', 'on', 'google', 'music']
['O', 'O', 'B-artist', 'O', 'B-album', 'O', 'B-service', 'I-service']
listen to westbam alumb allergic on google music
PlayMusic


In [7]:
unique_sequence_labels = list(set(sequence_labels))
unique_sequence_labels

['SearchScreeningEvent',
 'AddToPlaylist',
 'PlayMusic',
 'RateBook',
 'BookRestaurant',
 'SearchCreativeWork',
 'GetWeather']

These are the seven intents that we will classify sequences into.

In [8]:
sequence_labels = [
    unique_sequence_labels.index(l) for l in sequence_labels
]

print(f'There are {len(sequence_labels)} unique sequence labels')

There are 13084 unique sequence labels


In [9]:
unique_token_labels = list(set(reduce(lambda x, y: x + y, labels_for_tokens)))
labels_for_tokens = [
    [unique_token_labels.index(_) for _ in l]
    for l in labels_for_tokens
]

print(f'There are {len(unique_token_labels)} unique token labels')

There are 72 unique token labels


In [10]:
print(tokenized_utterances[0])
print(labels_for_tokens[0])
print([unique_token_labels[l] for l in labels_for_tokens[0]])
print(utterances[0])
print(sequence_labels[0])
print(unique_sequence_labels[sequence_labels[0]])

['listen', 'to', 'westbam', 'alumb', 'allergic', 'on', 'google', 'music']
[30, 30, 26, 30, 54, 30, 50, 61]
['O', 'O', 'B-artist', 'O', 'B-album', 'O', 'B-service', 'I-service']
listen to westbam alumb allergic on google music
2
PlayMusic
