# SSSD-ECG-nle methods

#### Requirements installation

In [None]:
%cd ../
# Install the ecglib package from this directory (the one that contains `src/`)
%pip install .
# Install the Cauchy CUDA kernel extension (needed for fast computation of the S4 kernel)
%pip install src/ecglib/models/architectures/sssd/extensions/cauchy/.
# Switch into src/ for the remaining cells
%cd src/

## SSSD-ECG-nle training

#### Imports

In [None]:
import ast
import os
import signal
import sys
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import torch
import wfdb
from torch.utils.data import DataLoader
from tqdm import tqdm

from ecglib.data import EcgDataset
from ecglib.data import load_ptb_xl
from ecglib.models.config.model_configs import SSSDConfig
from ecglib.models.model_builder import create_model

#### Prepare PTB-XL ECG dataset

TODO: the dataset-downloading logic needs to be reworked.

In [None]:
# Root location of the PTB-XL data handed to load_ptb_xl
path_to_unzip = '/home/ecg_data/physionet_data/raw_data/'
# Sampling frequency of the records to load: 100 (low-rate files) or 500 (high-rate files)
SAMPLE_FREQUENCY = 100
ptb_xl_info = load_ptb_xl(path_to_unzip=path_to_unzip, frequency=SAMPLE_FREQUENCY)
ptb_xl_info['frequency'] = SAMPLE_FREQUENCY

# Split in accordance with the PTB-XL benchmarking protocol:
# fold 9 -> validation, fold 10 -> test, every other fold -> train.
val_ptbxl_info, test_ptbxl_info = (
    ptb_xl_info.loc[ptb_xl_info['strat_fold'] == held_out_fold].reset_index()
    for held_out_fold in (9, 10)
)
train_ptbxl_info = ptb_xl_info.loc[~ptb_xl_info['strat_fold'].isin([9, 10])].reset_index()

In [None]:
# Configure parameters of dataset
@dataclass(repr=True)
class DataParams:
    """Dataset and dataloader settings, also reused by the generation config."""
    # Target codes: CRBBB, 1AVB, PVC or any code from `scp_statements.csv`.
    # For the multilabel scenario, append more codes to the list.
    syndromes: List[str] = field(default_factory=lambda: ['AFIB'])
    normalization: str = 'identical'
    # Standard lead order: ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'].
    # Only the 8 linearly independent leads ['I', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] are used;
    # the other 4 limb leads can be reconstructed from I and AVF.
    leads: List[int] = field(default_factory=lambda: [0, 5, 6, 7, 8, 9, 10, 11])
    sample_frequency: int = SAMPLE_FREQUENCY  # sampling frequency of the records (100 or 500)
    augmentation: Optional[Callable] = None  # Check README.md Preprocessing for more details
    batch_size: int = 8  # Training with batch_size=8 requires ~15 GB of GPU memory

data_config = DataParams()

In [None]:
# Sanity check: display the signal names stored in the first record's header so that
# the indices in `DataParams.leads` can be matched against the actual lead order.
wfdb.rdsamp(ptb_xl_info['fpath'].iloc[0])[1]['sig_name']

In [None]:
# Create EcgDatasets
def dataset_from_map_file(map_file):
    """
    Build an EcgDataset from a PTB-XL map file using the settings in `data_config`.

    The target of each record has one entry per code in `data_config.syndromes`:
    1.0 if the code is a key of the record's `scp_codes`, 0.0 otherwise.

    Parameters:
    map_file (pd.DataFrame): PTB-XL records; the `scp_codes` column must hold the
                             string representation of a dict keyed by SCP code

    Returns:
    EcgDataset built over `map_file`
    """
    # ast.literal_eval only accepts Python literals, unlike eval, which would execute
    # arbitrary code found in the metadata file. Each string is parsed exactly once.
    targets = [
        [1.0 if synd in codes else 0.0 for synd in data_config.syndromes]
        for codes in map(ast.literal_eval, map_file['scp_codes'])
    ]
    return EcgDataset(ecg_data=map_file,
                      target=targets,
                      frequency=data_config.sample_frequency,
                      leads=data_config.leads,
                      norm_type=data_config.normalization,
                      classes=len(data_config.syndromes),
                      augmentation=data_config.augmentation
                    )

