# BERT Lab 3: Document Classification

In this third lab, you will see how to fine-tune BERT for document classification, using the Wikipedia Personal Attacks dataset as an example.

In the last lab, you applied BERT to sentence classification. In this lab, you will look at longer pieces of text.

In [2]:
import logging
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler

### Setup

As always, let's start by setting the correct device and installing the transformers library.

In [None]:
import torch

if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
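    # Raise explicitly: assert statements are skipped when Python runs with -O
    raise RuntimeError("No GPU found. Please select a GPU runtime in Colab.")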

In [None]:
!pip install transformers

### Dataset

As stated earlier, you will use the Wikipedia Personal Attacks dataset, a corpus of discussion comments from English Wikipedia talk pages. Comments are grouped into different files by year. The objective is to classify whether a comment contains a personal attack. So, like the CoLA dataset, this is a binary classification task.

First, download the dataset, then use `pandas` to parse the `.tsv` files and the BERT tokenizer to pre-process the input data for your model.

In [None]:
# Downloading the dataset
import urllib.request as request
import os

os.makedirs("./Wikipedia_Talk_Labels", exist_ok=True)

files = [
    ("./Wikipedia_Talk_Labels/attack_annotated_comments.tsv", "https://ndownloader.figshare.com/files/7554634"),
    ("./Wikipedia_Talk_Labels/attack_annotations.tsv", "https://ndownloader.figshare.com/files/7554637")
]

for (filename, url) in files:
    if not os.path.exists(filename):
        print("Downloading", filename)
        request.urlretrieve(url, filename)
        print("    Done ....")

In [None]:
# Parse file
import pandas as pd
print("Parsing the dataset .tsv file ...")

comments = pd.read_csv(files[0][0], sep="\t", index_col=0)
annotations = pd.read_csv(files[1][0], sep="\t")

dataset_sizes = comments[["comment", "split"]].groupby("split").count()
print("Dataset Sizes: {}".format(dataset_sizes))

Let's look at a few rows to see what we have.

In [None]:
# Display the first rows of the table
comments.head()

In [None]:
annotations.head()

We are provided with three splits. Let's see how big each one is.

In [None]:
# See the sizes of each split
comments[["comment", "split"]].groupby('split').count()

#### Labels

The comments are uniquely identified by their `rev_id`. The annotation table has multiple rows for each comment because there are multiple labelers. You will consider a comment an attack if the majority of the annotators agree that it is an attack.

In [None]:
# A comment is an attack if the majority of the annotators agree
labels = annotations.groupby("rev_id")["attack"].mean() > 0.5
comments["attack"] = labels

# remove newline and tab tokens
comments["comment"] = comments["comment"].apply(lambda x: x.replace("NEWLINE_TOKEN", " "))
comments["comment"] = comments["comment"].apply(lambda x: x.replace("TAB_TOKEN", " "))

comments.head(10)
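
To make the labeling rule concrete, here is a minimal pure-Python sketch of the same majority vote on a few made-up annotations (the `rev_id` values, the votes, and the helper name `majority_labels` are invented for illustration). Note that the strict `> 0.5` threshold means a tie is not counted as an attack.

```python
from collections import defaultdict

# Toy (rev_id, attack) annotations; the values are invented for illustration.
toy_annotations = [
    (101, 1), (101, 1), (101, 0),   # 2 of 3 annotators flag an attack
    (102, 0), (102, 0), (102, 1),   # 1 of 3 annotators flags an attack
    (103, 1), (103, 0),             # tie: 1 of 2
]

def majority_labels(annotations):
    """Return {rev_id: bool}; True only if a strict majority voted 'attack'."""
    votes = defaultdict(list)
    for rev_id, attack in annotations:
        votes[rev_id].append(attack)
    # Same rule as the groupby/mean above: mean vote strictly greater than 0.5.
    return {rev_id: sum(v) / len(v) > 0.5 for rev_id, v in votes.items()}

toy_labels = majority_labels(toy_annotations)
print(toy_labels)   # {101: True, 102: False, 103: False}
```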

Now, let's divide the dataset into training, validation, and test comments.

In [12]:
# Splits

train_comments = comments.query("split=='train'")
val_comments = comments.query("split=='dev'")
test_comments = comments.query("split=='test'")

Let's display some of the comments labeled as containing an attack. We'll use `textwrap` to wrap each comment to a fixed width, so that it prints as a readable paragraph.

In [None]:
import textwrap
import random

# Get some positive samples (comments with attacks)
wrapper = textwrap.TextWrapper(width=80)
attack_examples = train_comments.query("attack")["comment"]

for i in range(10):
    j = random.choice(attack_examples.index)
    # Use .loc to make the label-based (rev_id) lookup explicit
    print(wrapper.fill(attack_examples.loc[j]))
    print(" ---------- ")

Let's compute some statistics on the distribution of labels.

In [None]:
total_comments = comments.shape[0]
num_attacks = comments.query("attack").shape[0]

print(f"Percentage of attack comments {(num_attacks / total_comments) * 100:.2f}%")

As we can see, the classes are highly imbalanced: always predicting 0 (no attack) would already give us about 88% accuracy.

In [None]:
labels = train_comments.attack.to_numpy().astype(int)
print(f"Number of positive comments {labels.sum()} and negative ones {len(labels) - labels.sum()}")
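
Because of this imbalance, accuracy alone can be misleading. The sketch below uses a made-up set of 100 labels with 12 positives to show that a constant "no attack" prediction reaches 88% accuracy while its ROC AUC is only 0.5, i.e. no better than chance. The `pairwise_auc` helper is a hypothetical pure-Python illustration of what ROC AUC measures (the probability that a random positive is scored above a random negative, with ties counted as one half); in practice you would rather call scikit-learn's `roc_auc_score`.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def pairwise_auc(y_true, scores):
    """ROC AUC as P(positive score > negative score); ties count as 0.5."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Made-up labels: 12 attacks out of 100 comments.
y_true = [1] * 12 + [0] * 88

# A "model" that always predicts "no attack" with the same score.
constant_pred = [0] * 100
constant_score = [0.0] * 100

baseline_accuracy = accuracy(y_true, constant_pred)
baseline_auc = pairwise_auc(y_true, constant_score)
print(f"accuracy = {baseline_accuracy:.2f}, ROC AUC = {baseline_auc:.2f}")
```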

## Tokenization & BERT input length limitation

As discussed in the introduction, BERT accepts at most 512 tokens per input. In this part of the lab, you will see how this limitation matters in practice and study some possible approaches to address it.

Let's start by seeing the distribution of lengths of the tokenized comments.

In [None]:
from tqdm.notebook import tqdm
from transformers import BertTokenizer

# Avoid printing warnings
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)

# Tokenize the comments
input_ids, lengths = [], []

# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

print("Tokenizing comments ....")
input_ids = [tokenizer.encode(s, add_special_tokens=True) for s in tqdm(train_comments.comment)]
lengths = [len(input_id) for input_id in input_ids]

In [None]:
# Plot the distribution of comment lengths

# Set style
sns.set(style="darkgrid")
sns.set(font_scale=1.5)
plt.rcParams["figure.figsize"] = (10,5)

# Cap all lengths at 512; otherwise the long tail would stretch the x-axis
lengths_shortened = [min(l, 512) for l in lengths]
# histplot replaces distplot, which is deprecated in recent seaborn versions
sns.histplot(lengths_shortened, kde=False)

plt.xlabel("Comment length (tokens)")
plt.ylabel("# Comments")
plt.title("Comment lengths")
plt.show()
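
The histogram gives a visual impression; a quick count makes the impact of a length limit more explicit. The helper below is a hypothetical sketch (its name and the sample lengths are made up for illustration); in the notebook you would pass it the `lengths` list computed above.

```python
def fraction_longer_than(lengths, limit):
    """Return the share of sequences whose token count exceeds `limit`."""
    if not lengths:
        return 0.0
    return sum(length > limit for length in lengths) / len(lengths)

# Made-up token counts standing in for the real `lengths` list.
sample_lengths = [12, 40, 75, 130, 200, 512, 700, 1500]

for limit in (128, 512):
    share = fraction_longer_than(sample_lengths, limit)
    print(f"{share:.1%} of the sample comments exceed {limit} tokens")
```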

#### Solutions:

How can we solve this problem? There is no obvious solution, but there are two simple and straightforward approaches that we can try:


**Truncation:**
The first solution is to simply drop some of the tokens and hope that the remaining text is enough to predict the correct label.
But which tokens should be dropped? We can drop tokens:
- From the beginning.
- At the end.
- In the middle.
- At a random starting position.

In this [paper](https://arxiv.org/abs/1905.05583), the authors experimented on the IMDb movie review dataset and found that keeping the first 128 tokens and the last 382 tokens performed best among the truncation methods they compared. This keeps 510 tokens in total, leaving room for the two special tokens `[CLS]` and `[SEP]` that must be added for BERT fine-tuning.
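
A minimal sketch of this kind of head+tail truncation on a list of token IDs, assuming the special tokens have not been added yet. The function name and the toy sizes are illustrative and not part of any library API.

```python
def head_tail_truncate(token_ids, head, tail):
    """Keep the first `head` and last `tail` tokens; return short inputs unchanged."""
    if len(token_ids) <= head + tail:
        return list(token_ids)
    # Guard tail == 0, because token_ids[-0:] would return the whole list.
    tail_part = token_ids[-tail:] if tail > 0 else []
    return list(token_ids[:head]) + list(tail_part)

# Toy example with small sizes so the effect is easy to see.
toy_ids = list(range(10))          # stands in for a long tokenized comment
truncated = head_tail_truncate(toy_ids, head=3, tail=2)
print(truncated)                   # [0, 1, 2, 8, 9]
```

With BERT, `head + tail` should not exceed 510 so that `[CLS]` and `[SEP]` still fit within the 512-token limit.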

**Chunking:**
The second solution is chunking. It consists of dividing the text into 512-token chunks and generating embeddings for each of these chunks. These embeddings are then combined (with a simple summation or another pooling strategy) before performing the final classification.
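
The sketch below illustrates the idea without loading a model: it splits a list of token IDs into fixed-size chunks and mean-pools one vector per chunk. The `fake_embed` function is a made-up stand-in for a BERT forward pass, and all names and sizes are hypothetical.

```python
def split_into_chunks(token_ids, chunk_size):
    """Split token IDs into consecutive chunks of at most `chunk_size` tokens."""
    return [token_ids[i:i + chunk_size] for i in range(0, len(token_ids), chunk_size)]

def mean_pool(vectors):
    """Average a list of equal-length vectors element-wise."""
    n = len(vectors)
    return [sum(values) / n for values in zip(*vectors)]

def fake_embed(chunk):
    """Stand-in for a BERT forward pass: returns a made-up 2-d 'embedding'."""
    return [float(len(chunk)), float(sum(chunk))]

toy_ids = list(range(1, 8))                    # 7 tokens: 1..7
chunks = split_into_chunks(toy_ids, chunk_size=3)
chunk_vectors = [fake_embed(c) for c in chunks]
document_vector = mean_pool(chunk_vectors)     # one vector for the whole text

print(chunks)            # [[1, 2, 3], [4, 5, 6], [7]]
print(document_vector)
```

In a real pipeline, each chunk would hold at most 512 tokens including `[CLS]` and `[SEP]`, and the pooled vector would be fed to the classification layer.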

##  <span style="color:red">Your turn. </span>
We propose to use truncation and to truncate all inputs to a maximum length of 128, as in the second lab. You will need to:

1. **First truncate the tokens if the length of the input sequence is > 128.**
1. **Pad the sequences of length < 128.**
1. **Create the attention masks to differentiate between the padding tokens and the rest of the tokens.**
1. **Split the data into train and validation: 90% and 10% as in lab 2.**
1. **Create the datasets and the dataloaders.**
1. **Create the model and the optimizer using the same hyperparameters as lab 2, but only use 1 epoch.**
1. **Fine-tune the BERT model for classification. Only train for one epoch, because the dataset is big and training will take approx. 30 min.**
1. **Apply the same preprocessing to the test data (tokenizer, truncate, create dataset and dataloader).**
1. **Apply the model to the test data.**
1. **Given how imbalanced the dataset is, compute the ROC AUC metric using [roc_auc_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html) on the test data. <span style="color:red">The result should be > 0.95 </span>**

Please use the second lab as a reference; the changes to be made are minimal.
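
As a hint for the first three steps, here is a minimal pure-Python sketch of truncation, padding, and attention-mask creation. The helper name `truncate_and_pad` is hypothetical, the toy IDs are made up, and a padding ID of 0 is assumed (this matches `[PAD]` in `bert-base-uncased`; `tokenizer.pad_token_id` is the safer way to obtain it). The sketch cuts at the end of the sequence, so a trailing `[SEP]` is dropped for long inputs unless you re-append it.

```python
def truncate_and_pad(token_ids, max_len, pad_id=0):
    """Cut `token_ids` to `max_len`, right-pad with `pad_id`, and build the mask."""
    ids = list(token_ids[:max_len])
    n_pad = max_len - len(ids)
    # 1 marks a real token, 0 marks padding that attention should ignore.
    mask = [1] * len(ids) + [0] * n_pad
    ids = ids + [pad_id] * n_pad
    return ids, mask

# Toy inputs (made-up IDs); 101 and 102 are [CLS] and [SEP] in bert-base-uncased.
short_ids = [101, 7592, 2088, 102]
long_ids = [101, 2023, 2003, 1037, 2200, 2146, 7615, 102]

padded_ids, padded_mask = truncate_and_pad(short_ids, max_len=6)
cut_ids, cut_mask = truncate_and_pad(long_ids, max_len=6)

print(padded_ids, padded_mask)   # [101, 7592, 2088, 102, 0, 0] [1, 1, 1, 1, 0, 0]
print(cut_ids, cut_mask)         # [101, 2023, 2003, 1037, 2200, 2146] [1, 1, 1, 1, 1, 1]
```

The resulting lists can then be converted with `torch.tensor(...)` and wrapped in a `TensorDataset`.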

In [None]:
# Your turn