<a href="https://colab.research.google.com/github/darisoy/EE517_Sp21/blob/master/hw3/hw3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🐍 Setup Python environment

In [1]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
[K     |▏                               | 10kB 17.3MB/s eta 0:00:01[K     |▎                               | 20kB 24.2MB/s eta 0:00:01[K     |▌                               | 30kB 29.2MB/s eta 0:00:01[K     |▋                               | 40kB 27.6MB/s eta 0:00:01[K     |▉                               | 51kB 15.1MB/s eta 0:00:01[K     |█                               | 61kB 13.2MB/s eta 0:00:01[K     |█▏                              | 71kB 14.7MB/s eta 0:00:01[K     |█▎                              | 81kB 14.4MB/s eta 0:00:01[K     |█▍                              | 92kB 15.1MB/s eta 0:00:01[K     |█▋                              | 102kB 16.4MB/s eta 0:00:01[K     |█▊                              | 112kB 16.4MB/s eta 0:00:01[K     |██                              | 

In [2]:
import numpy as np
import pandas as pd
import math
import torch
from tqdm.notebook import tqdm
from torch.utils.data import TensorDataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tokenizers import decoders
from sklearn.metrics import classification_report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📀 Load the data

In [3]:
labels = {'O' : 0,
          'B-geo-loc' : 1,
          'I-geo-loc' : 2,
          'B-product' : 3,
          'I-product' : 4,
          'B-facility' : 5,
          'I-facility' : 6,
          'B-company' : 7,
          'I-company' : 8,
          'B-person' : 9,
          'I-person' : 10,
          'B-sportsteam' : 11,
          'I-sportsteam' : 12,
          'B-musicartist' : 13,
          'I-musicartist' : 14,
          'B-movie' : 15,
          'I-movie' : 16,
          'B-tvshow' : 17,
          'I-tvshow' : 18,
          'B-other' : 19,
          'I-other' : 20,
          }
end_token = '<END>'
beg_token = '<BEG>'

In [4]:
def get_sentences(df):
    sentences = []
    labels = []
    running_sentence = ''
    runnnig_label = []
    for idx, row in df.iterrows():
        if row.word == end_token:
            if len(running_sentence[:-1]) > 0:
                sentences.append(running_sentence[:-1])
                labels.append(runnnig_label)
            running_sentence = ''
            runnnig_label = []
        else:
            running_sentence += row.word + ' '
            runnnig_label.append(row.tag)
    return sentences, labels

def get_data(type):
    data = pd.read_csv('https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/data/' + type, delimiter='\t', names=["word", "tag"], skip_blank_lines=False, quoting=3)
    data = data.fillna({'word': end_token, 'tag': 'O'})
    data.tag = data.tag.apply((lambda x: labels[x]))
    return get_sentences(data)

# 🔐 Encode the data using BERT transformer

## Load the transformer

In [5]:
transformer_name = "distilbert-base-uncased"
transformer = DistilBertModel.from_pretrained(transformer_name)
tokenizer = DistilBertTokenizer.from_pretrained(transformer_name)
tokenizer.decoder = decoders.WordPiece()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267967963.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




## Get dataset embeddings

In [6]:
def get_sublist_start_end(sl, l):
    results = []
    sll = len(sl)
    for ind in (i for i,e in enumerate(l) if e==sl[0]):
        if tokenizer.decode(l[ind:ind+sll]) ==  tokenizer.decode(sl):
            results.append([ind,ind+sll-1])
    return results

def get_embeddings(sentences):
    transformer.eval()
    transformer.to(device)
    data = []
    for sentence in tqdm(sentences):
        with torch.no_grad():
            tokens = tokenizer.encode(sentence)
            out = transformer(torch.tensor(tokens).unsqueeze(0).to(device))
            embed = []
            for i, word in enumerate(sentence.split()):
                target = word
                target_ids = tokenizer.encode(target, add_special_tokens=False)
                target_idx = get_sublist_start_end(target_ids, tokens)[0]
                embed.append(torch.mean(out[0][0][target_idx[0]:target_idx[1]+1], 0))
            data.append(torch.stack(embed))
    return data

In [7]:
train_sentences, train_tags = get_data('train')
train_embeddings = get_embeddings(train_sentences)

HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))




In [57]:
valid_sentences, valid_tags = get_data('dev')
valid_embeddings = get_embeddings(valid_sentences)

HBox(children=(FloatProgress(value=0.0, max=1003.0), HTML(value='')))




In [58]:
test_sentences, test_tags = get_data('test')
test_embeddings = get_embeddings(test_sentences)

HBox(children=(FloatProgress(value=0.0, max=3860.0), HTML(value='')))




# 🧑‍💻 Classify the embeddings using RNN

In [42]:
# Model Definition
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.hidden_dim = 128
        self.num_layers = 2
        self.rnn = nn.GRU(768, self.hidden_dim, self.num_layers, batch_first=True) 
        self.fc = nn.Linear(self.hidden_dim, len(labels))
    
    def forward(self, x):
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).requires_grad_().to(device)
        out, (h) = self.rnn(x, (h.detach()))
        out = self.fc(out) 
        return out

In [43]:
classifier = RNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
epochs = 10

In [50]:
# train
for epoch in range(epochs):
    running_loss = 0.0
    for i, sentence in enumerate(tqdm(train_embeddings)):
        tags = torch.tensor(train_tags[i])
        sentence, tags = sentence.to(device), tags.to(device)
        optimizer.zero_grad()
        outputs = classifier(sentence.unsqueeze(dim=0))
        loss = criterion(outputs.squeeze(dim=0), tags)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[Epoch %d]\tTrain Loss: \t\t%.3f' % (epoch+1, running_loss / len(train_embeddings)))

HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 1]	Train Loss: 		0.224


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 2]	Train Loss: 		0.145


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 3]	Train Loss: 		0.109


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 4]	Train Loss: 		0.079


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 5]	Train Loss: 		0.057


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 6]	Train Loss: 		0.045


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 7]	Train Loss: 		0.036


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 8]	Train Loss: 		0.028


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 9]	Train Loss: 		0.026


