# Code to Run Single Experiments for Traditional Classifiers

#### Do not change the cell below

In [1]:
# Move the process working directory three levels up from wherever the kernel
# was started. The later `src.*` imports and the relative "data/..." and
# "results/..." paths presumably rely on this landing at the repository root
# (TODO confirm).
# NOTE: the path is relative, so running this cell more than once in the same
# kernel keeps moving further up the tree -- run it exactly once per kernel.
import os
os.chdir("../../..")

## Load Libraries

In [2]:

# Import required packages
import json
import numpy as np

from src.models.GMM.model import GMM
from src.dataloading.TrainingDataLoader import TrainingDataLoader

from src.metrics.summarize import summary_all_metrics
from src.visualization.metrics_and_losses import plot_multiclass_metrics

from datetime import datetime



## Configuration 

In [3]:
# Single source of truth for randomness: used both for the data split
# ("seed") and for the model ("model_params" -> "random_state"), so the two
# cannot silently drift apart when the value is changed.
SEED = 3535

# Experiment configuration. The top-level entries are forwarded to
# TrainingDataLoader, "model_name" and "model_params" to the GMM constructor.
# NOTE: the run timer (`start_time`) is intentionally not set here; it is
# (re)started at the top of the data-loading cell, which made the assignment
# that used to live in this cell a dead store.
args = {
    "model_name": "GMM",
    "data_dir": "data/MIMIC/processed",
    "time_window": [0, 10],
    "feat_subset": "vitals-static",
    "train_test_ratio": 0.6,
    "train_val_ratio": 0.6,
    "seed": SEED,
    "normalize": True,
    "num_folds": 1,
    "model_params": {
        "K": 10,
        "random_state": SEED
    }
}

## Load Data and Process

In [4]:

# Start the wall-clock timer that the logging cell reports as "Time taken".
start_time = datetime.now()

#### LOAD DATA
# Forward exactly the loader-related entries of `args` as keyword arguments.
data_loader = TrainingDataLoader(
    **{
        option: args[option]
        for option in (
            "data_dir",
            "time_window",
            "feat_subset",
            "train_test_ratio",
            "train_val_ratio",
            "seed",
            "normalize",
            "num_folds",
        )
    }
)

# NOTE(review): this calls an underscore-prefixed (private-looking) method of
# the loader; kept as-is because no public accessor is visible from here.
data_characteristics = data_loader._get_data_characteristics()

# Unpack: (samples, timestamps, features) tuple and the number of outcomes.
input_shape = tuple(
    data_characteristics[dim_key]
    for dim_key in ("num_samples", "num_timestamps", "num_features")
)
output_dim = data_characteristics["num_outcomes"]

8328it [00:50, 164.41it/s]


## Load Model

In [5]:

#### BUILD MODEL
# Instantiate the GMM with the data dimensions computed in the loading cell
# and the hyper-parameters stored under args["model_params"].
gmm_params = args["model_params"]
model = GMM(
    input_shape=input_shape,
    output_dim=output_dim,
    model_name=args["model_name"],
    K=gmm_params["K"],
    random_state=gmm_params["random_state"],
)


## Train Model

In [6]:

# Fetch the training split of fold 0 and fit the model on it.
# (Only the training split is requested here; no validation data is used.)
train_split = data_loader.get_train_X_y(fold=0)
X_train, y_train = train_split
model.train(train_data=(X_train, y_train))



Training GMM model...




## Evaluate on Test Set and Get Metrics

In [7]:

#### EVALUATE MODEL
# Fetch the held-out test split of fold 0.
X_test, y_test = data_loader.get_test_X_y(fold=0, mode="test")

# Flatten every sample to one row (first axis kept, the rest collapsed)
# before handing it to the model.
num_test_samples = X_test.shape[0]
X_test_2D = X_test.reshape(num_test_samples, -1)
y_pred, clus_pred = model.predict(X_test_2D)

# Convert to labels: index of the largest entry along axis 1 of y_test.
labels_test = np.argmax(y_test, axis=1)

# Compute metrics from the true labels, predicted scores, flattened inputs
# and predicted cluster assignments.
metrics_dict = summary_all_metrics(
    labels_true=labels_test,
    scores_pred=y_pred,
    X=X_test_2D,
    clus_pred=clus_pred,
)
# Visualization is currently disabled:
# ax, lachiche_ax = plot_multiclass_metrics(metrics_dict=metrics_dict, class_names=data_characteristics["outcomes"])


## Log Results and Performance

In [8]:

# Log model, results and visualizations under a timestamped directory.
cur_time_as_str = datetime.now().strftime("%Y%m%d-%H%M%S")
test_dir = f"results/{args['model_name']}/{cur_time_as_str}/"

# Fitted mixture parameters, as returned by the model.
means, covariances, cluster_probs = model.get_model_objects()

# Run metadata: data description, configuration and computed metrics.
run_info = dict(
    data_characteristics=data_characteristics,
    args=args,
    metrics=metrics_dict,
)

# Objects to persist next to the run metadata (despite its name, this is a
# dict of objects, not a path).
output_dir = dict(
    data=dict(
        X=(X_train, X_test),
        y=(y_train, y_test),
    ),
    model=dict(
        means=means,
        covariances=covariances,
        cluster_probs=cluster_probs,
    ),
    labels_test=labels_test,
    y_pred=y_pred,
    clus_pred=clus_pred,
)

