In [None]:
import os
import sys
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.classification import MultilabelAUROC
import pytorch_lightning as pl
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F
sys.path.append('../..')
from multi_modal_heart.model.ecg_net_attention import ECGEncoder,ECGAttentionAE
from multi_modal_heart.model.ecg_net import ECGAE
from multi_modal_heart.model.ecg_net import BenchmarkClassifier
from multi_modal_heart.ECG.ecg_dataset import ECGDataset
from torchmetrics import Accuracy

class LitClassifier(pl.LightningModule):
    """Lightning module that stacks a `BenchmarkClassifier` head on top of an ECG encoder.

    During testing the module also buffers, batch by batch, the latent codes, the
    classifier's hidden features, the sample ids, the predictions and the labels, and
    (when the encoder exposes `get_attention()`) the lead-attention maps of correctly
    classified samples grouped by class. Call `clear_cache()` before every new test run,
    otherwise buffers and the accuracy metric keep accumulating.
    """
    def __init__(self,encoder,input_dim,num_classes=2,lr=1e-3,freeze_encoder=False):
        """
        Args:
            encoder: module called as `encoder(x, mask)` that returns the latent code.
            input_dim: passed to `BenchmarkClassifier` as `input_size`.
            num_classes: number of output classes of the classifier head.
            lr: learning rate used by `configure_optimizers`.
            freeze_encoder: if True the encoder parameters get `requires_grad=False`
                and the encoder is kept in eval mode (see `train`).
        """
        super().__init__()
        self.lr = lr
        self.freeze_encoder = freeze_encoder
        self.num_classes = num_classes
        self.encoder = encoder
        if self.freeze_encoder:
            self.encoder.eval()
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.accu_metric = Accuracy(task="multiclass",num_classes=num_classes)
        self.test_feature = None
        self._reset_test_buffers()
        #### classifier head (benchmark classifier)
        self.downsteam_net = BenchmarkClassifier(input_size=input_dim,hidden_size=128,output_size=num_classes)

    def _reset_test_buffers(self):
        """(Re)create the empty per-batch buffers filled by `test_step`."""
        self.test_feature_list = []
        self.test_eid_list = []
        self.preds_list = []
        self.latent_feature_list = []
        self.y_list = []
        # class index (int) -> list of [eid, lead_attention_map]
        self.attention_map_classwise = {i: [] for i in range(self.num_classes)}

    def train(self, mode=True):
        """Switch train/eval mode, but keep a frozen encoder in eval mode.

        Any `.train()` call on this module would otherwise put the frozen encoder back
        into training mode (re-enabling e.g. dropout and batch-norm statistic updates),
        which defeats `freeze_encoder=True`.
        """
        super().train(mode)
        if getattr(self, "freeze_encoder", False):
            self.encoder.eval()
        return self

    def forward(self, x, mask=None):
        """Encode `x` and return the classifier logits; the latent code is kept in `self.latent_code`."""
        latent_code = self.encoder(x,mask)
        self.latent_code  = latent_code
        return self.downsteam_net(latent_code)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss on `(x, y) = (batch[0], batch[1])`, logged as `train_loss`."""
        x = batch[0]
        y = batch[1]

        latent_code = self.encoder(x,mask=None)
        self.latent_code  = latent_code
        y_hat = self.downsteam_net(latent_code)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss,prog_bar=True)
        return loss

    def _collect_attention_maps(self, preds, y, eid_list):
        """Store the lead-attention map of every correctly classified sample under its class.

        Best effort: if the encoder has no usable attention (or anything else fails)
        the error is printed and the batch is skipped.
        """
        try:
            lead_attention,_ = self.encoder.get_attention()
            for i, eid in enumerate(eid_list):
                if preds[i]==y[i]:
                    print(f"correct prediction: {y[i]}, eid: {eid}")
                    attention_map = lead_attention[i]
                    if torch.is_tensor(attention_map):
                        # np.stack in on_test_epoch_end needs host-side arrays
                        attention_map = attention_map.detach().cpu().numpy()
                    # The dict is keyed by plain ints; a tensor key hashes by identity
                    # and would never match, so convert the label first.
                    self.attention_map_classwise[int(y[i])].append([eid,attention_map])
        except Exception as e:
            print(e)
            print("no attention map, skip")

    def test_step(self, batch, batch_idx):
        """Evaluate one `(x, y, eid)` batch, log `test_loss` and buffer features/predictions."""
        x = batch[0]
        y = batch[1]
        eid_list = batch[2]
        logits = self(x,mask=None)
        loss = F.cross_entropy(logits, y)
        self.log("test_loss", loss)
        preds = torch.argmax(logits, dim=1)
        ## hidden feature of the classifier head for the latent code stored by forward()
        last_hidden_feature = self.downsteam_net.get_features(self.latent_code)
        ## save the attention map
        self._collect_attention_maps(preds, y, eid_list)

        self.latent_feature_list.append(self.latent_code)
        self.test_feature_list.append(last_hidden_feature)
        self.test_eid_list.append(eid_list)
        self.preds_list.append(preds)
        self.y_list.append(y)

        self.accu_metric.update(preds, y)
        return loss

    def on_test_epoch_end(self):
        """Log `test_acc`, concatenate the buffered tensors and summarise the attention maps.

        For every class with at least one stored map, the element-wise mean and standard
        deviation are appended to `self.attention_name_map_list` as `[name, array]` pairs.
        """
        self.log("test_acc", self.accu_metric.compute())
        ## save the feature
        self.test_latent_feature = torch.cat(self.latent_feature_list,dim=0)
        self.test_feature = torch.cat(self.test_feature_list,dim=0)
        self.test_eid = torch.cat(self.test_eid_list,dim=0)
        self.preds = torch.cat(self.preds_list,dim=0)
        self.y = torch.cat(self.y_list,dim=0)
        self.attention_name_map_list = []
        ## if has attention map, save the attention map
        for key, group_attention in self.attention_map_classwise.items():
            if len(group_attention) == 0:
                continue
            extracted_attention_maps = np.stack([v[1] for v in group_attention],axis=0)
            average_attention_map_norm = np.mean(extracted_attention_maps,axis=0)
            std_attention_map_norm = np.std(extracted_attention_maps,axis=0)
            avg_name = f"AVG_attention_map_class_{key}"
            std_name = f"STD_attention_map_class_{key}"

            self.attention_name_map_list.append([avg_name,average_attention_map_norm])
            self.attention_name_map_list.append([std_name,std_attention_map_norm])

    def validation_step(self, batch, batch_idx):
        """Cross-entropy loss on `(batch[0], batch[1])`, logged as `val_loss`."""
        x = batch[0]
        y = batch[1]
        y_hat = self(x,mask=None)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)
        return loss

    def clear_cache(self):
        """Reset the accuracy metric and every buffer filled during testing."""
        self.accu_metric.reset()
        self.test_feature = None
        self._reset_test_buffers()

    def configure_optimizers(self):
        """AdamW over the parameters that still require gradients."""
        return torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr, weight_decay=1e-4)
    
def print_result(probs,super_classes_labels, topk=1):
    """Print the `topk` highest entries of `probs` as `<label> (<index>): <value>`.

    `probs` is passed to `torch.topk`; each selected index is looked up in
    `super_classes_labels` and the value is rounded to 4 decimals for display.
    """
    top = torch.topk(probs, topk)
    for score, class_idx in zip(top.values.tolist(), top.indices.tolist()):
        print(f'{super_classes_labels[class_idx]} ({class_idx}):', round(score, 4))
def calc_hamming_score(y_true, y_pred):
    """Mean over rows of |y_true AND y_pred| / |y_true OR y_pred| for binary indicator matrices.

    Both inputs must support element-wise `&` / `|` and `.sum(axis=1)`.
    """
    intersection_size = (y_true & y_pred).sum(axis=1)
    union_size = (y_true | y_pred).sum(axis=1)
    per_sample_score = intersection_size / union_size
    return per_sample_score.mean()


In [None]:
## load MI and healthy-control median-beat ECGs from UKB, normalise, pad and split them
import numpy as np
from scipy.stats import zscore
time_steps = 608
ukb_info_dir = "/home/engs2522/project/multi-modal-heart/multi_modal_heart/toolkits/ukb/non_imaging_information"
mi_data_path = f"{ukb_info_dir}/MI/batched_ecg_median_wave.npy"
healthy_data_path = f"{ukb_info_dir}/non_CVD/batched_ecg_median_wave_1045.npy"
eid_list_mi = np.load(f"{ukb_info_dir}/MI/median_eid_list.npy")
eid_list_healthy = np.load(f"{ukb_info_dir}/non_CVD/selected_1045_healthy_subjects_eid_list.npy")
print ("load MI data")
## NOTE: despite its name, `hf_data` holds the MI cohort (loaded from mi_data_path)
hf_data = np.load(mi_data_path)

healthy_data = np.load(healthy_data_path)

## every signal must have exactly one eid, otherwise labels/eids get misaligned below
assert len(eid_list_mi) == hf_data.shape[0], "MI eid list and MI signals differ in length"
assert len(eid_list_healthy) == healthy_data.shape[0], "healthy eid list and healthy signals differ in length"

## check duplicate eid (subjects present in both cohorts would get contradictory labels)
duplicate_eid = np.intersect1d(eid_list_mi, eid_list_healthy).tolist()
if duplicate_eid:
    print ("WARNING: {} eid(s) appear in both cohorts, e.g. {}".format(len(duplicate_eid), duplicate_eid[:10]))

## per-lead z-score along time; constant leads give NaN, which is replaced by 0
hf_data = zscore(hf_data,axis=-1)
healthy_data = zscore(healthy_data,axis=-1)
hf_data = np.nan_to_num(hf_data)
healthy_data = np.nan_to_num(healthy_data)


def _pad_time_axis(ecg, target_len):
    """Zero-pad the last axis of a 3-D array to `target_len`, centred (extra sample goes right)."""
    missing = target_len - ecg.shape[-1]
    if missing < 0:
        raise ValueError("signal length {} exceeds target length {}".format(ecg.shape[-1], target_len))
    left = missing // 2
    return np.pad(ecg,((0,0),(0,0),(left,missing-left)),"constant",constant_values=0)


## pad the data to `time_steps` samples
pad_num = (time_steps-healthy_data.shape[-1])//2
hf_data = _pad_time_axis(hf_data, time_steps)
healthy_data = _pad_time_axis(healthy_data, time_steps)


## label 1 = MI, label 0 = healthy; order matches the concatenated eid list and data
labels = np.concatenate([np.ones(hf_data.shape[0]),np.zeros(healthy_data.shape[0])])
eid_full_list = np.concatenate([eid_list_mi,eid_list_healthy])
data = np.concatenate([hf_data,healthy_data],axis=0)
print (data.shape)
print (labels.shape)
print (len(eid_full_list))
## append the eid to the label so that both travel together through the split
labels_eid = np.zeros((labels.shape[0],2))
labels_eid[:,0] = labels
labels_eid[:,1] = eid_full_list
print (labels_eid.shape)

## split the data: 50% test; the remaining 50% is split 90/10 into train/validation
## (i.e. 45% train, 5% validation, 50% test of the whole set)
from sklearn.model_selection import train_test_split
X_trainval, X_test, y_trainval_eid, y_test_eid = train_test_split(data, labels_eid, test_size=0.5, random_state=42)
X_train, X_val, y_train_eid, y_val_eid= train_test_split(X_trainval, y_trainval_eid, test_size=0.1, random_state=42)
y_train = y_train_eid[:,0]
y_val = y_val_eid[:,0]
y_test = y_test_eid[:,0]
for split_name, X_split, y_split in (("training", X_train, y_train),
                                     ("validation", X_val, y_val),
                                     ("test", X_test, y_test)):
    print ('num of {} data:{}, MI count:{}, MI ratio:{:.3f}'.format(
        split_name, X_split.shape[0], int(y_split.sum()), y_split.mean()))


In [None]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl

batch_size  = 128


def _make_dataset(signals, labels_with_eid):
    """Wrap signals plus a (label, eid) array into a TensorDataset of (x, y, eid).

    Labels and eids are converted straight to int64; going through a float32
    `torch.Tensor` first would silently round ids above 2**24.
    """
    tensor_x = torch.as_tensor(signals, dtype=torch.float32)
    tensor_y = torch.as_tensor(labels_with_eid[:,0]).long()
    tensor_eid = torch.as_tensor(labels_with_eid[:,1]).long()
    return TensorDataset(tensor_x,tensor_y,tensor_eid)


## training data: reshuffled every epoch so that batches differ between epochs
my_train_dataset = _make_dataset(X_train, y_train_eid)
my_dataloader = DataLoader(my_train_dataset,batch_size=batch_size,shuffle=True)

## validation data (fixed order)
my_val_dataset = _make_dataset(X_val, y_val_eid)
my_val_dataloader = DataLoader(my_val_dataset,batch_size=batch_size)

## test data (fixed order, later cells index into the first batch)
my_test_dataset = _make_dataset(X_test, y_test_eid)
my_test_dataloader = DataLoader(my_test_dataset,batch_size=batch_size)

## kept for backward compatibility: these names used to end up holding the test tensors
tensor_x, tensor_y, tensor_eid = my_test_dataset.tensors

In [None]:
import torch
import sys
sys.path.append('../../')
from multi_modal_heart.model.ecg_net import ECGAE
import torch.nn as nn
use_median_wave = True
time_steps = 608
model_name = "ECG_attention_pretrained_on_recon_ECG2Text"


def _build_resnet_ae():
    """Autoencoder configuration shared by every `resnet1d101_512*` model name."""
    return ECGAE(encoder_type="resnet1d101",in_channels=12,ECG_length=time_steps,decoder_type="ms_resnet",
                 embedding_dim=256,latent_code_dim=512,
                 add_time=False,
                 encoder_mha=False,
                 apply_method="",
                 decoder_outdim=12)


def _build_attention_ae(**extra_kwargs):
    """Attention autoencoder; `extra_kwargs` are forwarded unchanged to `ECGAttentionAE`."""
    return ECGAttentionAE(num_leads=12, time_steps=time_steps, z_dims=512, linear_out=512,
                          downsample_factor=5, base_feature_dim=4, if_VAE=False,
                          use_attention_pool=False,
                          no_linear_in_E=True, apply_lead_mask=False, **extra_kwargs)


def _build_attention_ae_explicit():
    """Attention autoencoder with lead and time attention explicitly left enabled."""
    return _build_attention_ae(no_lead_attention=False, no_time_attention=False)


_resnet_recon_ckpt = "../../log_median/resnet1d101_512+benchmark_classifier_ms_resnet/checkpoints/checkpoint_best_loss-v1.ckpt"
_attention_recon_ckpt = "../../log_median/ECG_attention_512_finetuned_no_attention_pool_no_linear_ms_resnet/checkpoints/checkpoint_best_loss-v2.ckpt"

## model_name -> (network builder, pretrained checkpoint)
_model_registry = {
    "resnet1d101_512_pretrained_recon": (_build_resnet_ae, _resnet_recon_ckpt),
    "resnet1d101_512": (_build_resnet_ae, _resnet_recon_ckpt),
    # alternative: "../../log_median/resnet1d101_512+benchmark_classifier_ms_resnet_ECG2Text/checkpoints/last-v2.ckpt"
    "resnet1d101_512_pretrained_recon+ECG2Text": (
        _build_resnet_ae,
        "../../log_median/resnet1d101_512+benchmark_classifier_ms_resnet_ECG2Text/checkpoints/checkpoint_best_loss-v2.ckpt"),
    # alternative: "/home/egs2522/project/multi-modal-heart/log_median_finetune/resnet1d101_512+benchmark_classifier_ms_resnet/checkpoints/checkpoint_best_loss.ckpt"
    "resnet1d101_512_pretrained_classification": (
        _build_resnet_ae,
        "/home/engs2522/project/multi-modal-heart/log_median_finetune/resnet1d101_512+benchmark_classifier_raw_ms_resnet/checkpoints/checkpoint_best_val_macro_auc.ckpt"),
    # alternative: "/home/engs2522/project/multi-modal-heart/log_median_finetune/resnet1d101_512+benchmark_classifier_ms_resnet_ECG2Text/checkpoints/checkpoint_best_loss.ckpt"
    "resnet1d101_512_pretrained_classification+ECG2Text": (
        _build_resnet_ae,
        "/home/engs2522/project/multi-modal-heart/log_median_finetune/resnet1d101_512+benchmark_classifier_raw_ms_resnet_ECG2Text/checkpoints/checkpoint_best_val_macro_auc-v1.ckpt"),
    "ECG_attention_pretrained_on_classification": (
        _build_attention_ae,
        "../../log_median_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet/checkpoints/checkpoint_best_val_macro_auc.ckpt"),
    "ECG_attention_pretrained_on_classification+ECG2Text": (
        _build_attention_ae,
        "../../log_median_finetune/ECG_attention_512_raw_no_attention_pool_no_linear_ms_resnet_ECG2Text/checkpoints/checkpoint_best_val_macro_auc-v1.ckpt"),
    "ECG_attention_pretrained_on_recon": (_build_attention_ae_explicit, _attention_recon_ckpt),
    "ECG_attention": (_build_attention_ae_explicit, _attention_recon_ckpt),
    "ECG_attention_pretrained_on_recon_ECG2Text": (
        _build_attention_ae_explicit,
        "../../log_median/ECG_attention_512_finetuned_no_attention_pool_no_linear_ms_resnet_ECG2Text/checkpoints/checkpoint_best_loss-v2.ckpt"),
}

if model_name not in _model_registry:
    raise NotImplementedError(
        "unknown model_name {!r}; expected one of {}".format(model_name, sorted(_model_registry)))

_build_network, checkpoint_path = _model_registry[model_name]
ecg_net = _build_network()
if model_name=="ECG_attention_pretrained_on_recon_ECG2Text":
    print ("load ECG attention model")

num_classes = 2 ## for binary classification


In [None]:
## load the pretrained autoencoder weights (encoder and decoder), then wrap the encoder in a classifier
def _submodule_state_dict(state_dict, prefix):
    """Return the entries of `state_dict` whose key starts with `prefix`, with the prefix removed."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if str(key).startswith(prefix)}


