In [None]:
# Saliency helpers; presumably provides compute_vanilla_gradient_saliency,
# visualize_ecg_saliency and average_saliency_per_segment used in later cells.
# NOTE(review): star import hides where names come from — consider explicit imports.
from saliency_utils import *

import sys
from pathlib import Path
# Prepend the parent of the current working directory to the import path, presumably so the
# model_finetuning package below resolves; because it uses Path.cwd(), the notebook must be
# started from its own directory.
parent_dir = Path.cwd().parent
sys.path.insert(0, str(parent_dir))

# Presumably provides ECGRegressionModel (used when building the finetuned model).
from model_finetuning.ModelClassAC_CMR import *

import torch
import neurokit2 as nk
from fairseq_signals.models.wav2vec2.wav2vec2_cmsc import Wav2Vec2CMSCModel
from omegaconf import OmegaConf
from scipy.io import loadmat
import os

In [None]:
# --- Configuration: edit these paths for your environment ---
# Pretrained ECG-FM checkpoint (mimic_iv_ecg_physionet_pretrained.pt).
path_to_foundational_model = "/mnt/cat/jdeseo/ECG_FM/ckpts/mimic_iv_ecg_physionet_pretrained.pt"
# Weights of the finetuned regression model.
path_to_finetuned_weights = "/mnt/cat/jdeseo/ECG_FM/MyModels/finetuned_model_weights_LA_LV_ratio.pt"
# Directory holding the ECG files (read with scipy.io.loadmat). Keep the trailing slash:
# later cells build file paths by plain string concatenation.
path_to_ecgs = "/mnt/cat/jdeseo/ECG_FM/ECG_processed/segmented/"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}", flush=True)

In [None]:
# Load the pretrained ECG-FM checkpoint and build the wav2vec2-CMSC encoder from its model config.
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only machine;
# the finetuned model wrapping this encoder is moved to `device` after its weights are loaded.
# NOTE(review): weights_only=False unpickles arbitrary Python objects (presumably needed for the
# stored config) — only load checkpoints from trusted sources.
ckpt_encoder = torch.load(path_to_foundational_model, map_location="cpu", weights_only=False)
cfg_encoder = OmegaConf.create(ckpt_encoder["cfg"]["model"])
# Flag consumed by the project's model code (semantics not visible here — TODO confirm).
cfg_encoder["saliency"] = False
print(cfg_encoder)
encoder = Wav2Vec2CMSCModel(cfg_encoder)

In [None]:
# Wrap the encoder in the regression head, load the finetuned weights and switch to eval mode.
# map_location="cpu" avoids a failure when GPU-saved weights are loaded on a CPU-only host;
# the model is moved to `device` right after loading.
state_dict = torch.load(path_to_finetuned_weights, map_location="cpu")
# num_outputs=4 matches the targets used for saliency: 0: LA max, 1: LA min, 2: LAEF, 3: LALV.
# feature_dim=768 is presumably the encoder's embedding size — TODO confirm.
finetuned_model = ECGRegressionModel(encoder, feature_dim=768, num_outputs=4)
finetuned_model.load_state_dict(state_dict)
finetuned_model.to(device)
finetuned_model.eval()

In [None]:
# Sorted so the file order — and therefore which saliency row belongs to which ECG — is
# reproducible across runs (os.listdir order is arbitrary).
ecg_dirs = sorted(os.listdir(path_to_ecgs))
batch_size = 64
# Ceiling division: includes a final partial batch but never an empty trailing one.
batch_num = (len(ecg_dirs) + batch_size - 1) // batch_size
batch_num

