In [1]:

import torch.nn as nn
import torch
## use an LSTM model to predict the current ECG sequence or a future sequence
class Sequence(nn.Module):
    """Two stacked ``nn.LSTMCell`` layers plus a linear read-out, unrolled step by step.

    Every time step of ``input`` is pushed through both LSTM cells and the linear
    layer; afterwards the model can keep rolling forward for ``future`` extra steps,
    feeding its own last prediction back in as the next input.

    Args:
        input_feature_dim: features per time step (also the size of each prediction).
        num_embedding: hidden/cell size of both LSTM cells.
    """
    def __init__(self, input_feature_dim=1, num_embedding=64):
        super(Sequence, self).__init__()
        self.num_embedding = num_embedding
        self.lstm1 = nn.LSTMCell(input_feature_dim, num_embedding)
        self.lstm2 = nn.LSTMCell(num_embedding, num_embedding)
        self.linear = nn.Linear(num_embedding, input_feature_dim)

    def forward(self, input, future = 0):
        """Run the model over ``input`` and optionally ``future`` autoregressive steps.

        Args:
            input: 2-D tensor split into per-step slices with ``split(1, dim=1)``, i.e.
                (batch, time_steps) when ``input_feature_dim == 1``.
            future: number of extra steps to predict after the last input step.

        Returns:
            Tensor of shape (batch, (time_steps + future) * input_feature_dim).

        Raises:
            ValueError: if ``input`` has no time steps.
        """
        if input.size(1) == 0:
            raise ValueError("input must contain at least one time step")
        # Create the recurrent state with the model's own dtype/device so the module
        # also works after .cuda() / .double() (previously always CPU float32).
        ref = self.lstm1.weight_ih
        state_shape = (input.size(0), self.num_embedding)
        h_t = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
        c_t = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
        h_t2 = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)
        c_t2 = torch.zeros(state_shape, dtype=ref.dtype, device=ref.device)

        outputs = []
        for input_t in input.split(1, dim=1):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs.append(output)
        for _ in range(future):  # autoregressive roll-out: feed the last prediction back in
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs.append(output)
        return torch.cat(outputs, dim=1)



In [36]:
import torch
import  torch.nn as nn
batch_size = 2
num_leads = 5
feature_dim = 512
time_length = 1024
num_heads = 8
input = torch.randn(num_leads, batch_size, feature_dim)  # (leads, batch, features): seq-first layout for nn.MultiheadAttention
mask_lead = torch.ones(batch_size, num_leads, time_length)
mask_lead[:, [0, 3, 4]] = 0  # simulate leads 0, 3 and 4 being fully masked in every sample
mask_lead = torch.mean(mask_lead, dim=2)  ## (batch_size, num_leads); 0 means the lead is fully masked


def build_lead_attn_mask(mask_lead, num_heads):
    """Build a boolean cross-lead attention mask for ``nn.MultiheadAttention``.

    Args:
        mask_lead: (batch, leads) tensor; entries equal to 0 mark leads that must not
            be attended to (they are hidden as keys for every query lead).
        num_heads: number of attention heads.

    Returns:
        Bool tensor of shape (batch * num_heads, leads, leads); True = "may not attend".
        Rows are ordered batch-major (index ``b * num_heads + h``), the layout
        MultiheadAttention uses for a 3-D ``attn_mask``. The previous
        ``unsqueeze(0).repeat(...)`` produced head-major order, which applies the
        masks to the wrong samples as soon as they differ across the batch.

    Note: if every lead of a sample is masked, attention for that sample yields NaN.
    """
    n_leads = mask_lead.size(1)
    blocked_keys = (mask_lead == 0).unsqueeze(1).expand(-1, n_leads, -1)  # (batch, leads, leads)
    return blocked_keys.repeat_interleave(num_heads, dim=0)


print(build_lead_attn_mask(mask_lead, num_heads=1))  # per-sample mask, (batch, leads, leads)
## duplicate the mask to match the number of heads (batch-major)
lead_attn_mask = build_lead_attn_mask(mask_lead, num_heads)
lead_mha = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads)
weighted_input, attention = lead_mha(input, input, input, attn_mask=lead_attn_mask)


tensor([[[ True, False, False,  True,  True],
         [ True, False, False,  True,  True],
         [ True, False, False,  True,  True],
         [ True, False, False,  True,  True],
         [ True, False, False,  True,  True]],

        [[ True, False, False,  True,  True],
         [ True, False, False,  True,  True],
         [ True, False, False,  True,  True],
         [ True, False, False,  True,  True],
         [ True, False, False,  True,  True]]])


In [39]:
## one-hot encode an age bucket index (bucket 7 out of 8)
age_bucket = torch.tensor(7)
age_code = nn.functional.one_hot(age_bucket, num_classes=8)
print(age_code)

tensor([0, 0, 0, 0, 0, 0, 0, 1])


In [None]:
## output
import torch
import torch.nn as nn
n_features = 64
hidden_size = 12
num_layers = 2
rnn = nn.RNN(
    input_size=n_features,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
)
batch_size = 5
len_seq = 1024  # not used below: the demo feeds a single time step

## initial hidden state is (num_layers, batch, hidden) even with batch_first=True
initial_state = torch.zeros(num_layers, batch_size, hidden_size)
input = torch.randn(batch_size, 1, n_features)  # (batch, time=1, features)
output, hn = rnn(input, initial_state)
print(output.shape)


