# FINE-TUNE TRANSFORMER MODEL ON CUSTOM DATASET

This is an adaptation of a new repo on using transformer models to detect state trolls on Twitter. I reckon many might not be interested in the subject matter itself, so I ported over the Colab notebook on fine-tuning with a custom dataset for folks who are looking to try this out.

This notebook took about 5.5 hours to run on a Colab Pro account with TPU and "high-RAM" settings. It could run slower or faster depending on your setup. The datasets needed, train_raw.csv and validate.csv, are in the [data folder](https://github.com/chuachinhon/practical_nlp/tree/master/data) of this repo.

The fine-tuned DistilBERT model produced at the end of this notebook is too big for GitHub, but you can download it from Dropbox [here](https://www.dropbox.com/sh/90h7ymog2oi5yn7/AACTuxmMTcso6aMxSmSiD8AVa) instead.

Details on the data collection and preparation, plus comparisons with other machine learning models, are available in my [separate repo](https://github.com/chuachinhon/transformers_state_trolls_cch) on using transformers to detect state trolls on Twitter.

In this Colab notebook, we'll fine-tune the DistilBERT model on about 90K rows of troll and real tweets (split 80/20 into training and test sets) using Hugging Face's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html). A further 10K rows have been set aside as validation data to see how the fine-tuned model performs on tweets it has not seen.

The proportion of troll vs. real tweets in the datasets was kept to a 50-50 split for practical reasons: I don't think anyone outside of Twitter knows what the *real-world* mix of state troll vs. real tweets is at any one point.

In any case, it makes sense to let the model be equally exposed to both types of tweets during training.

But these are my assumptions for this project. If you believe a different proportion of state troll vs. real tweets works better, by all means change it.
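If you want to try a different mix, one option is to downsample the training data to a target ratio before the train-test split. The sketch below assumes a pandas DataFrame with the troll_or_not label column used in this notebook (1 for troll, 0 for real); the helper name `resample_to_ratio` and the toy data are hypothetical, and it only downsamples (no oversampling).

```python
import pandas as pd


def resample_to_ratio(df, label_col="troll_or_not", troll_frac=0.5, seed=42):
    """Hypothetical helper: downsample df so a fraction `troll_frac` of rows has label 1.

    Keeps as many rows as possible without oversampling.
    Assumes 0 < troll_frac < 1 and binary labels (1 = troll, 0 = real).
    """
    trolls = df[df[label_col] == 1]
    real = df[df[label_col] == 0]
    # largest total size both classes can supply at the requested ratio
    total = int(min(len(trolls) / troll_frac, len(real) / (1 - troll_frac)))
    n_troll = round(total * troll_frac)
    n_real = total - n_troll
    sampled = pd.concat(
        [
            trolls.sample(n=n_troll, random_state=seed),
            real.sample(n=n_real, random_state=seed),
        ]
    )
    # shuffle so the two classes are interleaved
    return sampled.sample(frac=1, random_state=seed).reset_index(drop=True)


# hypothetical toy data: 60 "troll" rows and 40 "real" rows
toy = pd.DataFrame(
    {
        "clean_text": [f"tweet {i}" for i in range(100)],
        "troll_or_not": [1] * 60 + [0] * 40,
    }
)
balanced = resample_to_ratio(toy, troll_frac=0.5)
print(balanced["troll_or_not"].value_counts())
```

Because it never oversamples, the minority class caps the final size: asking for a 30-70 troll-to-real mix from the same toy data keeps all 40 real rows and 17 troll rows.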


## REFERENCES:

Hugging Face has very easy-to-follow examples on its site, and I've modelled the code below mostly on these two pages:

* [Fine-tuning with custom datasets](https://huggingface.co/transformers/master/custom_datasets.html)

* [Trainer (documentation)](https://huggingface.co/transformers/master/main_classes/trainer.html)

In [1]:
```python
! pip -q install transformers
```


In [2]:
```python
import numpy as np
import os
import pandas as pd
import torch

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
```


In [3]:
```python
from google.colab import drive
drive.mount('/content/drive/')
```

Mounted at /content/drive/


In [4]:
```python
os.chdir("/content/drive/My Drive/Colab Notebooks")
```

In [5]:
```python
# adjust the path as needed for your own Colab/Google Drive folder

raw = pd.read_csv("train_raw.csv")
```

In [6]:
```python
raw.shape
```

(89948, 5)

In [7]:
```python
raw.head()
```

| | tweetid | user_display_name | tweet_text | clean_text | troll_or_not |
|---|---|---|---|---|---|
| 0 | 1245883557362282497 | 85c9M6CDZxgBwoEye0rF12ZBgGl3xvz6Bnbvhp7MUKI= | having each tiny wish come true, or having som... | having each tiny wish come true or having some... | 1 |
| 1 | 961577921461866496 | 曲剑明 | ＠null It is 12:25 UTC now | null It is UTC now | 1 |
| 2 | 941616158075211776 | IFL1E0m0SRX2cdOtuLFV7xKtnBgxagKzNgkuGFvNtvs= | British number two Bedene to switch back to Sl... | British number two Bedene to switch back to Sl... | 1 |
| 3 | 850414479976345600 | Klausv | kalamitykait Thanks for bearing with us - you ... | kalamitykait Thanks for bearing with us you sh... | 1 |
| 4 | 960784360071925760 | 曲剑明 | ＠null It is 08:56 CET now | null It is CET now | 1 |


## 1.1: PREPARING THE DATA

In [8]:
```python
# Train-test split the main training dataset with scikit-learn's train_test_split:
# 80% train / 20% test, stratified so both splits keep the same troll/real ratio

X = list(raw["clean_text"].values)
y = list(raw["troll_or_not"].values)


train_texts, test_texts, train_labels, test_labels = train_test_split(
    X, y, random_state=42, test_size=0.2, stratify=y
)
```
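To see what stratify does, here is a toy run on hypothetical, perfectly balanced labels; with stratification, both splits keep the 50-50 ratio exactly.

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# hypothetical toy data: 100 texts, exactly half labelled 1
texts = [f"tweet {i}" for i in range(100)]
labels = [i % 2 for i in range(100)]

tr_texts, te_texts, tr_labels, te_labels = train_test_split(
    texts, labels, random_state=42, test_size=0.2, stratify=labels
)
print(Counter(tr_labels), Counter(te_labels))
```

Without stratify, a random 20-row test set could easily end up with, say, 12 of one class and 8 of the other.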


## 1.2: TOKENIZATION + TURN LABELS & ENCODINGS INTO A DATASET OBJECT

In [9]:
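```python
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# truncation=True cuts long inputs at the model's maximum length;
# padding=True pads every sequence to the longest one in the list
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)
```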

In [10]:
```python
class tweetsdataset(torch.utils.data.Dataset):
    """Wraps tokenizer output and labels so Trainer can index individual examples."""

    def __init__(self, encodings, labels):
        self.encodings = encodings  # dict-like: input_ids, attention_mask, ...
        self.labels = labels

    def __getitem__(self, idx):
        # one example: each encoding field as a tensor, plus its label
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = tweetsdataset(train_encodings, train_labels)
test_dataset = tweetsdataset(test_encodings, test_labels)
```
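A quick way to check the dataset wrapper without downloading a tokenizer is to feed it a hand-made encodings dict shaped like the tokenizer's output. The token IDs below are made up; the point is that each item is a dict of tensors, and that PyTorch's DataLoader (with its default collate function) stacks those dicts into batched tensors, which is what Trainer does under the hood. The class is repeated so the snippet runs on its own.

```python
import torch
from torch.utils.data import DataLoader


class tweetsdataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


# hypothetical, already-padded encodings for two tweets (token IDs are made up)
toy_encodings = {
    "input_ids": [[101, 2023, 102, 0], [101, 2017, 2024, 102]],
    "attention_mask": [[1, 1, 1, 0], [1, 1, 1, 1]],
}
toy_dataset = tweetsdataset(toy_encodings, [1, 0])

print(toy_dataset[0])  # one example: 1-D tensors plus a scalar label tensor
batch = next(iter(DataLoader(toy_dataset, batch_size=2)))
print({key: tuple(t.shape) for key, t in batch.items()})
```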

## 1.3: FINE-TUNE WITH TRAINER (NO HYPERPARAMETER SEARCH AT THIS POINT)

Before setting the training parameters, we'll define a function for the usual classification metrics. The big open questions here are how to know how many epochs to run, what the "best" learning rate is, and so on.

In scikit-learn, this is usually handled with a grid search. Hugging Face has introduced a [hyperparameter search feature for Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html#transformers.Trainer.hyperparameter_search), but I've not been able to get the search to complete within a reasonable period of time (so far it has taken longer than the actual fine-tuning process). So it looks like this will take further experiments elsewhere, and getting familiar with Optuna or Ray.

There is at least one [discussion thread](https://discuss.huggingface.co/t/using-hyperparameter-search-in-trainer/785/2) on Hugging Face's forum about hyperparameter search. It's worth checking out for usage examples and issues raised by other users.
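One way to see why a search takes longer than a single fine-tuning run: every trial is itself a fine-tuning run. A back-of-the-envelope sketch, assuming trial time scales linearly with epochs and taking this notebook's roughly 5.5 hours for 3 epochs as the cost basis; the grid values are hypothetical.

```python
from itertools import product

HOURS_PER_EPOCH = 5.5 / 3  # from this notebook: ~5.5 h wall time for 3 epochs

# hypothetical search grid
learning_rates = [1e-5, 3e-5, 5e-5]
epoch_options = [2, 3]

trials = list(product(learning_rates, epoch_options))
total_hours = sum(epochs * HOURS_PER_EPOCH for _, epochs in trials)
print(f"{len(trials)} trials, about {total_hours:.1f} hours of training")
```

Even this small grid costs about five full runs, which is why search backends such as Optuna and Ray Tune put so much emphasis on stopping unpromising trials early.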

In [11]:
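```python
def compute_metrics(pred):
    # pred.predictions holds the logits; argmax over the last axis gives class ids
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # average='binary' reports scores for the positive class (label 1, i.e. troll)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
```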
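To see what compute_metrics receives and returns, here is a numpy-only stand-in that hand-counts true/false positives instead of calling scikit-learn's precision_recall_fscore_support (a swapped-in implementation for illustration, not the notebook's code). A SimpleNamespace stands in for the EvalPrediction object Trainer passes in, which exposes .predictions and .label_ids; the logits are hypothetical.

```python
from types import SimpleNamespace

import numpy as np


def binary_metrics(pred):
    labels = np.asarray(pred.label_ids)
    preds = np.asarray(pred.predictions).argmax(-1)  # logits -> class ids
    tp = int(((preds == 1) & (labels == 1)).sum())
    fp = int(((preds == 1) & (labels == 0)).sum())
    fn = int(((preds == 0) & (labels == 1)).sum())
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": float((preds == labels).mean()),
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }


# hypothetical logits for four tweets; column 1 is the troll class (label 1)
fake_pred = SimpleNamespace(
    predictions=np.array([[2.0, -1.5], [-0.5, 1.2], [0.3, 0.1], [-1.0, 2.0]]),
    label_ids=np.array([0, 1, 1, 1]),
)
metrics = binary_metrics(fake_pred)
print(metrics)
```

The third tweet is a troll (label 1) but its logits favour class 0, so it is a false negative: precision stays at 1.0 while recall drops to 2/3.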

In [12]:
```python
# parameters below are based on my own trials

training_args = TrainingArguments(
    output_dir="results",  # output directory
    overwrite_output_dir=True,
    num_train_epochs=3,  # total number of training epochs
    per_device_train_batch_size=4,  # batch size per device during training
    per_device_eval_batch_size=4,  # batch size for evaluation
    warmup_steps=1000,  # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,  # strength of weight decay
    logging_dir="logs",  # directory for storing logs
    logging_steps=5000,  # default: 500
    save_steps=5000,  # default: 500
    learning_rate=1e-5,
    do_train=True,
    do_eval=True,
    evaluate_during_training=True,  # deprecated in later transformers releases in favour of evaluation_strategy
    seed=16,
    gradient_accumulation_steps=8,  # effective batch size of 4 x 8 = 32 with the memory footprint of 4
)

model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

trainer = Trainer(
    model=model,  # the instantiated Transformers model to be trained
    args=training_args,  # training arguments, defined above
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,  # training dataset
    eval_dataset=test_dataset,  # test dataset
)
```


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi
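With these settings, the progress-bar sizes and the final step count in the training log below can be worked out by hand. A sketch of the arithmetic, assuming a single device, scikit-learn rounding the test share up, and transformers 3.x flooring the number of optimizer updates per epoch; the helper name is hypothetical.

```python
import math


def training_schedule(n_rows, test_size, batch_size, grad_accum, epochs):
    """Hypothetical helper reproducing the step counts for a single-device run."""
    n_test = math.ceil(test_size * n_rows)  # scikit-learn rounds the test share up
    n_train = n_rows - n_test
    batches_per_epoch = math.ceil(n_train / batch_size)  # last partial batch is kept
    eval_batches = math.ceil(n_test / batch_size)
    updates_per_epoch = batches_per_epoch // grad_accum  # one optimizer step per grad_accum batches
    return {
        "n_train": n_train,
        "n_test": n_test,
        "batches_per_epoch": batches_per_epoch,
        "eval_batches": eval_batches,
        "effective_batch_size": batch_size * grad_accum,
        "total_steps": updates_per_epoch * epochs,
    }


schedule = training_schedule(n_rows=89948, test_size=0.2, batch_size=4, grad_accum=8, epochs=3)
print(schedule)
```

These line up with the 17,990 iterations per epoch, the 4,498 evaluation batches and the final global_step=6744 reported in the log.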

In [13]:
```python
%%time

trainer.train()
```



{'eval_loss': 0.28122428506985514, 'eval_accuracy': 0.8798221234018899, 'eval_f1': 0.8832739444984343, 'eval_precision': 0.8579819593035453, 'eval_recall': 0.9101023587004895, 'epoch': 0.44469149527515284, 'step': 1000}




{'eval_loss': 0.23676062129513506, 'eval_accuracy': 0.8996664813785437, 'eval_f1': 0.9007641981417339, 'eval_precision': 0.890338006738398, 'eval_recall': 0.9114374721851357, 'epoch': 0.8893829905503057, 'step': 2000}





{'eval_loss': 0.22289709488016393, 'eval_accuracy': 0.9093385214007782, 'eval_f1': 0.9090909090909092, 'eval_precision': 0.9108678655199375, 'eval_recall': 0.9073208722741433, 'epoch': 1.3344080044469149, 'step': 3000}




{'eval_loss': 0.21431779008663635, 'eval_accuracy': 0.9122290161200667, 'eval_f1': 0.9135694345612787, 'eval_precision': 0.8991487986208383, 'eval_recall': 0.9284601691143747, 'epoch': 1.7790994997220677, 'step': 4000}




{'loss': 0.2559881103515625, 'learning_rate': 3.036211699164346e-06, 'epoch': 2.224124513618677, 'step': 5000}




{'eval_loss': 0.2215935878084456, 'eval_accuracy': 0.9135630906058921, 'eval_f1': 0.9141026349223886, 'eval_precision': 0.9077345035655513, 'eval_recall': 0.9205607476635514, 'epoch': 2.224124513618677, 'step': 5000}






{'eval_loss': 0.22005129584224906, 'eval_accuracy': 0.9148971650917176, 'eval_f1': 0.9148166694486175, 'eval_precision': 0.9149693934335003, 'eval_recall': 0.9146639964396974, 'epoch': 2.6688160088938297, 'step': 6000}


CPU times: user 4d 12h 40min 31s, sys: 1h 14min 19s, total: 4d 13h 54min 50s
Wall time: 5h 29min 43s


TrainOutput(global_step=6744, training_loss=0.22906771079501223)
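The learning rate logged at step 5000 (about 3.04e-06) is consistent with Trainer's default schedule: a linear warmup over warmup_steps=1000, then linear decay to zero at the last of the 6,744 steps. The function below is a plain-Python restatement of that schedule for checking the numbers, not the library's own code.

```python
def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate after `step` optimizer updates under linear warmup + linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)  # ramp up from 0 to base_lr
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)  # decay to 0


for step in (500, 1000, 5000, 6744):
    print(step, linear_warmup_lr(step, base_lr=1e-5, warmup_steps=1000, total_steps=6744))
```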

## 1.4: EVALUATE RESULTS OF FINE-TUNING

HF's Trainer makes evaluating the fine-tuned model very easy. trainer.evaluate() scores the model on the eval_dataset passed to the trainer (the 20% test split), and with the usual classifier metrics (accuracy, F1, precision, recall) all above 0.9, the results certainly look very good.

In [14]:
```python
%%time

trainer.evaluate()
```



{'eval_loss': 0.2218987653671809, 'eval_accuracy': 0.9158421345191773, 'eval_f1': 0.9163813100629625, 'eval_precision': 0.9098486510199605, 'eval_recall': 0.9230084557187361, 'epoch': 2.9996664813785436, 'step': 6744}
CPU times: user 2h 10min 23s, sys: 1min 54s, total: 2h 12min 18s
Wall time: 6min 53s


{'epoch': 2.9996664813785436,
 'eval_accuracy': 0.9158421345191773,
 'eval_f1': 0.9163813100629625,
 'eval_loss': 0.2218987653671809,
 'eval_precision': 0.9098486510199605,
 'eval_recall': 0.9230084557187361}
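As a sanity check on the numbers above: F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R), so plugging in the reported precision and recall should reproduce the reported F1 of about 0.91638.

```python
precision = 0.9098486510199605  # eval_precision from trainer.evaluate()
recall = 0.9230084557187361  # eval_recall from trainer.evaluate()

f1 = 2 * precision * recall / (precision + recall)
print(f1)
```

Because F1 is a harmonic mean, it sits slightly below the simple average of the two and is pulled towards the lower of them.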

## 1.5: SAVE THE FINE-TUNED MODEL

The resulting model is too big to be pushed to GitHub, but I've uploaded a [copy to Dropbox](https://www.dropbox.com/sh/90h7ymog2oi5yn7/AACTuxmMTcso6aMxSmSiD8AVa) for anyone who wants to try it out.

You can of course [upload your model to Hugging Face's model hub](https://huggingface.co/transformers/master/model_sharing.html). I opted not to do so in this case since the use case for a state troll detector is fairly narrow (though the problem itself is huge).

In [16]:
```python
ft_model = "finetuned/troll_detect"
trainer.save_model(ft_model)  # writes the model weights and config
tokenizer.save_pretrained(ft_model)  # writes the tokenizer files alongside
```


('finetuned/troll_detect/vocab.txt',
 'finetuned/troll_detect/special_tokens_map.json',
 'finetuned/troll_detect/added_tokens.json')

## 1.6: QUICK EVALUATION ON VALIDATION SET

Trainer also provides an easy way to quickly evaluate the fine-tuned model against new data via its predict method. Just prepare the data in the same way as the train and test datasets and you are good to go.

From the looks of things, the model did very well at picking out the unseen state troll and real tweets, with accuracy, precision, recall and F1 all at about 0.92 on the 10K-row validation set.

See this [notebook](https://github.com/chuachinhon/transformers_state_trolls_cch/blob/master/notebooks/2.1_validate_finetuned_model_cch.ipynb) for a closer examination of the model's performance on the validation set.

In [17]:
```python
val = pd.read_csv("validate.csv")

val_texts = list(val["clean_text"].values)
val_labels = list(val["troll_or_not"].values)

# tokenize and wrap the validation data exactly like the train/test splits
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
val_dataset = tweetsdataset(val_encodings, val_labels)
```


In [18]:
```python
trainer.predict(val_dataset)
```

PredictionOutput(predictions=array([[ 0.78790015, -0.55224204],
       [ 1.5685508 , -1.5927229 ],
       [ 1.0601425 , -0.9790779 ],
       ...,
       [ 3.130257  , -3.0041547 ],
       [ 2.9400787 , -2.7221072 ],
       [ 2.944794  , -2.8034925 ]], dtype=float32), label_ids=array([0, 0, 1, ..., 0, 0, 0]), metrics={'eval_loss': 0.21223547568158246, 'eval_accuracy': 0.9179, 'eval_f1': 0.9189935865811544, 'eval_precision': 0.9178163184864012, 'eval_recall': 0.9201738786801027})
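predict returns raw logits (one row per tweet, one column per class) rather than labels. A minimal numpy sketch of turning them into predicted labels and troll probabilities, using the first three rows of logits and label_ids printed above; the softmax helper is written out here rather than taken from a library.

```python
import numpy as np

# first three rows of the logits printed above, and their true labels
logits = np.array(
    [
        [0.78790015, -0.55224204],
        [1.5685508, -1.5927229],
        [1.0601425, -0.9790779],
    ]
)
label_ids = np.array([0, 0, 1])


def softmax(x):
    shifted = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


probs = softmax(logits)
pred_labels = probs.argmax(-1)  # same result as logits.argmax(-1)
troll_prob = probs[:, 1]  # column 1 = label 1 (troll)

print(pred_labels, label_ids)
print(troll_prob.round(3))
```

Here the third tweet, labelled 1, is predicted as 0. Working with troll_prob instead of the argmax also lets you move the decision threshold away from 0.5 if you care more about precision or recall.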