# Train the 1D CNN Model on the PTB-XL Dataset

## Import modules

In [1]:
import h5py
import numpy as np

from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.optim import Adam

In [2]:
project_path = Path.cwd().parent
project_path

PosixPath('/home/dk/Desktop/projects/split-learning-1D-HE')

## Load the Dataset

In [3]:
class PTBXL(Dataset):
    """
    The class used by the client to 
    load the PTBXL dataset

    Args:
        Dataset ([type]): [description]
    """
    def __init__(self, train=True):
        if train:
            with h5py.File(project_path/'data/train_ptbxl.hdf5', 'r') as hdf:
                self.x = hdf['X_train'][:]
                self.y = hdf['y_train'][:]
        else:
            with h5py.File(project_path/'data/test_ptbxl.hdf5', 'r') as hdf:
                self.x = hdf['X_test'][:]
                self.y = hdf['y_test'][:]
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=torch.float), torch.tensor(self.y[idx])

In [4]:
train_ds = PTBXL(train=True)
test_ds = PTBXL(train=False)
train_loader = DataLoader(train_ds, batch_size=4)
test_loader = DataLoader(test_ds, batch_size=4)

Test if we loaded the data correctly

In [9]:
len(train_ds)

19267

In [84]:
x, y = next(iter(train_loader))
print(x.shape)
print(y)

torch.Size([4, 12, 1000])
tensor([0, 0, 0, 0])


## Model

In [85]:
class ECGConv1D(nn.Module):
    def __init__(self):
        super(ECGConv1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=12,
                               out_channels=16,
                               kernel_size=7,
                               padding=3,
                               stride=1)  # 16 x 1000
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)  # 16 x 500
        self.conv2 = nn.Conv1d(in_channels=16,
                               out_channels=8,
                               kernel_size=5,
                               padding=2)  # 8 x 500
        self.relu2 = nn.LeakyReLU()
        self.pool2 = nn.MaxPool1d(2)  # 8 x 250 = 2000

        self.linear = nn.Linear(8*250, 5)
        self.softmax = nn.Softmax(dim=1)  # 0 is the batch dimension

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 250*8)
        x = self.linear(x)
        x = self.softmax(x)
        
        return x

## Training Loop

In [86]:
def train(model: nn.Module, 
          lr: float, 
          epoch: int, 
          device: torch.device):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr)
    
    train_losses = list()
    train_accs = list()
    test_losses = list()
    test_accs = list()
    # best_test_acc = 0

    for e in range(epoch):
        print(f'Epoch {e+1} - ', end= '')

        # train
        train_loss = 0.0
        correct, total = 0, 0
        
        for batch in train_loader:
            # forward pass
            optimizer.zero_grad()
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            # backward pass
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            correct += torch.sum(y_hat.argmax(dim=1) == y).item()
            total += len(y)           

        train_losses.append(train_loss / len(train_loader))
        train_accs.append(correct / total)
        print(f"train_loss: {train_losses[-1]:.4f}, train_acc: {train_accs[-1]*100:.2f}%", end=' / ')
        
        # test
        with torch.no_grad():
            test_loss = 0.0
            correct, total = 0, 0
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                y_hat = model(x)
                loss = criterion(y_hat, y)
            
                test_loss += loss.item()
                correct += torch.sum(y_hat.argmax(dim=1) == y).item()
                total += len(y)
            
            test_losses.append(test_loss / len(test_loader))
            test_accs.append(correct / total)

        print(f"test_loss: {test_losses[-1]:.4f}, test_acc: {test_accs[-1]*100:.2f}%")


## Train the model

Set the seed for generating random numbers

In [87]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Initialize the mode, save init weights

In [88]:
model = ECGConv1D()
torch.save(model.state_dict(), './weights/init_weight_ptbxl.pth')

Hyperparameters

In [89]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    print(f'device: {torch.cuda.get_device_name(0)}')

learning_rate = 0.001
num_epoch = 10

device: NVIDIA GeForce RTX 2060


In [90]:
train(model, learning_rate, num_epoch, device)

Epoch 1 - train_loss: 1.3404, train_acc: 56.30% / test_loss: 1.2660, test_acc: 63.25%
Epoch 2 - train_loss: 1.2340, train_acc: 66.72% / test_loss: 1.2431, test_acc: 65.74%
Epoch 3 - train_loss: 1.2106, train_acc: 69.22% / test_loss: 1.2342, test_acc: 66.76%
Epoch 4 - train_loss: 1.1996, train_acc: 70.36% / test_loss: 1.2267, test_acc: 67.50%
Epoch 5 - train_loss: 1.1914, train_acc: 71.22% / test_loss: 1.2398, test_acc: 66.25%
Epoch 6 - train_loss: 1.1865, train_acc: 71.73% / test_loss: 1.2384, test_acc: 66.16%
Epoch 7 - train_loss: 1.1813, train_acc: 72.17% / test_loss: 1.2212, test_acc: 68.05%
Epoch 8 - train_loss: 1.1759, train_acc: 72.67% / test_loss: 1.2386, test_acc: 66.02%
Epoch 9 - train_loss: 1.1705, train_acc: 73.31% / test_loss: 1.2409, test_acc: 66.34%
Epoch 10 - train_loss: 1.1640, train_acc: 74.01% / test_loss: 1.2290, test_acc: 67.36%


In [91]:
torch.save(model.state_dict(), './weights/trained_weight_ptbxl.pth')