In [None]:
class ConvRNN(nn.Module):
    """Three-branch multi-resolution Conv1d + GRU model.

    Branch 1 sees the raw sequence, branch 2 a 2x average-pooled copy and branch 3 a
    4x average-pooled copy. Each branch applies two left-padded Conv1d+ReLU layers
    and a GRU; the last GRU hidden states of the three branches are concatenated and
    mapped to ``output_dim`` by ``linear1`` (one head per output channel when
    ``if_multi_lead`` is True). ``linear2`` is constructed but not used in ``forward``.

    Input: (batch, timesteps, input_dim).
    Output: (batch, output_dim), or (batch, out_ch, output_dim) when ``if_multi_lead``.
    """
    def __init__(self, input_dim, timesteps, output_dim, kernel_size1=7, kernel_size2=5, kernel_size3=3,
                 n_channels1=32, n_channels2=32, n_channels3=32, n_units1=32, n_units2=32, n_units3=32, if_multi_lead=False,out_ch = 12):
        super().__init__()
        # NOTE: submodules are registered in the same order as before; that order fixes
        # both the random-init sequence and the state_dict layout.
        self.avg_pool1 = nn.AvgPool1d(2, 2)
        self.avg_pool2 = nn.AvgPool1d(4, 4)
        self.conv11 = nn.Conv1d(input_dim, n_channels1, kernel_size=kernel_size1)
        self.conv12 = nn.Conv1d(n_channels1, n_channels1, kernel_size=kernel_size1)
        self.conv21 = nn.Conv1d(input_dim, n_channels2, kernel_size=kernel_size2)
        self.conv22 = nn.Conv1d(n_channels2, n_channels2, kernel_size=kernel_size2)
        self.conv31 = nn.Conv1d(input_dim, n_channels3, kernel_size=kernel_size3)
        self.conv32 = nn.Conv1d(n_channels3, n_channels3, kernel_size=kernel_size3)
        self.gru1 = nn.GRU(n_channels1, n_units1, batch_first=True)
        self.gru2 = nn.GRU(n_channels2, n_units2, batch_first=True)
        self.gru3 = nn.GRU(n_channels3, n_units3, batch_first=True)
        self.if_multi_lead = if_multi_lead

        concat_units = n_units1 + n_units2 + n_units3
        flat_input = input_dim * timesteps
        if if_multi_lead:
            self.linear1 = nn.ModuleList()
            self.linear2 = nn.ModuleList()
            self.out_ch = out_ch
            for _ in range(out_ch):
                # heads are created interleaved (linear1[i], linear2[i]) as before
                self.linear1.append(nn.Linear(concat_units, output_dim))
                self.linear2.append(nn.Linear(flat_input, output_dim))
        else:
            self.linear1 = nn.Linear(concat_units, output_dim)
            self.linear2 = nn.Linear(flat_input, output_dim)

        # Left-only zero padding of (kernel_size - 1): each conv is causal and keeps the length.
        self.zp11 = nn.ConstantPad1d((kernel_size1 - 1, 0), 0)
        self.zp12 = nn.ConstantPad1d((kernel_size1 - 1, 0), 0)
        self.zp21 = nn.ConstantPad1d((kernel_size2 - 1, 0), 0)
        self.zp22 = nn.ConstantPad1d((kernel_size2 - 1, 0), 0)
        self.zp31 = nn.ConstantPad1d((kernel_size3 - 1, 0), 0)
        self.zp32 = nn.ConstantPad1d((kernel_size3 - 1, 0), 0)

    @staticmethod
    def _branch(x, pool, pad_a, conv_a, pad_b, conv_b, gru):
        """Optional pooling -> two padded conv+ReLU layers -> GRU; return the GRU hidden state."""
        if pool is not None:
            x = pool(x)
        x = torch.relu(conv_a(pad_a(x)))
        x = torch.relu(conv_b(pad_b(x)))
        _, hidden = gru(x.permute(0, 2, 1))
        return hidden

    def forward(self, x):
        channels_first = x.permute(0, 2, 1)  # (batch, input_dim, timesteps) for Conv1d
        h1 = self._branch(channels_first, None, self.zp11, self.conv11, self.zp12, self.conv12, self.gru1)
        h2 = self._branch(channels_first, self.avg_pool1, self.zp21, self.conv21, self.zp22, self.conv22, self.gru2)
        h3 = self._branch(channels_first, self.avg_pool2, self.zp31, self.conv31, self.zp32, self.conv32, self.gru3)
        features = torch.cat([h1[-1], h2[-1], h3[-1]], dim=1)
        if self.if_multi_lead:
            return torch.stack([head(features) for head in self.linear1], dim=1)
        # the linear2 skip path (flattened input) is currently disabled
        return self.linear1(features)

