# CSE676 Project: Self-Supervised Learning (SSL) for 12-lead ECG
The code was adapted from https://github.com/tmehari/ecg-selfsupervised/tree/main?tab=readme-ov-file

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from helper_code2 import *
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import pandas as pd

transformation

In [None]:
class TimeOut:
    """Zero out one random contiguous segment along the last (time) axis.

    ``sample`` is a ``(data, label)`` pair where ``data`` is a tensor such as
    ``(B, C, T)`` or ``(C, T)``; returns ``(augmented_data, label)`` with the
    same shape. The input tensor is never modified (it is cloned first).
    For a batched input a single segment is drawn and zeroed in every sample.

    Args:
        crop_ratio_range: ``(low, high)`` bounds for the fraction of ``T`` that
            is zeroed; the fraction is drawn uniformly from this range.
    """
    def __init__(self, crop_ratio_range=(0.0, 0.5)):
        self.crop_ratio_range = crop_ratio_range

    def __call__(self, sample):
        data, label = sample
        data = data.clone()  # do not mutate the caller's tensor
        timesteps = data.shape[-1]
        crop_ratio = random.uniform(*self.crop_ratio_range)
        crop_timesteps = int(crop_ratio * timesteps)
        # randint is inclusive on both ends: the last valid start is T - crop,
        # which also keeps crop_ratio == 1.0 valid (start 0, whole signal).
        start_idx = random.randint(0, timesteps - crop_timesteps)
        data[..., start_idx:start_idx + crop_timesteps] = 0
        return data, label
    
class RandomResizeCrop:
    """Randomly crop along the time axis, then linearly resize to ``output_size``.

    ``sample`` is a ``(data, label)`` pair where ``data`` is ``(B, C, T)`` or
    ``(C, T)``; returns ``(resized, label)`` with the last axis equal to
    ``output_size`` and all other axes unchanged. For a batched input one crop
    window is drawn and applied to every sample.

    Args:
        crop_ratio_range: ``(low, high)`` bounds for the kept fraction of ``T``.
        output_size: length of the time axis after resizing.
    """
    def __init__(self, crop_ratio_range=(0.5, 1.0), output_size=4096):
        self.crop_ratio_range = crop_ratio_range
        self.output_size = output_size

    def __call__(self, sample):
        data, label = sample
        timesteps = data.shape[-1]
        crop_ratio = random.uniform(*self.crop_ratio_range)
        # Keep at least one time step so F.interpolate never receives an empty signal.
        crop_timesteps = max(1, int(crop_ratio * timesteps))
        # randint is inclusive: the last valid start is T - crop (ratio 1.0 -> start 0).
        start = random.randint(0, timesteps - crop_timesteps)
        cropped_data = data[..., start:start + crop_timesteps]
        if data.dim() == 3:
            resized = F.interpolate(cropped_data, size=self.output_size, mode='linear')
            return resized, label
        resized = F.interpolate(cropped_data.unsqueeze(0), size=self.output_size, mode='linear')
        # squeeze(0) removes only the batch dim added above, so a single-channel
        # (1, T) input keeps its channel axis.
        return resized.squeeze(0), label
    
class RandomTransformation:
    """Create two augmented views of the same ``(data, label)`` sample.

    Each view is produced independently by a random resize-crop followed by a
    random time-out. Returns ``(view_1, view_2)``, each a ``(data, label)`` pair.
    """
    def __init__(self, to_range=[0.0, 0.5], rrc_range=[0.5, 1.0]):
        self.to = TimeOut(to_range)
        self.rrc = RandomResizeCrop(rrc_range)

    def _augment_once(self, sample):
        # Crop/resize first, then blank out a segment of the resized signal.
        cropped = self.rrc(sample)
        return self.to(cropped)

    def __call__(self, x):
        first_view = self._augment_once(x)
        second_view = self._augment_once(x)
        return first_view, second_view

