Model Card for the Hybrid Autoencoder
Model Name: Hybrid Autoencoder for ECG and Metadata

Description: This model is designed to learn compressed representations of ECG waveforms combined with patient metadata. It uses separate pathways for the waveform data and the metadata, merges them into a dense representation, and then uses that representation to reconstruct both types of data.

Model Architecture:

Waveform Pathway: Convolutional layers followed by pooling and flattening.
Metadata Pathway: Dense layers.
Combined Encoding and Decoding: Dense layers.

Intended Use: Anomaly detection in ECG data when additional patient metadata is available and considered relevant.
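A common way to turn an autoencoder into an anomaly detector is to score each recording by its reconstruction error and flag scores above a threshold fitted on presumed-normal data. The sketch below assumes that approach, a mean-squared-error score, and a quantile threshold; the function names and the quantile value are illustrative choices, not part of the original design.

```python
import torch


def reconstruction_scores(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """Per-recording mean squared reconstruction error for (N, C, L) tensors."""
    return ((x - x_hat) ** 2).mean(dim=(1, 2))


def fit_threshold(normal_scores: torch.Tensor, q: float = 0.99) -> float:
    """Threshold at the q-quantile of scores computed on presumed-normal recordings."""
    return torch.quantile(normal_scores, q).item()


# Toy example: recording 2 is reconstructed with a constant error of 1.0
x = torch.zeros(3, 12, 5)
x_hat = x.clone()
x_hat[2] += 1.0
scores = reconstruction_scores(x, x_hat)  # tensor([0., 0., 1.])

threshold = fit_threshold(torch.tensor([0.1, 0.2, 0.3, 0.4]), q=0.5)  # about 0.25
is_anomaly = scores > threshold  # tensor([False, False, True])
```

In practice `x_hat` would come from the model's forward pass, and the quantile controls the trade-off between missed anomalies and false alarms.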

Data Used for Training: Assumes a dataset comprising ECG waveform data aligned with patient metadata such as age, sex, and device information.
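The model code further down embeds each metadata field, so age, sex, and device must arrive as integer indices: age in 0 to 119, sex as 0 or 1, and device as an index below the number of unique devices. A minimal encoding sketch follows; the `F`/`M` coding, the clamping of out-of-range ages, and the helper names are hypothetical and should be replaced by whatever the real dataset defines.

```python
from typing import Dict, Iterable, Tuple

MAX_AGE = 119  # the model's age embedding has 120 rows (indices 0-119)
SEX_CODES = {"F": 0, "M": 1}  # hypothetical coding; use the dataset's own


def build_device_vocab(device_names: Iterable[str]) -> Dict[str, int]:
    """Map each unique device name to an integer index (sorted for reproducibility)."""
    return {name: idx for idx, name in enumerate(sorted(set(device_names)))}


def encode_metadata(age: float, sex: str, device: str,
                    device_vocab: Dict[str, int]) -> Tuple[int, int, int]:
    """Turn raw metadata into the integer indices expected by the embedding layers."""
    age_idx = min(max(int(age), 0), MAX_AGE)  # clamp out-of-range ages
    sex_idx = SEX_CODES[sex]
    device_idx = device_vocab[device]  # raises KeyError for devices unseen in training
    return age_idx, sex_idx, device_idx


vocab = build_device_vocab(["device_b", "device_a", "device_b"])
print(vocab)  # {'device_a': 0, 'device_b': 1}
print(encode_metadata(130.7, "M", "device_b", vocab))  # (119, 1, 1)
```

Here `len(vocab)` gives the number of unique devices that the device embedding table must accommodate.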

Limitations: The model's effectiveness is highly dependent on the quality and preprocessing of the input data. The architecture needs fine-tuning and validation using real-world data to ensure robustness.
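As one example of the preprocessing this depends on, ECG amplitudes can differ widely between leads and devices, so waveforms are often standardized before training. The sketch below assumes per-lead z-scoring over time for tensors shaped (recordings, leads, samples); this is one option among many (filtering, resampling, and baseline-wander removal are others), not a step prescribed by the model.

```python
import torch


def standardize_per_lead(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Z-score each lead of each recording independently; x has shape (N, C, L)."""
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    return (x - mean) / (std + eps)  # eps guards against flat (zero-variance) leads


x = torch.tensor([[[1.0, 2.0, 3.0], [10.0, 10.0, 10.0]]])  # 1 recording, 2 leads
z = standardize_per_lead(x)  # lead 0 -> [-1, 0, 1]; flat lead 1 -> [0, 0, 0]
```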

Ethical Considerations: Care should be taken to avoid biases that may arise from imbalanced data across different demographic groups. Privacy concerns should be addressed when handling patient data.
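One simple check for demographic bias is to compare reconstruction-error scores across groups on held-out recordings believed to be normal: if one group is consistently reconstructed worse, a fixed threshold will raise more false alarms for that group. The sketch below assumes integer group codes (for example the 0/1 sex coding used by the model); the helper name is hypothetical.

```python
import torch


def mean_score_by_group(scores: torch.Tensor, groups: torch.Tensor) -> dict:
    """Mean anomaly score for each integer group code (e.g., sex 0/1 or age bands)."""
    return {int(g): scores[groups == g].mean().item() for g in torch.unique(groups)}


scores = torch.tensor([1.0, 3.0, 2.0, 4.0])
sex = torch.tensor([0, 0, 1, 1])
print(mean_score_by_group(scores, sex))  # {0: 2.0, 1: 3.0}
```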

This framework lays the foundation for the model; further tuning, training, and validation are needed to adapt it to specific tasks or datasets.
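A training step for a reconstruction model like this one could look like the sketch below, which assumes an MSE loss between the input waveform and its reconstruction. `StandInModel` is a hypothetical placeholder with the same `forward(x, age, sex, device)` signature as `TCNAutoencoder`, included only so the loop runs without `pytorch_tcn`; the optimizer choice and learning rate are illustrative.

```python
import torch
from torch import nn, optim


def train_step(model: nn.Module, optimizer: optim.Optimizer,
               x: torch.Tensor, age: torch.Tensor, sex: torch.Tensor,
               device: torch.Tensor) -> float:
    """One optimization step minimizing MSE between the input and its reconstruction."""
    model.train()
    optimizer.zero_grad()
    x_hat = model(x, age, sex, device)
    loss = nn.functional.mse_loss(x_hat, x)
    loss.backward()
    optimizer.step()
    return loss.item()


class StandInModel(nn.Module):
    """Hypothetical stand-in with the same forward signature as TCNAutoencoder."""

    def __init__(self, num_inputs: int):
        super().__init__()
        self.proj = nn.Conv1d(num_inputs, num_inputs, kernel_size=1)

    def forward(self, x, age, sex, device):
        return self.proj(x)  # ignores metadata; only here so the loop runs


torch.manual_seed(0)
model = StandInModel(num_inputs=2)
optimizer = optim.Adam(model.parameters(), lr=0.05)
x = torch.randn(8, 2, 50)
age = torch.randint(0, 120, (8,))
sex = torch.randint(0, 2, (8,))
device = torch.zeros(8, dtype=torch.long)
losses = [train_step(model, optimizer, x, age, sex, device) for _ in range(100)]
```

With `pytorch_tcn` installed, an instance of `TCNAutoencoder` can take the stand-in's place; validation would reuse the same loss under `torch.no_grad()` on held-out recordings.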

In [None]:
#%pip install pytorch-tcn

In [7]:
import torch
from torch import nn
from torch import optim

from pytorch_tcn import TCN


In [8]:
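class TCNAutoencoder(nn.Module):
    """TCN encoder/decoder over ECG waveforms, conditioned on embedded patient metadata.

    Expects waveforms shaped (batch, num_inputs, length) and integer-coded metadata
    tensors shaped (batch,): age in [0, 119], sex in {0, 1}, device in [0, num_devices).
    """

    def __init__(self, num_inputs, num_channels, kernel_size, dropout, metadata_dims, num_devices):
        super().__init__()
        # Encoder: (N, num_inputs, L) -> (N, num_channels[-1], L)
        self.encoder = TCN(
            num_inputs=num_inputs,
            num_channels=num_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            causal=True,
        )
        # One embedding table per categorical metadata field
        self.age_embedding = nn.Embedding(120, metadata_dims[0])  # assumes ages 0 to 119
        self.sex_embedding = nn.Embedding(2, metadata_dims[1])  # assumes sex is coded 0 or 1
        self.device_embedding = nn.Embedding(num_devices, metadata_dims[2])  # one row per unique device

        # Decoder input = encoder output channels + all metadata embedding channels
        decoder_input_dim = num_channels[-1] + sum(metadata_dims)
        self.decoder = TCN(
            num_inputs=decoder_input_dim,
            num_channels=num_channels[::-1],  # mirror the encoder's channel widths
            kernel_size=kernel_size,
            dropout=dropout,
            causal=True,
            output_projection=num_inputs,  # project back to the original number of input channels
        )

    def forward(self, x, age, sex, device):
        encoded = self.encoder(x)  # (N, num_channels[-1], L)

        age_emb = self.age_embedding(age)  # (N, metadata_dims[0])
        sex_emb = self.sex_embedding(sex)  # (N, metadata_dims[1])
        device_emb = self.device_embedding(device)  # (N, metadata_dims[2])

        # Concatenate the metadata embeddings and repeat them along the time axis
        metadata_emb = torch.cat([age_emb, sex_emb, device_emb], dim=-1)  # (N, sum(metadata_dims))
        metadata_emb = metadata_emb.unsqueeze(2).expand(-1, -1, encoded.size(2))  # (N, sum(metadata_dims), L)

        concatenated = torch.cat([encoded, metadata_emb], dim=1)  # (N, num_channels[-1] + sum(metadata_dims), L)
        decoded = self.decoder(concatenated)  # (N, num_inputs, L)
        return decoded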
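The least obvious step in `forward` is how per-patient metadata, which has no time axis, is joined with the time-resolved encoder output. The sketch below reproduces just that step in plain `torch` (no `pytorch_tcn`), using a random tensor as a stand-in for the encoder output; all sizes are arbitrary assumptions.

```python
import torch
from torch import nn

batch, channels, length = 4, 32, 500  # illustrative sizes
metadata_dims = (8, 2, 4)
num_devices = 3

age_embedding = nn.Embedding(120, metadata_dims[0])
sex_embedding = nn.Embedding(2, metadata_dims[1])
device_embedding = nn.Embedding(num_devices, metadata_dims[2])

encoded = torch.randn(batch, channels, length)  # stand-in for the TCN encoder output (N, C, L)
age = torch.randint(0, 120, (batch,))
sex = torch.randint(0, 2, (batch,))
device = torch.randint(0, num_devices, (batch,))

metadata_emb = torch.cat(
    [age_embedding(age), sex_embedding(sex), device_embedding(device)], dim=-1
)  # (N, 14)
metadata_emb = metadata_emb.unsqueeze(2).expand(-1, -1, length)  # (N, 14, L), same vector at every time step
concatenated = torch.cat([encoded, metadata_emb], dim=1)  # (N, C + 14, L)
print(concatenated.shape)  # torch.Size([4, 46, 500])
```

Because `expand` returns a view rather than a copy, repeating the metadata across all time steps costs no extra memory until the concatenation.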