In [2]:
##
class LinearDecoder(nn.Module):
    """Decode a flat feature vector into ``out_ch`` sequences with per-channel linear heads.

    With ``decompose=True`` the feature vector is split into two halves; each half gets
    its own linear head per channel and the two head outputs are summed.

    Args:
        input_feature_dim: size of the input feature vector (must be even if decompose).
        output_length: length of each decoded sequence.
        out_ch: number of output channels (one head, or head pair, per channel).
        decompose: enable the two-half decomposition described above.
    """
    def __init__(self, input_feature_dim, output_length, out_ch=1,decompose=False):
        super().__init__()
        self.out_ch = out_ch
        self.decompose = decompose
        self.linear1 = nn.ModuleList()
        self.linear2 = nn.ModuleList()
        if decompose:
            assert input_feature_dim % 2 == 0
        head_in = input_feature_dim // 2 if decompose else input_feature_dim
        for _ in range(out_ch):
            # heads are created interleaved (linear1[i] then linear2[i]) when decomposing
            self.linear1.append(nn.Linear(head_in, output_length))
            if decompose:
                self.linear2.append(nn.Linear(head_in, output_length))

    def forward(self, x):
        '''
        x: [Batch, feature_dim]
        return: [Batch, out_ch, output_length]
        '''
        half = x.size(1) // 2

        def decode_channel(idx):
            if self.decompose:
                return self.linear1[idx](x[:, :half]) + self.linear2[idx](x[:, half:])
            return self.linear1[idx](x)

        return torch.stack([decode_channel(idx) for idx in range(self.out_ch)], dim=1)
## smoke test: 5 samples of 64-d features decoded into 12 channels of length 1024
input_data = torch.randn(5, 64)
model = LinearDecoder(64, 1024, out_ch=12, decompose=False)
output = model(input_data)
print(output.shape)

torch.Size([5, 12, 1024])


In [3]:

## smoke test (needs the ConvRNN cell to have been run): input is (batch, timesteps, leads)
conv_rnn = ConvRNN(12, 1024, 1024, if_multi_lead=True, out_ch=12)
out = conv_rnn(torch.randn(5, 1024, 12))
del conv_rnn  # the model was a temporary before; don't keep its weights alive
print(out.shape)

NameError: name 'ConvRNN' is not defined

In [None]:
import sys
import torch
sys.path.append('../..')
from multi_modal_heart.model.ecg_net import doubleECGNet

# Build the network and push a dummy batch through it. Previously the constructed model
# was bound to `output` and immediately overwritten by the random tensor, so the network
# was never called and `output.shape` merely echoed the input shape.
ecg_model = doubleECGNet(decoder_type="linear", decoder_outdim=12)
dummy_batch = torch.randn(5, 12, 1024)  # assumed (batch, leads, time) like the other ECG nets here — TODO confirm
output = ecg_model(dummy_batch)
output.shape

In [None]:
## smoke test: 10 sequences of 512 scalar steps -> one prediction per step
seq_model = Sequence(input_feature_dim=1)
test_input = torch.randn(10, 512)  # (batch, time_steps)
red = seq_model(test_input, future=0)
print(red.shape)

In [None]:
# 
import torch.nn as nn
import torch
n_features = 512
batch_size = 10
signal_length = 1024 // 32
num_lead = 12
input = torch.randn(batch_size, n_features, num_lead, signal_length)  # (batch, features, leads, time)
use_attention_mask = True

## cross-time attention: one token per time step, leads folded into the embedding -> (T, N, leads*F)
time_input = input.permute(3, 0, 2, 1).reshape(signal_length, batch_size, -1)
time_encoder = PositionalEncoding(d_model=n_features * num_lead, dropout=0.5)
time_input = time_encoder(time_input)
mha = nn.MultiheadAttention(embed_dim=n_features * num_lead, num_heads=8)

if use_attention_mask:
    # Causal mask. nn.MultiheadAttention treats a *bool* mask as "True = may not attend",
    # whereas a *float* mask is ADDED to the attention scores. The previous float mask
    # `1 - triu(ones, diagonal=1)` therefore only added +1 to the allowed positions and
    # blocked nothing.
    attn_mask = torch.triu(torch.ones(signal_length, signal_length, dtype=torch.bool), diagonal=1)
else:
    attn_mask = None
cross_time, attn_output_weights = mha(time_input, time_input, time_input, attn_mask=attn_mask)
print(cross_time.shape)

## lead-wise attention: one token per lead, time folded into the embedding -> (leads, N, F*T)
lead_input = input.permute(2, 0, 1, 3).reshape(num_lead, batch_size, -1)
embedding = n_features * signal_length
# NOTE(review): embed_dim = 512 * 32 = 16384 -> ~1.07e9 projection parameters (~4.3 GB in fp32).
lead_mha = nn.MultiheadAttention(embed_dim=embedding, num_heads=8)

lead_encoder = PositionalEncoding(d_model=n_features * signal_length, dropout=0.5)
lead_input = lead_encoder(lead_input)
cross_lead, lead_attn_output_weights = lead_mha(lead_input, lead_input, lead_input)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
class AttentionPool1d(nn.Module):
    """Pool an (N, C, L) sequence into one vector per sample with multi-head attention.

    A mean-pooled token is prepended to the sequence, a learned positional embedding
    is added, and the mean token attends over all ``spacial_dim + 1`` tokens. The
    attended mean token, projected by ``c_proj``, is returned.

    Args:
        spacial_dim: sequence length L the positional embedding is sized for.
        embed_dim: channel count C.
        num_heads: number of attention heads.
        output_dim: output size (defaults to ``embed_dim``).
        dropout_rate: attention dropout used while training.
    """
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dropout_rate: float = 0.1):
        super().__init__()
        # scaled random init of the (L + 1, C) positional table
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
        self.dropout_p = dropout_rate

    def forward(self, x):
        """x: (N, C, L) -> (N, output_dim)."""
        tokens = x.permute(2, 0, 1)  # (L, N, C)
        summary = tokens.mean(dim=0, keepdim=True)  # (1, N, C)
        tokens = torch.cat([summary, tokens], dim=0)  # (L + 1, N, C)
        tokens = tokens + self.positional_embedding.unsqueeze(1).to(tokens.dtype)
        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.dropout_p,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)