In [3]:
class XResBlock1d(nn.Module):
    """Basic 1-D residual block: conv-BN-ReLU-conv-BN plus shortcut, then ReLU.

    The shortcut is an identity when shape is preserved. Otherwise it is
    average-pool (stride ``stride``), then a 1x1 conv, then BN. ``bn2``'s
    weight starts at zero, so a freshly built block passes only the
    shortcut branch through.
    """
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        self.conv1d1 = nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=same_pad)
        self.bn1 = nn.BatchNorm1d(out_channel)
        self.conv1d2 = nn.Conv1d(out_channel, out_channel, kernel_size=kernel_size, stride=1, padding=same_pad)
        self.bn2 = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()
        # NOTE: the attribute name `shorcut` (sic) is part of the state_dict keys;
        # renaming it would break loading saved checkpoints.
        needs_projection = stride != 1 or in_channel != out_channel
        if not needs_projection:
            self.shorcut = nn.Identity()
        else:
            self.shorcut = nn.Sequential(
                nn.AvgPool1d(kernel_size=2, stride=stride, ceil_mode=True),
                nn.Conv1d(in_channel, out_channel, kernel_size=1),
                nn.BatchNorm1d(out_channel),
            )
        # Zero-init the last BN scale so the residual branch starts as a no-op.
        nn.init.constant_(self.bn2.weight, 0)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1d1(x)))
        residual = self.bn2(self.conv1d2(hidden))
        skip = self.shorcut(x)
        return self.relu(residual + skip)
    