# One dataset per split, built in train / validation / test order.
train_dataset, val_dataset, test_dataset = (
    dataset_from_map_file(split_info)
    for split_info in (train_ptbxl_info, val_ptbxl_info, test_ptbxl_info)
)

In [None]:
# Create torch dataloaders
def dataloader_from_dataset(dataset, train=True):
    """
    Wrap `dataset` in a DataLoader that uses `data_config.batch_size`.

    Training loaders (train=True) shuffle and drop the last incomplete batch;
    evaluation loaders keep the original order and every sample.
    """
    loader_kwargs = dict(
        batch_size=data_config.batch_size,
        shuffle=train,
        drop_last=train,
    )
    return DataLoader(dataset, **loader_kwargs)

# Only the training loader shuffles and drops the last incomplete batch.
train_loader, val_loader, test_loader = (
    dataloader_from_dataset(train_dataset),
    dataloader_from_dataset(val_dataset, train=False),
    dataloader_from_dataset(test_dataset, train=False),
)

#### Create SSSD-ECG-nle

In [None]:
# Model configuration: one input channel per selected lead and
# one label-embedding class per target syndrome.
sssd_ecg_nle_config = SSSDConfig(
    label_embed_classes=len(data_config.syndromes),
    in_channels=len(data_config.leads),
)

In [None]:
# Build the SSSD-ECG model for the selected leads and target syndromes.
sssd_model = create_model(
    model_name='sssd_ecg',
    config=sssd_ecg_nle_config,
    num_classes=len(data_config.syndromes),
    leads_count=len(data_config.leads),
    pathology=data_config.syndromes,
)

#### Training pipeline

In [None]:
# Configure training parameters
@dataclass(repr=True)
class TrainParams:
    """
    Training-loop and diffusion-schedule settings.

    `diffusion_hyperparams` is attached to the instance after construction,
    computed by `calc_diffusion_hyperparams` from T, beta_0 and beta_T.
    """
    num_viewed_samples: int = 10**6  # Training stops once this many samples have been viewed
    batch_size: int = data_config.batch_size
    # Diffusion hyperparameters
    T: int = 200  # number of diffusion (denoising) steps
    beta_0: float = 0.0001  # first beta in the Markov chain
    beta_T: float = 0.02  # last beta; betas are linearly interpolated between beta_0 and beta_T
    grad_norm: Optional[float] = None  # Max gradient norm for norm clipping (None disables it)
    grad_val: Optional[float] = None  # Max absolute gradient value for value clipping (None disables it)
    lr: float = 0.0002  # optimizer learning rate

train_config = TrainParams()

In [None]:
def calc_diffusion_hyperparams(T, beta_0, beta_T):
    """
    Compute the hyperparameters of a diffusion process with a linear beta schedule.

    Parameters:
    T (int):                    number of diffusion steps
    beta_0, beta_T (float):     start/end values of the beta schedule; every beta_t in
                                between is linearly interpolated

    Returns:
    a dictionary with keys "T" (int) and "Beta", "Alpha", "Alpha_bar", "Sigma"
    (torch tensors of shape (T,), created on the CPU), where
        Alpha     = 1 - Beta
        Alpha_bar = cumulative product of Alpha
        Sigma**2  = Beta_tilde, the posterior variance used when sampling
    """
    Beta = torch.linspace(beta_0, beta_T, T)  # linear schedule
    Alpha = 1 - Beta
    Alpha_bar = torch.cumprod(Alpha, dim=0)  # \bar{\alpha}_t = \prod_{s<=t} \alpha_s
    # \tilde{\beta}_t = \beta_t * (1 - \bar{\alpha}_{t-1}) / (1 - \bar{\alpha}_t) for t >= 1,
    # and \tilde{\beta}_0 = \beta_0 (no previous step).
    Beta_tilde = Beta.clone()
    Beta_tilde[1:] = Beta[1:] * ((1 - Alpha_bar[:-1]) / (1 - Alpha_bar[1:]))
    Sigma = torch.sqrt(Beta_tilde)  # \sigma_t^2 = \tilde{\beta}_t
    return {"T": T, "Beta": Beta, "Alpha": Alpha, "Alpha_bar": Alpha_bar, "Sigma": Sigma}