In [None]:
## cross lead attention
import math
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``, with
    ``w_k = 10000^(-2k / d_model)``. Odd ``d_model`` is supported (it used to fail
    because there is one fewer cosine column than sine columns).

    Args:
        d_model: embedding size.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence that can be encoded.
    """
    def __init__(self, d_model: int, dropout: float = 0.2, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # d_model // 2 cosine columns: one fewer than the sine columns when d_model is odd
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the encoding was built for.
        """
        if x.size(0) > self.pe.size(0):
            raise ValueError(f"sequence length {x.size(0)} exceeds max_len {self.pe.size(0)}")
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [9]:
import torch
import sys
import torch
sys.path.append('../..')
from multi_modal_heart.model.ecg_net import doubleECGNet

from multi_modal_heart.model.ecg_net_attention import ECGAttentionAE

ecg_net = ECGAttentionAE(
    num_leads=12, time_steps=1024, z_dims=512,
    linear_out=512, downsample_factor=4, upsample_factor=4,
    base_feature_dim=4, if_VAE=False, use_attention_pool=False, no_linear_in_E=True,
)
## reconstruction shape for one 12-lead, 1024-sample ECG
ecg_net(torch.randn(1, 12, 1024)).shape

no linear layer


torch.Size([1, 12, 1024])

In [None]:
import torch.nn as nn
import sys
sys.path.append('../../')

from multi_modal_heart.model.ecg_net import ECGAE
import lightning.pytorch as pl

class LitAutoEncoder(pl.LightningModule):
    """Minimal LightningModule shell around an ECG autoencoder ``network``.

    Only ``network`` and ``task_name`` are stored; the key names, ``grad_clip``,
    ``warmup``, ``max_iters``, ``batch_size`` and ``**kwargs`` are accepted for
    signature compatibility but not used. No ``save_hyperparameters()`` call is made,
    so the constructor arguments are not stored in checkpoints.
    """
    def __init__(self, network,task_name,input_key_name="input_seq", target_key_name="cleaned_seq", future_key_name="next_seq", grad_clip=False, warmup=100,
        max_iters=2000,batch_size=128, **kwargs):
        super().__init__()
        self.network, self.task_name = network, task_name


ae = ECGAE(encoder_type="ms_resnet", in_channels=8, ECG_length=1024, embedding_dim=256, latent_code_dim=64, add_time=False,
           apply_method="", decoder_outdim=8, time_dim=0, act=nn.GELU(), encoder_mha=True)

# `load_from_checkpoint` is a classmethod that RETURNS a new module with the weights loaded;
# calling it on an instance and discarding the result (as before) left `model` randomly
# initialised. LitAutoEncoder does not call save_hyperparameters(), so its constructor
# arguments have to be passed again here.
model = LitAutoEncoder.load_from_checkpoint(
    "/home/engs2522/project/multi-modal-heart/log/dae_64+mha_ms_resnet/checkpoints/epoch=85-step=2924.ckpt",
    network=ae,
    task_name="ECGAE",
    input_key_name="input_seq",
    target_key_name="cleaned_seq",
)

# encoder_vec = ecg_net.encodeECG(torch.randn(4,12,1024),mask = torch.randn(4,12,1024),auto_pad_input=True)
# print (encoder_vec.shape)
# decoder_vec = ecg_net.decodeECG(encoder_vec)
# print (decoder_vec.shape)

In [None]:
## test resnet
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('../..')
from multi_modal_heart.model.ecg_net import ECG_ResNetencoder
resencoder = ECG_ResNetencoder(in_channels=12, ECG_length=1024, embedding_dim=256, output_dim=64)
encoded_shape = resencoder(torch.randn(4, 12, 1024)).shape  # not displayed; kept for inspection
## pooled-feature shape is the cell's displayed output
resencoder.get_features_after_pooling(torch.randn(4, 12, 1024)).shape

In [None]:
from multi_modal_heart.model.ecg_net_attention import ECGAttentionAE
ecg_net = ECGAttentionAE(
    num_leads=12, time_steps=1024, z_dims=512, linear_out=64,
    downsample_factor=5, base_feature_dim=4, if_VAE=False, no_linear_in_E=True,
)
pooled_shape = ecg_net.encoder.get_features_after_pooling(torch.randn(4, 12, 1024)).shape  # not displayed
# ecg_net.encoder(torch.randn(4,12,1024)).shape
ecg_net(torch.randn(4, 12, 1024)).shape

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, Callback