## map_location="cpu": the checkpoint can be read on a machine without the GPU it was saved from;
## load_state_dict copies the values onto whatever device the modules live on
checkpoint = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
## the trailing dot keeps sibling attributes that merely share the name prefix out of the result
encoder_params = _submodule_state_dict(checkpoint, "network.encoder.")
decoder_params = _submodule_state_dict(checkpoint, "network.decoder.")
ecg_net.decoder.load_state_dict(decoder_params)
ecg_net.encoder.load_state_dict(encoder_params)
classification_net = LitClassifier(encoder=ecg_net.encoder,input_dim=512,num_classes=num_classes,lr=1e-3,freeze_encoder=True)


In [None]:
import torch
import sys
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import os
pl.seed_everything(42)

## logs and checkpoints go to ./finetune_on_MI/<model_name>/
tb_logger = TensorBoardLogger("./finetune_on_MI", name=model_name, version="")
checkpoint_dir  = os.path.join(tb_logger.log_dir,"checkpoints")

## keep the checkpoint with the lowest validation loss, plus the last one
checkpoint_callback_best_loss_min = pl.callbacks.ModelCheckpoint(dirpath=checkpoint_dir,
                                                    filename='checkpoint_best_loss',
                                                    save_top_k=1, monitor="val_loss"
                                                    , mode='min',save_last=True)