class XResNet18(nn.Module):
    """1-D XResNet-style encoder followed by a SimCLR-style projection head.

    Despite the name, the depth is set by ``layers`` (number of XResBlock1d
    blocks in each of the four stages). The default ``[2, 2, 2, 2]`` is the
    ResNet-18 layout.

    Args:
        in_channel: number of input channels (leads).
        out_channel: output dimension of the projection head.
        layers: blocks per stage; only the first four entries are used.

    ``forward_encoder`` returns the pooled 512-d feature (``(B, 512)`` for a
    ``(B, in_channel, T)`` input). ``forward`` returns the L2-normalised
    projection of that feature.
    """
    def __init__(self, in_channel=12, out_channel=64, layers=[2, 2, 2, 2]):
        super().__init__()
        # ----- Encoder -----
        # NOTE(review): there is no ReLU between the first two conv/BN pairs of
        # the stem, so they compose to a (nearly) linear map. Inserting one would
        # shift the Sequential indices and break existing state_dicts.
        self.stem = nn.Sequential(
            nn.Conv1d(in_channel, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(32),
            nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.stem_pool = nn.MaxPool1d(3, 2, padding=1)

        # Four residual stages; stages 2-4 use stride 2 and double the width.
        self.block1 = self.make_layer(64, 64, layers[0])
        self.block2 = self.make_layer(64, 128, layers[1], stride=2)
        self.block3 = self.make_layer(128, 256, layers[2], stride=2)
        self.block4 = self.make_layer(256, 512, layers[3], stride=2)

        # ----- Projection head -----
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.projection = nn.Sequential(
            nn.Linear(512, 2048),
            nn.ReLU(),
            nn.Linear(2048, out_channel),
        )

    def make_layer(self, in_channel, out_channel, n_block, stride=1):
        """Build one stage: a (possibly strided) first block, then n_block-1 plain blocks."""
        first = XResBlock1d(in_channel, out_channel, stride=stride)
        rest = [XResBlock1d(out_channel, out_channel, stride=1) for _ in range(n_block - 1)]
        return nn.Sequential(first, *rest)

    def forward_encoder(self, x):
        """Run stem and residual stages, then global-average-pool over time."""
        h = self.stem_pool(self.stem(x))
        for stage in (self.block1, self.block2, self.block3, self.block4):
            h = stage(h)
        return self.avgpool(h).squeeze(-1)

    def forward_projection(self, feature):
        """Project encoder features and L2-normalise each row."""
        return F.normalize(self.projection(feature), dim=1)

    def forward(self, x):
        return self.forward_projection(self.forward_encoder(x))

Dataloader

In [95]:
class ToTensor:
    """Convert a ``(T, C)`` numpy array into a float32 torch tensor of shape ``(C, T)``."""
    def __call__(self, array):
        as_float = torch.from_numpy(array).float()
        return as_float.T  # (T, C) -> (C, T)


class NormalizeECG:
    """Z-score normalise a tensor along one axis: ``(x - mean) / (std + eps)``.

    Args:
        dim: axis over which mean and (unbiased) std are computed. The default
            ``0`` reproduces the original behaviour.

    NOTE(review): after ``ToTensor`` the tensor is ``(C, T)``, so ``dim=0``
    normalises each time step across the leads. Per-lead normalisation over
    time would be ``dim=-1``. Changing it invalidates encoders trained with
    the current setting, so the default is left as-is.
    """
    def __init__(self, dim=0):
        self.dim = dim

    def __call__(self, sample, eps=1e-7):
        mean = sample.mean(dim=self.dim, keepdim=True)
        std = sample.std(dim=self.dim, keepdim=True)
        # eps guards against division by zero for constant slices.
        return (sample - mean) / (std + eps)

class PadECG:
    """Right-pad with zeros, or truncate, the last (time) axis to exactly ``pad_to``.

    Works for any rank (e.g. ``(C, T)`` or ``(B, C, T)``); all other axes are
    left unchanged.
    """

    def __init__(self, pad_to=4096):
        self.pad_to = pad_to

    def __call__(self, sample):
        length = sample.shape[-1]
        if length >= self.pad_to:
            return sample[..., :self.pad_to]
        # A 2-tuple pads only the last dimension: (left, right).
        return F.pad(sample, (0, self.pad_to - length), "constant", 0)
        
class ResizeECG:
    """Bring the last (time) axis to ``out_size``.

    Longer signals are truncated. Shorter ones are linearly interpolated
    (stretched) up to ``out_size``. Accepts ``(C, T)`` or batched
    ``(B, C, T)`` tensors and keeps all other axes, including a
    single-channel axis.
    """
    def __init__(self, out_size=4096):
        self.out_size = out_size

    def __call__(self, sample):
        if sample.shape[-1] >= self.out_size:
            return sample[..., :self.out_size]
        if sample.dim() == 3:
            return F.interpolate(sample, size=self.out_size, mode='linear')
        resized = F.interpolate(sample.unsqueeze(0), size=self.out_size, mode='linear')
        # Remove only the batch dim added above (squeeze() would also drop C == 1).
        return resized.squeeze(0)


class FolderDataset(Dataset):
    """Labelled records (``.dat``/``.hea`` pairs) discovered recursively under a folder.

    Records are found with ``rglob('*.dat')``. Each label comes from
    ``get_label(load_header(path))`` and each signal from ``load_signals``.
    These helpers, and ``Path``, presumably come from ``helper_code2`` —
    verify. Records whose signal has fewer than ``min_len`` rows are dropped
    at construction time.

    Args:
        folder (str): Path to the folder containing the .dat and .hea pairs.
        transform: optional callable applied to the raw signal in ``__getitem__``.
        min_len (int): minimum ``signal.shape[0]`` required to keep a record.
        upsampling (bool): if True, duplicate randomly chosen positive records
            until positives and negatives are equally frequent.

    Attributes:
        N_pos, N_neg: class counts measured *before* upsampling; used by ``get_weight``.
    """
    def __init__(self, folder, transform=None, min_len=800, upsampling=False):
        self.folder = folder
        self.min_len = min_len

        self.transform = transform

        self.record_paths, self.labels = self.find_records()
        self.remove_short()
        self.N_pos = self.get_n_pos()
        self.N_neg = self.get_n_neg()
        if upsampling:
            self.upsampling()

    def upsampling(self):
        """Append copies of random positive records until n_pos == n_neg.

        Raises:
            ValueError: if upsampling is needed but there are no positive records.
        """
        deficit = self.get_n_neg() - self.get_n_pos()
        if deficit <= 0:
            return
        pos_indices = [i for i, label in enumerate(self.labels) if label == 1]
        if not pos_indices:
            raise ValueError("Cannot upsample: the dataset contains no positive records.")
        # Draw all duplicates at once instead of recounting the labels on every step.
        for idx in random.choices(pos_indices, k=deficit):
            self.labels.append(1)
            self.record_paths.append(self.record_paths[idx])

    def find_records(self):
        """Return ``(paths, labels)`` for every ``*.dat`` file under ``self.folder``.

        Paths have their suffix stripped (record base names).

        Raises:
            ValueError: if no ``.dat`` files are found.
        """
        root = Path(self.folder)

        records = []
        for p in root.rglob('*.dat'):
            p = p.with_suffix('')
            header = load_header(p)
            label = get_label(header)
            records.append([p, label])

        if not records:
            raise ValueError(f"No .dat records found under {self.folder!r}")
        paths, labels = zip(*records)
        return list(paths), list(labels)

    def remove_short(self):
        """Drop, in place, records whose signal has fewer than ``min_len`` rows.

        Every record is loaded once here, which can be slow for large folders.
        """
        kept_paths, kept_labels = [], []
        for path, label in zip(self.record_paths, self.labels):
            signal, _ = load_signals(str(path))
            if signal.shape[0] >= self.min_len:
                kept_paths.append(path)
                kept_labels.append(label)
        # Slice-assign so the existing list objects are updated in place.
        self.record_paths[:] = kept_paths
        self.labels[:] = kept_labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(signal, label)``; ``signal`` is transformed when a transform is set."""
        record = self.record_paths[idx]
        signal, fields = load_signals(str(record))

        if self.transform:
            signal = self.transform(signal)

        return signal, self.labels[idx]

    def get_weight(self):
        """Return ``N_neg / N_pos`` (pre-upsampling counts), e.g. for a BCE ``pos_weight``.

        Raises:
            ValueError: if there are no positive records (the ratio would be inf/nan).
        """
        if self.N_pos == 0:
            raise ValueError("get_weight: the dataset contains no positive records.")
        return self.N_neg / self.N_pos

    def get_n_pos(self):
        """Current number of records labelled 1 (includes upsampled copies)."""
        return int(np.count_nonzero(np.asarray(self.labels) == 1))

    def get_n_neg(self):
        """Current number of records labelled 0."""
        return int(np.count_nonzero(np.asarray(self.labels) == 0))
    
class Compose:
    """Chain callables: ``Compose([f, g])(x) == g(f(x))``."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        result = x
        for step in self.transforms:
            result = step(result)
        return result

class DataModule(pl.LightningDataModule):
    """LightningDataModule for self-supervised pre-training on one folder of records.

    Args:
        path: folder scanned by FolderDataset.
        transformation: per-sample transform applied inside the dataset.
        augmentation: stored for the LightningModule to read via
            ``trainer.datamodule.augmentation``; not applied here.
        batchsize: DataLoader batch size.
        num_workers: DataLoader worker processes (default 0 = load in the main
            process). NOTE(review): with workers > 0 on Windows / in notebooks
            the dataset classes may need to be importable from a module.
    """
    def __init__(self, path, transformation=None, augmentation=None, batchsize=64, num_workers=0):
        super().__init__()
        self.path = path
        self.batchsize = batchsize
        self.transformation = transformation
        self.augmentation = augmentation
        self.num_workers = num_workers
        self.train_dataset = None

    def setup(self, stage=None):
        # FolderDataset loads every record to filter short ones; build it only once.
        if self.train_dataset is None:
            self.train_dataset = FolderDataset(self.path, transform=self.transformation)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batchsize, shuffle=True,
                          num_workers=self.num_workers)

loss

In [5]:
#https://colab.research.google.com/drive/1UK8BD3xvTpuSj75blz8hThfXksh19YbA?usp=sharing#scrollTo=GBNm6bbDT9J3
def nt_xent_loss(out_1, out_2, temperature=0.5, eps=1e-6):
    """NT-Xent (normalised temperature-scaled cross-entropy) contrastive loss.

    Args:
        out_1, out_2: ``(N, D)`` embeddings of the two views. Row ``i`` of
            ``out_1`` and row ``i`` of ``out_2`` form a positive pair.
        temperature: divisor applied to the dot-product similarities.
        eps: constant added to each anchor's denominator, i.e. the loss is
            ``-log(pos / (sum_{k != i} exp(sim_ik / T) + eps))``.

    Returns:
        Scalar tensor: the mean loss over all ``2N`` anchors.

    The computation stays in log space (logsumexp / logaddexp). A direct
    ``exp(sim / T)`` overflows to inf, and the loss to NaN, for small
    temperatures (e.g. T = 0.01 in float32) or under fp16.
    """
    out = torch.cat([out_1, out_2], dim=0)
    n_samples = out.shape[0]

    logits = torch.mm(out, out.t().contiguous()) / temperature
    # Exclude each anchor's similarity with itself from the denominator.
    self_mask = torch.eye(n_samples, dtype=torch.bool, device=out.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # log(sum_{k != i} exp(logit_ik) + eps) without leaving log space.
    log_denom = torch.logsumexp(logits, dim=-1)
    log_eps = torch.log(torch.as_tensor(eps, dtype=log_denom.dtype, device=log_denom.device))
    log_denom = torch.logaddexp(log_denom, log_eps)

    # Positive logits: view-1 row i with view-2 row i, used for both anchors.
    pos_logits = torch.sum(out_1 * out_2, dim=-1) / temperature
    pos_logits = torch.cat([pos_logits, pos_logits], dim=0)

    return (log_denom - pos_logits).mean()

pl module

In [6]:
class SSLModule(pl.LightningModule):
    """Contrastive pre-training wrapper around an encoder.

    Each training batch is turned into two augmented views by
    ``trainer.datamodule.augmentation``. Both views are encoded, and
    ``loss_fn(z1, z2, temperature=...)`` is minimised with Adam under a
    cosine learning-rate schedule.

    NOTE(review): the augmentation receives the whole collated batch. With
    RandomTransformation's batched path, one crop / time-out is drawn per
    batch and shared by every sample. Per-sample augmentation may give more
    diverse negatives.
    """
    def __init__(self, enconder, loss_fn=None, lr=1e-3, temperature=0.5, epochs=20):
        super().__init__()
        # `enconder` (sic) is the public keyword name and is kept for compatibility.
        self.encoder = enconder
        self.loss_fn = loss_fn
        self.lr = lr
        self.temperature = temperature
        self.epochs = epochs

    def training_step(self, batch, batch_idx):
        augment = self.trainer.datamodule.augmentation
        view_a, view_b = augment(batch)
        x1 = view_a[0]
        x2 = view_b[0]
        # Encode the first view before the second (BatchNorm stats update in this order).
        z1, z2 = self.encoder(x1), self.encoder(x2)
        loss = self.loss_fn(z1, z2, temperature=self.temperature)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.epochs)
        return [opt], [cosine]

### Pretrain phase

In [7]:
# Two-view SSL augmentation: random resize-crop followed by time-out.
augmentation = RandomTransformation(to_range=[0.0, 0.5], rrc_range=[0.5, 1.0])
# Per-record preprocessing: (T, C) array -> (C, T) tensor -> normalised -> length 4096.
transformation = Compose([ToTensor(), NormalizeECG(), ResizeECG(out_size=4096)])

In [8]:
# Pre-training hyper-parameters.
layers = [3, 4, 6, 3]   # residual blocks per stage (same depths as ResNet-34)
out_channel = 256       # projection-head output dimension
temperature = 0.2       # NT-Xent temperature
epochs = 100
lr = 5e-3
bs = 128                # batch size

In [None]:
datamodule = DataModule(path="./code15_output/", transformation=transformation, augmentation=augmentation, batchsize=bs)
encoder = XResNet18(out_channel=out_channel, layers=layers)
model = SSLModule(enconder=encoder, loss_fn=nt_xent_loss, lr=lr, temperature=temperature, epochs=epochs)
# 'auto' uses the GPU when one is available and otherwise falls back to CPU;
# 'gpu' raises on machines without CUDA.
trainer = pl.Trainer(max_epochs=epochs, accelerator='auto')

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
d:\Tools\Anaconda\envs\gymenv\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [65]:
trainer.fit(model=model, datamodule=datamodule)  # run SSL pre-training

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params | Mode 
----------------------------------------------
0 | encoder | XResNet18 | 8.8 M  | train
----------------------------------------------
8.8 M     Trainable params
0         Non-trainable params
8.8 M     Total params
35.248    Total estimated model params size (MB)
140       Modules in train mode
0         Modules in eval mode
d:\Tools\Anaconda\envs\gymenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 99: 100%|██████████| 312/312 [03:36<00:00,  1.44it/s, v_num=1, train_loss_step=0.639, train_loss_epoch=1.050]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 312/312 [03:37<00:00,  1.44it/s, v_num=1, train_loss_step=0.639, train_loss_epoch=1.050]


v1: epochs = 50
lr=1e-2
out_channel = 256
layers = [3, 4,  6, 3]
temperature = 0.3
bs = 128

v2: epochs = 100
lr=5e-3
out_channel = 256
layers = [3, 4,  6, 3]
temperature = 0.2
bs = 128

In [17]:
# Plot the per-epoch pre-training loss logged by Lightning's CSVLogger.
log = pd.read_csv("lightning_logs/version_1/metrics.csv")
log = log[["epoch", "train_loss_epoch"]].dropna()
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(log["epoch"], log["train_loss_epoch"])
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training loss")

Text(0.5, 1.0, 'Training loss')

: 

In [None]:
# Uncomment to export the pre-trained encoder weights after training.
# NOTE(review): the fine-tuning section loads "encoder_pretrain.pt", so either
# save under that name or rename the file before fine-tuning.
#torch.save(model.encoder.state_dict(), "encoder.pt")

### Finetune phase (transfer learning)

In [None]:
class ClassifierModule(pl.LightningModule):
    """Binary classification head on top of an encoder's ``forward_encoder`` features.

    ``encoder.forward_encoder(x)`` must return ``(B, 512)`` features, as
    XResNet18 does. The head is a single linear layer (``linear=True``) or a
    512-256-``out_dim`` MLP with dropout.

    Args:
        encoder: feature extractor exposing ``forward_encoder``.
        out_dim: number of logits (1 for binary BCE).
        lr: Adam learning rate.
        epochs: ``T_max`` of the cosine learning-rate schedule.
        pos_weight: positive-class weight for ``BCEWithLogitsLoss``.
        linear: use a single linear layer instead of the MLP head.
        frozen: if True, encoder parameters get ``requires_grad=False`` and the
            encoder is kept in eval mode while training, so its BatchNorm
            running statistics stay fixed too.

    Side effects:
        Whenever a validation epoch reaches a new lowest ``val_loss``, the head's
        state_dict is written to ``Best_classifier.pt``. The encoder's is written
        to ``Best_encoder.pt`` when it is not frozen.
    """
    def __init__(self, encoder, out_dim=1, lr=1e-3, epochs=10, pos_weight=1, linear=False, frozen=True):
        super().__init__()
        self.encoder = encoder
        self.lr = lr
        if linear:
            self.classifier = nn.Linear(512, out_dim)
        else:
            self.classifier = nn.Sequential(
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(256, out_dim)
                )
        self.epochs = epochs
        # Start at +inf so the first real validation epoch is always checkpointed;
        # a fixed threshold of 1 never saves when the weighted BCE stays >= 1.
        self.best_val_loss = float("inf")
        self.frozen = frozen

        # Freeze encoder
        if self.frozen:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(pos_weight, dtype=torch.float32))

    def forward(self, x):
        feature = self.encoder.forward_encoder(x)
        out = self.classifier(feature)
        return out.squeeze(-1)

    def on_train_epoch_start(self):
        # requires_grad=False does not stop BatchNorm from updating running stats
        # in train mode; keep a frozen encoder in eval mode for every epoch.
        if self.frozen:
            self.encoder.eval()

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y.float())
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y.float())
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs)
        return [optimizer], [scheduler]

    def on_validation_epoch_end(self):
        # The pre-fit sanity check runs on only a few batches; it must not set the
        # baseline or write checkpoints.
        if self.trainer.sanity_checking:
            return
        loss = self.trainer.callback_metrics.get("val_loss")
        if loss is None:
            return
        loss = float(loss)
        if loss < self.best_val_loss:
            self.best_val_loss = loss
            torch.save(self.classifier.state_dict(), "Best_classifier.pt")
            if not self.frozen:
                torch.save(self.encoder.state_dict(), "Best_encoder.pt")
    
