Fine tuning BERT - Naufal Yahya Kurnianto 13519141

In [1]:
!pip install transformers

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForMaskedLM, BertModel, AutoModel
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

from tabulate import tabulate
from tqdm import trange
import random

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
!unzip -o data_worthcheck.zip

Archive:  data_worthcheck.zip
  inflating: data_worthcheck/labels.txt  
  inflating: data_worthcheck/dev.csv  
  inflating: data_worthcheck/check.py  
  inflating: data_worthcheck/train.csv  
  inflating: data_worthcheck/test.csv  


In [3]:
!head -10 data_worthcheck/dev.csv 

text_a,label
jek dajal ga depok bang,no
detikcom untung depok masuk wilayah nya ridwan kamil kalo masuk wilayah nya anis abis lu bully ama buzzer kolam,no
df dom jakarta depok yg gunain vc cabang nya cabang yg tercantum pas kesana gabisa bayar pake shopeepay,no
your2rl depok jkt,no
doakan indonesia selamat virus corona pkb depok gelar nusantara bershalawat,yes
warga depok terganggu isu corona,yes
kenapaa mendengar kabar salah wni positif corona depok dimana tinggal ku ku kawatir takut,yes
hug f cibinong bogor depok ga makan siang bareng m24,no
mukenahhh tlongggg ak maw hp ak kentank bingits sdh kebelah hadiah ultah hshs ak depok btw follback yh,no


In [4]:
!head -10 data_worthcheck/train.csv 

,text_a,label
0,betewe buka twitter cuman ngetweet liat home berita corona panik kepikiran ndamau buka2 home yg aware aja i ll stay at home nda rumah kalo nda penting2 banget,no
1,mas piyuuu mugo2 corona tuh mulut tersumpal ma corona,no
2,e100ss gini buka informasi sejelas nya identitas daerah penderita terjangkit infokan masyarakat mengisolasi nya kontak langsung penderita positif corona ditutup tutupi,yes
3,neng solo wes ono terduga corona cobo neng ati mu neng conora,no
4,midiahn nii akun gak takut takut nya isu corona wkwkwkw,no
5,hey corona prrgi sna,no
6,gara corona masuk ketempat aja mesti scan jidat gw kek jajanan indomaret,no
7,jokowi menteri2 nya silakan tes corona,yes
8,pencegahan corona other moms minum multivitamin my mom minum rebusan sambiloto,yes


In [5]:
!head -10 data_worthcheck/test.csv 

text_a,label
jek dajal ga depok bang,no
detikcom untung depok masuk wilayah nya ridwan kamil kalo masuk wilayah nya anis abis lu bully ama buzzer kolam,no
df dom jakarta depok yg gunain vc cabang nya cabang yg tercantum pas kesana gabisa bayar pake shopeepay,no
your2rl depok jkt,no
doakan indonesia selamat virus corona pkb depok gelar nusantara bershalawat,yes
warga depok terganggu isu corona,yes
kenapaa mendengar kabar salah wni positif corona depok dimana tinggal ku ku kawatir takut,yes
hug f cibinong bogor depok ga makan siang bareng m24,no
mukenahhh tlongggg ak maw hp ak kentank bingits sdh kebelah hadiah ultah hshs ak depok btw follback yh,no


In [6]:
file_path = 'data_worthcheck/train.csv'
df = pd.DataFrame({'label':int(), 'text':str()}, index = [])
i = 0
with open(file_path) as f:
  next(f)
  for line in f.readlines():
    split = line.split(',')
    df = df.append({'label': 1 if split[2] == 'yes\n' else 0,
                    'text': split[1]},
                    ignore_index = True)
df.head(10)

Unnamed: 0,label,text
0,0,betewe buka twitter cuman ngetweet liat home b...
1,0,mas piyuuu mugo2 corona tuh mulut tersumpal ma...
2,1,e100ss gini buka informasi sejelas nya identit...
3,0,neng solo wes ono terduga corona cobo neng ati...
4,0,midiahn nii akun gak takut takut nya isu coron...
5,0,hey corona prrgi sna
6,0,gara corona masuk ketempat aja mesti scan jida...
7,1,jokowi menteri2 nya silakan tes corona
8,1,pencegahan corona other moms minum multivitami...
9,0,mamaciaaa mnrut gue jngan dkt2 corona cb dkt y...


In [7]:
text = df.text.values
labels = df.label.values

