# Fine-tuning BERT for Binary Classification



In [1]:
# Mount the Google drive for access to files
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
# Basic Python modules
import os
import re
from collections import defaultdict, Counter
import random
import pickle

# For data manipulation and analysis
import pandas as pd
import numpy as np

# For machine learning tools and evaluation
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# For deep learning
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
import torch

In [3]:
incerto_dir = '/content/drive/MyDrive/incerto-autore'
new_poems_dir = os.path.join(incerto_dir, 'data', 'poems')
poems_split_df = pd.read_csv(os.path.join(new_poems_dir, 'poems_split.csv'))
len(poems_split_df)

682

In [4]:
!pip3 install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.3-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m105.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.2 transformers-4.27.3


In [5]:
# using DistilBERT for testing --> can switch to BERT once set up
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments

In [6]:
# Choose the GPU we want to process this script
device_name = 'cuda'      

# Choose the BERT model that we want to use (make sure to keep the cased/uncased consistent)
model = 'dbmdz/bert-base-italian-xxl-uncased'
#model = os.path.join(incerto_dir, 'contbertoldo-all', 'checkpoint')

# This is the maximum number of tokens in any document sent to BERT
max_length = 512                                                        

In [7]:
if 'contbertoldo' in model:
  finetuned_path = os.path.join(incerto_dir, 'output','finetuned-models', 'binary-class', 'bertoldo')
elif 'italian':
  finetuned_path = os.path.join(incerto_dir, 'output','finetuned-models', 'binary-class', 'bert-ita')
if not os.path.exists(finetuned_path):
  os.makedirs(finetuned_path)

### BERT setup

In [8]:
# load the encoder/tokenizer
tokenizer = BertTokenizer.from_pretrained(model)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/243k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

In [9]:
# class for Torch dataset
class SCDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [10]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=20,   # batch size for evaluation
    learning_rate=5e-5,              # initial learning rate for Adam optimizer
    warmup_steps=50,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
    evaluation_strategy='steps',
)

In [11]:
# load pre-trained model
model = BertForSequenceClassification.from_pretrained(model).to(device_name)

Downloading pytorch_model.bin:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at dbmdz/bert-base-italian-xxl-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification

In [12]:
# Define a custom evaluation function (this could be changes to return accuracy metrics)
def compute_metrics(pred):
  labels = pred.label_ids
  preds = pred.predictions.argmax(-1)
  acc = accuracy_score(labels, preds)
  return {
      'accuracy': acc,
  }

## Classification task setup

In [13]:
annotations_df = poems_split_df.loc[poems_split_df['author'] != 'Unknown']
len(annotations_df)

622

In [14]:
for author in annotations_df['author'].unique():

  print(author)
  author_finetuned_path = os.path.join(finetuned_path, author)
  if not os.path.exists(author_finetuned_path):
    os.makedirs(author_finetuned_path)

    X = annotations_df['poem'].tolist()
    y = annotations_df['author'].map(lambda x: 1 if x==author else 0).tolist()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25)
    print('Y train', Counter(y_train))
    print('Y test', Counter(y_test))
    print(X_test[0:3])

    # Pass training/testing sentences to tokenizer, truncate them if over max length, and add padding (PAD tokens up to 512)
    train_encodings = tokenizer(X_train,  truncation=True, padding=True)
    test_encodings = tokenizer(X_test,  truncation=True, padding=True)

    # Combine encoded text and labels into a torch dataset object.
    train_dataset = SCDataset(train_encodings, y_train)
    test_dataset = SCDataset(test_encodings, y_test)

    # Create the trainer object based on what we've set up prior to this point! This combines our model, training_args, train_dataset and test_dataset, and custom evaluation function compute_metrics.
    trainer = Trainer(
        model=model,                         # the instantiated 🤗 Transformers model to be trained
        args=training_args,                  # training arguments, defined above
        train_dataset=train_dataset,         # training dataset
        eval_dataset=test_dataset,            # evaluation dataset
        compute_metrics=compute_metrics      # custom evaluation function
    )

    # Fine-tune the model on our dataset/labels. The trainer object will periodically output the state of the model.
    trainer.train()

    # built in evaluation function
    trainer.evaluate()

    #save model
    model.save_pretrained(author_finetuned_path)

    print(Counter(y_test))

    predicted_labels = trainer.predict(test_dataset)
    actual_predicted_labels = predicted_labels.predictions.argmax(-1)

    class_report = classification_report(predicted_labels.label_ids.flatten(), actual_predicted_labels.flatten(), output_dict=True)
    print(classification_report(predicted_labels.label_ids.flatten(), actual_predicted_labels.flatten()))

    # New + simple save of classification report
    class_report_df = pd.DataFrame(class_report).transpose()
    class_report_df.to_csv(os.path.join(author_finetuned_path, 'classification_report.csv'))