callbacks=[
    # FineTuneLearningRateFinder(milestones=[5, 10],min_lr=1e-5, max_lr=1e-3,
    #                             mode='linear', early_stop_threshold=4.0),
    checkpoint_callback_best_loss_min
    ]


## The logger and the checkpoint callback must be handed to the Trainer: the test cell
## loads a checkpoint from this directory, which is only written when they are active.
trainer = pl.Trainer(accelerator="gpu",
                    devices=1, max_epochs=50,
                    logger=tb_logger,log_every_n_steps=1,check_val_every_n_epoch = 1,
                    callbacks=callbacks,
                    )
trainer.fit(classification_net, train_dataloaders=my_dataloader,val_dataloaders=my_val_dataloader)


In [None]:
## test the fine-tuned model
checkpoint_path = "/home/engs2522/project/multi-modal-heart/multi_modal_heart/tasks/finetune_on_MI/ECG_attention_pretrained_on_recon_ECG2Text/checkpoints/last-v7.ckpt"
## load_from_checkpoint is a classmethod that RETURNS a new model; it does not modify the
## object it is called on, so the result has to be assigned. The constructor arguments are
## passed explicitly because the module does not store its hyper-parameters in the checkpoint.
classification_net = LitClassifier.load_from_checkpoint(checkpoint_path,encoder=ecg_net.encoder,input_dim=512,num_classes=num_classes,lr=1e-3,freeze_encoder=True)
## start from empty test buffers / a fresh accuracy metric
classification_net.clear_cache()
classification_net.eval()
classification_net.freeze()