class OneCycleLR(Callback):
    """Callback that drives the first optimizer with ``torch.optim.lr_scheduler.OneCycleLR``.

    The scheduler is created when training starts and kept on the callback as
    ``self.scheduler``. The previous version stored it in a non-API ``trainer``
    attribute, and its ``on_train_epoch_end`` required an ``outputs`` argument that
    Lightning 2.x no longer passes.

    Args:
        max_lr, total_steps, pct_start, anneal_strategy: forwarded unchanged to
            ``torch.optim.lr_scheduler.OneCycleLR``.
        interval: ``"epoch"`` (default, previous behaviour) steps the schedule once per
            training epoch; ``"step"`` steps it after every training batch, which is the
            granularity OneCycleLR's ``total_steps`` normally refers to.

    Raises:
        ValueError: if ``interval`` is neither ``"epoch"`` nor ``"step"``.
    """
    def __init__(self, max_lr, total_steps, pct_start=0.3, anneal_strategy='cos', interval='epoch'):
        super(OneCycleLR, self).__init__()
        if interval not in ('epoch', 'step'):
            raise ValueError(f"interval must be 'epoch' or 'step', got {interval!r}")
        self.max_lr = max_lr
        self.total_steps = total_steps
        self.pct_start = pct_start
        self.anneal_strategy = anneal_strategy
        self.interval = interval
        self.scheduler = None

    def on_train_start(self, trainer, pl_module):
        optimizer = trainer.optimizers[0]
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.max_lr,
            total_steps=self.total_steps,
            pct_start=self.pct_start,
            anneal_strategy=self.anneal_strategy,
        )

    def _step(self):
        """Advance the schedule, but never past the end of the cycle.

        OneCycleLR raises once it is stepped more than ``total_steps`` times; after the
        cycle completes the learning rate is left at its final value instead.
        """
        if self.scheduler is not None and self.scheduler.last_epoch < self.total_steps:
            self.scheduler.step()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self.interval == 'step':
            self._step()

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        # `outputs` is optional so both old and new Lightning hook signatures work.
        if self.interval == 'epoch':
            self._step()


# Your PyTorch Lightning module
class MyModule(pl.LightningModule):
    """Minimal runnable LightningModule for trying out the OneCycleLR callback.

    Fits a single linear layer to fixed random regression data. The previous template
    returned an undefined ``loss`` from ``training_step`` and had no training data, so
    ``trainer.fit(model)`` could not run.
    """
    def __init__(self):
        super(MyModule, self).__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        features, targets = batch
        loss = torch.nn.functional.mse_loss(self.layer(features), targets)
        self.log("train_loss", loss)
        return loss

    def train_dataloader(self):
        # seeded generator so the toy data is identical on every run
        generator = torch.Generator().manual_seed(0)
        features = torch.randn(64, 8, generator=generator)
        targets = torch.randn(64, 1, generator=generator)
        dataset = torch.utils.data.TensorDataset(features, targets)
        return torch.utils.data.DataLoader(dataset, batch_size=16)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# Training script
model = MyModule()
lr_monitor = LearningRateMonitor(logging_interval='step')
one_cycle = OneCycleLR(max_lr=0.1, total_steps=100, pct_start=0.3, anneal_strategy='cos')
callbacks = [lr_monitor, one_cycle]
trainer = pl.Trainer(devices=1, callbacks=callbacks, max_epochs=10)
trainer.fit(model)

In [None]:
## visualize the latent space
from nomic import atlas
import numpy as np

num_embeddings = 1000
embedding_dim = 256
# Placeholder embeddings until real autoencoder latents are plugged in. Seeded so the
# UMAP projection and plot below are reproducible on Restart & Run All.
embeddings = np.random.default_rng(0).random((num_embeddings, embedding_dim))

categories = ['rhizome', 'cartography', 'lindenstrauss']
data = [{'category': categories[i % len(categories)], 'id': i}
        for i in range(len(embeddings))]

# project = atlas.map_embeddings(embeddings=embeddings,
#                                 data=data,
#                                 id_field='id',
#                                 colorable_fields=['category']
#                                 )

In [None]:
import umap
import matplotlib.pyplot as plt
## project the embeddings to 2D with UMAP
reducer = umap.UMAP()
# for the default configuration, fit_transform(X) returns exactly fit(X).embedding_
umap_embeddings = reducer.fit(embeddings).embedding_



In [None]:
# Colour points by category. Passing only `cmap` (as before) is ignored by matplotlib
# because no `c=` values are given.
category_codes = [categories.index(item['category']) for item in data]
fig, ax = plt.subplots()
points = ax.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=category_codes, cmap='Spectral')
handles, _ = points.legend_elements()
ax.legend(handles, categories, title='category')
ax.set_aspect('equal', 'datalim')
ax.set_title('UMAP projection of the Autoencoder embeddings', fontsize=24)
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
plt.show()

In [None]:
from multi_modal_heart.model.basic_conv1d import MyResidualBlock1D

In [None]:
from multi_modal_heart.model.ecg_net_attention import ECGAttentionAE

In [8]:
## build classification network

## start training
import os
import sys
sys.path.append('../../')
from torch.utils.data import DataLoader
from multi_modal_heart.ECG.ecg_dataset import ECGDataset
## initialize a dataloader (all data)
data_folder = "/home/engs2522/project/multi-modal-heart/multi_modal_heart/data/ptbxl/"
# Split CSVs live under <data_folder>/raw_split. Joining relative names keeps them tied to
# `data_folder`; the previous code joined with absolute paths, which makes os.path.join
# silently discard `data_folder`. The resulting paths are unchanged.
split_folder = os.path.join(data_folder, "raw_split")
train_data_statement_path = os.path.join(split_folder, "Y_train.csv")
validate_data_statement_path = os.path.join(split_folder, "Y_validate.csv")
test_data_statement_path = os.path.join(split_folder, "Y_test.csv")


