In [12]:
from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, CountVectorizer
import numpy as np
from transformers import Trainer, TrainingArguments
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
#plot roc curve
from sklearn.metrics import roc_curve, auc
import pandas as pd

In [13]:
# Load the paragraph-level Gutenberg corpus (one row per paragraph; the cells
# below use its 'Text' and 'Author' columns) and drop incomplete rows.
df = (
    pd.read_csv('../data/gutenberg_paragraphs.csv')
    .dropna()
    .reset_index(drop=True)
)

# Optional class balancing.
# WARNING: balancing happens here, BEFORE the train/test split. With
# 'upsampling', duplicated paragraphs can end up in both train and test,
# which inflates test metrics. Prefer balancing only the training split.
BALANCE_CLASSES = False
BALANCED_METHOD = 'upsampling'  # 'upsampling' or 'downsampling'
RANDOM_SEED = 42

if BALANCE_CLASSES:
    class_counts = df['Author'].value_counts()
    if BALANCED_METHOD == 'upsampling':
        # Bring every author up to the largest class size, sampling with replacement.
        n_per_author, with_replacement = int(class_counts.max()), True
    elif BALANCED_METHOD == 'downsampling':
        # Bring every author down to the smallest class size, without replacement.
        n_per_author, with_replacement = int(class_counts.min()), False
    else:
        raise ValueError(
            f"Unknown BALANCED_METHOD {BALANCED_METHOD!r}; "
            "expected 'upsampling' or 'downsampling'"
        )
    # GroupBy.sample keeps the 'Author' column in the result. groupby().apply()
    # operating on the grouping column is deprecated in pandas >= 2.2 and would
    # eventually drop it.
    df = (
        df.groupby('Author')
        .sample(n=n_per_author, replace=with_replacement, random_state=RANDOM_SEED)
        .reset_index(drop=True)
    )

# Paragraph count per author (after any balancing).
df.groupby('Author').count()


Unnamed: 0_level_0,Text
Author,Unnamed: 1_level_1
"Alcott, Louisa May",7228
"Austen, Jane",6755
"Christie, Agatha",4266
"Doyle, Arthur Conan",4211
"Shakespeare, William",1410
"Verne, Jules",3977


In [14]:
# Raw paragraph strings fed to the vectorizer.
books = df['Text'].values
# Sorted so the order matches sklearn's class ordering: classifiers expose
# classes_ as the sorted unique labels, and predict_proba columns and
# confusion_matrix rows follow that order. Index i here then refers to the
# same author everywhere.
authors = np.sort(df['Author'].unique())

# Bag-of-words token counts.
# NOTE(review): the vocabulary is fit on the full corpus before the train/test
# split, so it includes test-only tokens. Fitting on the training split only
# would be the leak-free setup.
vectorizer = CountVectorizer()
X = vectorizer.fit_transform(books)

# Author name <-> integer index, consistent with the sorted `authors` order.
authors2idx = {author: idx for idx, author in enumerate(authors)}
idx2authors = dict(enumerate(authors))
# String author label for each row of X (a fixed-width unicode array, as before).
y = df['Author'].to_numpy(dtype=str)





In [None]:
# Hold out 10% of paragraphs for evaluation. Stratify on the author label:
# the classes are imbalanced (7228 Alcott vs 1410 Shakespeare paragraphs in
# the count table above). Without stratification, each author's share of the
# test set can drift, making the macro-averaged metrics noisier.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=RANDOM_SEED, stratify=y
)

# Multiclass logistic regression on raw counts. max_iter is raised above
# sklearn's default (100) to give the solver room to converge.
model = LogisticRegression(verbose=1, max_iter=1000, C=1)
model.fit(X_train, y_train)



In [None]:
def _print_split_metrics(title, y_true, y_pred):
    """Print accuracy and macro-averaged recall/precision/F1 for one split.

    Returns (accuracy, recall, precision, f1) so callers can keep the values.
    """
    split_accuracy = accuracy_score(y_true, y_pred)
    split_recall = recall_score(y_true, y_pred, average='macro')
    split_precision = precision_score(y_true, y_pred, average='macro')
    split_f1 = f1_score(y_true, y_pred, average='macro')
    print(title)
    print("Accuracy:", split_accuracy)
    print("Recall:", split_recall)
    print("Precision:", split_precision)
    print("F1:", split_f1)
    return split_accuracy, split_recall, split_precision, split_f1


y_pred = model.predict(X_test)
accuracy, recall, precision, f1 = _print_split_metrics("Test metrics", y_test, y_pred)
print("-------------------------------------")
# Predict on the training split once, rather than once per metric.
# Comparing train and test scores shows how much the model overfits.
y_train_pred = model.predict(X_train)
train_accuracy, train_recall, train_precision, train_f1 = _print_split_metrics(
    "Train metrics", y_train, y_train_pred
)


In [None]:
# Pin the row/column order to the model's class order so the tick labels are
# guaranteed to match. Without `labels`, confusion_matrix picks its own
# (sorted) order, which need not match the order of `authors`.
class_labels = model.classes_
cm = confusion_matrix(y_test, y_pred, labels=class_labels)

fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(im, ax=ax)

tick_marks = np.arange(len(class_labels))
ax.set_xticks(tick_marks)
ax.set_xticklabels(class_labels, rotation=45, ha='right')
ax.set_yticks(tick_marks)
ax.set_yticklabels(class_labels)
ax.set(xlabel='Predicted label', ylabel='True label',
       title='Confusion matrix (test split)')

# Write each cell's count on the heatmap, in a colour readable on its background.
threshold = cm.max() / 2
for row, col in np.ndindex(cm.shape):
    ax.text(col, row, cm[row, col], ha='center', va='center',
            color='white' if cm[row, col] > threshold else 'black')

# Call tight_layout after the labels exist so the rotated ticks are not clipped.
fig.tight_layout()
plt.show()

In [None]:

# One-vs-rest ROC curve per author on the test split.
# predict_proba columns follow model.classes_, so the true-label indices must
# use that same ordering (not the order of `authors` / `authors2idx`).
class_labels = model.classes_
n_classes = len(class_labels)
classes = list(range(n_classes))

y_pred_prob = model.predict_proba(X_test)  # shape (n_test, n_classes)
# Convert author names to column indices of y_pred_prob.
label_to_col = {label: col for col, label in enumerate(class_labels)}
y_true = np.array([label_to_col[author] for author in y_test])

fpr, tpr, roc_auc = {}, {}, {}
for i in classes:
    # Binary problem: author i vs. all other authors.
    fpr[i], tpr[i], _thresholds = roc_curve(y_true == i, y_pred_prob[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

plt.figure(figsize=(8, 8))
lw = 2
for i in classes:
    plt.plot(fpr[i], tpr[i], "-", lw=lw,
             label=f'ROC curve of class {class_labels[i]} (area = {roc_auc[i]:0.2f})')
# Diagonal shows a chance-level classifier for reference.
plt.plot([0, 1], [0, 1], "k--", lw=1)

plt.xlim([-0.05, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC for Author Classification')
plt.legend(loc="lower right")
plt.show()