HBox(children=(FloatProgress(value=0.0, max=2394.0), HTML(value='')))


[Epoch 10]	Train Loss: 		0.020


# 📈 Evaluate valid and test data

## Helper functions

In [54]:
def evaluate(sentences, sentence_tags):
    correct = 0
    total = 0
    running_loss = 0
    truth = []
    preds = []
    for i, sentence in enumerate(tqdm(sentences)):
        tags = torch.tensor(sentence_tags[i])
        sentence, tags = sentence.to(device), tags.to(device)
        optimizer.zero_grad()

        outputs = classifier(sentence.unsqueeze(dim=0)).squeeze(dim=0)
        pred = outputs.squeeze(dim=1).argmax(dim=1)
        loss = criterion(outputs, tags)
        running_loss += loss.item()

        correct += torch.sum(tags == pred)
        total += len(tags)

        for t in sentence_tags[i]:
            truth.append(t)
        for p in pred:
            preds.append(p.item())

    print('Overall Accuracy: \t%.3f%% \tloss: %.3f' % (100*correct/total, running_loss/len(sentences)))
    return preds, truth

In [55]:
import sys
from collections import defaultdict

def split_tag(chunk_tag):
    """
    split chunk tag into IOBES prefix and chunk_type
    e.g. 
    B-PER -> (B, PER)
    O -> (O, None)
    """
    if chunk_tag == 0:
        return ('O', None)
    return list(labels.keys())[chunk_tag].split('-', maxsplit=1)

def is_chunk_end(prev_tag, tag):
    """
    check if the previous chunk ended between the previous and current word
    e.g. 
    (B-PER, I-PER) -> False
    (B-LOC, O)  -> True
    Note: in case of contradicting tags, e.g. (B-PER, I-LOC)
    this is considered as (B-PER, B-LOC)
    """
    prefix1, chunk_type1 = split_tag(prev_tag)
    prefix2, chunk_type2 = split_tag(tag)

    if prefix1 == 'O':
        return False
    if prefix2 == 'O':
        return prefix1 != 'O'

    if chunk_type1 != chunk_type2:
        return True

    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']

