# Code to Run a Single Experiment for the VRNNGMM Model

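#### Do not change the cell below: it moves the working directory up three levels (to the repository root) so that the `src.*` imports and relative paths used later resolve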

In [1]:
import os
# Move from this notebook's directory up three levels (presumably the repository root)
# so that the `src.*` imports and relative paths such as "src/models/..." and
# "results/..." used in later cells resolve.
# NOTE(review): not idempotent — re-running this cell moves up another three levels.
os.chdir("../../..")

## Load Libraries

In [2]:

# Import required packages
import json
import numpy as np

from src.models.VRNNGMM.model import VRNNGMM
from src.dataloading.TrainingDataLoader import TrainingDataLoader

from src.metrics.summarize import summary_multiclass_metrics, summary_clustering_metrics, print_avg_metrics_paper
import src.logging.logger_utils as logger_utils

from datetime import datetime

## Configuration 

In [3]:
# Wall-clock reference used to report the total run time at the end of the notebook.
start_time = datetime.now()

#### LOAD CONFIGURATIONS
# All data, model and training settings for this run come from this JSON file.
config_path = "src/models/VRNNGMM/run_config.json"
with open(config_path, "r") as config_file:
    args = json.load(config_file)


## Load Data and Process

In [4]:

#### LOAD DATA
# Every loader argument has the same name as its key in `run_config.json`,
# so the keyword arguments are taken directly from the config.
loader_keys = (
    "data_dir",
    "time_window",
    "feat_subset",
    "train_test_ratio",
    "train_val_ratio",
    "seed",
    "normalize",
    "num_folds",
)
loader_kwargs = {key: args[key] for key in loader_keys}
data_loader = TrainingDataLoader(**loader_kwargs)

# NOTE(review): relies on a private accessor of TrainingDataLoader.
data_characteristics = data_loader._get_data_characteristics()

# Unpack into (num_samples, num_timestamps, num_features) and the number of outcome classes.
input_shape = tuple(
    data_characteristics[dim] for dim in ("num_samples", "num_timestamps", "num_features")
)
output_dim = data_characteristics["num_outcomes"]

8328it [01:02, 133.98it/s]


## Load Model

In [5]:
# Re-import the model module so that edits to model.py are picked up without restarting the kernel.
from importlib import reload
import src.models.VRNNGMM.model as Model

# `reload` re-executes the module, but names bound earlier via `from ... import VRNNGMM`
# still point at the OLD class object. Rebind the name explicitly; otherwise the model
# built in the next cell would silently use the stale, pre-reload class.
# Assigning the result (instead of leaving it as the last expression) also keeps the
# module repr, which contains an absolute local path, out of the notebook output.
Model = reload(Model)
VRNNGMM = Model.VRNNGMM


<module 'src.models.VRNNGMM.model' from 'c:\\Users\\hruia\\PycharmProjects\\phd-repo-template\\src\\models\\VRNNGMM\\model.py'>

In [6]:

# Initialize the model from the `model_params` section of the run configuration.
model_params = args["model_params"]
model = VRNNGMM(
    input_dims=input_shape[-1],  # num_features (last entry of input_shape)
    output_dim=output_dim,
    K=model_params["K"],
    latent_dim=model_params["latent_dim"],
    gate_num_hidden_layers=model_params["gate_num_hidden_layers"],
    gate_num_hidden_nodes=model_params["gate_num_hidden_nodes"],
    bias=model_params["bias"],
    dropout=model_params["dropout"],
    device=model_params["device"],
    seed=model_params["seed"],
    K_fold_idx=model_params["K_fold_idx"],
)


## Train Model

In [7]:

# Get the whole training data and the validation data for fold 1.
# NOTE(review): the fold is hard-coded to 1 here, while the model was given
# args["model_params"]["K_fold_idx"]; confirm these are meant to agree.
X_train, y_train = data_loader.get_train_X_y(fold=1)
X_val, y_val = data_loader.get_test_X_y(fold=1, mode="val")

# The model is fit on float32 copies; the original arrays are kept because
# they are logged later with the run outputs.
train_params = args["train_params"]
model.fit(
    train_data=(X_train.astype(np.float32), y_train.astype(np.float32)),
    val_data=(X_val.astype(np.float32), y_val.astype(np.float32)),
    lr=train_params["lr"],
    batch_size=train_params["batch_size"],
    num_epochs=train_params["num_epochs"],
)

Printing Losses loss, Log Lik, KL
Epoch 1 (50) :  286.76 - -108.78 - 177.99 - 0.00     Val 1 (50): 283.61 - -106.27 - 177.34
Epoch 2 (50) :  279.28 - -102.01 - 177.27 - 0.00     Val 2 (50): 275.63 - -98.67 - 176.96
Epoch 3 (50) :  272.70 - -95.81 - 176.88 - 0.00     Val 3 (50): 270.15 - -93.52 - 176.64
Epoch 4 (50) :  266.58 - -89.95 - 176.63 - 0.00     Val 4 (50): 263.80 - -87.38 - 176.41
Epoch 5 (50) :  260.24 - -83.77 - 176.47 - 0.00     Val 5 (50): 256.06 - -79.78 - 176.28
Epoch 6 (50) :  253.78 - -77.43 - 176.35 - 0.00     Val 6 (50): 247.47 - -71.26 - 176.21
Epoch 7 (50) :  246.67 - -70.40 - 176.27 - 0.00     Val 7 (50): 238.00 - -61.84 - 176.16
Epoch 8 (50) :  239.49 - -63.28 - 176.21 - 0.00     Val 8 (50): 229.86 - -53.74 - 176.12
Epoch 9 (50) :  232.59 - -56.42 - 176.17 - 0.00     Val 9 (50): 219.60 - -43.51 - 176.09
Epoch 10 (50) :  225.03 - -48.90 - 176.13 - 0.00     Val 10 (50): 213.09 - -37.03 - 176.06
Epoch 11 (50) :  218.38 - -42.28 - 176.10 - 0.00     Val 11 (50): 207.2

## Evaluate on Test Set and Get Metrics

In [14]:
#### EVALUATE MODEL
# Run the trained model on the held-out test split of fold 1. `predict` returns
# outcome scores, cluster assignments and additional model outputs.
X_test, y_test = data_loader.get_test_X_y(fold=1, mode="test")
test_inputs = X_test.astype(np.float32)
test_targets = y_test.astype(np.float32)
y_pred, clus_pred, model_objects = model.predict(X=test_inputs, y=test_targets)

In [15]:
# Convert targets to integer class labels (argmax over axis 1; presumably one-hot targets).
labels_test = np.argmax(y_test, axis=1)

# Clustering metrics operate on flat feature vectors, so merge the time and feature axes.
flat_X_test = X_test.reshape(X_test.shape[0], -1)

# Compute the classification and clustering metrics, then merge them into one dictionary.
multiclass_dic = summary_multiclass_metrics(labels_true=labels_test, scores_pred=y_pred)
clustering_dic = summary_clustering_metrics(
    X=flat_X_test,
    labels_true=labels_test,
    clus_pred=clus_pred,
)
metrics_dict = dict(multiclass_dic)
metrics_dict.update(clustering_dic)

## Log Results and Performance

In [16]:

# Log the model, results and visualizations into a timestamped run directory.
cur_time_as_str = datetime.now().strftime("%Y%m%d-%H%M%S")
test_dir = f"results/VRNNGMM/{cur_time_as_str}/"

# Save outputs into data objects and run information.
# `clus_pred` is included so the cluster assignments behind the clustering
# metrics can be re-analysed from the logged run.
data_objects = {
    "y_pred": y_pred,
    "clus_pred": clus_pred,
    "labels_test": labels_test,
    "X": (X_train, X_val, X_test),
    "y": (y_train, y_val, y_test),
    "test_output": model_objects,
}

run_info = {
    "data_characteristics": data_characteristics,
    "args": args,
    "metrics": metrics_dict,
}

model.log_model(save_dir=test_dir, objects_to_log=data_objects, run_info=run_info)
print("Time taken: ", datetime.now() - start_time)

Time taken:  0:04:06.308193


## Logging into CSV

In [17]:

# ===================== CSV LOGGING
csv_path = "results/VRNNGMM/tracker.csv"

# One column per top-level config entry (the nested `model_params` section is
# excluded), followed by the paper metrics.
# NOTE(review): `train_params` is a nested dict and is still written into a single
# cell; confirm that is intended.
params_header = [key for key in args.keys() if key not in ["model_params"]]
metrics_header = ["F1", "Precision", "Recall", "Auroc", "SIL", "DBI", "VRI"]
logger_utils.make_csv_if_not_exists(csv_path, params_header + metrics_header)

# Append Row
metrics_to_print = list(print_avg_metrics_paper(metrics_dict))
# Guard against silently writing a row whose metric values do not line up with the header.
if len(metrics_to_print) != len(metrics_header):
    raise ValueError(
        f"Expected {len(metrics_header)} metric values for columns {metrics_header}, "
        f"got {len(metrics_to_print)}."
    )
row_append = (*[args[key] for key in params_header], *metrics_to_print)
# `write_csv_row` takes (csv_path, row): pass the row as ONE sequence, not unpacked.
logger_utils.write_csv_row(csv_path, row_append)

macro_f1_score: 0.166
precision: 0.124
recall: 0.250
ovr_auroc: 0.508
silhouette: -0.092
davies_bouldin: 34.143
calinski_harabasz: 1.005


TypeError: write_csv_row() takes 2 positional arguments but 17 were given

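## Analysis and Printing

> **Note:** the two cells below (In [12], In [13]) were last executed *before* the evaluation cells above, so their outputs come from an earlier run (e.g. `ovr_auroc` is 0.489 here vs 0.508 above). Re-run them after the evaluation cells to get consistent numbers.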

In [12]:
# Report the mean of each headline classification metric to three decimals
# (np.mean presumably averages per-class values; confirm in summary_multiclass_metrics).
headline_metrics = ("macro_f1_score", "precision", "recall", "ovr_auroc")
for metric_name in headline_metrics:
    mean_value = np.mean(metrics_dict[metric_name])
    print(f"{metric_name}: {mean_value:.3f}")

macro_f1_score: 0.166
precision: 0.124
recall: 0.250
ovr_auroc: 0.489


In [13]:
# Confusion matrix as returned by summary_multiclass_metrics.
# NOTE(review): confirm which axis holds true vs. predicted labels.
conf_mat = metrics_dict["confusion_matrix"]
print(conf_mat)

[[   0    0    0    0]
 [   0    0    0    0]
 [   0    0    0    0]
 [  10 1144  529 1649]]