data_loaders = []
sampling_rate = 500
batch_size = 12
max_seq_len = 608
data_proc_config = {
    "if_clean": True,
}
data_aug_config = {
    "noise_frequency_list": [5, 20, 100, 150, 175],
    "noise_amplitude_range": [0., 0.2],
    "powerline_frequency_list": [50],
    "powerline_amplitude_range": [0., 0.05],
    "artifacts_amplitude_range": [0., 0.1],
    "artifacts_frequency_list": [5, 10],
    "artifacts_number_range": [0, 3],
    "linear_drift_range": [0., 0.1],
    "random_prob": 0.5,
    "if_mask_signal": True,
    "mask_whole_lead_prob": 0.2,
    "lead_mask_prob": 0.2,
    "region_mask_prob": 0.15,
    "mask_length_range": [0.08, 0.18],
    "mask_value": 0.0,
}
for label_csv_path in [train_data_statement_path, validate_data_statement_path, test_data_statement_path]:
    csv_name = label_csv_path.split("/")[-1]
    if_test = "test" in csv_name  # test split: no augmentation, no shuffling, keep the last batch
    dataset = ECGDataset(
        data_folder,
        label_csv_path=label_csv_path,
        # median beats; the warnings printed below suggest 600-sample beats — TODO confirm vs max_seq_len=608
        use_median_wave=True,
        sampling_rate=sampling_rate,
        max_seq_len=max_seq_len,
        augmentation=not if_test,
        data_proc_config=data_proc_config,
        data_aug_config=data_aug_config,
    )
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=not if_test,
        drop_last=not if_test,
    )
    print('load {} data: {} samples'.format(csv_name, len(dataset)))
    data_loaders.append(data_loader)

train_loader, validate_loader, test_loader = data_loaders[0], data_loaders[1], data_loaders[2]


load Y_train.csv data: 17111 samples
load Y_validate.csv data: 2156 samples
load Y_test.csv data: 2163 samples


  self.label_csv_df = pd.read_csv(label_csv_path)