result = trainer.test(classification_net,my_test_dataloader)


In [None]:
from multi_modal_heart.ECG.ecg_utils import arraytodataframe,plot_multiframe_in_one_figure

def plot_overlapped_multi_lead_signals(sample_y_nd,sample_x_hat_nd, labels=["GT ","Pred "]):
    '''
    Overlay two multi-lead signals lead by lead in a single grid figure.

    input: sample_y_nd: numpy array, shape: (12, time_steps)
         sample_x_hat_nd: numpy array, shape: (12, time_steps)
         labels: two prefixes, prepended to the lead names of the first / second signal
    8-lead inputs are accepted as well; the lead count is read from sample_x_hat_nd.
    Any other lead count raises NotImplementedError.
    return: the figure produced by plot_multiframe_in_one_figure
    '''
    first_df = arraytodataframe(sample_y_nd)
    second_df = arraytodataframe(sample_x_hat_nd)

    # lead count -> (lead names, subplot grid)
    layouts = {
        12: (['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6'], (4,3)),
        8: (['I','II','V1','V2','V3','V4','V5','V6'], (4,2)),
    }
    num_leads = sample_x_hat_nd.shape[0]
    if num_leads not in layouts:
        raise NotImplementedError
    lead_names, figure_arrangement = layouts[num_leads]

    first_prefix, second_prefix = labels[0], labels[1]
    first_df.columns = [first_prefix+lead for lead in lead_names]
    second_df.columns = [second_prefix+lead for lead in lead_names]

    return plot_multiframe_in_one_figure([first_df,second_df],figsize=(15,4), figure_arrangement=figure_arrangement, logger=None)


