In [None]:
import os
import sys
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.classification import MultilabelAUROC
import pytorch_lightning as pl
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

sys.path.append('../..')
from multi_modal_heart.model.ecg_net_attention import ECGEncoder,ECGAttentionAE
from multi_modal_heart.model.ecg_net import ECGAE
from multi_modal_heart.model.ecg_net import BenchmarkClassifier
from multi_modal_heart.ECG.ecg_dataset import ECGDataset

class LitClassifier(pl.LightningModule):
    """Pretrained ECG encoder followed by a ``BenchmarkClassifier`` head.

    The head attribute keeps its original (misspelled) name ``downsteam_net``
    because fine-tuned checkpoints are split by the ``"downsteam_net"`` key
    prefix when their weights are loaded.

    Args:
        encoder: module exposing ``get_features_after_pooling(x, mask)``.
        input_dim: size of the pooled feature vector produced by ``encoder``.
        num_classes: number of output logits (default 5).
        hidden_size: hidden width of the classifier head (default 128, the
            width used by the existing checkpoints).
    """

    def __init__(self, encoder, input_dim, num_classes=5, hidden_size=128):
        super().__init__()

        self.encoder = encoder
        #### add classifier if use benchmark classifier
        self.downsteam_net = BenchmarkClassifier(
            input_size=input_dim, hidden_size=hidden_size, output_size=num_classes
        )

    def forward(self, x, mask=None):
        """Return the classifier head's raw outputs for an ECG batch.

        Callers in this notebook apply ``torch.sigmoid`` to the result, so the
        outputs are treated as logits.

        Args:
            x: ECG batch passed straight to the encoder; presumably laid out as
                (batch, leads, time) -- TODO confirm against the encoder.
            mask: optional mask forwarded to the encoder; the visible callers
                pass ``None``, which is now the default.
        """
        latent_code = self.encoder.get_features_after_pooling(x, mask)
        return self.downsteam_net(latent_code)
    
def print_result(probs, super_classes_labels, topk=1):
    """Print the ``topk`` highest-probability labels with their scores.

    Each line has the form ``<label> (<index>): <prob rounded to 4 dp>``.

    Args:
        probs: 1-D tensor of per-class probabilities, or a 2-D tensor of shape
            (batch, num_classes). For 2-D input the top-k list of every row is
            printed, preceded by a ``sample <i>:`` header line.
        super_classes_labels: sequence mapping class index -> label name.
        topk: number of labels to print per sample; clamped to the number of
            classes so that ``torch.topk`` does not raise for large values.
    """
    if probs.dim() == 2:
        # batched input: report each sample separately
        for sample_idx, row in enumerate(probs):
            print(f'sample {sample_idx}:')
            print_result(row, super_classes_labels, topk=topk)
        return

    k = min(topk, probs.shape[-1])
    top_probs, label_indices = torch.topk(probs, k)
    for prob, idx in zip(top_probs.tolist(), label_indices.tolist()):
        label = super_classes_labels[idx]
        print(f'{label} ({idx}):', round(prob, 4))
def calc_hamming_score(y_true, y_pred):
    """Example-based multi-label accuracy (the "Hamming score").

    For every sample the score is |true AND pred| / |true OR pred|; the result
    is the mean over samples. A sample whose true and predicted label sets are
    both empty is a perfect match and scores 1.0, instead of producing
    0/0 = nan and turning the whole mean into nan.

    Args:
        y_true: array-like of shape (n_samples, n_labels); nonzero entries are
            treated as positive labels. Torch tensors must be on the CPU.
        y_pred: array-like with the same shape and encoding as ``y_true``.

    Returns:
        The mean per-sample score as a numpy float (nan when there are no
        samples).
    """
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    intersection = np.logical_and(y_true, y_pred).sum(axis=1)
    union = np.logical_or(y_true, y_pred).sum(axis=1)
    # rows with an empty union keep the pre-filled 1.0 instead of dividing by zero
    per_sample = np.divide(
        intersection,
        union,
        out=np.ones(union.shape, dtype=float),
        where=union > 0,
    )
    return per_sample.mean()
