In [71]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from pathlib import Path
import pytorch_lightning as pl
import ast
import pandas as pd

In [72]:
# All artefacts are resolved relative to the directory the kernel was started in.
project_path = Path.cwd()
test_csv_path = project_path.joinpath('data/coherent/csv/test_ecg.csv')
mitdb_pretrained_weight = project_path.joinpath('weights/trained_weight_mitdb.pth')
trained_weight_path = project_path.joinpath('weights/epoch=70-step=710.ckpt')

In [73]:
class TestECGDataset(Dataset):
    """Per-patient view over a CSV of ECG segments.

    The CSV must contain a ``patient`` column (patient identifier) and a
    ``value`` column holding one whitespace-separated ECG segment per row.
    Items are looked up by patient identifier, not by integer position.
    """

    def __init__(self, csv_path: Path) -> None:
        self.df = pd.read_csv(csv_path)

    def __getitem__(self, patient_id):
        """
        Return every ECG segment recorded for ``patient_id`` (no label).

        The result is a float32 tensor of shape ``(n_segments, 1, segment_len)``:
        one row per CSV row of that patient, plus a singleton channel axis
        (the layout a single-input-channel ``Conv1d`` expects). All segments of
        one patient must have the same length.

        Raises:
            KeyError: if ``patient_id`` does not occur in the ``patient`` column.
        """
        patient_rows = self.df[self.df['patient'] == patient_id]
        if patient_rows.empty:
            # Without this guard an unknown id yields an empty (0, 1) tensor
            # that only fails later, far away from the actual cause.
            raise KeyError(patient_id)
        patient_ecg = torch.tensor(
            [[float(x) for x in ecg.split()] for ecg in patient_rows['value']],
            dtype=torch.float32,
        )
        return torch.unsqueeze(patient_ecg, dim=1)

    def __len__(self):
        """Number of distinct patients, i.e. the keys accepted by ``__getitem__``."""
        return int(self.df['patient'].nunique())

# Held-out ECG data; items are fetched by patient id rather than integer position.
test_dataset = TestECGDataset(csv_path=test_csv_path)

In [74]:
# Pick one patient and pull all of their segments: (n_segments, 1, segment_len).
patient_id = 'b1ba081a-6299-6beb-5b72-582cd986697e'
patient_ecg = test_dataset[patient_id]
patient_ecg.size()

torch.Size([3, 1, 401])