In [9]:
## peek at the shapes of one training batch
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    data = first_batch
    print(data["input_seq"].shape)
    print(data["cleaned_seq"].shape)

  warn(
  warn(
  warn(


['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
['gamboa error: index -1 is out of bounds for axis 0 with size 0', 'zong error: index 600 is out of bounds for axis 0 with size 600']
['zong error: index 600 is out of bounds for axis 0 with size 600']
torch.Size([12, 12, 608])
torch.Size([12, 12, 608])


In [None]:
import wfdb
# Raw record (path suggests the 100 Hz version), kept for reference:
# wfdb.rdsamp("/home/engs2522/project/multi-modal-heart/multi_modal_heart/data/ptbxl/records100/00000/00001_lr")
median_beat_record = "/home/engs2522/project/multi-modal-heart/multi_modal_heart/data/ptbxl/median_beats/unig/11000/011848_medians"
wfdb.rdsamp(median_beat_record)

In [1]:
## UKB dataset
## build classification network

## start training
import os
import sys
sys.path.append('../../')
from torch.utils.data import DataLoader
from multi_modal_heart.ECG.ecg_UKB_dataset import ECGUKBDataset

data_loaders = []
sampling_rate = 100
batch_size = 12
max_seq_len = 1024
data_proc_config = {"if_clean": False}
data_aug_config = dict(
    noise_frequency_list=[5, 20, 100, 150, 175],
    noise_amplitude_range=[0., 0.2],
    powerline_frequency_list=[50],
    powerline_amplitude_range=[0., 0.05],
    artifacts_amplitude_range=[0., 0.1],
    artifacts_frequency_list=[5, 10],
    artifacts_number_range=[0, 3],
    linear_drift_range=[0., 0.1],
    random_prob=0.5,
    if_mask_signal=True,
    mask_whole_lead_prob=0.2,
    lead_mask_prob=0.2,
    region_mask_prob=0.15,
    mask_length_range=[0.08, 0.18],
    mask_value=0.0,
)

## initialize a dataloader (all data)
data_folder = "/home/engs2522/project/multi-modal-heart/multi_modal_heart/data/UKB/"
split_dir = os.path.join(data_folder, "splits/100_norm_100_HF")
train_data_statement_path = os.path.join(split_dir, "UKB_train.csv")
validate_data_statement_path = os.path.join(split_dir, "UKB_val.csv")
test_data_statement_path = os.path.join(split_dir, "UKB_test.csv")

for label_csv_path in (train_data_statement_path, validate_data_statement_path, test_data_statement_path):
    csv_name = label_csv_path.split("/")[-1]
    if_test = "test" in csv_name
    is_training_split = not if_test  # train/val: augment, shuffle and drop the ragged last batch
    dataset = ECGUKBDataset(
        data_folder,
        label_csv_path=label_csv_path,
        sampling_rate=sampling_rate,
        use_median_wave=False,
        max_seq_len=max_seq_len,
        augmentation=is_training_split,
        data_proc_config=data_proc_config,
        data_aug_config=data_aug_config,
    )
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0,
                             shuffle=is_training_split, drop_last=is_training_split)
    print('load {} data: {} samples'.format(csv_name, len(dataset)))
    data_loaders.append(data_loader)

train_loader, validate_loader, test_loader = data_loaders



load UKB_train.csv data: 100 samples
load UKB_val.csv data: 50 samples
load UKB_test.csv data: 50 samples


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from torch import optim, nn
import torch
from torchmetrics.classification import MultilabelAUROC
 
import sys
sys.path.append('../../')
from multi_modal_heart.model.ecg_net import ClassifierMLP
from multi_modal_heart.model.ecg_net_attention import ECGAttentionAE
from multi_modal_heart.ECG.utils import evaluate_experiment
class LitAutoEncoder(pl.LightningModule):
    """LightningModule wrapper whose forward simply calls ``network``.

    NOTE(review): this redefines (and shadows) the earlier ``LitAutoEncoder`` of this
    notebook with a different constructor signature. No ``save_hyperparameters()``
    call is made, so ``network`` must be supplied again when loading a checkpoint.
    """
    def __init__(self, network):
        super().__init__()
        self.network = network
        self.latent_code_dim = 64
        self.num_classes = 5

    def forward(self, x):
        reconstruction = self.network(x)
        return reconstruction

class LitClassifier(pl.LightningModule):
    """Multi-label ECG classifier: an ECG encoder followed by an MLP head.

    Trained with ``BCEWithLogitsLoss``; validation and test report macro and per-class
    multilabel AUROC.

    Args:
        encoder: module mapping an ECG batch to a latent code. The head is built with
            ``input_size=64``, so the encoder is assumed to output 64-dim codes — TODO
            confirm for other encoders.
        num_classes: number of labels.
        learning_rate: learning rate of the classification head.
        freeze_encoder: if True the encoder's parameters get ``requires_grad=False`` and
            the encoder always runs in eval mode.
        task_name: prefix of every logged metric key.
        max_iters: stored only; not used by this class.
        encoder_learning_rate: learning rate of the encoder parameter group (1e-4, the
            value the previous ``1e-4 or self.lr`` expression always produced).
    """
    def __init__(self,encoder,num_classes=5,learning_rate=1e-3,freeze_encoder=False,
                 task_name = "ECG_Classifier", max_iters =20000, encoder_learning_rate=1e-4):
        super().__init__()

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.freeze_encoder = freeze_encoder
        self.learning_rate = learning_rate
        self.encoder_learning_rate = encoder_learning_rate
        # attribute name (sic) kept: it determines the checkpoint state_dict keys
        self.downsteam_net = ClassifierMLP(input_size=64, hidden_sizes=[256], output_size=num_classes)

        self.latent_code_dim = 64
        self.test_preds = []
        self.test_ground_truth = []
        self.task_name = task_name
        self.max_iters = max_iters
        self.save_hyperparameters()
        # alphabetical; assumed to match the column order of batch["super_class_encoding"] — TODO confirm
        self.class_names = ['CD', 'HYP', 'MI', 'NORM', 'STTC']
        ## define metrics here
        self.macro_auroc_metric = MultilabelAUROC(num_labels=num_classes, average="macro", thresholds=None)
        self.classwise_auroc_metric = MultilabelAUROC(num_labels=num_classes, average=None, thresholds=None)
        self.test_macro_auroc_metric = MultilabelAUROC(num_labels=num_classes, average="macro", thresholds=None)
        self.test_classwise_auroc_metric = MultilabelAUROC(num_labels=num_classes, average=None, thresholds=None)

    def _class_label(self, index):
        """Name of class ``index``, falling back to the index if ``class_names`` is too short."""
        return self.class_names[index] if index < len(self.class_names) else str(index)

    def forward(self, x, eval=False):
        """Encode ``x`` and return the head's logits (before sigmoid).

        ``eval=True`` (or ``freeze_encoder``) switches the encoder to eval mode,
        otherwise to train mode; the mode change persists after the call.
        """
        if self.freeze_encoder or eval:
            self.encoder.eval()
        else:
            self.encoder.train()
        latent_code = self.encoder(x)
        return self.downsteam_net(latent_code)

    def run_task(self, batch, batch_idx, prefix_name=""):
        """Shared step: forward pass, BCE-with-logits loss and loss logging.

        Returns ``(loss, sigmoid probabilities, float targets)``. During training the
        encoder is kept in eval mode until ``global_step`` exceeds 200.
        """
        input = batch["input_seq"]
        target = batch["super_class_encoding"].float()
        encoder_in_eval_mode = not (prefix_name == "train" and self.global_step > 200)
        outputs_before_sigmoid = self.forward(input, encoder_in_eval_mode)
        loss = torch.nn.BCEWithLogitsLoss()(outputs_before_sigmoid, target)
        self.log(f"{self.task_name}/{prefix_name}_loss", loss)
        return loss, torch.sigmoid(outputs_before_sigmoid), target

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.run_task(batch, batch_idx, prefix_name="train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss, pred, target = self.run_task(batch, batch_idx, prefix_name="val")
        self.classwise_auroc_metric(pred, target.long())
        self.macro_auroc_metric(pred, target.long())
        return loss

    def on_validation_epoch_end(self):
        """Log macro and per-class validation AUROC, then reset BOTH metrics.

        The classwise metric used to be updated every step but never reset, so its
        stored predictions grew without bound across epochs.
        """
        macro_auc = self.macro_auroc_metric.compute()
        classwise_auc = self.classwise_auroc_metric.compute()
        self.log(f"{self.task_name}/val_macro_auc", macro_auc)
        for i in range(classwise_auc.shape[0]):
            self.log(f"{self.task_name}/val_auc_{self._class_label(i)}", classwise_auc[i])
        print("macro_auc", macro_auc.item())
        self.macro_auroc_metric.reset()
        self.classwise_auroc_metric.reset()

    def test_step(self, batch, batch_idx):
        loss, pred, target = self.run_task(batch, batch_idx, prefix_name="test")
        self.test_classwise_auroc_metric(pred, target.long())
        self.test_macro_auroc_metric(pred, target.long())

    def configure_optimizers(self):
        """AdamW with separate encoder/head learning rates plus linear warmup (20 steps) and decay."""
        # Imported here so the class does not rely on a later notebook cell's import.
        from transformers import get_linear_schedule_with_warmup

        optimizer = optim.AdamW([
            # was `1e-4 or self.lr` (always 1e-4) and `self.learning_rate or self.lr`,
            # where `self.lr` does not exist
            {'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate},
            {'params': self.downsteam_net.parameters(), 'lr': self.learning_rate},
        ], betas=(0.9, 0.95))

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=20,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        print(self.task_name)

    def on_test_epoch_end(self):
        """Log and print test macro / per-class AUROC, reset the metrics and return them."""
        test_classwise = self.test_classwise_auroc_metric.compute()
        test_summary = self.test_macro_auroc_metric.compute()
        self.log(f"{self.task_name}/test_macro_auc", test_summary)
        for i in range(test_classwise.shape[0]):
            self.log(f"{self.task_name}/test_auc_{self._class_label(i)}", test_classwise[i])
        print(test_summary)
        print(test_classwise)
        self.test_classwise_auroc_metric.reset()
        self.test_macro_auroc_metric.reset()
        return test_classwise, test_summary
n_leads = 12
input_length = 608

ecg_net = ECGAttentionAE(num_leads=n_leads, time_steps=input_length, z_dims=64, downsample_factor=5, base_feature_dim=4, if_VAE=False)
# `load_from_checkpoint` is a classmethod that RETURNS the loaded module. The previous
# `model.load_from_checkpoint(...)` discarded that result, so finetuning started from
# random weights. LitAutoEncoder does not save its hyperparameters, so `network` is passed again.
model = LitAutoEncoder.load_from_checkpoint(
    "/home/engs2522/project/multi-modal-heart/log/ECG_attention_64_ms_resnet/checkpoints/epoch=299-step=10200.ckpt",
    map_location="cpu",
    network=ecg_net,
)

model.network.decoder = None  ## only the encoder is used downstream; drop the decoder to save space
classifier = LitClassifier(model.network.encoder, num_classes=5, freeze_encoder=False, task_name="freeze_encoder",
                           learning_rate=1e-3, max_iters=20000)
# Shape smoke test in eval mode without gradients, so it does not update the pretrained
# encoder's normalisation statistics with random data.
with torch.no_grad():
    prediction = classifier(torch.randn(5, n_leads, input_length), eval=True).shape


In [None]:
import os
from finetuning_scheduler import FinetuningScheduler
from transformers import    get_linear_schedule_with_warmup
from pytorch_lightning.tuner import Tuner

finetune_max_epochs = 50
# Must equal the prefix LitClassifier logs under (its `task_name`). The old hard-coded
# "finetune" made EarlyStopping monitor a key that is never logged.
finetune_task_name = classifier.task_name
# NOTE: only takes effect if CUDA has not been initialised yet in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor=f"{finetune_task_name}/val_macro_auc",
    min_delta=0.00,
    patience=10,
    verbose=False,
    mode='max',  # AUROC: higher is better ('min' treated every improvement as a regression)
)



class FineTuneLearningRateFinder(pl.callbacks.LearningRateFinder):
    """Learning-rate finder that runs the LR search at selected training epochs.

    ``on_fit_start`` is overridden to do nothing; the search instead runs at the start
    of epoch 0 and of every epoch listed in ``milestones``.
    """

    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # deliberately a no-op (replaces the parent's fit-start hook)
        return

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        should_search = epoch == 0 or epoch in self.milestones
        if should_search:
            self.lr_find(trainer, pl_module)

lr_finder_callback = FineTuneLearningRateFinder(milestones=(5, 10))
finetuner = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=finetune_max_epochs,
    fast_dev_run=True,
    callbacks=[lr_finder_callback, early_stop_callback],
)

finetuner.fit(model=classifier, train_dataloaders=train_loader, val_dataloaders=validate_loader)
## evaluate the model
finetuner.test(classifier, test_loader)
print(pl.__version__)

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
superclasses = ["NORM", "MI", "HYP", "STTC", "CD"]
result = mlb.fit_transform([superclasses])
input = [["NORM"]]
mlb.transform(input)  # not displayed: only the cell's last expression is shown
mlb.classes_


## LLM model


In [None]:
import torch
import torch.nn as nn

## global average pooling: collapse the last two dimensions to 1x1
m = nn.AdaptiveAvgPool2d(output_size=(1, 1))
input_data = torch.randn(1, 128, 7, 64)
m(input_data).size()