def is_chunk_start(prev_tag, tag):
    """
    check if a new chunk started between the previous and current word
    """
    prefix1, chunk_type1 = split_tag(prev_tag)
    prefix2, chunk_type2 = split_tag(tag)

    if prefix2 == 'O':
        return False
    if prefix1 == 'O':
        return prefix2 != 'O'

    if chunk_type1 != chunk_type2:
        return True

    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']


def calc_metrics(tp, p, t, percent=True):
    """
    compute overall precision, recall and FB1 (default values are 0.0)
    if percent is True, return 100 * original decimal value
    """
    precision = tp / p if p else 0
    recall = tp / t if t else 0
    fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
    if percent:
        return 100 * precision, 100 * recall, 100 * fb1
    else:
        return precision, recall, fb1


def count_chunks(true_seqs, pred_seqs):
    """
    true_seqs: a list of true tags
    pred_seqs: a list of predicted tags
    return: 
    correct_chunks: a dict (counter), 
                    key = chunk types, 
                    value = number of correctly identified chunks per type
    true_chunks:    a dict, number of true chunks per type
    pred_chunks:    a dict, number of identified chunks per type
    correct_counts, true_counts, pred_counts: similar to above, but for tags
    """
    correct_chunks = defaultdict(int)
    true_chunks = defaultdict(int)
    pred_chunks = defaultdict(int)

    correct_counts = defaultdict(int)
    true_counts = defaultdict(int)
    pred_counts = defaultdict(int)

    prev_true_tag, prev_pred_tag = 0, 0
    correct_chunk = None

    for true_tag, pred_tag in zip(true_seqs, pred_seqs):
        if true_tag == pred_tag:
            correct_counts[true_tag] += 1
        true_counts[true_tag] += 1
        pred_counts[pred_tag] += 1

        _, true_type = split_tag(true_tag)
        _, pred_type = split_tag(pred_tag)

        if correct_chunk is not None:
            true_end = is_chunk_end(prev_true_tag, true_tag)
            pred_end = is_chunk_end(prev_pred_tag, pred_tag)

            if pred_end and true_end:
                correct_chunks[correct_chunk] += 1
                correct_chunk = None
            elif pred_end != true_end or true_type != pred_type:
                correct_chunk = None

        true_start = is_chunk_start(prev_true_tag, true_tag)
        pred_start = is_chunk_start(prev_pred_tag, pred_tag)

        if true_start and pred_start and true_type == pred_type:
            correct_chunk = true_type
        if true_start:
            true_chunks[true_type] += 1
        if pred_start:
            pred_chunks[pred_type] += 1

        prev_true_tag, prev_pred_tag = true_tag, pred_tag
    if correct_chunk is not None:
        correct_chunks[correct_chunk] += 1

    return (correct_chunks, true_chunks, pred_chunks, 
        correct_counts, true_counts, pred_counts)

def get_result(correct_chunks, true_chunks, pred_chunks,
    correct_counts, true_counts, pred_counts, verbose=True):
    """
    if verbose, print overall performance, as well as preformance per chunk type;
    otherwise, simply return overall prec, rec, f1 scores
    """
    # sum counts
    sum_correct_chunks = sum(correct_chunks.values())
    sum_true_chunks = sum(true_chunks.values())
    sum_pred_chunks = sum(pred_chunks.values())

    sum_correct_counts = sum(correct_counts.values())
    sum_true_counts = sum(true_counts.values())
    O_correct_counts = sum(v for k, v in correct_counts.items() if k == 0)
    O_true_counts = sum(v for k, v in true_counts.items() if k == 0)
    O_pred_counts = sum(v for k, v in pred_counts.items() if k == 0)

    chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks))))

    # compute overall precision, recall and FB1 (default values are 0.0)
    prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
    res = (prec, rec, f1)
    if not verbose:
        return res

    print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
    print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')
    print()
    print("%i Entity Types:" % (len(chunk_types)))
    print("accuracy: %6.2f%%; " % (100*sum_correct_counts/sum_true_counts), end='')
    print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))

    for t in chunk_types:
        prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])
        print("%17s: " %t , end='')
        print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
                    (prec, rec, f1), end='')
        print("  %d" % pred_chunks[t])

    print()
    print("No Types: ")
    print("accuracy: %6.2f%%; " % (100*O_correct_counts/O_true_counts), end='')
    prec, rec, f1 = calc_metrics(O_correct_counts, O_pred_counts, O_true_counts)
    print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1), end='')
    print("  %d" % O_pred_counts)
    return res