model.log_model(save_dir=test_dir, objects_to_log=output_dir, run_info=run_info)

# Wall-clock time elapsed since `start_time` was last set.
elapsed = datetime.now() - start_time
print("Time taken: ", elapsed)

Time taken:  0:01:00.972522


In [9]:
# Print the mean of each headline metric (for per-class arrays this is the
# average over classes), rounded to three decimals.
summary_metric_names = (
    "macro_f1_score",
    "precision",
    "recall",
    "ovr_auroc",
    "silhouette",
    "davies_bouldin",
    "calinski_harabasz",
)
for metric_name in summary_metric_names:
    metric_mean = np.mean(metrics_dict[metric_name])
    print(f"{metric_name}: {metric_mean:.3f}")

macro_f1_score: 0.183
precision: 0.330
recall: 0.256
ovr_auroc: 0.626
silhouette: -0.043
davies_bouldin: 6.887
calinski_harabasz: 41.486


In [10]:
# Show the confusion matrix stored by summary_all_metrics.
# NOTE(review): judging by the recorded output, rows index the PREDICTED class
# and columns the TRUE class (the transpose of the scikit-learn convention):
# for the last class, 1625 / row-sum 3271 = 0.497 equals the reported
# precision, while 1625 / column-sum 1649 = 0.985 equals the reported recall.
# Confirm against summary_all_metrics before interpreting rows as ground truth.
print(metrics_dict["confusion_matrix"])

[[   0    0    0    0]
 [   0   10    0   14]
 [   2   10   15   10]
 [   8 1124  514 1625]]


In [11]:
# Display the full run record (data characteristics, args and metrics) that
# was passed to model.log_model in the logging cell.
run_info

{'data_characteristics': {'num_samples': 8328,
  'num_timestamps': 11,
  'num_features': 9,
  'num_outcomes': 4,
  'features': ['TEMP',
   'HR',
   'RR',
   'SPO2',
   'SBP',
   'DBP',
   'age',
   'gender',
   'ESI'],
  'outcomes': ['Death', 'Discharge', 'ICU', 'Ward']},
 'args': {'model_name': 'GMM',
  'data_dir': 'data/MIMIC/processed',
  'time_window': [0, 10],
  'feat_subset': 'vitals-static',
  'train_test_ratio': 0.6,
  'train_val_ratio': 0.6,
  'seed': 3535,
  'normalize': True,
  'num_folds': 1,
  'model_params': {'K': 10, 'random_state': 3535}},
 'metrics': {'accuracy': array([0.9969988 , 0.6554622 , 0.83913565, 0.49879953], dtype=float32),
  'macro_f1_score': array([0.        , 0.01712329, 0.05300353, 0.6605691 ], dtype=float32),
  'micro_f1_score': 0.49519807,
  'precision': array([0.        , 0.41666666, 0.4054054 , 0.49678996], dtype=float32),
  'recall': array([0.        , 0.00874126, 0.02835539, 0.98544574], dtype=float32),
  'ovr_auroc': array([0.70942205, 0.580907  , 

In [12]:
# Display the fitted mixture parameters, with means and covariances reshaped
# to one (timestamps, features)-shaped slab per cluster.
# NOTE(review): reshaping the covariances to the same shape as the means only
# works if one value is stored per input dimension -- presumably diagonal
# covariances; confirm against the GMM implementation.
per_cluster_shape = (model.K, *model.input_shape[1:])
means_by_cluster = np.reshape(means, per_cluster_shape)
covariances_by_cluster = np.reshape(covariances, per_cluster_shape)

print("Means: ", means_by_cluster)
print("Covariances: ", covariances_by_cluster)
print("Cluster Probs: ", cluster_probs)

Means:  [[[9.88355412e+01 9.98460914e+01 1.98649271e+01 9.64807641e+01
   1.29293043e+02 7.49065264e+01 6.12835318e+01 1.00000000e+00
   2.00000000e+00]
  [9.88362456e+01 9.98667538e+01 1.98502456e+01 9.64885224e+01
   1.29374205e+02 7.49341124e+01 6.12835318e+01 1.00000000e+00
   2.00000000e+00]
  [9.88313264e+01 9.97583553e+01 1.98531619e+01 9.64756572e+01
   1.29258384e+02 7.49898655e+01 6.12835318e+01 1.00000000e+00
   2.00000000e+00]
  [9.88853592e+01 9.95415711e+01 1.98982110e+01 9.65030986e+01
   1.29040012e+02 7.46854927e+01 6.12835318e+01 1.00000000e+00
   2.00000000e+00]
  [9.88729726e+01 9.94660500e+01 1.98328172e+01 9.65133936e+01
   1.29098929e+02 7.45573585e+01 6.12835318e+01 1.00000000e+00
   2.00000000e+00]
  [9.89311375e+01 9.93078473e+01 1.97583355e+01 9.64580578e+01
   1.28473155e+02 7.42254226e+01 6.12835318e+01 1.00000000e+00
   2.00000000e+00]
  [9.89467446e+01 9.87113578e+01 1.98135317e+01 9.64869899e+01
   1.27813849e+02 7.35837494e+01 6.12835318e+01 1.00000000e