# Vision Transformer for ECG

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset # wraps an iterable around the dataset
from torchvision import datasets    # stores the samples and their corresponding labels
from torchvision.transforms import transforms  # transformations we can perform on our dataset
from torchvision.transforms import ToTensor
import pandas as pd
import numpy as np
import os
#import wandb
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter


import torch.optim as optim
import torch.nn.functional as F


import math
import numpy as np

In [2]:
# Get cpu, gpu or mps device for training 
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [3]:
device

'cuda'

# Dataset

In [21]:
class ECGDataSet(Dataset):
    
    def __init__(self, split='train'):

        self.split = split

        # data loading
        current_directory = os.getcwd()
        self.parent_directory = os.path.dirname(current_directory)
        train_small_path = os.path.join(self.parent_directory, 'data', 'deepfake-ecg-small', str(self.split) + '.csv')
        self.df = pd.read_csv(train_small_path)  # Skip the header row
        
        # Avg RR interval
        # in milli seconds
        RR = torch.tensor(self.df['avgrrinterval'].values, dtype=torch.float32)
        # calculate HR
        self.y = 60 * 1000/RR

        # Size of the dataset
        self.samples = self.df.shape[0]

    def __getitem__(self, index):
        
        # file path
        filename= self.df['patid'].values[index]
        asc_path = os.path.join(self.parent_directory, 'data', 'deepfake-ecg-small', str(self.split), str(filename) + '.asc')
        
        ecg_signals = pd.read_csv( asc_path, header=None, sep=" ") # read into dataframe
        ecg_signals = torch.tensor(ecg_signals.values) # convert dataframe values to tensor
        
        ecg_signals = ecg_signals.float()
        
        # Transposing the ecg signals
        ecg_signals = ecg_signals/6000 # normalization
        ecg_signals = ecg_signals.t() 
        
        qt = self.y[index]
        # Retrieve a sample from x and y based on the index
        return ecg_signals, qt

    def __len__(self):
        # Return the total number of samples in the dataset
        return self.samples

In [22]:
# ECG dataset
train_dataset = ECGDataSet(split='train')
validate_dataset = ECGDataSet(split='validate')

## Data Loaders

In [23]:
# data loader
# It allows you to efficiently load and iterate over batches of data during the training or evaluation process.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=20)
validate_dataloader = DataLoader(dataset=validate_dataset, batch_size=8, shuffle=False, num_workers=20)

## Utility Funtions

In [24]:
# for the tensorboard
writer = SummaryWriter()

In [25]:
# train function with tensorbard
def trainVIT(dataloader, model, loss_fn, optimizer, epoch):
    #size = len(dataloader.dataset)
    model.train()
    loss = 0

    total_loss = 0
    # get the number of batches
    num_batches = len(dataloader)

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        
        # check the shape of pred and y here
        if batch == 1:
            print(pred.shape)       # this is [8,1]
            print(y.shape)          # this is [8]

        loss = loss_fn(pred, y)

        total_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    loss_avg = total_loss/num_batches
    print(f"Epoch [{epoch+1}], Average Loss: {loss_avg:.4f}")
    writer.add_scalar("Loss/train", loss_avg, epoch)

In [18]:
# A dummy train function just to check the patch embedding
def trainPE(dataloader, model1,model2):
    #size = len(dataloader.dataset)
    model1.train()
    model2.train()
    #loss = 0

    #total_loss = 0
    # get the number of batches
    #num_batches = len(dataloader)

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model1(X)
        print("_________")
        pred2 = model2(pred)
        
        # check the shape of pred and y here
        #if batch == 1:
         #   print(pred.shape)       # this is [8,1]
          #  print(y.shape)          # this is [8]

        #loss = loss_fn(pred, y)

        #total_loss += loss.item()

        # Backpropagation
        #loss.backward()
        #optimizer.step()
        #optimizer.zero_grad()

    #loss_avg = total_loss/num_batches
    #print(f"Epoch [{epoch+1}], Average Loss: {loss_avg:.4f}")
    #writer.add_scalar("Loss/train", loss_avg, epoch)

# Model

## Patch Embedding

