**Classifier Wilde vs Mistral7B-Instruct**

In this notebook we train a classifier for authorship attribution. The two candidate authors are Oscar Wilde and the baseline model Mistral7B-Instruct. The classifier implementation is based on the paper [BertAA: BERT fine-tuning for Authorship Attribution](https://aclanthology.org/2020.icon-main.16.pdf).

In [None]:
# uncomment the following lines to run in colab
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# run this cell to install simpletransformers if you are running in colab
# !pip install -U simpletransformers

In [None]:
# import the required libraries
from pandas import DataFrame
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from simpletransformers.classification import ClassificationModel
import torch
import json
import re
import matplotlib.pyplot as plt

In [None]:
# data and models paths
import sys

# Detect Colab automatically instead of (un)commenting lines by hand: previously both
# assignments were active, so the Colab Drive path always won, even when running locally.
# In Colab, Google Drive must be mounted first (see the drive-mount cell).
IN_COLAB = 'google.colab' in sys.modules
root_path = './drive/MyDrive/DL-ENS' if IN_COLAB else '.'
dir_data = f'{root_path}/data'
models_path = f'{root_path}/models'
wilde_texts_path = f'{dir_data}/wilde_complete.txt'
# NOTE: despite its name, this is the path to a single JSON file, not a list
mistral_gen_texts_list = f'{dir_data}/BaseModelCompletionsToTrainClassifier/dataset_mistral7B_gen_texts.json'
# index i is the label assigned to author i
authors_names = ["Wilde", "Mistral7B-Instruct"]

In [None]:
# function to read the texts of a specific author
def read_texts(path: str, label: int, len_to_read=None, max_length=350) -> dict:
  """
  Read the texts of a specific author and split them into fixed-size word chunks.

  inputs:
    path: str: path to the file with the texts (read as UTF-8 plain text)
    label: int: label assigned to every chunk produced from this file
    len_to_read: int | None: if given, only the first `len_to_read` CHARACTERS of the
      file are used (this counts characters, not texts or words)
    max_length: int: maximum number of whitespace-separated words per chunk;
      the last chunk may be shorter
  outputs:
    dt: dict: {'text': [chunk, ...], 'label': [label, ...]}, one label per chunk;
      both lists are empty when the (truncated) file contains no words
  """
  # NOTE(review): the Mistral data file is JSON. Reading it as plain text means JSON syntax
  # (brackets, quotes, commas, keys) ends up inside the chunks and may act as a spurious
  # class cue for the classifier. Consider parsing it with `json` first (schema not visible here).
  # Read-only mode: the previous 'r+' needlessly required write permission on the data file.
  with open(path, 'r', encoding='utf-8') as fd:
    text = fd.read()
  if len_to_read is not None:
    text = text[:len_to_read]
  words = text.split()
  dt = {'text': [], 'label': []}
  for start in range(0, len(words), max_length):
    # slicing clamps at the end of the list, so no explicit min() is needed
    dt['text'].append(' '.join(words[start:start + max_length]))
    dt['label'].append(label)
  return dt

In [None]:
# build dataset for classification with both authors
# each source gets its position in this tuple as label: 0 -> Wilde, 1 -> Mistral7B-Instruct
dt = {'text': [], 'label': []}
for author_id, source_path in enumerate((wilde_texts_path, mistral_gen_texts_list)):
  author_chunks = read_texts(source_path, author_id)
  for column in ('text', 'label'):
    dt[column].extend(author_chunks[column])

In [None]:
# turn the dict of parallel lists into a DataFrame (one row per text chunk) and preview it
dt = DataFrame(dt)
dt.head()

In [None]:
# split the dataset into train and test
# stratify on the label so both splits keep the same (imbalanced) proportion of each author
dt_train, dt_test = train_test_split(
  dt, test_size=0.2, random_state=42, shuffle=True, stratify=dt['label']
)

In [None]:
dt_train.head(n=5)  # preview the first rows of the training split

In [None]:
dt_test.head(n=5)  # preview the first rows of the test split

In [None]:
# check the distribution of the labels in the train dataset
# a per-author bar chart replaces `DataFrame.hist`, which put the 0/1 label column into
# 10 bins and showed no author names, title or axis labels
train_label_counts = (
  dt_train['label']
  .value_counts()
  .reindex(range(len(authors_names)), fill_value=0)
)
fig, ax = plt.subplots()
ax.bar(authors_names, train_label_counts.to_numpy())
ax.set(title='Label distribution in the train set', xlabel='author', ylabel='number of text chunks')
plt.show()

In [None]:
# check the distribution of the labels in the test dataset
# a per-author bar chart replaces `DataFrame.hist`, which put the 0/1 label column into
# 10 bins and showed no author names, title or axis labels
test_label_counts = (
  dt_test['label']
  .value_counts()
  .reindex(range(len(authors_names)), fill_value=0)
)
fig, ax = plt.subplots()
ax.bar(authors_names, test_label_counts.to_numpy())
ax.set(title='Label distribution in the test set', xlabel='author', ylabel='number of text chunks')
plt.show()

In [None]:
# define model for classifier and initial weights:
# model type passed to ClassificationModel, and the pretrained checkpoint to fine-tune
model_name, model_weights = 'bert', 'bert-base-cased'

In [None]:
# train the model, taking the class imbalance of the dataset into account.
# Class weights are inverse class frequencies ("balanced" weighting:
# n_samples / (n_classes * count_c)), so the minority author is up-weighted.
# The previous expression gave each class a weight equal to its OWN frequency,
# which up-weighted the majority class instead.
n_classes = len(authors_names)
class_counts = dt_train['label'].value_counts().reindex(range(n_classes), fill_value=0)
assert (class_counts > 0).all(), f"every class must appear in the training set, got {class_counts.to_dict()}"
class_weights = (len(dt_train) / (n_classes * class_counts)).tolist()
model = ClassificationModel(
  model_name,
  model_weights,
  num_labels=n_classes,
  weight=class_weights,
  args={
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'num_train_epochs': 5,
    'manual_seed': 42,  # reproducible fine-tuning, same seed as the train/test split
  },
  use_cuda=torch.cuda.is_available(),  # fall back to CPU instead of crashing without a GPU
)
# simpletransformers looks for 'text' and 'labels' columns; rename so it does not
# have to fall back on column order
model.train_model(dt_train.rename(columns={'label': 'labels'}))

In [None]:
# predict the test dataset labels:
# `predictions` holds one predicted class id per test text; `raw_out` holds the model's
# raw outputs (presumably logits — confirm against the simpletransformers docs)
test_texts = dt_test['text'].tolist()
predictions, raw_out = model.predict(test_texts)

In [None]:
# classification report for the test dataset
# `labels` is passed explicitly so the report stays aligned with `authors_names`
# even if one class is missing from the test labels or from the predictions
label_ids = list(range(len(authors_names)))
print(classification_report(dt_test['label'], predictions, labels=label_ids, target_names=authors_names))
# confusion matrix: rows are the true authors, columns the predicted authors
DataFrame(
  confusion_matrix(dt_test['label'], predictions, labels=label_ids),
  index=authors_names,
  columns=authors_names,
)

In [None]:
# save classifier model
import os

model_save_name = 'BertAA_wilde_vs_mistral7B.pt'
path = f"{models_path}/{model_save_name}"
# torch.save fails if the target directory does not exist yet (e.g. a fresh checkout or Drive folder)
os.makedirs(models_path, exist_ok=True)
# NOTE: this pickles the whole simpletransformers ClassificationModel object. Loading it needs
# compatible library versions and `torch.load(path, weights_only=False)` on recent PyTorch
# (weights_only defaults to True since 2.6). `save_pretrained` on the underlying Hugging Face
# model/tokenizer is a more portable alternative.
torch.save(model, path)

In [1]:
print("Done", end="\n")  # marks that execution reached the final cell

Done