# Precompute the noise schedule consumed by the training loop.
train_config.diffusion_hyperparams = calc_diffusion_hyperparams(
    T=train_config.T,
    beta_0=train_config.beta_0,
    beta_T=train_config.beta_T,
)

In [None]:
def train(model, optimizer, criterion, dataloader, train_config, device):
    """
    Train a conditional diffusion denoiser with the noise-prediction objective.

    For every sample a random diffusion step t in {0, ..., T-1} is drawn, the clean
    ECG x_0 is noised as x_t = sqrt(Alpha_bar[t]) * x_0 + sqrt(1 - Alpha_bar[t]) * z
    with z ~ N(0, I), and the model is trained to predict z from (x_t, targets, t).

    Parameters:
    model:        network called as model((x_t, targets, None, diffusion_steps));
                  its output is compared with z by `criterion`
    optimizer:    torch optimizer over the model parameters
    criterion:    loss between predicted and true noise (e.g. MSELoss)
    dataloader:   yields batches unpacked as `ids, (inputs, targets)` with the ECG
                  signal in inputs[0]
    train_config: object with num_viewed_samples, batch_size, grad_norm, grad_val and
                  diffusion_hyperparams (dict with "T" and "Alpha_bar")
    device:       torch device to train on

    Returns:
    the model, once `train_config.num_viewed_samples` samples have been viewed
    (immediately if the dataloader is empty)
    """
    samples_per_epoch = len(dataloader) * train_config.batch_size
    if samples_per_epoch == 0:
        return model
    # Ceil division: with floor division a budget smaller than one epoch ran zero epochs.
    num_epochs = -(-train_config.num_viewed_samples // samples_per_epoch)

    T = train_config.diffusion_hyperparams["T"]
    # Move the schedule to the device once, not on every batch.
    Alpha_bar = train_config.diffusion_hyperparams["Alpha_bar"].to(device)

    num_curr_samples = 0  # samples viewed across ALL epochs (must not reset per epoch)
    for epoch in range(num_epochs):
        print("Epoch number:", epoch)
        model.train()
        for batch in tqdm(dataloader, total=len(dataloader)):
            # Get batch
            _, (inputs, targets) = batch
            ecg_signal = inputs[0].to(device)
            targets = targets.long().to(device)
            batch_len = ecg_signal.shape[0]
            # One step per sample, shaped (B, 1, 1) so that Alpha_bar[steps] broadcasts over
            # the trailing (channels, length) axes of each sample (assumes ecg_signal is
            # (B, C, L)). A (B, 1) shape would broadcast along the channel axis instead,
            # mixing up the samples' noise levels.
            diffusion_steps = torch.randint(T, size=(batch_len, 1, 1)).to(device)
            z = torch.normal(0, 1, size=ecg_signal.shape).to(device)
            alpha_bar_t = Alpha_bar[diffusion_steps]
            transformed_X = torch.sqrt(alpha_bar_t) * ecg_signal + torch.sqrt(1 - alpha_bar_t) * z
            optimizer.zero_grad()
            # Loss propagation; the model receives the steps as (B, 1)
            epsilon_theta = model((transformed_X, targets, None, diffusion_steps.view(batch_len, 1),))
            loss = criterion(epsilon_theta, z)
            loss.backward()

            # Gradient Norm Clipping
            if train_config.grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=train_config.grad_norm)

            # Gradient Val Clipping
            if train_config.grad_val is not None:
                torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=train_config.grad_val)
            # Optimizer step
            optimizer.step()
            # Stop once the sample budget has been reached
            num_curr_samples += batch_len
            if num_curr_samples >= train_config.num_viewed_samples:
                return model
    return model

In [None]:
# Initialize params: pick the GPU when available, move the model there and train it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
criterion = torch.nn.MSELoss()
sssd_model = sssd_model.to(device)
optimizer = torch.optim.Adam(
    params=sssd_model.parameters(),
    lr=train_config.lr,
)
trained_sssd_model = train(sssd_model, optimizer, criterion, train_loader, train_config, device)

In [None]:
# Release cached, unused GPU memory held by PyTorch's caching allocator after training.
# No torch.no_grad() is needed: empty_cache performs no autograd-tracked work.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## SSSD-ECG-nle Generation