Franco
Y train Counter({0: 301, 1: 165})
Y test Counter({0: 97, 1: 59})
['Lassa che s un nemico a l altro chieda al suo bisogno aiuto ei gli vien dato che la virtu convien che gli odii ecceda e io creder devro ch aspro e ingrato esser mi debba il mio signor diletto perch ei sia forse d altra innamorato Oime che d altra standosi nel letto me lascia raffreddar sola e scontenta colma d affanni e piena di dispetto altra ei fa del suo amor lieta e contenta e del mio mal con lei fors ancor ride che vanagloriosa ne diventa Quanto per me si lagrima e si stride dolce concento e de le loro orecchie', 'Quando m aprira mai benigna Aurora Quel lieto di de l aureo albergo uscita Ch a te pur torni homai cara e gradita Moglie e rompa si lunga aspra dimora Cosi godano ogn hor quest occhi anchora L imagin bella c ho nel cor scolpita Com io privo di te mia luce e vita Non ho mai lieta o riposata un hora Tu fidissima sposa honesta e bella Sola mi giovi e cangi l pianto in riso Con l angelico volto e la fa



Step,Training Loss,Validation Loss,Accuracy
10,0.6746,0.664565,0.621795
20,0.6601,0.638483,0.621795
30,0.6247,0.616759,0.705128
40,0.5613,0.574096,0.685897
50,0.5499,0.367201,0.858974
60,0.3997,0.45365,0.807692
70,0.1628,0.397781,0.852564
80,0.2389,0.331927,0.891026
90,0.1382,0.274959,0.903846


Counter({0: 97, 1: 59})
              precision    recall  f1-score   support

           0       0.95      0.90      0.92        97
           1       0.84      0.92      0.88        59

    accuracy                           0.90       156
   macro avg       0.89      0.91      0.90       156
weighted avg       0.91      0.90      0.90       156

AntonGiacomoCorso
Y train Counter({0: 426, 1: 40})
Y test Counter({0: 137, 1: 19})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemic



Step,Training Loss,Validation Loss,Accuracy
10,1.3009,0.607166,0.858974
20,0.417,0.447248,0.878205
30,0.2799,0.385509,0.878205
40,0.3018,0.427497,0.878205
50,0.3423,0.37702,0.878205
60,0.2696,0.399319,0.878205
70,0.3623,0.37133,0.878205
80,0.2897,0.385199,0.878205
90,0.2171,0.396787,0.878205


Counter({0: 137, 1: 19})


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.88      1.00      0.94       137
           1       0.00      0.00      0.00        19

    accuracy                           0.88       156
   macro avg       0.44      0.50      0.47       156
weighted avg       0.77      0.88      0.82       156

PietroBembo
Y train Counter({0: 427, 1: 39})
Y test Counter({0: 141, 1: 15})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemica eterna di virtute Invidia d 



Step,Training Loss,Validation Loss,Accuracy
10,0.2681,0.322596,0.903846
20,0.3212,0.319284,0.903846
30,0.2714,0.313323,0.903846
40,0.3428,0.325573,0.903846
50,0.2014,0.350232,0.903846
60,0.3099,0.296257,0.903846
70,0.2235,0.305139,0.891026
80,0.2255,0.43113,0.846154
90,0.1994,0.289702,0.916667


Counter({0: 141, 1: 15})
              precision    recall  f1-score   support

           0       0.94      0.97      0.95       141
           1       0.60      0.40      0.48        15

    accuracy                           0.92       156
   macro avg       0.77      0.69      0.72       156
weighted avg       0.91      0.92      0.91       156

CelioMagno
Y train Counter({0: 419, 1: 47})
Y test Counter({0: 146, 1: 10})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemica eter



Step,Training Loss,Validation Loss,Accuracy
10,0.3943,0.273315,0.916667
20,0.3063,0.237061,0.935897
30,0.3893,0.314454,0.935897
40,0.3176,0.236357,0.935897
50,0.3375,0.225927,0.935897
60,0.3001,0.20317,0.935897
70,0.3463,0.199298,0.935897
80,0.3366,0.194925,0.935897
90,0.2008,0.192199,0.935897


Counter({0: 146, 1: 10})


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.94      1.00      0.97       146
           1       0.00      0.00      0.00        10

    accuracy                           0.94       156
   macro avg       0.47      0.50      0.48       156
weighted avg       0.88      0.94      0.90       156

DomenicoVenier
Y train Counter({0: 437, 1: 29})
Y test Counter({0: 145, 1: 11})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemica eterna di virtute Invidia



Step,Training Loss,Validation Loss,Accuracy
10,0.2952,0.238083,0.929487
20,0.155,0.246547,0.929487
30,0.2237,0.265672,0.929487
40,0.3265,0.261211,0.929487
50,0.2124,0.248687,0.929487
60,0.163,0.277547,0.929487
70,0.1672,0.27799,0.929487
80,0.2653,0.248266,0.929487
90,0.2695,0.240676,0.929487


Counter({0: 145, 1: 11})


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.93      1.00      0.96       145
           1       0.00      0.00      0.00        11

    accuracy                           0.93       156
   macro avg       0.46      0.50      0.48       156
weighted avg       0.86      0.93      0.90       156

GiorgioGradenigo
Y train Counter({0: 457, 1: 9})
Y test Counter({0: 152, 1: 4})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemica eterna di virtute Invidia



Step,Training Loss,Validation Loss,Accuracy
10,0.1348,0.128661,0.974359
20,0.1254,0.124433,0.974359
30,0.0398,0.131125,0.974359
40,0.0402,0.14613,0.974359
50,0.162,0.115157,0.974359
60,0.0927,0.128379,0.974359
70,0.0961,0.126869,0.974359
80,0.0918,0.125382,0.974359
90,0.095,0.123623,0.974359


Counter({0: 152, 1: 4})


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.97      1.00      0.99       152
           1       0.00      0.00      0.00         4

    accuracy                           0.97       156
   macro avg       0.49      0.50      0.49       156
weighted avg       0.95      0.97      0.96       156

MarcoVenier
Y train Counter({0: 456, 1: 10})
Y test Counter({0: 154, 1: 2})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemica eterna di virtute Invidia d o



Step,Training Loss,Validation Loss,Accuracy
10,0.0684,0.068615,0.987179
20,0.0963,0.068707,0.987179
30,0.3371,0.069573,0.987179
40,0.0945,0.070894,0.987179
50,0.0738,0.070403,0.987179
60,0.1596,0.068425,0.987179
70,0.119,0.069791,0.987179
80,0.0718,0.068612,0.987179
90,0.1239,0.068587,0.987179


Counter({0: 154, 1: 2})


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.99      1.00      0.99       154
           1       0.00      0.00      0.00         2

    accuracy                           0.99       156
   macro avg       0.49      0.50      0.50       156
weighted avg       0.97      0.99      0.98       156

Petrarca
Y train Counter({0: 338, 1: 128})
Y test Counter({0: 121, 1: 35})
['Questa si mesta e dolorosa vita Che non posa o si ferma a guisa d onda Raporta hor quici hor quidi entro al mio core DOLCE l amato e dilettoso tempo Che lieto mi godea l alto mio sole Onde provo vivendo un aspra morte Prima vorrei che interompesse morte Quest anni ohime che in questa amara vita Viver noiando il Ciel la luna il sole E misurar le piagge e solcar l onda Nemica al mio soave e caro tempo Con questo travagliato tristo core Perche so ben che l rio pensier che il core Sempre tien desto condurami a morte', 'Benche fortuna rade volte amica Et la nemica eterna di virtute Invidia d of



Step,Training Loss,Validation Loss,Accuracy
10,1.063,0.923071,0.775641
20,1.123,0.627526,0.775641
30,0.6043,0.552896,0.775641
40,0.6254,0.531255,0.775641
50,0.5923,0.531047,0.775641
60,0.5838,0.513111,0.775641
70,0.4831,0.506955,0.775641
80,0.5509,0.479701,0.775641
90,0.6142,0.502266,0.775641


Counter({0: 121, 1: 35})
              precision    recall  f1-score   support

           0       0.78      1.00      0.87       121
           1       0.00      0.00      0.00        35

    accuracy                           0.78       156
   macro avg       0.39      0.50      0.44       156
weighted avg       0.60      0.78      0.68       156



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