In [75]:
class ECGModel(pl.LightningModule):
    """1-D CNN producing 5 logits per single-channel ECG segment.

    Architecture: two ``Conv1d -> LeakyReLU -> MaxPool1d(2)`` stages followed
    by a three-layer MLP. Both convolutions preserve length, and each pooling
    layer halves it (floor), so ``linear1`` (sized for a pooled length of 100)
    only fits input lengths 400-403; the shape comments below assume 401.

    Training and validation use ``binary_cross_entropy_with_logits``, so the
    5 outputs are scored as independent binary targets and ``ys`` must be a
    float tensor with the same shape as the logits.

    Args:
        pretrained_weight_path: MIT-BIH pretrained convolution weights.
            Currently unused by the constructor (pretrained initialisation is
            disabled); call ``load_init_weights`` explicitly to apply them.
        lr: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, pretrained_weight_path: Path, lr=0.001):
        super().__init__()
        self.lr = lr
        # Feature extractor (shapes shown for an input of length 401).
        self.conv1 = nn.Conv1d(1, 16, 7, padding=3)  # [bz, 16, 401]
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool1d(2)  # [bz, 16, 200]
        self.conv2 = nn.Conv1d(16, 16, 5, padding=2)  # [bz, 16, 200]
        self.relu2 = nn.LeakyReLU()
        self.pool2 = nn.MaxPool1d(2)  # [bz, 16, 100]
        # Pretrained MIT-BIH initialisation is currently not applied here;
        # use load_init_weights(pretrained_weight_path) to opt in.
        # Classifier head.
        self.linear1 = nn.Linear(16*100, 512)
        self.relu3 = nn.LeakyReLU()
        self.linear2 = nn.Linear(512, 128)
        self.relu4 = nn.LeakyReLU()
        self.linear3 = nn.Linear(128, 5)

    def load_init_weights(self, init_weight_path: Path) -> None:
        """Copy conv1/conv2 weights and biases from a pretrained checkpoint.

        The checkpoint must be a flat mapping containing ``conv1.weight``,
        ``conv1.bias``, ``conv2.weight`` and ``conv2.bias``. Not called by
        ``__init__``.
        """
        # torch.load unpickles its input: only use it on trusted files.
        checkpoint = torch.load(init_weight_path, map_location="cpu")
        with torch.no_grad():
            for name in ("conv1", "conv2"):
                layer = getattr(self, name)
                # copy_ (unlike .data assignment) also checks shapes.
                layer.weight.copy_(checkpoint[f"{name}.weight"])
                layer.bias.copy_(checkpoint[f"{name}.bias"])

    def forward(self, x):
        """Map ``(batch, 1, length)`` (or unbatched ``(1, length)``) to ``(N, 5)`` logits."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # view(-1, 1600) would silently fold a wrong-length signal into extra
        # rows (e.g. length 801 -> twice as many "samples"), so check first.
        if x.shape[-2:] != (16, 100):
            raise ValueError(
                "expected pooled features of shape (..., 16, 100) "
                f"(input length 400-403), got {tuple(x.shape)}"
            )
        x = x.view(-1, 16*100)
        x = self.relu3(self.linear1(x))
        x = self.relu4(self.linear2(x))
        return self.linear3(x)

    def _shared_step(self, batch, log_name):
        """Compute and log the BCE-with-logits loss for one ``(xs, ys)`` batch."""
        xs, ys = batch
        y_hats = self(xs)
        loss = F.binary_cross_entropy_with_logits(y_hats, ys)
        self.log(log_name, loss, prog_bar=True, on_epoch=True, on_step=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged as ``train_loss``)."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Log the validation loss as ``val_loss``."""
        self._shared_step(batch, "val_loss")

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [76]:
model = ECGModel(mitdb_pretrained_weight)
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only kernel.
# torch.load unpickles its input, so only point it at trusted checkpoint files.
checkpoint = torch.load(trained_weight_path, map_location='cpu')
model.eval()  # inference-only from here on
# The module weights live under the checkpoint's 'state_dict' key.
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [77]:
# Pure inference: skip building an autograd graph.
with torch.no_grad():
    y_hats = model(patient_ecg)
y_hats  # raw logits, one row of 5 per segment

tensor([[ 5.5490, -1.5859, -4.3191, -1.1513, -3.2590],
        [ 5.3677, -1.5840, -4.2544, -1.1541, -3.1579],
        [ 5.5371, -1.5713, -4.3006, -1.1459, -3.2567]],
       grad_fn=<AddmmBackward0>)

In [78]:
# Independent per-output probabilities (the model is trained with
# BCE-with-logits). Recomputed from the model rather than from the previous
# cell's y_hats, so re-running this cell cannot apply the sigmoid twice.
with torch.no_grad():
    y_hats = torch.sigmoid(model(patient_ecg))
y_hats

tensor([[0.9961, 0.1700, 0.0131, 0.2403, 0.0370],
        [0.9954, 0.1702, 0.0140, 0.2397, 0.0408],
        [0.9961, 0.1720, 0.0134, 0.2412, 0.0371]], grad_fn=<SigmoidBackward0>)

In [79]:
# Average over this patient's segments, then threshold each output at 0.5 to
# get a 0/1 multi-label prediction (y_hats is expected to hold probabilities here).
mean_y_hats = y_hats.mean(dim=0)
pred = (mean_y_hats > 0.5).long().tolist()

In [80]:
pred  # one 0/1 flag per model output (5 outputs); output-to-class mapping is not shown here — TODO document

[1, 0, 0, 0, 0]