In [None]:
# Configure generation parameters
@dataclass(repr=True)
class GenerationParams:
    """
    Settings for sampling synthetic ECGs with the trained model.

    `diffusion_hyperparams` is attached to the instance after construction.
    """
    generation_path: str = 'afib_synthetic_ptbxl'  # output directory for generated files
    batch_size: int = 10
    # Copy the lists so that mutating this config does not silently change data_config.
    leads: List[int] = field(default_factory=lambda: list(data_config.leads))
    # PTB-XL records are 10 s long, so the length follows the sampling frequency
    # (1000 samples at 100 Hz).
    ecg_length: int = 10 * SAMPLE_FREQUENCY
    T: int = train_config.T  # denoising num steps
    beta_0: float = train_config.beta_0  # first beta in markov chain
    beta_T: float = train_config.beta_T  # last beta. Linear interpolation between beta_0 and beta_T
    # CRBBB, 1AVB, PVC or any code from `scp_statements.csv`. For multilabel-scenario append to list
    syndromes: List[str] = field(default_factory=lambda: list(data_config.syndromes))

gen_config = GenerationParams()
gen_config.diffusion_hyperparams = calc_diffusion_hyperparams(gen_config.T, gen_config.beta_0, gen_config.beta_T)

In [None]:
def sampling_label(net, size, diffusion_hyperparams, device, cond=None, metadata=None):
    r"""
    Run the full reverse diffusion chain p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t).

    Parameters:
    net (torch network):            denoising model, called as net((x, cond, metadata, diffusion_steps));
                                    its output is expected to broadcast against x (predicted noise)
    size (tuple):                   shape of the tensor to generate, (batch, channels, length)
    diffusion_hyperparams (dict):   dictionary returned by calc_diffusion_hyperparams
                                    (keys "T", "Alpha", "Alpha_bar", "Sigma")
    device:                         torch device on which samples are generated
    cond:                           conditioning forwarded to `net` (e.g. an integer label tensor)
    metadata:                       extra conditioning forwarded to `net` unchanged (may be None)

    Returns:
    the generated signal(s) as a torch.tensor of shape `size`

    Raises:
    ValueError: if `size` is not 3-dimensional or the schedule tensors do not have length T
    """
    _dh = diffusion_hyperparams
    T, Alpha, Alpha_bar, Sigma = _dh["T"], _dh["Alpha"], _dh["Alpha_bar"], _dh["Sigma"]
    if not (len(Alpha) == len(Alpha_bar) == len(Sigma) == T):
        raise ValueError("Alpha, Alpha_bar and Sigma must all have length T")
    if len(size) != 3:
        raise ValueError(f"size must be (batch, channels, length), got {size}")
    bs = size[0]

    x = torch.normal(0, 1, size=size).to(device)  # x_T ~ N(0, I)
    with torch.no_grad():
        for t in range(T - 1, -1, -1):
            # Same step index for every sample in the batch, shaped (bs, 1)
            diffusion_steps = torch.full((bs, 1), float(t), device=device)
            epsilon_theta = net((x, cond, metadata, diffusion_steps,))  # predict the noise
            # x_{t-1} <- mu_theta(x_t) = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            x = (x - (1 - Alpha[t]) / torch.sqrt(1 - Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t])
            if t > 0:
                x = x + Sigma[t] * torch.normal(0, 1, size=size).to(device)  # add the variance term
    return x

In [None]:
def generate_four_leads(tensor):
    """
    Expand an 8-lead ECG batch into the standard 12 leads.

    The four missing limb leads are linear combinations of leads I and aVF:
        II  = I/2 + aVF            III = aVF - I/2
        aVR = -3/4 * I - aVF/2     aVL = 3/4 * I - aVF/2

    Parameters:
    tensor (torch.Tensor): (batch, 8, length); channel 0 is lead I, channel 1 is aVF and
                           channels 2-7 are the chest leads V1-V6

    Returns:
    torch.Tensor of shape (batch, 12, length) ordered I, II, III, aVR, aVL, aVF, V1-V6
    """
    lead_i = tensor[:, 0:1, :]     # I = LA - RA
    lead_avf = tensor[:, 1:2, :]   # aVF = 3/2 (LL - V_w), V_w = 1/3 (RA + LA + LL)
    chest = tensor[:, 2:8, :]      # Vi - V_w

    half_i = 0.5 * lead_i
    three_quarter_i = 0.75 * lead_i
    half_avf = 0.5 * lead_avf

    lead_ii = half_i + lead_avf                 # II = LL - RA
    lead_iii = lead_avf - half_i                # III = LL - LA
    lead_avr = -three_quarter_i - half_avf      # aVR = 3/2 (RA - V_w)
    lead_avl = three_quarter_i - half_avf       # aVL = 3/2 (LA - V_w)

    ordered = (lead_i, lead_ii, lead_iii, lead_avr, lead_avl, lead_avf, chest)
    return torch.cat(ordered, dim=1)