In [None]:
## reconstruct the first test batch with the autoencoder and overlay input vs. reconstruction
ecg_net.cuda()
ecg_net.eval()
## take the batch here: `x` is otherwise only defined further down in the notebook,
## so this cell would fail on a fresh kernel
x = next(iter(my_test_dataloader))[0]
with torch.no_grad():
    reconstructed_signal = ecg_net(x.cuda(),mask=None)
figure = plot_overlapped_multi_lead_signals(x[0].cpu().numpy(),
                                            reconstructed_signal[0].detach().cpu().numpy(),
                                            labels = ["input ","recon "])

In [None]:
## put both networks into inference mode and grab the first test batch (x, y, eid)
ecg_net.eval()
classification_net.eval()
classification_net.freeze()
test_batches = iter(my_test_dataloader)
test_data = next(test_batches)



In [None]:
## encode the first test batch and keep only the MI samples (label 1) that are classified correctly
classification_net.clear_cache()
classification_net.eval()
classification_net.cuda()
ecg_net.cuda()  # the decoder is not part of classification_net, so move it explicitly
x = test_data[0]
print(x.size())
ecg_id = test_data[2]
indice = 0
x = x.cuda()
y = test_data[1].cuda()
## inference only: no autograd graph is needed for these forward passes
with torch.no_grad():
    z = classification_net.encoder(x,mask=None)
    reconstructed_signal = ecg_net.decoder(z)
    pred = classification_net.downsteam_net(z)
