# setup

In [1]:
import torch
import os

In [2]:
# Run from the repository root so the project packages imported below
# (`data`, `models`, `utils`, `configs`) and relative paths such as
# 'output/...' resolve. os.path.basename handles both '/' and '\\'
# separators, unlike splitting on '/' (which never matches on Windows).
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

from data.load_data import LoadData
from models.baseline import ResnetBaseline

from models.moe import ResnetMoE
from utils import synthesis

# Initialize data loader, reference model, and device

In [3]:
# Checkpoint to evaluate: 'moe' (ResnetMoE) or 'gate' (ResnetBaseline).
model_label = 'moe'

In [4]:
# Build the data loader and an untrained reference architecture for the
# selected checkpoint. In this notebook `reference` is only used by the
# (currently disabled) state-dict key check in the model-loading cell.
if model_label == 'gate':
    from configs.gate import LoadDataConfig, Downstream_cnn_args

    loader_config = LoadDataConfig()
    resnet_config = Downstream_cnn_args()

    dataloader = LoadData(**loader_config.__dict__)
    reference = ResnetBaseline(**resnet_config.__dict__)
elif model_label == 'moe':
    # NOTE(review): the MoE run takes its data config from configs.baseline,
    # not configs.moe — confirm this is intended.
    from configs.baseline import LoadDataConfig
    from configs.moe import MoE_cnn_args

    loader_config = LoadDataConfig()
    moe_config = MoE_cnn_args()

    dataloader = LoadData(**loader_config.__dict__)
    reference = ResnetMoE(**moe_config.__dict__)