In [None]:
# Compute vanilla-gradient saliency for every ECG, one batch of `batch_size` files at a time.
# Row k of the concatenated results corresponds to ecg_dirs[k]; the segment-table cell
# indexes saliency rows with enumerate(ecg_dirs), so this ordering must be preserved.
TARGET_OUTPUT_IDX = 1  # Target outputs: 0: LA max, 1: LA min, 2: LAEF, 3: LALV
all_maps = []
all_ecgs = []
for batch in range(batch_num):
    # Slicing never raises; a past-the-end slice is simply empty.
    ecg_dirs_batch = ecg_dirs[batch * batch_size:(batch + 1) * batch_size]
    if not ecg_dirs_batch:
        continue  # torch.stack cannot stack an empty list
    test_samples = [
        torch.from_numpy(loadmat(path_to_ecgs + ecg_file)["feats"]).float()
        for ecg_file in ecg_dirs_batch
    ]
    # NOTE(review): X_test stays on the CPU here; presumably compute_vanilla_gradient_saliency
    # moves it to the model's device — TODO confirm.
    X_test = torch.stack(test_samples, dim=0)
    saliency_maps = compute_vanilla_gradient_saliency(
        finetuned_model, X_test, target_output_idx=TARGET_OUTPUT_IDX
    )
    all_maps.append(saliency_maps)
    all_ecgs.append(X_test)


In [None]:
# Concatenate the per-batch results; row k corresponds to ecg_dirs[k].
saliency_maps = torch.cat(all_maps)
# Average over dim 1 (presumably the lead axis — TODO confirm) to get one saliency trace per ECG.
saliency_collapsed = torch.mean(saliency_maps, dim=1)
# Guard keeps the cell re-runnable: after the first run all_ecgs is already a tensor, and
# passing it back into torch.cat would fail.
if isinstance(all_ecgs, list):
    all_ecgs = torch.cat(all_ecgs)

In [None]:
# Integer index (into ecg_dirs / all_ecgs) of the ECG to plot; change it to inspect another recording.
sample_id = 0
assert 0 <= sample_id < len(all_ecgs), f"sample_id must be in [0, {len(all_ecgs)})"
visualize_ecg_saliency(all_ecgs, saliency_maps, sample_idx=sample_id, smooth=True)

In [None]:
# Average saliency per waveform segment (P wave, PQ, QRS, ST, T wave, TP, total) and the number
# of heartbeats for each ECG, with segments delineated by NeuroKit2 on a single lead.
# ECGs whose delineation/averaging raises are kept with None segment values and 0 heartbeats,
# so every column of `data` stays aligned with "ID".
SAMPLING_RATE_HZ = 500  # sampling rate passed to nk.ecg_process
DELINEATION_LEAD = 1  # row of "feats" used for delineation (presumably lead II — TODO confirm)

data = {"ID": [], "P_wave": [], "PQ": [], "QRS": [], "ST": [], "T_wave": [], "TP": [], "num_heartbeats": [], "Total_saliency": []}
failed_ecgs = []  # (file name, exception repr) for ECGs that fell back to empty values

for i, ecg_file in enumerate(ecg_dirs):
    ecg = loadmat(path_to_ecgs + ecg_file)["feats"]
    try:
        _, info = nk.ecg_process(ecg[DELINEATION_LEAD, :], sampling_rate=SAMPLING_RATE_HZ, method="neurokit")
        saliency = saliency_collapsed[i, :].cpu()
        avg_sal, num_cycles = average_saliency_per_segment(
            info["ECG_P_Onsets"], info["ECG_P_Offsets"],
            info["ECG_Q_Peaks"], info["ECG_S_Peaks"],
            info["ECG_T_Onsets"], info["ECG_T_Offsets"],
            saliency,
        )
        # Build the whole row before appending anything, so a failure part-way through
        # (e.g. a key missing from avg_sal) cannot leave the columns with different lengths.
        row = {key: avg_sal[key] for key in data if key not in ("ID", "num_heartbeats")}
        row["num_heartbeats"] = num_cycles
    except Exception as exc:  # best effort: keep going, but record why this ECG was skipped
        failed_ecgs.append((ecg_file, repr(exc)))
        row = {key: None for key in data}
        row["num_heartbeats"] = 0
    row["ID"] = ecg_file
    for key, column in data.items():
        column.append(row[key])

print(f"Segment saliency computed for {len(ecg_dirs) - len(failed_ecgs)}/{len(ecg_dirs)} ECGs; "
      f"{len(failed_ecgs)} fell back to empty values (see failed_ecgs).")