pred_label = torch.argmax(pred,dim=1)
## select those predictions with correct prediction (computed once, reused below)
correct_mi_mask = (pred_label ==y)&(y==1)
print(correct_mi_mask)
if not bool(correct_mi_mask.any()):
    print("WARNING: no correctly classified MI sample in this batch; the cells below need at least one")

z = z[correct_mi_mask]
filtered_x = x[correct_mi_mask]
print (z.shape)
reconstructed_signal = reconstructed_signal[correct_mi_mask]

## accuracy over the whole batch (not only the selected samples)
acc = classification_net.accu_metric(pred_label,y)
print (acc)
# figure = plot_overlapped_multi_lead_signals(x[indice].cpu().numpy(),reconstructed_signal[indice].detach().cpu().numpy())


In [None]:
## hidden feature of the classifier head for the selected latent codes z
classifier_head = classification_net.downsteam_net
last_hidden_feature = classifier_head.get_features(z)
print(last_hidden_feature.shape)


In [None]:
hidden_z_dim = 28 ## index of the hidden-feature dimension under study
## NOTE: the `*_z_74` names below are historical; they all refer to dimension `hidden_z_dim`
previous_z_74_value = last_hidden_feature[:,hidden_z_dim]
## smallest / largest / median value of that dimension and the sample index where each occurs
min_z_74, min_indice = previous_z_74_value.min(dim=0)
max_z_74, max_indice = previous_z_74_value.max(dim=0)
median_z_74, median_indice = previous_z_74_value.median(dim=0)
print(f"the range of {hidden_z_dim} is: [{min_z_74},{max_z_74}]")

