-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
66 lines (53 loc) · 1.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import TruncatedSVD
from catboost import CatBoostClassifier
from preprocessing.feature_engineering import (
load_data,
split_users,
make_features,
svd,
cross_validate,
)
from evaluation.evaluation_metrics import (
get_class_accuracy_cv,
confusion_matrix_cv,
print_classification_report,
plot_feature_importances,
)
from figures.construct_plots import plot_data
from constants import WINDOW_SIZE, EPOCH_SIZE
def main():
    """Run the pipeline end to end: load data, cross-validate, report results.

    Loads the dataset from ``data/``, hands a CatBoost multiclass classifier
    and a 10-component ``TruncatedSVD`` to ``cross_validate`` (with ``n=5``),
    then prints accuracies under a Wake/NREM/REM header, passes the
    concatenated labels and predictions to ``confusion_matrix_cv``, prints the
    classification report, and plots the feature importances averaged over
    the returned runs. Both plotting helpers are called with ``save=True``.
    """
    dataset = load_data("data/")
    print("Data loaded.")

    # Hyperparameters for the CatBoost multiclass model.
    model_params = {
        "loss_function": "MultiClass",
        "iterations": 100,
        "depth": 6,
        "learning_rate": 0.05,
        "verbose": False,
        "auto_class_weights": "SqrtBalanced",
    }
    classifier = CatBoostClassifier(**model_params)
    # Named ``reducer`` so the imported ``svd`` helper is not shadowed here.
    reducer = TruncatedSVD(n_components=10)

    print("Training model...")
    y_tests, y_preds, feature_importances = cross_validate(
        dataset, classifier, reducer, n=5
    )

    # Accuracies, printed as one tab-separated row under the header.
    class_accuracies = get_class_accuracy_cv(y_tests, y_preds)
    print("\nWake\tNREM\tREM")
    accuracy_cells = class_accuracies.round(2).astype(str)
    print("\t".join(accuracy_cells))

    # Pool the per-run arrays; predictions are additionally flattened to 1-D.
    pooled_truth = np.concatenate(y_tests)
    pooled_preds = np.concatenate(y_preds).reshape(-1)
    confusion_matrix_cv(pooled_truth, pooled_preds, save=True)

    # NOTE(review): if print_classification_report prints by itself and
    # returns None, this line will also print "None" -- confirm the helper's
    # return value.
    report = print_classification_report(y_tests, y_preds)
    print(report)

    # Average the importances across the first axis (one entry per run).
    mean_importances = np.array(feature_importances).mean(axis=0)
    plot_feature_importances(mean_importances, save=True)
# Run the pipeline only when this file is executed as a script, so importing
# the module does not trigger data loading and training.
if __name__ == "__main__":
    main()