# Vision Transformer for ECG

# Patch Embedding

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset # wraps an iterable around the dataset
from torchvision import datasets    # stores the samples and their corresponding labels
from torchvision.transforms import transforms  # transformations we can perform on our dataset
from torchvision.transforms import ToTensor
import pandas as pd
import numpy as np
import os
#import wandb
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter


import torch.optim as optim
import torch.nn.functional as F

In [3]:
# Pick the device used for training: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

## Dataset

In [4]:
class ECGDataSet(Dataset):
    """ECG recordings paired with a heart-rate regression target.

    Files are looked up under ``<parent of the current working directory>/data/deepfake-ecg-small``:
    ``<split>.csv`` holds the per-record metadata (columns ``patid`` and
    ``avgrrinterval`` are used) and ``<split>/<patid>.asc`` holds one recording.

    Parameters
    ----------
    split : str
        Name of the split to load; it selects both the CSV file and the
        sub-folder containing the ``.asc`` recordings.

    Attributes
    ----------
    df : pandas.DataFrame
        Metadata table of the split.
    y : torch.Tensor
        Float32 target per record, computed as ``60 * 1000 / avgrrinterval``
        (heart rate in beats per minute when the interval is in milliseconds).
    samples : int
        Number of rows in the metadata table.
    """

    def __init__(self, split='train'):
        self.split = split

        # The data folder lives next to (not inside) the current working directory.
        self.parent_directory = os.path.dirname(os.getcwd())
        metadata_csv = os.path.join(
            self.parent_directory, 'data', 'deepfake-ecg-small', str(self.split) + '.csv'
        )
        self.df = pd.read_csv(metadata_csv)

        # Average RR interval (milliseconds) -> heart rate.
        rr_interval = torch.tensor(self.df['avgrrinterval'].values, dtype=torch.float32)
        self.y = 60 * 1000 / rr_interval

        # One sample per metadata row.
        self.samples = self.df.shape[0]

    def __getitem__(self, index):
        """Load the recording at ``index`` and return ``(signal, target)``.

        The ``.asc`` file is parsed as a space-separated table without a
        header, converted to float32, divided by 6000 (normalization) and
        transposed, so the file's columns become the first tensor dimension.
        """
        record_id = self.df['patid'].values[index]
        record_path = os.path.join(
            self.parent_directory, 'data', 'deepfake-ecg-small', str(self.split), str(record_id) + '.asc'
        )

        table = pd.read_csv(record_path, header=None, sep=" ")
        signal = torch.tensor(table.values).float() / 6000

        target = self.y[index]
        return signal.t(), target

    def __len__(self):
        """Return the number of records in this split."""
        return self.samples

In [5]:
# One ECGDataSet per split: training data and held-out validation data.
train_dataset = ECGDataSet('train')
validate_dataset = ECGDataSet('validate')

## Data Loaders

In [8]:
# DataLoaders batch the datasets so they can be iterated during training and evaluation.
# Both loaders share the same batch size and number of worker processes.
loader_options = {"batch_size": 8, "num_workers": 20}
# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
validate_dataloader = DataLoader(validate_dataset, shuffle=False, **loader_options)



## Utility Functions

In [9]:
# TensorBoard summary writer shared by the notebook: trainTB logs the per-epoch
# training loss through it and the training cell closes it at the end of the run.
# Created with no arguments, so the log location is SummaryWriter's default.
writer = SummaryWriter()