In [None]:
## overlay the input ECGs with the lowest and the highest value of the selected hidden dimension
lowest_input = filtered_x[min_indice].detach().cpu().numpy()
highest_input = filtered_x[max_indice].detach().cpu().numpy()
figure = plot_overlapped_multi_lead_signals(lowest_input,
                                   highest_input,
                                   labels = ["low risk ","high risk "])

In [None]:
## overlay the reconstructed ECGs with the lowest and the highest value of the selected hidden dimension
lowest_recon = reconstructed_signal[min_indice].detach().cpu().numpy()
highest_recon = reconstructed_signal[max_indice].detach().cpu().numpy()
figure = plot_overlapped_multi_lead_signals(lowest_recon,
                                   highest_recon,
                                   labels = ["low risk ","high risk "])

In [None]:
## median_indice was computed on the filtered subset (correctly classified MI samples), so it
## must index `filtered_x`; indexing the unfiltered batch `x` would pick an unrelated sample
figure = plot_overlapped_multi_lead_signals(filtered_x[median_indice].detach().cpu().numpy(),
                                            reconstructed_signal[max_indice].detach().cpu().numpy(),
                                            labels = ["medium risk ","high risk "])

In [None]:
## load the MI/HF table that the following cells use to look up incident cases
import pandas as pd
mi_hf_table_path = '/home/engs2522/project/multi-modal-heart/multi_modal_heart/toolkits/ukb/non_imaging_information/MI/MI_HF_coxreg_df.csv'
MI_HF_coxreg_df = pd.read_csv(mi_hf_table_path)
MI_HF_coxreg_df.head(5)

In [None]:
## hidden-feature value vs. time to HF for the MI subjects that later develop HF (HF_status == 1)
incident_case = MI_HF_coxreg_df[MI_HF_coxreg_df['HF_status']==1]
hidden_z_dim = 28
## O(1) lookups instead of scanning the table once per subject
incident_eids = set(incident_case["eid"].tolist())
## first table row per eid, the same row a `df[df.eid == eid].values[0]` lookup returns
time_to_hf_by_eid = MI_HF_coxreg_df.drop_duplicates(subset="eid", keep="first").set_index("eid")["time_to_HF"]

ecg_wave_tensor_list = []
duration_list = []
## `hf_data` is the z-scored, padded MI cohort; its rows are aligned with `eid_list_mi`
## (no variable called `mi_data` is defined anywhere in this notebook)
for eid, ecg_wave in zip(eid_list_mi, hf_data):
    if eid in incident_eids:
        print (eid)
        ecg_wave_tensor_list.append(torch.Tensor(ecg_wave).unsqueeze(0).float().cuda())
        duration_list.append(time_to_hf_by_eid.loc[eid])
