In [None]:
import json
import os
import pickle
import sys

from matplotlib import pyplot as plt
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, GroupKFold
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Put the directory three levels above the working directory on the import
# path (presumably the project root) so the `src.config` imports below resolve.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir))
if module_path not in sys.path:
    # Appended only once, so re-running the cell does not grow sys.path.
    sys.path.append(module_path)

from src.config import PATH_INTERIM_CORPUS  # noqa: E402
from src.config import PATH_BEST_MODELS, PICKLE_PROTOCOL  # noqa: E402

In [None]:
# Directory holding the interim depression datasets; the dataset files read
# below are all resolved against it.
INTERIM_DATASETS_PATH = os.path.join(
    PATH_INTERIM_CORPUS,
    "xml/depression",
)

In [None]:
# Training split. The pickle holds a 3-item sequence that is unpacked as the
# feature matrix, the labels and the group ids later passed to GroupKFold.
dmc_corpus_train = os.path.join(INTERIM_DATASETS_PATH, "depression-dmc-train.pkl")

with open(dmc_corpus_train, mode="rb") as fp:
    train_payload = pickle.load(fp)

x_train, y_train, groups_train = train_payload

In [None]:
# Held-out test split. Same 3-item layout as the training pickle; the third
# item is not needed for evaluation and is discarded.
dmc_corpus_test = os.path.join(INTERIM_DATASETS_PATH, "depression-dmc-test.pkl")

with open(dmc_corpus_test, mode="rb") as fp:
    test_payload = pickle.load(fp)

x_test, y_test, _ = test_payload

In [None]:
# Base estimator handed to the grid search. The search grid also lists
# "random_state", so the seed given here only matters when this tree is
# fitted outside of that search.
dtc = DecisionTreeClassifier(
    random_state=0,
)

In [None]:
# Hyper-parameter grid for the cross-validated search: each combination of
# the values listed here is one candidate tree.
parameters = dict(
    criterion=["gini", "entropy"],
    splitter=["best", "random"],
    max_depth=[3, 4],
    # Per the scikit-learn docs, ints are absolute sample counts and a float
    # is a fraction of the training samples.
    min_samples_leaf=[1, 0.1, 10],
    random_state=[42],
    class_weight=[None, "balanced"],
)

In [None]:
# Sanity check of the grouped folds: for each of the 21 splits print the sum
# of the training labels and the sum of the held-out labels (i.e. the number
# of positives on each side if labels are encoded 0/1 — TODO confirm).
gkf = GroupKFold(n_splits=21).split(x_train, y_train, groups_train)

for train_index, test_index in gkf:
    train_label_sum = sum(y_train[idx] for idx in train_index)
    test_label_sum = sum(y_train[idx] for idx in test_index)
    print(train_label_sum, test_label_sum)

In [None]:
# A split generator can only be iterated once, so a fresh one is created
# here for the search to consume.
gkf = GroupKFold(n_splits=21).split(x_train, y_train, groups_train)

# Exhaustive search over `parameters`, scoring every candidate by F1 on the
# grouped folds.
clf = GridSearchCV(estimator=dtc, param_grid=parameters, scoring="f1", cv=gkf)

clf.fit(x_train, y_train)

In [None]:
# Best cross-validated F1 found by the search and the parameter combination
# that produced it (shown as the cell output).
(clf.best_score_, clf.best_params_)

In [None]:
# Predict the held-out test set with the search's best estimator.
best_tree = clf.best_estimator_
y_test_pred = best_tree.predict(x_test)

In [None]:
# Per-class precision / recall / F1 of the test-set predictions.
test_report = classification_report(y_test, y_test_pred)
print(test_report)

In [None]:
# Draw the selected tree, labelling its split nodes with the corpus feature
# names stored next to the datasets.
dmc_corpus_feature_names = os.path.join(
    INTERIM_DATASETS_PATH, "depression-dmc-feature-names.json"
)

# JSON text is UTF-8; decode it explicitly instead of relying on the
# platform's locale encoding, so non-ASCII feature names load identically on
# every machine.
with open(dmc_corpus_feature_names, encoding="utf-8") as fp:
    feature_names = json.load(fp)

# plot_tree assigns class names in ascending order of the class labels;
# assumes the labels are 0 (negative) / 1 (positive) — TODO confirm.
class_names = ["negative", "positive"]

# Very large canvas so the node text of the tree stays legible.
fig = plt.figure(figsize=(60, 30), facecolor="white")
plot_tree(
    clf.best_estimator_,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,  # colour nodes by their majority class
    proportion=True,  # show values/samples as proportions, not counts
)
plt.show()

In [None]:
# Persist the best estimator found by the grid search.
model_path = os.path.join(
    PATH_BEST_MODELS,
    "positive_f1/reddit/depression/selected_models/dmc_decision_tree.pkl",
)

# open(..., "wb") does not create missing directories, so make sure the
# destination folder exists before writing (a no-op when it already does).
os.makedirs(os.path.dirname(model_path), exist_ok=True)

with open(model_path, "wb") as fp:
    pickle.dump(clf.best_estimator_, fp, protocol=PICKLE_PROTOCOL)