In [8]:
tokenizer = BertTokenizer.from_pretrained(
    'indobenchmark/indobert-base-p1',
    do_lower_case = True
    )

def print_rand_sentence():
  '''Displays the tokens and respective IDs of a random text sample'''
  index = random.randint(0, len(text)-1)
  table = np.array([tokenizer.tokenize(text[index]), 
                    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text[index]))]).T
  print(tabulate(table,
                 headers = ['Tokens', 'Token IDs'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence()

╒═════════════╤═════════════╕
│ Tokens      │   Token IDs │
╞═════════════╪═════════════╡
│ nes         │       24995 │
├─────────────┼─────────────┤
│ ##sie       │       28157 │
├─────────────┼─────────────┤
│ ##jud       │       28256 │
├─────────────┼─────────────┤
│ ##ge        │        2731 │
├─────────────┼─────────────┤
│ kalo        │        1686 │
├─────────────┼─────────────┤
│ ka          │        1187 │
├─────────────┼─────────────┤
│ bah         │         168 │
├─────────────┼─────────────┤
│ dibersihkan │       10364 │
├─────────────┼─────────────┤
│ info        │        1617 │
├─────────────┼─────────────┤
│ ditutup     │        5819 │
├─────────────┼─────────────┤
│ sampe       │        5985 │
├─────────────┼─────────────┤
│ wk          │       20181 │
├─────────────┼─────────────┤
│ ##tu        │          33 │
├─────────────┼─────────────┤
│ yg          │         741 │
├─────────────┼─────────────┤
│ ditentukan  │        3835 │
├─────────────┼─────────────┤
│ semoga  

In [9]:
token_id = []
attention_masks = []

def preprocessing(input_text, tokenizer):
  '''
  Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
    - input_ids: list of token ids
    - token_type_ids: list of token type ids
    - attention_mask: list of indices (0,1) specifying which tokens should considered by the model (return_attention_mask = True).
  '''
  return tokenizer.encode_plus(
                        input_text,
                        add_special_tokens = True,
                        max_length = 32,
                        padding = 'max_length',
                        return_attention_mask = True,
                        return_tensors = 'pt',
                        truncation = True
                   )


for sample in text:
  encoding_dict = preprocessing(sample, tokenizer)
  token_id.append(encoding_dict['input_ids']) 
  attention_masks.append(encoding_dict['attention_mask'])


token_id = torch.cat(token_id, dim = 0)
attention_masks = torch.cat(attention_masks, dim = 0)
labels = torch.tensor(labels)

In [10]:
def print_rand_sentence_encoding():
  '''Displays tokens, token IDs and attention mask of a random text sample'''
  index = random.randint(0, len(text) - 1)
  tokens = tokenizer.tokenize(tokenizer.decode(token_id[index]))
  token_ids = [i.numpy() for i in token_id[index]]
  attention = [i.numpy() for i in attention_masks[index]]

  table = np.array([tokens, token_ids, attention]).T
  print(tabulate(table, 
                 headers = ['Tokens', 'Token IDs', 'Attention Mask'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence_encoding()

╒════════════╤═════════════╤══════════════════╕
│ Tokens     │   Token IDs │   Attention Mask │
╞════════════╪═════════════╪══════════════════╡
│ [CLS]      │           2 │                1 │
├────────────┼─────────────┼──────────────────┤
│ presiden   │        1871 │                1 │
├────────────┼─────────────┼──────────────────┤
│ tegaskan   │       28072 │                1 │
├────────────┼─────────────┼──────────────────┤
│ serius     │        3934 │                1 │
├────────────┼─────────────┼──────────────────┤
│ tangani    │       18683 │                1 │
├────────────┼─────────────┼──────────────────┤
│ wabah      │       18465 │                1 │
├────────────┼─────────────┼──────────────────┤
│ cor        │        3021 │                1 │
├────────────┼─────────────┼──────────────────┤
│ ##ona      │        2524 │                1 │
├────────────┼─────────────┼──────────────────┤
│ jakarta    │         678 │                1 │
├────────────┼─────────────┼────────────

In [11]:
val_ratio = 0.2
# Recommended batch size: 16, 32. See: https://arxiv.org/pdf/1810.04805.pdf
batch_size = 16

# Indices of the train and validation splits stratified by labels
train_idx, val_idx = train_test_split(
    np.arange(len(labels)),
    test_size = val_ratio,
    shuffle = True,
    stratify = labels)

# Train and validation sets
train_set = TensorDataset(token_id[train_idx], 
                          attention_masks[train_idx], 
                          labels[train_idx])

val_set = TensorDataset(token_id[val_idx], 
                        attention_masks[val_idx], 
                        labels[val_idx])

# Prepare DataLoader
train_dataloader = DataLoader(
            train_set,
            sampler = RandomSampler(train_set),
            batch_size = batch_size
        )

validation_dataloader = DataLoader(
            val_set,
            sampler = SequentialSampler(val_set),
            batch_size = batch_size
        )

In [12]:
def b_tp(preds, labels):
  '''Returns True Positives (TP): count of correct predictions of actual class 1'''
  return sum([preds == labels and preds == 1 for preds, labels in zip(preds, labels)])

def b_fp(preds, labels):
  '''Returns False Positives (FP): count of wrong predictions of actual class 1'''
  return sum([preds != labels and preds == 1 for preds, labels in zip(preds, labels)])

def b_tn(preds, labels):
  '''Returns True Negatives (TN): count of correct predictions of actual class 0'''
  return sum([preds == labels and preds == 0 for preds, labels in zip(preds, labels)])

def b_fn(preds, labels):
  '''Returns False Negatives (FN): count of wrong predictions of actual class 0'''
  return sum([preds != labels and preds == 0 for preds, labels in zip(preds, labels)])

def b_metrics(preds, labels):
  '''
  Returns the following metrics:
    - accuracy    = (TP + TN) / N
    - precision   = TP / (TP + FP)
    - recall      = TP / (TP + FN)
    - specificity = TN / (TN + FP)
  '''
  preds = np.argmax(preds, axis = 1).flatten()
  labels = labels.flatten()
  tp = b_tp(preds, labels)
  tn = b_tn(preds, labels)
  fp = b_fp(preds, labels)
  fn = b_fn(preds, labels)
  b_accuracy = (tp + tn) / len(labels)
  b_precision = tp / (tp + fp) if (tp + fp) > 0 else 'nan'
  b_recall = tp / (tp + fn) if (tp + fn) > 0 else 'nan'
  b_specificity = tn / (tn + fp) if (tn + fp) > 0 else 'nan'
  return b_accuracy, b_precision, b_recall, b_specificity

In [13]:
# model = AutoModel.from_pretrained(
#     "indolem/indobert-base-uncased",
#     num_labels = 2,
#     output_attentions = False,
#     output_hidden_states = False,
# )

model = BertForSequenceClassification.from_pretrained (
    "indobenchmark/indobert-base-p1",
    num_labels = 2,
    output_attentions = False,
    output_hidden_states = False,
)

# model = BertModel.from_pretrained (
#     "bert-base-multilingual-uncased",
#     num_labels = 2,
#     output_attentions = False,
#     output_hidden_states = False,
# )

# Recommended learning rates (Adam): 5e-5, 3e-5, 2e-5. See: https://arxiv.org/pdf/1810.04805.pdf
optimizer = torch.optim.AdamW(model.parameters(), 
                              lr = 3e-5,
                              eps = 1e-08
                              )

# Run on GPU
model.cuda()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at indobenchmark/indobert-base-p1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Recommended number of epochs: 2, 3, 4. See: https://arxiv.org/pdf/1810.04805.pdf
epochs = 2

for _ in trange(epochs, desc = 'Epoch'):
    
    # ========== Training ==========
    
    # Set model to training mode
    model.train()
    
    # Tracking variables
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        optimizer.zero_grad()
        # Forward pass
        train_output = model(b_input_ids, 
                             token_type_ids = None, 
                             attention_mask = b_input_mask, 
                             labels = b_labels)
        # Backward pass
        train_output.loss.backward()
        optimizer.step()
        # Update tracking variables
        tr_loss += train_output.loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1

    # ========== Validation ==========

    # Set model to evaluation mode
    model.eval()

    # Tracking variables 
    val_accuracy = []
    val_precision = []
    val_recall = []
    val_specificity = []

    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
          # Forward pass
          eval_output = model(b_input_ids, 
                              token_type_ids = None, 
                              attention_mask = b_input_mask)
        logits = eval_output.logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        # Calculate validation metrics
        b_accuracy, b_precision, b_recall, b_specificity = b_metrics(logits, label_ids)
        val_accuracy.append(b_accuracy)
        # Update precision only when (tp + fp) !=0; ignore nan
        if b_precision != 'nan': val_precision.append(b_precision)
        # Update recall only when (tp + fn) !=0; ignore nan
        if b_recall != 'nan': val_recall.append(b_recall)
        # Update specificity only when (tn + fp) !=0; ignore nan
        if b_specificity != 'nan': val_specificity.append(b_specificity)

    print('\n\t - Train loss: {:.4f}'.format(tr_loss / nb_tr_steps))
    print('\t - Validation Accuracy: {:.4f}'.format(sum(val_accuracy)/len(val_accuracy)))
    print('\t - Validation Precision: {:.4f}'.format(sum(val_precision)/len(val_precision)) if len(val_precision)>0 else '\t - Validation Precision: NaN')
    print('\t - Validation Recall: {:.4f}'.format(sum(val_recall)/len(val_recall)) if len(val_recall)>0 else '\t - Validation Recall: NaN')
    print('\t - Validation Specificity: {:.4f}\n'.format(sum(val_specificity)/len(val_specificity)) if len(val_specificity)>0 else '\t - Validation Specificity: NaN')

Epoch:  50%|█████     | 1/2 [02:22<02:22, 142.50s/it]


	 - Train loss: 0.2988
	 - Validation Accuracy: 0.9031
	 - Validation Precision: 0.8421
	 - Validation Recall: 0.8198
	 - Validation Specificity: 0.9379



Epoch: 100%|██████████| 2/2 [04:47<00:00, 143.54s/it]


	 - Train loss: 0.1798
	 - Validation Accuracy: 0.9018
	 - Validation Precision: 0.7918
	 - Validation Recall: 0.8882
	 - Validation Specificity: 0.9060






In [15]:
file_path = 'data_worthcheck/test.csv'
df_testval = pd.DataFrame({'label':int(), 'text':str()}, index = [])
with open(file_path) as f:
  next(f)
  for line in f.readlines():
    split = line.split(',')
    # if (len(split[0])>1024):
    #   continue
    df_testval = df_testval.append({'label': 1 if split[1] == 'yes\n' else 0,
                    'text': split[0]},
                    ignore_index = True)

df_test = pd.DataFrame({'label':int(), 'text':str()}, index = [])
with open(file_path) as f:
  next(f)
  for line in f.readlines():
    split = line.split(',')
    # if (len(split[0])>1024):
    #   continue
    test_ids = []
    test_attention_mask = []

    # Apply the tokenizer
    encoding = preprocessing(split[0], tokenizer)

    # Extract IDs and Attention Mask
    test_ids.append(encoding['input_ids'])
    test_attention_mask.append(encoding['attention_mask'])
    test_ids = torch.cat(test_ids, dim = 0)
    test_attention_mask = torch.cat(test_attention_mask, dim = 0)

    # Forward pass, calculate logit predictions
    with torch.no_grad():
      output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))

    prediction = 'yes' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'no'

    df_test = df_test.append({'label': 1 if prediction == 'yes' else 0,
                    'text': split[0]},
                    ignore_index = True)

# # We need Token IDs and Attention Mask for inference on the new sentence
# test_ids = []
# test_attention_mask = []

# # Apply the tokenizer
# encoding = preprocessing(new_sentence, tokenizer)

# # Extract IDs and Attention Mask
# test_ids.append(encoding['input_ids'])
# test_attention_mask.append(encoding['attention_mask'])
# test_ids = torch.cat(test_ids, dim = 0)
# test_attention_mask = torch.cat(test_attention_mask, dim = 0)

# # Forward pass, calculate logit predictions
# with torch.no_grad():
#   output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))

# prediction = 'yes' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'no'

df_testval.head(100)
df_test.head(100)

Unnamed: 0,label,text
0,0,jek dajal ga depok bang
1,1,detikcom untung depok masuk wilayah nya ridwan...
2,0,df dom jakarta depok yg gunain vc cabang nya c...
3,0,your2rl depok jkt
4,1,doakan indonesia selamat virus corona pkb depo...
...,...,...
95,0,jrantesalu8 aniesbaswedan permadiaktivis gunro...
96,0,sebener nya gua panik bgt pdhl td lg joke ttg ...
97,0,udah bercandain corona tunggu terjangkit
98,0,meremehkan tindakan2 pemberitaan pengen tau aj...


In [16]:
print("%.2f" % (100-len(df_test.compare(df_testval))/len(df_testval)*100) + " % classification correct")

85.89 % classification correct