def ConLLEval(true_seqs, pred_seqs, verbose=True):
    (correct_chunks, true_chunks, pred_chunks,
        correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs)
    result = get_result(correct_chunks, true_chunks, pred_chunks,
        correct_counts, true_counts, pred_counts, verbose=verbose)
    return result

## Run

In [59]:
print('[Validation Data]')
preds, truth = evaluate(valid_embeddings, valid_tags)
print()
ConLLEval(truth, preds)

[Validation Data]


HBox(children=(FloatProgress(value=0.0, max=1003.0), HTML(value='')))


Overall Accuracy: 	94.371% 	loss: 0.324

processed 16256 tokens with 661 phrases; found: 644 phrases; correct: 252.

10 Entity Types:
accuracy:  94.37%; precision:  39.13%; recall:  38.12%; FB1:  38.62
          company: precision:  20.00%; recall:  46.15%; FB1:  27.91  90
         facility: precision:  15.00%; recall:  15.79%; FB1:  15.38  40
          geo-loc: precision:  49.30%; recall:  60.34%; FB1:  54.26  142
            movie: precision:   0.00%; recall:   0.00%; FB1:   0.00  4
      musicartist: precision:  13.64%; recall:   7.32%; FB1:   9.52  22
            other: precision:   9.92%; recall:   9.09%; FB1:   9.49  121
           person: precision:  66.30%; recall:  70.18%; FB1:  68.18  181
          product: precision:  10.00%; recall:   2.70%; FB1:   4.26  10
       sportsteam: precision:  70.00%; recall:  30.00%; FB1:  42.00  30
           tvshow: precision:  25.00%; recall:  50.00%; FB1:  33.33  4

No Types: 
accuracy:  98.66%; precision:  97.04%; recall:  98.66%; FB1:  97

(39.130434782608695, 38.12405446293495, 38.62068965517241)

In [60]:
print('[Test Data]')
preds, truth = evaluate(test_embeddings, test_tags)
print()
ConLLEval(truth, preds)

[Test Data]


HBox(children=(FloatProgress(value=0.0, max=3860.0), HTML(value='')))


Overall Accuracy: 	92.765% 	loss: 0.492

processed 61880 tokens with 3473 phrases; found: 3218 phrases; correct: 1301.

10 Entity Types:
accuracy:  92.77%; precision:  40.43%; recall:  37.46%; FB1:  38.89
          company: precision:  45.57%; recall:  47.99%; FB1:  46.75  654
         facility: precision:  17.44%; recall:  11.86%; FB1:  14.12  172
          geo-loc: precision:  54.44%; recall:  61.11%; FB1:  57.59  990
            movie: precision:  11.11%; recall:   2.94%; FB1:   4.65  9
      musicartist: precision:  21.54%; recall:   7.33%; FB1:  10.94  65
            other: precision:  10.60%; recall:   9.08%; FB1:   9.78  500
           person: precision:  48.80%; recall:  67.43%; FB1:  56.62  666
          product: precision:   6.98%; recall:   1.22%; FB1:   2.08  43
       sportsteam: precision:  36.63%; recall:  25.17%; FB1:  29.84  101
           tvshow: precision:   5.56%; recall:   3.03%; FB1:   3.92  18

No Types: 
accuracy:  98.62%; precision:  95.94%; recall:  98.62%; F

(40.42883778744562, 37.460408868413474, 38.88805858616051)