В этом файле мы обучим три модели — градиентный бустинг, линейную логистическую регрессию и полносвязную нейронную сеть — и сравним их. В качестве основной метрики возьмём ROC AUC.

In [15]:
import numpy as np
import catboost as cat

Загружаем эмбеддинги и таргеты

In [None]:
import json

# Training split: embeddings go to `x`, targets to `y`.
# The encoding is pinned to UTF-8 so parsing does not depend on the
# platform's locale default.
# NOTE(review): presumably one 384-dimensional vector per sample, since the
# network defined later declares input_shape=(384,) — confirm.
with open("train_embeddings.json", encoding="utf-8") as f:
    x = np.array(json.load(f))

with open("train_target.json", encoding="utf-8") as f:
    y = np.array(json.load(f))

In [5]:
# Test split: embeddings go to `x_test`, targets to `y_test`.
# The encoding is pinned to UTF-8 so parsing does not depend on the
# platform's locale default.
with open("test_embeddings.json", encoding="utf-8") as f:
    x_test = np.array(json.load(f))

with open("test_target.json", encoding="utf-8") as f:
    y_test = np.array(json.load(f))

Делаем one-hot кодирование таргета

In [6]:
# Class names in column order; position i is the one-hot column of class i.
classes = [
    'Anxiety',
    'Bipolar',
    'Depression',
    'Normal',
    'Personality disorder',
    'Stress',
    'Suicidal',
]

# Reverse lookup (class name -> column index), derived from `classes` so the
# two tables cannot drift apart.
indexes = {name: position for position, name in enumerate(classes)}

In [7]:
def to_matrix(lables, index_map=None):
    """One-hot encode a sequence of class labels.

    Args:
        lables: Sequence of class labels; every label must be a key of the
            mapping in use.
        index_map: Optional mapping from label to column index. Its values
            must lie in ``range(len(index_map))``. Defaults to the
            module-level ``indexes`` table.

    Returns:
        A float array of shape ``(len(lables), len(index_map))`` holding a
        single 1 per row, in the column assigned to that row's label.

    Raises:
        KeyError: If a label is not present in the mapping.
    """
    if index_map is None:
        index_map = indexes

    # The width follows the mapping rather than a hard-coded class count.
    res = np.zeros((len(lables), len(index_map)))

    for row, label in enumerate(lables):
        res[row, index_map[label]] = 1

    return res

def to_classes(probas, class_names=None):
    """Map each row of class scores to the name of its highest-scoring class.

    Args:
        probas: 2-D array-like of shape ``(n_samples, n_classes)`` with one
            score per class. Scores are assumed to be NaN-free.
        class_names: Optional sequence of class names, one per column of
            ``probas``. Defaults to the module-level ``classes`` list.

    Returns:
        A 1-D numpy array with one class name per input row. When several
        columns share the maximum, the lowest column index wins.
    """
    if class_names is None:
        class_names = classes

    if len(probas) == 0:
        # Nothing to reduce, and argmax over axis 1 needs a 2-D input.
        return np.array([])

    # np.argmax returns the first occurrence of the maximum in each row.
    best_columns = np.argmax(probas, axis=1)
    return np.array([class_names[column] for column in best_columns])


# One-hot encoded targets for the train and test splits. `y_cat` is what the
# neural network is fitted on; `y_cat_test` is the reference for ROC AUC.
y_cat = to_matrix(y)
y_cat_test = to_matrix(y_test)

Функции для оценки качества модели

In [8]:
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

def print_metrics(pred):
    """Print weighted precision, recall, F1 score and ROC AUC on the test set.

    Args:
        pred: Per-class scores for the test samples, one column per entry of
            ``classes``. They are compared against the module-level
            ``y_cat_test`` (for ROC AUC) and ``y_test`` (for the rest).
    """
    # ROC AUC works on the raw scores against the one-hot targets.
    auc_value = roc_auc_score(y_cat_test, pred, labels=classes, average="weighted")

    # The remaining metrics need hard labels, so pick the top class per row.
    predicted_labels = to_classes(pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_test, predicted_labels, labels=classes, average="weighted"
    )

    report_lines = (
        f"precision = {precision}",
        f"recall = {recall}",
        f"f1 score = {f1}",
        f"ROC AUC = {auc_value}",
    )
    print(*report_lines, sep="\n")

def print_confusion_matrix(pred):
    """Plot the confusion matrix of predicted labels against ``y_test``.

    Args:
        pred: Predicted class labels for the test samples.
    """
    counts = confusion_matrix(y_test, pred, labels=classes)

    _, axes = plt.subplots(figsize=(7.5, 7.5))
    axes.matshow(counts, cmap=plt.cm.YlOrRd, alpha=0.5)

    # Write every count into its cell: x is the column, y is the row.
    for (row, col), count in np.ndenumerate(counts):
        axes.text(x=col, y=row, s=count, va='center', ha='center', size='xx-large')

    tick_positions = range(7)
    plt.xticks(tick_positions, classes, fontsize=8)
    plt.yticks(tick_positions, classes, fontsize=8)
    plt.xlabel("Predictions", fontsize=16)
    plt.ylabel("Actuals", fontsize=16)
    plt.title("Confusion Matrix", fontsize=15)
    plt.show()
    plt.clf()

Обучаем градиентный бустинг

In [None]:
# Gradient boosting on the embeddings: 5 boosting iterations, trees of depth 9.
grad_boost_model = cat.CatBoostClassifier(iterations=5, depth=9)

# NOTE(review): the test split doubles as the early-stopping eval set, and
# early_stopping_rounds (20) exceeds the iteration budget (5) — confirm both
# are intended.
grad_boost_model.fit(
    x,
    y,
    eval_set=(x_test, y_test),
    early_stopping_rounds=20,
    verbose=False,
    plot=True,
)

Обучаем линейную логистическую регрессию

In [None]:
from sklearn.linear_model import LogisticRegression

# Logistic regression baseline with scikit-learn's default hyperparameters,
# fitted on the training embeddings `x` and their labels `y`.
linear_model = LogisticRegression()
linear_model.fit(x, y)

Обучаем нейронную сеть

In [None]:
from keras.layers import Dense, Dropout
from keras.models import Sequential
# from tensorflow.keras.layers import BatchNormalization

# Width of the second hidden layer.
n = 180

# Fully connected classifier: 384-d input -> 384 -> n -> 7-way softmax, with
# dropout after each hidden layer. Only the first layer declares the input
# shape; the later layers take theirs from the preceding layer's output.
neuro_network_model = Sequential([
    Dense(384, input_shape=(384,), activation="relu"),
    Dropout(0.35),
    Dense(n, activation="relu"),
    Dropout(0.35),
    Dense(7, activation='softmax'),
])

neuro_network_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["f1_score"])

# Trained on the one-hot targets, as categorical cross-entropy expects.
neuro_network_model.fit(x, y_cat, batch_size=50, epochs=5)

Смотрим на метрики моделей

In [None]:
# For each model: print the metrics computed from its per-class scores, then
# plot the confusion matrix of its predicted labels.
print("Gradient boosting:")
print_metrics(grad_boost_model.predict_proba(x_test))
print_confusion_matrix(grad_boost_model.predict(x_test))

print("\nLinear log regress:")
print_metrics(linear_model.predict_proba(x_test))
print_confusion_matrix(linear_model.predict(x_test))

print("\nNeuro network:")
# The network only exposes scores, so run inference once and reuse the result
# for both the metrics and the label-based confusion matrix.
nn_probas = neuro_network_model.predict(x_test, verbose=0)
print_metrics(nn_probas)
print_confusion_matrix(to_classes(nn_probas))