class FinetuneDataModule(pl.LightningDataModule):
    """LightningDataModule for supervised fine-tuning on labelled train/val folders.

    Args:
        train_path: folder with training records.
        val_path: folder with validation records. If falsy, no validation set
            is built and ``val_dataloader`` returns None.
        transformation: per-sample transform applied in both datasets.
        augmentation: stored only; not used by this class.
        batchsize: DataLoader batch size.
        upsampling: forwarded to the training FolderDataset only.
        num_workers: DataLoader worker processes (default 0 = main process).
    """
    def __init__(self, train_path, val_path, transformation=None, augmentation=None, batchsize=64, upsampling=False, num_workers=0):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.batchsize = batchsize
        self.transformation = transformation
        self.augmentation = augmentation
        self.upsampling = upsampling
        self.num_workers = num_workers
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        # Building a FolderDataset loads every record and, with upsampling,
        # re-draws the duplicated positives. Reuse datasets that already exist
        # (e.g. after a manual setup() call before trainer.fit).
        if self.train_dataset is None:
            self.train_dataset = FolderDataset(self.train_path, transform=self.transformation, upsampling=self.upsampling)
        # Only build a validation set when a path is given (Path(None) would raise).
        if self.val_path and self.val_dataset is None:
            self.val_dataset = FolderDataset(self.val_path, transform=self.transformation)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batchsize, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        if not self.val_path:
            return None
        return DataLoader(self.val_dataset, batch_size=self.batchsize, shuffle=False,
                          num_workers=self.num_workers)