In [None]:
def generation(model, ecg_info, gen_config, device, mode='train'):
    """
    Generate one synthetic ECG per record of `ecg_info` and save the results to disk.

    For each record, the condition vector has one entry per code in
    `gen_config.syndromes`: 1.0 if the code is a key of the record's `scp_codes`,
    0.0 otherwise. 8-lead samples are expanded to 12 leads with `generate_four_leads`.
    Each ECG is saved as `<generation_path>/data/<name>.npz`. A copy of `ecg_info`
    whose `fpath` column points at the generated files is written to
    `<generation_path>/<basename(generation_path)>_<mode>_map_file.csv`.

    Parameters:
    model:                  trained diffusion network, passed to `sampling_label`
    ecg_info (pd.DataFrame): records to mimic; needs `scp_codes` (string repr of a dict)
                            and `filename_lr` / `filename_hr` columns
    gen_config:             GenerationParams with `diffusion_hyperparams` attached
    device:                 torch device used for sampling
    mode (str):             tag inserted into the output map-file name

    Raises:
    ValueError: if `gen_config.leads` has neither 8 nor 12 entries
    """
    bs = gen_config.batch_size
    leads_num = len(gen_config.leads)
    if leads_num not in (8, 12):
        raise ValueError(f"Expected 8 or 12 leads, got {leads_num}")
    postfix = 'lr' if SAMPLE_FREQUENCY == 100 else 'hr'  # low-rate / high-rate file names
    data_dir = os.path.join(gen_config.generation_path, 'data')
    os.makedirs(data_dir, exist_ok=True)
    raw_df = ecg_info.copy()

    with tqdm(total=len(ecg_info)) as bar:
        for start in range(0, len(ecg_info), bs):
            # iloc slicing clamps at the end, so the last batch may be smaller
            batch_df = ecg_info.iloc[start:start + bs]
            # Condition; literal_eval parses the dict without executing arbitrary code
            targets = [
                [1.0 if synd in codes else 0.0 for synd in gen_config.syndromes]
                for codes in map(ast.literal_eval, batch_df['scp_codes'])
            ]
            cond = torch.tensor(targets).long().to(device)
            # Sampling
            generated_ecg = sampling_label(
                model,
                (len(batch_df), leads_num, gen_config.ecg_length),
                gen_config.diffusion_hyperparams,
                device,
                cond,
                metadata=None
            )
            generated_ecg12 = generate_four_leads(generated_ecg) if leads_num == 8 else generated_ecg
            generated_np = generated_ecg12.detach().cpu().numpy()  # one device->host copy per batch
            # Saving; rows are addressed by their index label, so any index works
            for j, (row_label, row) in enumerate(batch_df.iterrows()):
                file_name = '_'.join(row[f'filename_{postfix}'].split('/')[1:]) + '.npz'
                fpath = os.path.join(data_dir, file_name)
                np.savez(fpath, generated_np[j])
                raw_df.loc[row_label, 'fpath'] = fpath
            bar.update(len(batch_df))

    # Use only the last path component in the file name so nested paths still work
    map_name = os.path.basename(os.path.normpath(gen_config.generation_path))
    raw_df.to_csv(os.path.join(gen_config.generation_path, f"{map_name}_{mode}_map_file.csv"), index=False)

In [None]:
# Make sure the output folder exists, switch the model to inference mode on the
# target device, then synthesize a counterpart for every training record.
os.makedirs(os.path.join(gen_config.generation_path, 'data'), exist_ok=True)
sssd_model.to(device).eval()
generation(sssd_model, train_ptbxl_info, gen_config, device, mode='train')