else:
    # Fail fast here rather than with a NameError on `dataloader` in a later cell.
    raise ValueError(
        "model_label must be 'gate' or 'moe', got {!r}".format(model_label)
    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Threshold tuning (validation set)

In [5]:
# Load the fine-tuned checkpoint. The saved object is a whole model (it
# exposes .state_dict() and .to()), not a bare state_dict.
# - map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# - weights_only=False is required to unpickle a full nn.Module on torch>=2.6;
#   this runs arbitrary pickle code, so only load checkpoints you trust.
model = torch.load(
    'output/transfer_{}.pt'.format(model_label),
    map_location=device,
    weights_only=False,
)
# assert model.state_dict().keys() == reference.state_dict().keys()
# eval() disables dropout / batch-norm updates for inference; harmless if
# `synthesis` also switches the model to eval mode.
model = model.to(device).eval()

In [6]:
# Per-class threshold search on the validation split. Called with `None`
# thresholds, `synthesis` appears to run in search mode and return
# (best F1 per class, best threshold per class) — see the output below.
val_dl = dataloader.get_val_dataloader()
best_f1s, best_thresholds = synthesis(
    model,
    val_dl,
    None,  # no fixed thresholds -> search for the best ones
    device,
)
(best_f1s, best_thresholds)

  0%|          | 0/272 [00:00<?, ?it/s]

6


100%|██████████| 272/272 [04:56<00:00,  1.09s/it]


6


([0.5854513584574934,
  0.8181818181818182,
  0.802547770700637,
  0.6628099173553719,
  0.7417974322396577,
  0.7488738738738739],
 [0.33, 0.46, 0.45, 0.41000000000000003, 0.52, 0.41000000000000003])

In [7]:
# import numpy as np
# from tqdm import tqdm
# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# from utils import get_inputs

In [8]:
# num_classes = 6
# thresholds = np.arange(0, 1.01, 0.01)  # Array of thresholds from 0 to 1 with step 0.01
# predictions = {thresh: [[] for _ in range(num_classes)] for thresh in thresholds}
# true_labels_dict = [[] for _ in range(num_classes)]

In [9]:
# val_dl = dataloader.get_val_dataloader()

In [10]:
# model.eval()
# with torch.no_grad():
#     for val_batch in tqdm(val_dl):
#         raw, exam_id, label = val_batch
#         ecg = get_inputs(raw).to(device)
#         label = label.to(device).float()

#         logits = model(ecg)
#         probs = torch.sigmoid(logits)

#         for class_idx in range(num_classes):
#             for thresh in thresholds:
#                 predicted_binary = (probs[:, class_idx] >= thresh).float()
#                 predictions[thresh][class_idx].extend(
#                     predicted_binary.cpu().numpy()
#                 )
#             true_labels_dict[class_idx].extend(
#                 label[:, class_idx].cpu().numpy()
#             )

In [11]:
# best_thresholds = [0.5] * num_classes
# best_f1s = [0.0] * num_classes

# for class_idx in (range(num_classes)):
#     for thresh in thresholds:
#         f1 = f1_score(
#             true_labels_dict[class_idx],
#             predictions[thresh][class_idx],
#             zero_division=0,
#         )

#         if f1 > best_f1s[class_idx]:
#             best_f1s[class_idx] = f1
#             best_thresholds[class_idx] = thresh

In [12]:
# best_f1s, best_thresholds

# Test-set evaluation with tuned thresholds

In [13]:
# Score the held-out test split using the thresholds tuned on validation;
# with explicit thresholds `synthesis` returns binarized predictions,
# the true labels, and a dict of per-class metrics.
test_dl = dataloader.get_test_dataloader()
all_binary_results, all_true_labels, metrics_dict = synthesis(
    model,
    test_dl,
    best_thresholds,
    device,
)
metrics_dict

100%|██████████| 135/135 [03:35<00:00,  1.60s/it]

6





{'Accuracy': [0.9837346608011114,
  0.9883653623523964,
  0.9921856911322066,
  0.9881338272748321,
  0.9895230377402177,
  0.9863973141931003],
 'Precision': [0.5544871794871795,
  0.7575757575757576,
  0.7388535031847133,
  0.5836065573770491,
  0.739454094292804,
  0.6766355140186916],
 'Recall': [0.5492063492063493,
  0.8875739644970414,
  0.8140350877192982,
  0.6953125,
  0.7967914438502673,
  0.8537735849056604],
 'F1 Score': [0.5518341307814992,
  0.8174386920980927,
  0.7746243739565942,
  0.6345811051693405,
  0.767052767052767,
  0.7549530761209593]}

In [14]:
# test_dl = dataloader.get_test_dataloader()

In [15]:
# all_binary_results = []
# all_true_labels = []

In [16]:
# model.eval()
# with torch.no_grad():
#     for test_batch in tqdm(test_dl):
#         raw, exam_id, label = test_batch
#         ecg = get_inputs(raw).to(device)
#         label = label.to(device).float()

#         logits = model(ecg)
#         probs = torch.sigmoid(logits)

#         binary_result = torch.zeros_like(probs)
#         for i in range(len(best_thresholds)):
#             binary_result[:, i] = (
#                 probs[:, i] >= best_thresholds[i]
#             ).float()

#         # Append binary results and true labels for this batch
#         all_binary_results.append(binary_result)
#         all_true_labels.append(label)
# all_binary_results = torch.cat(all_binary_results, dim=0)
# all_true_labels = torch.cat(all_true_labels, dim=0)

In [17]:
# accuracy_scores = []
# precision_scores = []
# recall_scores = []
# f1_scores = []

In [18]:
# for class_idx in range(num_classes):
#     class_binary_results = all_binary_results[:, class_idx].cpu().numpy()
#     class_true_labels = all_true_labels[:, class_idx].cpu().numpy()

#     accuracy = accuracy_score(class_true_labels, class_binary_results)
#     precision = precision_score(
#         class_true_labels, class_binary_results, zero_division=0
#     )
#     recall = recall_score(
#         class_true_labels, class_binary_results, zero_division=0
#     )
#     f1 = f1_score(class_true_labels, class_binary_results, zero_division=0)

#     accuracy_scores.append(accuracy)
#     precision_scores.append(precision)
#     recall_scores.append(recall)
#     f1_scores.append(f1)

# metrics_dict = {
#     "Class": dataloader.output_col,
#     "Accuracy": accuracy_scores,
#     "Precision": precision_scores,
#     "Recall": recall_scores,
#     "F1 Score": f1_scores,
# }

In [19]:
# metrics_dict