In [12]:
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
from math import ceil
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils import shuffle
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
from numpy import load

In [2]:
# Load the feature array X and the label array y from their .npz archives.
# (File names suggest BERT title embeddings and training labels — TODO confirm.)
#
# `load` on an .npz returns an NpzFile that keeps the underlying file open.
# The `with` block closes that handle as soon as the array has been read,
# instead of leaving it open for the lifetime of the kernel.
with load('title_full_bert.npz') as dict_data:
    # 'arr_0' is the key np.savez gives to the first positional array
    X = dict_data['arr_0']

with load('ytrain_label.npz') as dict_data:
    y = dict_data['arr_0']

In [3]:
# Partition the data into a 60% / 40% split, stratified on y.
# The fixed random_state makes the partition repeatable across runs.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=42)

# n_splits=1, so the generator yields exactly one (train, test) index pair;
# take it directly rather than looping over a single iteration.
train_idx, test_idx = next(splitter.split(X, y))

# Index the NumPy arrays with the integer index arrays in one vectorised step.
# This yields the same arrays as building a Python list row by row and
# converting it with np.array, without the per-row Python overhead or the
# temporary list of row objects.
xtrain = X[train_idx]
ytrain = y[train_idx]
xtest = X[test_idx]
ytest = y[test_idx]

In [6]:
# Train a log-loss SGD classifier incrementally: 25 epochs over the training
# split, fed to partial_fit in mini-batches of `batch_size` rows.
#
# NOTE: max_iter only affects fit(); partial_fit performs a single pass over
# whatever batch it is given, so the epoch loop below is what controls the
# amount of training.
logmod = SGDClassifier(loss="log_loss", n_jobs=-1, max_iter=1000, random_state=42)
batch_size = 100000
max_len = len(xtrain)
num_batch = ceil(max_len / batch_size)
for epoch in range(25):
    # Reshuffle once per epoch and train on the SHUFFLED copies. The result is
    # kept under its own names so the full-dataset X / y loaded earlier are not
    # overwritten. Seeding with the epoch number gives a different but
    # repeatable permutation each epoch.
    x_epoch, y_epoch = shuffle(xtrain, ytrain, random_state=epoch)
    for i in range(num_batch):
        start = i * batch_size
        # Slicing clips at the end of the array, so the final (short) batch
        # needs no special case.
        stop = start + batch_size
        # `classes` is mandatory on the first partial_fit call; passing it on
        # every call is harmless and keeps the loop uniform.
        logmod.partial_fit(x_epoch[start:stop], y_epoch[start:stop], classes=[0, 1, 2])

In [8]:
# Configure an L-BFGS logistic regression (at most 500 iterations) and fit it.
# NOTE(review): this rebinds `logmod`, so the SGD model trained in the batch
# cell is discarded. It is also fitted on xtest/ytest, while the following
# cells score it on xtrain/ytrain — the split roles are swapped relative to
# their names. Confirm this is intentional.
lbfgs_lr = LogisticRegression(solver="lbfgs", max_iter=500)
logmod = lbfgs_lr.fit(xtest, ytest)  # fit() returns the estimator itself

In [9]:
# Predict labels for the xtrain split with the current `logmod`.
ypred = logmod.predict(xtrain)

# Report overall accuracy, then the confusion matrix
# (rows = true labels, columns = predicted labels).
acc = accuracy_score(ytrain, ypred)
conf_mat = confusion_matrix(ytrain, ypred)
print(acc)
print(conf_mat)

0.961994097292444
[[568400   2718  28849]
 [  2502 592669   4829]
 [ 24609   4898 570378]]


In [14]:
# Per-class precision / recall / F1 for the same ytrain-vs-ypred comparison.
report = classification_report(ytrain, ypred)
print(report)

              precision    recall  f1-score   support

           0       0.95      0.95      0.95    599967
           1       0.99      0.99      0.99    600000
           2       0.94      0.95      0.95    599885

    accuracy                           0.96   1799852
   macro avg       0.96      0.96      0.96   1799852
weighted avg       0.96      0.96      0.96   1799852



In [10]:
import joblib

# Persist the fitted model to disk with zlib compression level 3.
# The dump call is left as the cell's final expression so the list of
# written file names is still displayed as the cell output.
model_path = "../../src/models/lr_label_bert_full_60_42.joblib"
joblib.dump(logmod, model_path, compress=3)

['../../src/models/lr_label_bert_full_60_42.joblib']