In [148]:
# Rebuild the encoder with the pre-training architecture (out_channel, layers)
# and load its saved weights.
encoder = XResNet18(out_channel=out_channel, layers=layers)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine (Lightning
# moves the module to the accelerator later). weights_only restricts unpickling
# to tensors/containers, which is all a state_dict needs.
encoder.load_state_dict(torch.load("encoder_pretrain.pt", map_location="cpu", weights_only=True))

<All keys matched successfully>

In [149]:
# Fine-tuning hyper-parameters (override the pre-training values).
out_dim = 1     # single logit for binary classification
bs = 128        # batch size
epochs = 10
lr = 5e-4

In [117]:
train_loader = FinetuneDataModule(
    train_path="./training_data/",
    val_path='./val_data/',
    transformation=transformation,
    batchsize=bs,
    upsampling=False,
)
train_loader.setup()
# Negative-to-positive ratio of the training set, used as BCE pos_weight.
pos_weight = train_loader.train_dataset.get_weight()

In [150]:
# frozen=False: the encoder is trained together with the classifier head.
classifier = ClassifierModule(encoder=encoder, lr=lr, out_dim = out_dim, epochs=epochs, pos_weight=pos_weight, frozen=False)
# 'auto' uses the GPU when available and falls back to CPU instead of raising.
trainer = pl.Trainer(max_epochs=epochs, accelerator='auto')

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=classifier, datamodule=train_loader)  # run supervised fine-tuning

V2 epochs = 10
lr=1e-4
out_dim = 1
bs = 128
no scheduler

V5 epochs = 10
lr=5e-4
out_dim = 1
bs = 128
with scheduler

V6 epochs = 20
lr=1e-3
out_dim = 1
bs = 128
with scheduler

V7 epochs = 20
lr=1e-3
out_dim = 1
bs = 128
with scheduler
pos_weight = 28

V8 epochs = 20
lr=1e-3
out_dim = 1
bs = 128
with scheduler
upsampling

V9 epochs = 10
lr=1e-3
out_dim = 1
bs = 128
with scheduler
upsampling
non_linear

V10 epochs = 20
lr=1e-3
out_dim = 1
bs = 128
with scheduler
weighted
non_linear

V12 epochs = 20
lr=1e-3
out_dim = 1
bs = 128
with scheduler
weighted
non_linear
no freeze