In [26]:
class PatchEmbed(nn.Module):
    """Split image (ECG in our case) into patches and then embed them.

    ECG --> 8,5000

    Paramerters
    ----------
    img_size : int
        Size of image (ECG) in pixels (samples).    (This is 1D 5000)

    patch_size : int

    in_chans : int
        Number of input channels. (This is 8)

    embed_dim : int
        Embedding dimension.

    Attributes
    ----------

    n_patches : int
        Number of patches inside of our image.

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches and their embedding.

    """
    # This class is modified so that it works with 1D data.
    def __init__(self, img_size=5000, patch_size=50, in_chans=8, embed_dim=768):
        super().__init__()
        img_size = img_size
        patch_size = patch_size
        self.n_patches = (img_size // patch_size)

        # embed_dim is the output channel size of the convolutional layer.
        self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape is `(batch_size, in_chans, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape is `(batch_size, n_patches, embed_dim)`.

        """

        x = self.proj(x) # (batch_size, embed_dim, n_patches)
        # I dont think flatten is needed for 1D data.
        #x = x.flatten(2) # flatten with 1st 2 dims intact
        # (batch_size, embed_dim, n_patches) --> (batch_size, n_patches, embed_dim)
        x = x.transpose(1,2)    # (batch_size, n_patches, embed_dim)
        # print the shape of x
        #print(x.shape)

        return x

## Attention Network

In [27]:
class Attention(nn.Module):
    """Attention mechanism.

    Parameters
    ----------
    dim : int
        Last dimension of the input tensors (embed_dim).
        The input and out dimension of per token features

    n_heads : int
        Number of attention heads.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    attn_p : float
        Dropout probability applied to the query, key and value tensors.

    proj_p : float
        Dropout probability applied to the output tensor.

    Attributes
    ----------

    scale : float
        Normalizing constant for the dot product.

    qkv : nn.Linear
        Linear projection for the query, key and value.

    proj : nn.Linear
        Linear mapping that takes in the concatenated output of all attention
        heads and maps it into a new space.

    attn_drop, proj_drop : nn.Dropout
        Dropout layers.

    """


    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        # define the dimentionality of each of the heads
        self.head_dim = dim // n_heads
        # when we concatonate all the heads we should get the same dim as the input
        
        # from attention is all you need paper
        self.scale = dim ** -0.5

        # get an embedding and output q, k, v
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)

        # get the concatenated output of all the heads and maps to a new mapping
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape is `(batch_size, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape is `(batch_size, n_patches + 1, dim)`.

        """

        n_samples, n_tokens, dim = x.shape

        if dim != self.dim:
            raise ValueError
        
        # n_patches + 1 --> class token as the first token
        qkv = self.qkv(x)   # (batch_size, n_patches + 1, 3 * dim) x is mulit dimensional
        qkv = qkv.reshape(n_samples,n_tokens,3,self.n_heads,self.head_dim) # (batch_size, n_patches + 1, 3, n_heads, head_dim)
        qkv = qkv.permute(2,0,3,1,4)    # (3, batch_size, n_heads, n_patches + 1, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]
        k_t = k.transpose(-2,-1)    # (batch_size, n_heads, head_dim, n_patches + 1)
        dp = (q @ k_t) * self.scale # (batch_size, n_heads, n_patches + 1, n_patches + 1)
        attn = dp.softmax(dim=-1) # (batch_size, n_heads, n_patches + 1, n_patches + 1)
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v # (batch_size, n_heads, n_patches + 1, head_dim)
        weighted_avg = weighted_avg.transpose(1,2) # (batch_size, n_patches + 1, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2) # (batch_size, n_patches + 1, dim)

        x = self.proj(weighted_avg) # (batch_size, n_patches + 1, dim)
        x = self.proj_drop(x) # (batch_size, n_patches + 1, dim)
        #print(x.shape)
        return x


## MLP

In [28]:
class MLP(nn.Module):
    """Multi-layer perceptron.
    
    Parameters
    ----------
    in_features : int
        Number of input features.

    hidden_features : int
        Number of nodes in the hidden layer.

    out_features : int
        Number of output features.

    p : float
        Dropout probability.

    Attributes
    ----------
    fc : nn.Linear
        The first linear layer.

    act : nn.GELU
        GELU activation function.

    fc2 : nn.Linear
        The second linear layer.

    drop : nn.Dropout
        Dropout layer.
    
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape is `(batch_size, n_patches + 1, in_features)`.

        Returns
        -------
        torch.Tensor
            Shape is `(batch_size, n_patches + 1, out_features)`.

        """

        x = self.fc1(x) # (batch_size, n_patches + 1, hidden_features)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)    # (batch_size, n_patches + 1, out_features)

        return x


## Block

In [29]:
class Block(nn.Module):
    """Transformer block.

    Parameters
    ----------

    dim : int
        Number of input channels (embed_dim).

    n_heads : int
        Number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension size of the `MLP` module relative to `dim`.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p, attn_p : float
        Dropout probability.

    Attributes
    ----------

    norm1, norm2 : nn.LayerNorm
        Layer normalization.

    attn : Attention
        Attention module.
    
    mlp : MLP
        MLP module.

    """

    def __init__(self,dim,n_heads, mlp_ratio=4.0, qkv_basis=True, p=0, attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_basis,
            attn_p=attn_p,
            proj_p=p
        )

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim
        )

    def forward(self, x):
        """Run forward pass.
        
        Parameters
        ----------
        x : torch.Tensor
            Shape is `(batch_size, n_patches + 1, dim)`.

        Returns
        -------
        torch.Tensor
            Shape is `(batch_size, n_patches + 1, dim)`.

        """

        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x
    

## Vision Transformer

In [30]:
class VisionTransformer(nn.Module):
    """Simplified implementation of the Vision Transformer.

    Parameters
    ----------
    img_size : int
        Both height and the width of the image (ECG in our case). (This is 1D 5000)

    patch_size : int
        Both height and the width of the patch.

    in_chans : int
        Number of input channels. (This is 8)

    n_classes : int
        Number of classes to predict.

    embed_dim : int
        Dimensionality of the token/patch embeddings.

    depth : int
        Number of blocks.

    n_heads : int
        Number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension size of the `MLP` module relative to `embed_dim`.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p, attn_p : float
        Dropout probability.

    Attributes
    ----------

    patch_embed : PatchEmbed
        Instance of `PatchEmbed` layer.

    cls_token : nn.Parameter
        Learnable parameter that will represent the first token in the sequence.
        It has `embed_dim` elements.

    pos_emb : nn.Parameter
        Positional embedding of the cls_token + all the patches.
        It has `(n_patches + 1) * embed_dim` elements.

    pos_drop : nn.Dropout
        Dropout layer.

    blocks : nn.ModuleList
        List of `Block` modules.

    norm : nn.LayerNorm
        Layer normalization.

    """

    def __init__(
            self,
            img_size=5000,
            patch_size=50,
            in_chans=8,
            n_classes=1,
            embed_dim=768,
            depth=12,
            n_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            p=0.,
            attn_p=0.
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )

        # zero tensors for the class token and the positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))

        self.pos_drop = nn.Dropout(p=p)

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                n_heads=n_heads,
                mlp_ratio=mlp_ratio,
                qkv_basis=qkv_bias,
                p=p,
                attn_p=attn_p
            ) for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape is `(batch_size, in_chans, img_size)`.

        Returns
        -------
        logits: torch.Tensor
            Logits over all the classes - `(batch_size, n_classes)`.

        """

        n_samples = x.shape[0]
        x = self.patch_embed(x) # (batch_size, n_patches + 1, embed_dim)

        cls_token = self.cls_token.expand(n_samples, -1, -1) # (batch_size, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1) # (batch_size, n_patches + 1, embed_dim)
        x = x + self.pos_embed # (batch_size, n_patches + 1, embed_dim)
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        cls_token_final = x[:, 0] # just the cls token
        x = self.head(cls_token_final)

        x = torch.squeeze(x)

        return x


## Transformer Encoder 2

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        
        # Info
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])