# ecg_net = ECGAttentionAE(num_leads=12, time_steps=1024, z_dims=512, linear_out=512, downsample_factor=5, base_feature_dim=4,if_VAE=False,
#                          use_attention_pool=False,no_linear_in_E=True, apply_lead_mask=False, no_time_attention=False)
# classification_net = LitClassifier(encoder=ecg_net.encoder,input_dim=512,num_classes=5)
# # checkpoint_path  ="../../log_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_abl_no_time_attention_ms_resnet/checkpoints/last-v3.ckpt"
# # checkpoint_path  ="../../log_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet_ECG2Text/checkpoints/last-v5.ckpt"
# checkpoint_path = "../../log_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet/checkpoints/last-v8.ckpt"
# print (torch.load(checkpoint_path)["state_dict"].keys())
# mm_checkpoint = torch.load(checkpoint_path)["state_dict"]
# encoder_params = {(".").join(key.split(".")[1:]):value for key, value in mm_checkpoint.items() if str(key).startswith("encoder")}
# classification_params = {(".").join(key.split(".")[1:]):value for key, value in mm_checkpoint.items() if str(key).startswith("downsteam_net")}
# classification_net.encoder.load_state_dict(encoder_params)
# classification_net.downsteam_net.load_state_dict(classification_params)



In [None]:
import torch
import sys
sys.path.append('../../')
from multi_modal_heart.model.ecg_net import ECGAE
use_median_wave = True
time_steps = 608  # length the median-wave input is zero-padded to before inference
ecg_net = ECGAE(encoder_type="resnet1d101", in_channels=12, ECG_length=time_steps, decoder_type="ms_resnet",
                embedding_dim=256, latent_code_dim=512,
                add_time=False,
                encoder_mha=False,
                apply_method="",
                decoder_outdim=12)
classification_net = LitClassifier(encoder=ecg_net.encoder, input_dim=512, num_classes=5)
# resnet_checkpoint = '../../log_finetune/resnet1d101_512+benchmark_classifier_ms_resnet/checkpoints/epoch=23-val_auroc:benchmark_classifier/val_macro_auc=0.91.ckpt'
# NOTE(review): absolute, user-specific path; consider deriving it from a configurable project root.
resnet_checkpoint = "/home/engs2522/project/multi-modal-heart/log_median_finetune/resnet1d101_512+benchmark_classifier_ms_resnet_ECG2Text/checkpoints/checkpoint_best_loss.ckpt"
# map_location="cpu": the freshly built model lives on the CPU at this point, and
# this keeps the cell loadable on machines without the GPU used for training.
checkpoint = torch.load(resnet_checkpoint, map_location="cpu")["state_dict"]
# Checkpoint keys are prefixed with the LitClassifier attribute name ("encoder." /
# "downsteam_net."); match the full prefix including the dot and strip only it.
encoder_params = {key.split(".", 1)[1]: value for key, value in checkpoint.items() if key.startswith("encoder.")}
classification_params = {key.split(".", 1)[1]: value for key, value in checkpoint.items() if key.startswith("downsteam_net.")}
classification_net.encoder.load_state_dict(encoder_params)
classification_net.downsteam_net.load_state_dict(classification_params)


In [None]:
# use_median_wave = False
# # checkpoint_path = "../../log_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet/checkpoints/epoch=49-val_auroc:benchmark_classifier/val_macro_auc=0.90.ckpt"
checkpoint_path = "../../log_median_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet_ECG2Text/checkpoints/checkpoint_best_loss.ckpt"
# checkpoint_path = "../../log_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet/checkpoints/epoch=49-val_auroc:benchmark_classifier/val_macro_auc=0.90.ckpt"

use_median_wave = True
time_steps = 608  # length the median-wave input is zero-padded to before inference
ecg_net = ECGAttentionAE(num_leads=12, time_steps=time_steps, z_dims=512, linear_out=512, downsample_factor=5, base_feature_dim=4, if_VAE=False, use_attention_pool=False,
                         no_linear_in_E=True, apply_lead_mask=False)
classification_net = LitClassifier(encoder=ecg_net.encoder, input_dim=512, num_classes=5)

# Deserialize the checkpoint once and reuse it for both the key listing and the
# weight split; map to the CPU because the model has not been moved to the GPU yet.
mm_checkpoint = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
print(mm_checkpoint.keys())
# Checkpoint keys are prefixed with the LitClassifier attribute name ("encoder." /
# "downsteam_net."); match the full prefix including the dot and strip only it.
encoder_params = {key.split(".", 1)[1]: value for key, value in mm_checkpoint.items() if key.startswith("encoder.")}
classification_params = {key.split(".", 1)[1]: value for key, value in mm_checkpoint.items() if key.startswith("downsteam_net.")}
classification_net.encoder.load_state_dict(encoder_params)
classification_net.downsteam_net.load_state_dict(classification_params)