In [10]:
# train function with tensorboard
def trainTB(dataloader, model, loss_fn, optimizer, epoch):
    """Train ``model`` for one epoch and log the mean batch loss to TensorBoard.

    Parameters
    ----------
    dataloader : iterable with ``len()``
        Yields ``(X, y)`` batches; ``len(dataloader)`` is the number of batches.
    model : torch.nn.Module
        Model to optimise; it is switched to training mode.
    loss_fn : callable
        Called as ``loss_fn(pred, y)`` and must return a scalar tensor.
    optimizer : torch.optim.Optimizer
        Optimizer that owns the model parameters.
    epoch : int
        Zero-based epoch index; printed as ``epoch + 1`` and used as the
        TensorBoard step for the ``Loss/train`` scalar.

    Raises
    ------
    ZeroDivisionError
        If ``dataloader`` contains no batches.

    Notes
    -----
    Uses the module-level ``device`` and ``writer``.
    """
    model.train()

    # get the number of batches
    num_batches = len(dataloader)
    total_loss = 0.0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)

        # The dataset yields scalar targets, so a batch of targets is [B] while a
        # regression head emits [B, 1]. Passing those to an element-wise loss such
        # as MSELoss broadcasts them to [B, B] and silently computes the wrong
        # value. Give the target the prediction's shape whenever both hold the
        # same number of elements (floating-point targets only, so class-index
        # targets are left untouched).
        if (
            y.is_floating_point()
            and y.shape != pred.shape
            and y.numel() == pred.numel()
        ):
            y = y.reshape(pred.shape)

        loss = loss_fn(pred, y)
        total_loss += loss.item()

        # Backpropagation; clear gradients first so nothing stale is accumulated.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_avg = total_loss/num_batches
    print(f"Epoch [{epoch+1}], Average Loss: {loss_avg:.4f}")
    writer.add_scalar("Loss/train", loss_avg, epoch)

## Model

In [12]:
class PatchEmbed(nn.Module):
    """Cut a 1D multi-channel signal (an ECG here) into patches and embed each one.

    This is the image patch embedding adapted to 1D data: an ECG of shape
    ``(8, 5000)`` plays the role of the image.

    Parameters
    ----------
    img_size : int
        Length of the signal in samples (5000 for the ECG).

    patch_size : int
        Number of consecutive samples that form one patch.

    in_chans : int
        Number of input channels (8 for the ECG).

    embed_dim : int
        Embedding dimension, i.e. the number of output channels of the
        convolution.

    Attributes
    ----------
    n_patches : int
        Number of patches per signal, ``img_size // patch_size``.

    proj : nn.Conv1d
        Convolution whose kernel size and stride both equal ``patch_size``,
        so a single layer performs the splitting into non-overlapping
        patches and their embedding.
    """

    def __init__(self, img_size=5000, patch_size=50, in_chans=8, embed_dim=768):
        super().__init__()
        self.n_patches = img_size // patch_size

        # kernel_size == stride: each output position sees exactly one patch.
        self.proj = nn.Conv1d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed the patches of a batch of signals.

        Parameters
        ----------
        x : torch.Tensor
            Shape is `(batch_size, in_chans, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape is `(batch_size, n_patches, embed_dim)`.
        """
        # The convolution yields (batch_size, embed_dim, n_patches); no flatten
        # is needed for 1D data, only the last two axes are swapped.
        return self.proj(x).transpose(1, 2)

## Training

In [None]:
# Shape of one ECG sample: 8 channels by 5000 time steps.
# (The batch dimension is not part of this tuple; the loaders use batches of 8.)
input_shape = (8, 5000)  # Modify this according to your input shape

output_size = 1  # Number of output units

# PatchEmbed's signature is (img_size, patch_size, in_chans, embed_dim). Passing
# `input_shape` and `output_size` positionally made img_size a tuple and
# patch_size 1, which fails on `img_size // patch_size`. Pass the signal length
# and channel count by keyword instead and keep the default patch size/embedding.
n_channels, n_samples = input_shape
model = PatchEmbed(img_size=n_samples, in_chans=n_channels)
# NOTE(review): PatchEmbed on its own returns (batch, n_patches, embed_dim), so
# `output_size` is not used yet. A head reducing the embeddings to `output_size`
# values looks necessary before training against the scalar target - confirm the
# intended architecture.
model.to(device)
print(model)

In [None]:
import torch.optim as optim

# Regression objective: mean squared error between prediction and target.
loss_fn = nn.MSELoss()

# NAdam optimizer over all of the model's parameters.
optimizerN = optim.NAdam(params=model.parameters(), lr=5e-4)

In [None]:
# Number of full passes over the training data.
epochs = 10
try:
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        trainTB(train_dataloader, model, loss_fn, optimizerN, t)
    print("Done!")
finally:
    # Close the TensorBoard writer even when training raises or is interrupted,
    # so the scalars logged so far are not left in an unclosed event file.
    writer.close()