if not ecg_wave_tensor_list:
    raise ValueError("none of the MI subjects in eid_list_mi is an incident HF case")

batched_ecg_wave = torch.cat(ecg_wave_tensor_list,dim=0)
## inference only, no autograd graph needed
with torch.no_grad():
    z_feature = classification_net.encoder.get_features_after_pooling(batched_ecg_wave,mask=None)
    h_feature = classification_net.downsteam_net.get_features(z_feature)
h_74_list = h_feature[:,hidden_z_dim].detach().cpu().numpy()
final_list  = [[h,d] for h,d in zip(h_74_list,duration_list)]

z_value_time_to_hf_df = pd.DataFrame(final_list,columns=[f"h_{hidden_z_dim}","time_to_hf"])
z_value_time_to_hf_df.head()

In [None]:
# scatter plot: selected hidden feature (y) against time to HF (x), points coloured by time to HF
import seaborn as sns
sns.set_theme(style="whitegrid")
hidden_feature_column = f"h_{hidden_z_dim}"
ax = sns.scatterplot(data=z_value_time_to_hf_df,
                     x="time_to_hf",
                     y=hidden_feature_column,
                     hue="time_to_hf",
                     palette="cool")

In [None]:
## Target for the latent optimisation: the hidden feature of the sample with the lowest value
## in dimension `hidden_z_dim`, with that one dimension overwritten by the largest value
## observed among the selected samples (`max_z_74`)
target_h_feature = h_selected = last_hidden_feature[[min_indice],:]
print(h_selected.shape)
target_h_feature[:,hidden_z_dim] = max_z_74
print (target_h_feature)


In [None]:
## optimise the latent code z of the selected sample so that dimension `hidden_z_dim` of the
## classifier's hidden feature matches the target value
from torch.optim import AdamW
from torch.nn import MSELoss, L1Loss
from tqdm import tqdm
n_optim_steps = 2000
log_every = 100
z_selected= z[[min_indice],:].detach().clone()
z_selected_original = z_selected.detach().clone()
target_h_feature = target_h_feature.detach().clone()
target_h_feature.requires_grad = False
z_selected.requires_grad = True
## only z is optimised; weight decay pulls the optimised code towards zero
optimizer = AdamW([z_selected],lr=1e-3,weight_decay=1e-1)
loss_fn = MSELoss()
## The network stays frozen: gradients w.r.t. the input z flow through frozen weights,
## so unfreezing is unnecessary and would only accumulate gradients on the classifier.
classification_net.eval()
for i in tqdm(range(n_optim_steps)):
    optimizer.zero_grad()
    ## use a loop-local name: `last_hidden_feature` (the whole selected batch) and
    ## `reconstructed_signal` are still needed when earlier cells are re-run
    h_current = classification_net.downsteam_net.get_features(z_selected)
    loss = loss_fn(input = h_current[:,hidden_z_dim],target = target_h_feature[:,hidden_z_dim])
    loss.backward()
    optimizer.step()
    if i%log_every==0:
        print ("loss:{}".format(loss.item()))
        print ("h_{}:{}".format(hidden_z_dim, h_current[:,hidden_z_dim]))
        ## measure the difference of the optimized value and the one before optimization
        print ("z difference:{}".format(torch.sum(torch.abs(z_selected_original-z_selected))))


In [None]:
## decode the latent code before and after optimisation and overlay the two reconstructions
with torch.no_grad():
    recon_selected = ecg_net.decoder(z_selected)
    recon_original = ecg_net.decoder(z_selected_original)
## both reconstructions hold exactly one sample (z_selected has a single row), so the only
## valid index is 0; `min_indice` refers to the position inside the original filtered batch
figure = plot_overlapped_multi_lead_signals(recon_original[0].detach().cpu().numpy(),
                                            recon_selected[0].detach().cpu().numpy(),
                                            labels = ["before (z^0)","after (z*) "])