<a href="https://colab.research.google.com/github/navneetkrc/Flair_SOTA_NLP/blob/master/Text_clickbait_detection_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text Classification Task using BERT (ClickBait detection task)

In this example we will demonstrate how to fine-tune and evaluate the BERT model on the ClickBait dataset using PyTorchWrapper.

#### Downloading Data
First of all we download and extract the data.

In [0]:
! mkdir -p data/

! wget -P data/ https://github.com/bhargaviparanjape/clickbait/raw/master/dataset/clickbait_data.gz
! wget -P data/ https://github.com/bhargaviparanjape/clickbait/raw/master/dataset/non_clickbait_data.gz

! gunzip -f data/clickbait_data.gz 
! gunzip -f data/non_clickbait_data.gz 


#### Additional libraries

Next we install the `transformers` library, which provides the pretrained BERT model, and the `pytorch_wrapper` library.

In [0]:
!pip install transformers

In [0]:
!pip install pytorch_wrapper

#### Import Statements

In [5]:
import torch
import os
import random
import math

from torch import nn
from collections import Counter
from glob import glob
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler, SubsetRandomSampler
from tqdm.auto import tqdm
from transformers import BertModel, BertTokenizer

from pytorch_wrapper import modules, System
from pytorch_wrapper import functional as pwF
from pytorch_wrapper import evaluators as evaluators
from pytorch_wrapper.loss_wrappers import GenericPointWiseLossWrapper
from pytorch_wrapper.training_callbacks import EarlyStoppingCriterionCallback
from pytorch_wrapper.samplers import SubsetOrderedBatchWiseRandomSampler, SubsetOrderedSequentialSampler, \
    OrderedSequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


#### Dataset Definition
Next we create the `ClickBaitDataset` class. We use the pretrained WordPiece tokenizer provided by the `transformers` library to prepare the input for the BERT model.

In [6]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


class ClickBaitDataset(Dataset):
    def __init__(self):
        self.ids = []
        self.texts = []
        self.texts_len = []
        self.targets = []

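        current_id = 0
        for filename in ['clickbait_data', 'non_clickbait_data']:
            with open(f'data/{filename}', encoding='utf-8') as f:
                for line in tqdm(f):
                    # Skip empty lines.
                    if line == '\n':
                        continue
                    self.ids.append(current_id)
                    current_id += 1  # give every example a unique id
                    # Encode the headline as BERT token ids, including [CLS] and [SEP].
                    text = bert_tokenizer.encode(line.lower(), add_special_tokens=True)
                    self.texts.append(text)
                    self.texts_len.append(len(text))
                    # The label is True for clickbait headlines and False otherwise.
                    self.targets.append(filename == 'clickbait_data')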

        self._shuffle_examples()

    def __getitem__(self, index):

        return (
            self.ids[index],
            (
                self.texts[index],
                self.texts_len[index]
            ),
            self.targets[index]
        )

    def __len__(self):
        return len(self.ids)

    def _shuffle_examples(self, seed=12345):
        """
        Shuffles the examples with the given seed.
        :param seed: The seed used for shuffling.
        """
        random.seed(seed)
        l = list(zip(self.ids, self.texts, self.texts_len, self.targets))
        random.shuffle(l)
        self.ids, self.texts, self.texts_len, self.targets = zip(*l)

    @staticmethod
    def collate_fn(batch):
        """
        Function that combines a list of examples into a batch (Called internally by dataloaders).
        """
        batch_zipped = list(zip(*batch))
        input_zipped = list(zip(*batch_zipped[1]))

        ids = batch_zipped[0]
        texts = torch.tensor(ClickBaitDataset.pad_to_max(input_zipped[0]), dtype=torch.long)
        texts_len = torch.tensor(input_zipped[1], dtype=torch.int)
        targets = torch.tensor(batch_zipped[2], dtype=torch.float)

        return {

            'id': ids,
            'input': [texts, texts_len],
            'target': targets
        }

    @staticmethod
    def pad_to_max(lst, max_len=None, pad_int=0):
        """
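        Pads the given list of token lists to the length of the longest one.
        If max_len is given, the padded length is capped and longer lists are truncated.
        :param lst: List of lists of tokens.
        :param max_len: Optional upper bound on the padded length.
        :param pad_int: Value used for padding.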
        """
        pad = len(max(lst, key=len))
        if max_len is not None:
            pad = min(max_len, pad)

        return [i + [pad_int] * (pad - len(i)) if len(i) <= pad else i[:pad] for i in lst]
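
To make the padding behaviour concrete, the following standalone sketch reproduces the logic of `pad_to_max` on toy token-id lists. The helper name `pad_or_truncate` and the ids are made up for illustration and are not part of the notebook.

```python
def pad_or_truncate(sequences, max_len=None, pad_id=0):
    """Pad every sequence to the longest length in the batch (capped at max_len)."""
    target_len = max(len(seq) for seq in sequences)
    if max_len is not None:
        target_len = min(max_len, target_len)
    # Shorter sequences get pad_id appended; longer ones are cut to target_len.
    return [seq[:target_len] + [pad_id] * (target_len - len(seq)) for seq in sequences]


# Two toy "token id" lists of different lengths (illustrative values only).
batch = [[101, 7592, 102], [101, 2023, 2003, 1037, 102]]

padded = pad_or_truncate(batch)             # pad to the longest sequence
capped = pad_or_truncate(batch, max_len=4)  # cap the length at 4

print(padded)
print(capped)
```

Without `max_len`, only padding happens; truncation can occur only when the cap is smaller than the longest sequence in the batch.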


#### Model Definition
In this example we will use the pretrained base uncased BERT model. 

In [0]:
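class BERTModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Single-logit classification head on top of the pooled BERT encoding.
        self.output_linear = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, text, text_len):
        # Last-layer hidden states, shape (batch, seq_len, hidden_size).
        # No attention mask is passed, so BERT also attends to padding positions.
        bert_last_hidden_states = self.bert(text)[0]
        # Mask out padding so that only real tokens contribute to the mean.
        mask = pwF.create_mask_from_length(text_len, bert_last_hidden_states.shape[-2], zeros_at_end=True)
        encoding = pwF.masked_mean_pooling(bert_last_hidden_states, mask, -2)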
        return self.output_linear(encoding).squeeze(-1)
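
The pooling step averages the hidden states of the real tokens only and ignores padding. The sketch below shows that arithmetic for a single sequence, using plain Python lists so it runs without PyTorch. `masked_mean` is a hypothetical helper written for illustration; it is not how `pytorch_wrapper` implements `masked_mean_pooling`.

```python
def masked_mean(hidden_states, length):
    """Average the first `length` token vectors of one sequence, ignoring padding."""
    dim = len(hidden_states[0])
    # The mask is 1 for real tokens and 0 for the padding positions at the end.
    mask = [1 if pos < length else 0 for pos in range(len(hidden_states))]
    summed = [0.0] * dim
    for vec, keep in zip(hidden_states, mask):
        for d in range(dim):
            summed[d] += keep * vec[d]
    # Divide by the number of real tokens (guarding against a zero length).
    return [s / max(sum(mask), 1) for s in summed]


# Toy sequence: 4 positions, 2-dimensional vectors, the last 2 positions are padding.
states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0], [100.0, 100.0]]
pooled = masked_mean(states, length=2)
print(pooled)
```

The large values at the padded positions have no effect on the result, which is the point of masking before averaging.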


#### Training

Next we create the dataset object along with three data loaders (for training, validation, and testing). We will also make use of `SubsetOrderedBatchWiseRandomSampler` and `SubsetOrderedSequentialSampler` in order to batch together texts with similar lengths.

In [8]:
train_val_test_dataset = ClickBaitDataset()

eval_size = math.floor(0.1 * len(train_val_test_dataset))
train_val_test_dataset_indexes = list(range(len(train_val_test_dataset)))
train_split_indexes = train_val_test_dataset_indexes[2 * eval_size:]
val_split_indexes = train_val_test_dataset_indexes[eval_size:2 * eval_size]
test_split_indexes = train_val_test_dataset_indexes[:eval_size]

batch_size = 32
train_dataloader = DataLoader(
    train_val_test_dataset,
    sampler=SubsetOrderedBatchWiseRandomSampler(
        train_split_indexes,
        get_order_value_callable=lambda example_index: train_val_test_dataset[example_index][1][1],
        batch_size=batch_size
    ),
    batch_size=batch_size,
    collate_fn=ClickBaitDataset.collate_fn
)

val_dataloader = DataLoader(
    train_val_test_dataset,
    sampler=SubsetOrderedSequentialSampler(
        val_split_indexes,
        get_order_value_callable=lambda example_index: train_val_test_dataset[example_index][1][1]
    ),
    batch_size=batch_size,
    collate_fn=ClickBaitDataset.collate_fn
)

test_dataloader = DataLoader(
    train_val_test_dataset,
    sampler=SubsetOrderedSequentialSampler(
        test_split_indexes,
        get_order_value_callable=lambda example_index: train_val_test_dataset[example_index][1][1]
    ),
    batch_size=batch_size,
    collate_fn=ClickBaitDataset.collate_fn
)
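
The split arithmetic and the idea of batching texts of similar length can be sketched in plain Python. This is only a rough illustration of the concept, not the implementation of `SubsetOrderedBatchWiseRandomSampler`; the toy `lengths` list and the helper `length_bucketed_batches` are assumptions made for the example.

```python
import math
import random

# Toy token counts for 20 examples (illustrative values only).
lengths = [5, 12, 7, 30, 9, 11, 6, 25, 8, 10, 14, 13, 4, 22, 17, 19, 21, 3, 16, 18]
indexes = list(range(len(lengths)))

# Same slicing scheme as above: first 10% test, next 10% validation, rest training.
eval_size = math.floor(0.1 * len(indexes))
test_idx = indexes[:eval_size]
val_idx = indexes[eval_size:2 * eval_size]
train_idx = indexes[2 * eval_size:]


def length_bucketed_batches(subset, batch_size, seed=0):
    """Sort by length, cut into batches, then shuffle the order of the batches."""
    ordered = sorted(subset, key=lambda i: lengths[i])
    batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]
    random.Random(seed).shuffle(batches)
    return batches


batches = length_bucketed_batches(train_idx, batch_size=4)
for batch in batches:
    print([lengths[i] for i in batch])
```

Each printed batch contains texts of similar length, so little padding is needed, while the shuffled batch order keeps some randomness between epochs.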


Then we create the model and wrap it in a `System` object, which handles training, evaluation, and prediction on the chosen device.

In [9]:
model = BERTModel()

last_activation = nn.Sigmoid()
if torch.cuda.is_available():
    system = System(model, last_activation=last_activation, device=torch.device('cuda'))
else:
    system = System(model, last_activation=last_activation, device=torch.device('cpu'))
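
The model's `forward` returns a raw score (a logit), and the sigmoid is handed to `System` separately as `last_activation`. We assume here that this activation is applied only when predicting or evaluating, which is why the training cell in this notebook uses `nn.BCEWithLogitsLoss` on the raw scores. The sketch below, in plain Python, checks that a loss computed directly from a logit matches binary cross-entropy computed from the sigmoid probability. The function names are illustrative.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_on_probability(p, y):
    """Binary cross-entropy computed from a probability."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(z, y):
    """Numerically stable binary cross-entropy computed directly from a logit."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


logit, target = 2.0, 1.0
loss_from_logit = bce_with_logits(logit, target)
loss_from_prob = bce_on_probability(sigmoid(logit), target)
print(loss_from_logit, loss_from_prob)
```

Both values agree up to floating-point error; working from the logit avoids taking the logarithm of a probability that has been rounded to 0 or 1.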


Next we fine-tune the model on the training set using Adam with a small learning rate (0.00005) and early stopping on the validation F1 score (patience of 3).

In [10]:
loss_wrapper = GenericPointWiseLossWrapper(nn.BCEWithLogitsLoss())
evals = {
    'acc': evaluators.AccuracyEvaluator(),
    'f1': evaluators.F1Evaluator(),
    'auc': evaluators.AUROCEvaluator()
}

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, system.model.parameters()),
    lr=0.00005
)

_ = system.train(
    loss_wrapper,
    optimizer,
    train_data_loader=train_dataloader,
    evaluators=evals,
    evaluation_data_loaders={
        'val': val_dataloader
    },
    callbacks=[
        EarlyStoppingCriterionCallback(
            patience=3,
            evaluation_data_loader_key='val',
            evaluator_key='f1',
            tmp_best_state_filepath='data/click_bait_cur_best.weights'
        )
    ]
)


--------------------------------------------------------------------------------

Epoch: 0

Training...

Time elapsed: 4052

Evaluating...

val

acc: 98.84%
binary-f1: 0.9883
auroc: 0.9995

--------------------------------------------------------------------------------

Epoch: 1

Training...

Time elapsed: 4067

Evaluating...

val

acc: 98.59%
binary-f1: 0.9857
auroc: 0.9992

--------------------------------------------------------------------------------

Epoch: 2

Training...

Time elapsed: 4082

Evaluating...

val

acc: 98.84%
binary-f1: 0.9882
auroc: 0.9993

--------------------------------------------------------------------------------

Epoch: 3

Training...

Time elapsed: 4098

Evaluating...

val

acc: 98.84%
binary-f1: 0.9883
auroc: 0.9994

--------------------------------------------------------------------------------

Epoch: 4

Training...

Time elapsed: 4083

Evaluating...

val

acc: 98.59%
binary-f1: 0.9858
auroc: 0.9987

Epoch chosen: 0
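
The log is consistent with a patience-based stopping rule: epoch 0 has the best validation F1, the next four epochs do not improve on it, and training stops after epoch 4. The sketch below replays the (rounded) validation F1 values from the log through such a rule. The rule itself, stopping once more than `patience` consecutive evaluations fail to improve, is an assumption chosen to match the log and not a description of how `EarlyStoppingCriterionCallback` is implemented.

```python
def early_stopping_run(scores, patience):
    """Return (best_epoch, last_epoch_run) for a list of per-epoch validation scores."""
    best_score = float('-inf')
    best_epoch = None
    remaining = patience
    for epoch, score in enumerate(scores):
        if score > best_score:
            # Strict improvement: remember the epoch and reset the patience budget.
            best_score, best_epoch = score, epoch
            remaining = patience
        else:
            remaining -= 1
            if remaining < 0:
                return best_epoch, epoch
    return best_epoch, len(scores) - 1


# Validation F1 per epoch, as printed in the log above (rounded to 4 decimals).
val_f1 = [0.9883, 0.9857, 0.9882, 0.9883, 0.9858]
best_epoch, last_epoch = early_stopping_run(val_f1, patience=3)
print(best_epoch, last_epoch)
```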


Next we evaluate the model on the test set.

In [11]:
results = system.evaluate(test_dataloader, evals)
for r in results:
    print(results[r])


acc: 99.12%
binary-f1: 0.9913
auroc: 0.9992
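
For reference, accuracy and binary F1 can be computed from predicted probabilities as in the sketch below. It assumes a decision threshold of 0.5 and uses made-up probabilities and labels; it is not the code of the `pytorch_wrapper` evaluators.

```python
def accuracy_and_f1(probabilities, targets, threshold=0.5):
    """Compute accuracy and binary F1 after thresholding the probabilities."""
    preds = [p >= threshold for p in probabilities]
    tp = sum(p and t for p, t in zip(preds, targets))
    fp = sum(p and not t for p, t in zip(preds, targets))
    fn = sum(not p and t for p, t in zip(preds, targets))
    correct = sum(p == t for p, t in zip(preds, targets))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return correct / len(targets), f1


# Toy predictions and labels (illustrative values only).
probs = [0.9, 0.8, 0.3, 0.6, 0.1, 0.2]
labels = [True, True, True, False, False, False]
acc, f1 = accuracy_and_f1(probs, labels)
print(f'acc: {acc:.4f}, f1: {f1:.4f}')
```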


We can also use the `predict` method to obtain predictions for all the examples returned by a `DataLoader`.

In [12]:
predictions = system.predict(test_dataloader, perform_last_activation=True)


In [17]:
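example_id = 6
input_loc = 1
text_loc = 0

# Caution: test_dataloader iterates in the sampler's order, not in dataset index order,
# so predictions['outputs'][example_id] may belong to a different text than the one printed.
print(bert_tokenizer.decode(train_val_test_dataset[example_id][input_loc][text_loc]))
print(predictions['outputs'][example_id])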


[CLS] 7 very easy ways to eat healthier this week [SEP]
0.0005689897807314992


In [19]:
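example_id = 4
input_loc = 1
text_loc = 0

# Caution: the printed text is looked up by dataset index, while the score is looked up
# by position in the sampler's iteration order, so the two may not refer to the same example.
print(bert_tokenizer.decode(train_val_test_dataset[example_id][input_loc][text_loc]))
print(predictions['outputs'][example_id])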

[CLS] which classical composer best suits your taste [SEP]
0.9998045563697815
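
Because the test loader uses a length-ordered sampler, position `i` in `predictions['outputs']` does not have to correspond to dataset index `i`. The toy sketch below shows how scores can be mapped back to dataset indexes once the iteration order is known. In the notebook that order could presumably be obtained with `list(test_dataloader.sampler)`, assuming the sampler is deterministic; this has not been verified here, and all values below are made up.

```python
# Dataset index -> token count (toy values).
lengths = {0: 9, 1: 4, 2: 7, 3: 5}
subset = [0, 1, 2, 3]

# Order in which a length-sorting sampler would yield the dataset indexes.
iteration_order = sorted(subset, key=lambda i: lengths[i])

# Pretend scores, returned in iteration order (one score per yielded example).
outputs_in_iteration_order = [0.11, 0.22, 0.33, 0.44]

# Map each dataset index back to its own score.
score_by_index = dict(zip(iteration_order, outputs_in_iteration_order))

print(iteration_order)
print(score_by_index[0])
```

Indexing `outputs_in_iteration_order[0]` would give the score of dataset index 1 (the shortest text), whereas `score_by_index[0]` gives the score that actually belongs to dataset index 0.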


Finally we save the model's weights.

In [0]:
system.save_model_state('data/click_bait_final.weights')
