# Modeling #

## Import APIs ##

In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import wfdb
import ast

## Load data ##

### Metadata ###

In [4]:
# Read the cleaned PTB-XL metadata table, using the ECG record id as the index.
ptbxl_df = pd.read_csv(
    './cleaned_data/cleaned_ptbxl_metadata.csv',
    index_col='ecg_id',
)

In [5]:
# Preview the first rows to sanity-check the loaded metadata table.
ptbxl_df.head()

ecg_id,age,sex,device,validated_by_human,diagnostic_superclass,strat_fold,filename_lr,filename_hr
1,56.0,1,CS-12 E,True,['NORM'],3,records100/00000/00001_lr,records500/00000/00001_hr
2,19.0,0,CS-12 E,True,['NORM'],2,records100/00000/00002_lr,records500/00000/00002_hr
3,37.0,1,CS-12 E,True,['NORM'],5,records100/00000/00003_lr,records500/00000/00003_hr
4,24.0,0,CS-12 E,True,['NORM'],3,records100/00000/00004_lr,records500/00000/00004_hr
5,19.0,1,CS-12 E,True,['NORM'],4,records100/00000/00005_lr,records500/00000/00005_hr


In [6]:
# Pull the metadata columns into an independent copy so that later edits to
# `metadata` cannot write through to `ptbxl_df`.
metadata_columns = ['age', 'sex', 'device', 'validated_by_human']
metadata = ptbxl_df[metadata_columns].copy()
metadata.head()

ecg_id,age,sex,device,validated_by_human
1,56.0,1,CS-12 E,True
2,19.0,0,CS-12 E,True
3,37.0,1,CS-12 E,True
4,24.0,0,CS-12 E,True
5,19.0,1,CS-12 E,True


### Waveform data ###

In [7]:
# Load one cleaned waveform CSV per metadata row, in metadata order, and
# combine them into a single NumPy array.
waveform_data = []
# Iterate the path column directly: `ptbxl_df.loc[idx]` would materialise a
# whole row as a Series for every record just to read one field.
for record_path in ptbxl_df['filename_hr']:
    waveform_df = pd.read_csv(
        './cleaned_data/waveform_data/' + record_path + '.csv',
        index_col='Time (s)',  # the time axis becomes the index, not a data column
    )
    waveform_data.append(waveform_df)
# Rebinding the same name lets the list of DataFrames be freed once the
# array has been built.
waveform_data = np.array(waveform_data)

In [8]:
# Inspect the array dimensions; the recorded output below is (21799, 1000, 12).
waveform_data.shape

(21799, 1000, 12)

## Create recommended train-test split ##

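This train-test split follows the recommendation in the example code distributed with the PTB-XL dataset (https://physionet.org/content/ptb-xl/1.0.3/): stratification fold 10 is held out as the test set and all remaining folds are used for training.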

In [9]:
# Hold out one stratification fold as the test set; every other fold is
# used for training.
test_fold = 10
is_test = ptbxl_df.strat_fold == test_fold
is_train = ~is_test

# Training partition
waveform_train = waveform_data[np.where(is_train)]
metadata_train = metadata[is_train]
y_train = ptbxl_df.loc[is_train, 'diagnostic_superclass']

# Test partition
waveform_test = waveform_data[np.where(is_test)]
metadata_test = metadata[is_test]
y_test = ptbxl_df.loc[is_test, 'diagnostic_superclass']

# TODO: decide whether the data should be normalized before modeling.

## Modeling ##

### CNN Autoencoder ###

In [10]:
class CNNAutoencoder(nn.Module):
    """1-D convolutional autoencoder for 12-channel signals.

    The encoder stacks two unpadded ``Conv1d`` layers (12 -> 32 -> 64
    channels, kernel size 5, stride 1). The decoder mirrors them with two
    ``ConvTranspose1d`` layers (64 -> 32 -> 12 channels). Each unpadded
    convolution shortens the sequence by 4 samples and each transposed
    convolution lengthens it by 4, so the reconstruction has the same shape
    as the input.
    """

    def __init__(self):
        # Zero-argument super() does not depend on the class being reachable
        # under a particular name.
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=12,
                      out_channels=32,
                      kernel_size=5,
                      stride=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=32,
                      out_channels=64,
                      kernel_size=5,
                      stride=1),
            nn.ReLU()
        )

        # NOTE(review): the trailing ReLU clamps the reconstruction to
        # non-negative values. If the input signals can be negative they
        # cannot be reproduced exactly -- confirm this is intended.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(in_channels=64,
                               out_channels=32,
                               kernel_size=5,
                               stride=1),
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=32,
                               out_channels=12,
                               kernel_size=5,
                               stride=1),
            nn.ReLU()
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction.

        Args:
            x: Tensor of shape ``(batch, 12, length)``. ``length`` must be at
                least 9 so that both unpadded kernel-5 convolutions produce
                a non-empty output.

        Returns:
            The reconstructed tensor, with the same shape as ``x``. The
            latent code is available separately via ``self.encoder(x)``.
        """
        encoded_output = self.encoder(x)
        decoded_output = self.decoder(encoded_output)
        return decoded_output

### TCN Autoencoder ###

Model Card for the Hybrid Autoencoder
Model Name: Hybrid Autoencoder for ECG and Metadata

Description: This model is designed to learn compressed representations of combined ECG waveform and patient metadata. It utilizes separate pathways for waveform data and metadata, merging them into a dense representation which is then used to reconstruct both types of data.

Model Architecture:

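Waveform Pathway: Convolutional layers followed by pooling and flattening (the implementation below uses a temporal convolutional network, TCN, encoder without pooling or flattening).
Metadata Pathway: Dense layers (implemented below as embedding layers for age, sex, and device).
Combined Encoding and Decoding: Dense layers (implemented below as a TCN decoder applied to the waveform encoding concatenated with the metadata embeddings).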
Intended Use: Intended for anomaly detection in ECG data where additional patient metadata is available and considered relevant.

Data Used for Training: Assumes a dataset comprising ECG waveform data aligned with patient metadata such as age, sex, and device information.

Limitations: The model's effectiveness is highly dependent on the quality and preprocessing of the input data. The architecture needs fine-tuning and validation using real-world data to ensure robustness.

Ethical Considerations: Care should be taken to avoid biases that may arise from imbalanced data across different demographic groups. Privacy concerns should be addressed when handling patient data.

This notebook sets up only the foundation of the model; further tuning, training, and validation are needed to adapt it to a specific task or dataset.

In [3]:
from pytorch_tcn import TCN

class TCNAutoencoder(nn.Module):
    """TCN autoencoder whose decoder is conditioned on patient metadata.

    The waveform is encoded by a causal TCN. Age, sex and device are each
    mapped through an embedding table, repeated along the time axis and
    concatenated with the encoded waveform on the channel axis. A second
    causal TCN decodes that tensor and projects it back to ``num_inputs``
    channels.
    """

    def __init__(self, num_inputs, num_channels, kernel_size, dropout, metadata_dims,
                 num_devices=None):
        """Build the encoder, the metadata embeddings and the decoder.

        Args:
            num_inputs: Number of input (and reconstructed) waveform channels.
            num_channels: Sequence of channel widths for the encoder TCN;
                the decoder uses the same widths in reverse order.
            kernel_size: Kernel size passed to both TCNs.
            dropout: Dropout rate passed to both TCNs.
            metadata_dims: Embedding sizes, indexed as
                ``(age_dim, sex_dim, device_dim)``.
            num_devices: Number of distinct device ids, i.e. the size of the
                device embedding table. When omitted, a module-level
                ``num_devices`` variable is used if one exists (legacy
                behaviour).

        Raises:
            ValueError: If ``num_devices`` is not given and no module-level
                ``num_devices`` is defined.
        """
        super().__init__()
        if num_devices is None:
            # Earlier revisions read `num_devices` as a free variable; keep
            # honouring a module-level definition so existing callers work.
            num_devices = globals().get('num_devices')
            if num_devices is None:
                raise ValueError(
                    "num_devices must be passed to TCNAutoencoder (or defined "
                    "at module level) to size the device embedding"
                )

        self.encoder = TCN(
            num_inputs=num_inputs,
            num_channels=num_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            causal=True,
        )
        self.age_embedding = nn.Embedding(120, metadata_dims[0])  # Assuming age range from 0 to 119
        self.sex_embedding = nn.Embedding(2, metadata_dims[1])  # Assuming sex is binary (0 or 1)
        self.device_embedding = nn.Embedding(num_devices, metadata_dims[2])  # one row per unique device id

        # The decoder sees the encoder's last channel width plus every
        # metadata embedding stacked on the channel axis.
        decoder_input_dim = num_channels[-1] + sum(metadata_dims)
        self.decoder = TCN(
            num_inputs=decoder_input_dim,
            num_channels=num_channels[::-1],
            kernel_size=kernel_size,
            dropout=dropout,
            causal=True,
            output_projection=num_inputs,
        )

    def forward(self, x, age, sex, device):
        """Reconstruct ``x`` conditioned on the metadata indices.

        Args:
            x: Waveform tensor. Assumed to be ``(batch, num_inputs, time)``,
                i.e. the TCN's channels-first layout -- TODO confirm against
                the ``pytorch_tcn`` version in use.
            age: Integer index tensor into the age embedding, shape
                ``(batch,)``. Values must lie in ``[0, 120)``.
            sex: Integer index tensor into the sex embedding, shape
                ``(batch,)``. Values must be 0 or 1.
            device: Integer index tensor into the device embedding, shape
                ``(batch,)``.

        Returns:
            The decoder output for the metadata-augmented encoding.
        """
        encoded = self.encoder(x)

        age_emb = self.age_embedding(age)
        sex_emb = self.sex_embedding(sex)
        device_emb = self.device_embedding(device)

        # (batch, sum(metadata_dims)) -> (batch, sum(metadata_dims), time):
        # repeat the per-record metadata vector at every time step.
        metadata_emb = torch.cat([age_emb, sex_emb, device_emb], dim=-1)
        metadata_emb = metadata_emb.unsqueeze(2).expand(-1, -1, encoded.size(2))

        # Stack waveform features and metadata along the channel axis.
        concatenated = torch.cat([encoded, metadata_emb], dim=1)
        decoded = self.decoder(concatenated)
        return decoded