In [None]:
## load MI data from UKB
from scipy.stats import zscore

# NOTE(review): absolute, user-specific path; consider deriving it from a configurable project root.
mi_data_path = "/home/engs2522/project/multi-modal-heart/multi_modal_heart/toolkits/ukb/non_imaging_information/MI/batched_ecg_median_wave.npy"

test_data = np.load(mi_data_path)
## zero mean and unit variance along axis 2 (presumably the time axis of a
## (samples, leads, time) array -- TODO confirm the saved layout)
test_data = zscore(test_data, axis=2)
## constant (zero-variance) traces become 0/0 = nan under zscore; map them to 0
test_data = np.nan_to_num(test_data)

## zero-pad axis 2 symmetrically so it has exactly `time_steps` samples
pad_total = time_steps - test_data.shape[2]
if pad_total < 0:
    raise ValueError(
        f"ECG length {test_data.shape[2]} exceeds time_steps={time_steps}; cannot pad"
    )
pad_left = pad_total // 2
pad_right = pad_total - pad_left  # the extra sample of an odd gap goes on the right
test_data = np.pad(test_data, ((0, 0), (0, 0), (pad_left, pad_right)), "constant", constant_values=0)
assert test_data.shape[2] == time_steps
print(test_data.shape)

In [None]:
## hand the preprocessed batch to the GPU as float32 for inference
test_data_tensor = torch.from_numpy(test_data).to(device="cuda", dtype=torch.float32)

In [None]:
## get model prediction
# Samples per forward pass: running the whole cohort at once can exhaust GPU
# memory; lower this value if it still does.
INFERENCE_BATCH_SIZE = 256

classification_net.eval()
classification_net.freeze()  # Lightning: eval mode and requires_grad=False on all parameters
classification_net.cuda()
with torch.no_grad():
    logits = torch.cat(
        [
            classification_net(batch, None)
            for batch in torch.split(test_data_tensor, INFERENCE_BATCH_SIZE)
        ]
    )
    probs = torch.sigmoid(logits)


In [None]:
# Per-sample sigmoid probabilities from the previous cell, one column per class;
# columns are presumably ordered NORM, MI, HYP, STTC, CD (the label order used
# for the binarizer) -- TODO confirm against the training label order.
probs

In [None]:
## use threshold to get prediction
DECISION_THRESHOLD = 0.5  # probability at or above which a class counts as predicted
MI_CLASS_INDEX = 1  # MI class is the second one (label order NORM, MI, HYP, STTC, CD)

all_class_predictions = probs >= DECISION_THRESHOLD
print(all_class_predictions.shape)
predict_labels = all_class_predictions[:, MI_CLASS_INDEX]
print(predict_labels.shape)
print(torch.sum(predict_labels))
## Fraction of samples flagged as MI. Presumably every sample in this UKB MI
## cohort is an MI case, so this is the MI sensitivity/recall -- TODO confirm.
print('acc:', predict_labels.float().mean())

In [None]:
## evaluate MI data
from sklearn.preprocessing import MultiLabelBinarizer

# Fix the column order explicitly via `classes=` (intended to match the model's
# output order, in which MI is column 1) rather than fitting with the default
# sorted order and then overwriting the fitted `classes_` attribute.
mlb = MultiLabelBinarizer(classes=["NORM", "MI", "HYP", "STTC", "CD"])
result = mlb.fit_transform([["NORM", "MI", "HYP", "STTC", "CD"]])


In [None]:
# NOTE(review): this rebuilds the attention autoencoder for 1024-sample input and
# replaces `classification_net` with a freshly initialised (untrained) classifier;
# no checkpoint is loaded here, so load weights before predicting with it.
ecg_net = ECGAttentionAE(
    num_leads=12,
    time_steps=1024,
    z_dims=512,
    linear_out=512,
    downsample_factor=5,
    base_feature_dim=4,
    if_VAE=False,
    use_attention_pool=False,
    no_linear_in_E=True,
    apply_lead_mask=False,
)
classification_net = LitClassifier(
    encoder=ecg_net.encoder,
    input_dim=512,
    num_classes=5,
)