In [None]:
class Transformer(nn.Module):
    def __init__(
            self,
            num_tokens,
            dim_model,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            dropout_p,

    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model

        #Layers
        self.positional_encoder = PositionalEncoder(
            dim_model = dim_model,
            dropout_p = dropout_p,
            max_len = 5000 )
        # change this embedding to the patch embedding
        self.embedding = nn.Embedding(num_tokens, dim_model)

        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p
        )
        self.out = nn.Linear(dim_model, num_tokens)

    def forward(self,
                src,
                tgt):
        # Src size must be (batch_size, src sequence length)
        # Tgt size must be (batch_size, tgt sequence length)

        # Embedding + positional encoding - Out size = (batch_size, sequence length, dim_model)
        src = self.embedding(src) * math.sqrt(self.dim_model)
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        # we permute to obtain size (sequence length, batch_size, dim_model),
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        # Transformer blocks - Out size = (sequence length, batch_size, num_tokens)
        transformer_out = self.transformer(src, tgt)
        out = self.out(transformer_out)


# Training

In [31]:
input_shape = (8,5000)  # Modify this according to your input shape
# 128 is the batch size, 8 is the number of channels, 5000 is the number of time steps

output_size = 1  # Number of output units

model = VisionTransformer(
    img_size=5000,
    patch_size=50,
    in_chans=8,
    n_classes=1,
    embed_dim=768,
    depth=12,
    n_heads=12,
    mlp_ratio=4.,
    qkv_bias=True,
    p=0.,
    attn_p=0.
)
model.to(device)
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv1d(8, 768, kernel_size=(50,), stride=(50,))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=1, bias=True)


In [32]:
import torch.optim as optim

# Loss function for linear values (e.g., regression)
loss_fn = nn.MSELoss()  # Mean Squared Error loss

# use Nadam optimizer
optimizerN = optim.NAdam(model.parameters(), lr=0.0005)

In [33]:
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    #trainPE(train_dataloader, model1,model2)
    trainVIT(train_dataloader, model, loss_fn, optimizerN, t)
print("Done!")
writer.close()

Epoch 1
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [1], Average Loss: 195.1032
Epoch 2
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [2], Average Loss: 57.2543
Epoch 3
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [3], Average Loss: 55.0929
Epoch 4
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [4], Average Loss: 46.8162
Epoch 5
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [5], Average Loss: 44.7930
Epoch 6
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [6], Average Loss: 34.5139
Epoch 7
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [7], Average Loss: 32.1793
Epoch 8
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [8], Average Loss: 31.5948
Epoch 9
-------------------------------
torch.Size([8])
torch.Size([8])
Epoch [9], Average Loss: 28.5218
Epoch 10
-------------------------------
torch.Size([8