# Computational Experiment: Transferring our pretrained model to a supervised classification problem

In [12]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

import wfdb
import ast

In [13]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## load data

Important: keep the `diagnostic_superclass` labels — they are the targets for the supervised classifier.

In [14]:
# Load the cleaned metadata table, indexed by `ecg_id`.
ptbxl_data = pd.read_csv('./cleaned_data/cleaned_ptbxl_metadata.csv', index_col='ecg_id')

# Partition the records by their diagnostic superclass label.
superclass = ptbxl_data['diagnostic_superclass']
normal_data = ptbxl_data[superclass == 'NORMAL']
abnormal_data = ptbxl_data[superclass == 'ABNORMAL']

# Keep the metadata features used by the models together with the class label.
metadata_columns = ['age', 'sex', 'device', 'validated_by_human', 'diagnostic_superclass']
normal_metadata = normal_data.loc[:, metadata_columns].copy()
abnormal_metadata = abnormal_data.loc[:, metadata_columns].copy()
print(f'Normal metadata shape: {normal_metadata.shape}')
print(f'Abnormal metadata shape: {abnormal_metadata.shape}')

# Pre-extracted waveform arrays, one file per class.
# NOTE(review): presumably stored in the same row order as the metadata above — confirm.
normal_ecg_data = np.load("./cleaned_data/normal_ecg_data.npy")
abnormal_ecg_data = np.load("./cleaned_data/abnormal_ecg_data.npy")


Normal metadata shape: (9069, 5)
Abnormal metadata shape: (9069, 5)


## Split Data

In [15]:
# Sequential (unshuffled) train/test split, applied to each class separately.
# Each class gets its own split index so the train fraction also holds when the
# two classes do not contain the same number of records.
train_fraction = 0.8
split_idx = int(normal_data.shape[0] * train_fraction)
abnormal_split_idx = int(abnormal_data.shape[0] * train_fraction)

# Waveforms and metadata are sliced by position, so they must have the same number of rows.
assert normal_ecg_data.shape[0] == normal_metadata.shape[0], \
    "normal ECG array and normal metadata have different row counts"
assert abnormal_ecg_data.shape[0] == abnormal_metadata.shape[0], \
    "abnormal ECG array and abnormal metadata have different row counts"

normal_ecg_train = normal_ecg_data[:split_idx]
normal_ecg_test = normal_ecg_data[split_idx:]
print(f'Normal ECG train shape: {normal_ecg_train.shape}')
print(f'Normal ECG test shape: {normal_ecg_test.shape}')

abnormal_ecg_train = abnormal_ecg_data[:abnormal_split_idx]
abnormal_ecg_test = abnormal_ecg_data[abnormal_split_idx:]
print(f'Abnormal ECG train shape: {abnormal_ecg_train.shape}')
print(f'Abnormal ECG test shape: {abnormal_ecg_test.shape}')

# .iloc makes the positional slicing explicit (the frames are indexed by ecg_id, not by position).
normal_metadata_train = normal_metadata.iloc[:split_idx]
normal_metadata_test = normal_metadata.iloc[split_idx:]
print(f'Normal metadata train shape: {normal_metadata_train.shape}')
print(f'Normal metadata test shape: {normal_metadata_test.shape}')

abnormal_metadata_train = abnormal_metadata.iloc[:abnormal_split_idx]
abnormal_metadata_test = abnormal_metadata.iloc[abnormal_split_idx:]
print(f'Abnormal metadata train shape: {abnormal_metadata_train.shape}')
print(f'Abnormal metadata test shape: {abnormal_metadata_test.shape}')

Normal ECG train shape: (7255, 1000, 12)
Normal ECG test shape: (1814, 1000, 12)
Abnormal ECG train shape: (7255, 1000, 12)
Abnormal ECG test shape: (1814, 1000, 12)
Normal metadata train shape: (7255, 5)
Normal metadata test shape: (1814, 5)
Abnormal metadata train shape: (7255, 5)
Abnormal metadata test shape: (1814, 5)


## Normalize waveforms


In [16]:
#norm:
def normalize_waveform(data):
    """Min-max scale every channel of every record to the range [0, 1].

    Each (record, channel) trace is scaled independently using only its own
    minimum and maximum along the sample axis, so no statistics are shared
    between records.

    Args:
        data: Array indexed as (record, sample, channel).

    Returns:
        A new floating-point array of the same shape. Floating-point inputs
        keep their dtype; any other dtype is promoted to float64 so the scaled
        values are not truncated to integers. A trace whose maximum equals its
        minimum is returned as all zeros.
    """
    values = np.asarray(data)
    if not np.issubdtype(values.dtype, np.floating):
        # Writing fractions into an integer array would truncate them to 0 or 1.
        values = values.astype(np.float64)

    # Per-(record, channel) extrema along the sample axis; keepdims allows broadcasting.
    min_val = values.min(axis=1, keepdims=True)
    max_val = values.max(axis=1, keepdims=True)
    varies = max_val != min_val

    # Divide only where the trace is not constant; constant traces keep the zeros from `out`.
    return np.divide(
        values - min_val,
        max_val - min_val,
        out=np.zeros_like(values),
        where=varies,
    )

# Each record is scaled using only its own min/max, so normalising the train and
# test splits separately cannot leak information between them.
std_normal_ecg_train, std_normal_ecg_test = (
    normalize_waveform(split) for split in (normal_ecg_train, normal_ecg_test)
)

std_abnormal_ecg_train, std_abnormal_ecg_test = (
    normalize_waveform(split) for split in (abnormal_ecg_train, abnormal_ecg_test)
)

In [17]:
# Stack both classes into one train set and one test set. Normal rows come first,
# followed by abnormal rows, in the ECG arrays and in the metadata frames alike,
# so row i of an ECG array still corresponds to row i of its metadata frame.
combined_ecg_train = np.concatenate([std_normal_ecg_train, std_abnormal_ecg_train])
combined_ecg_test = np.concatenate([std_normal_ecg_test, std_abnormal_ecg_test])

combined_metadata_train = pd.concat([normal_metadata_train, abnormal_metadata_train])
combined_metadata_test = pd.concat([normal_metadata_test, abnormal_metadata_test])

## Encode metadata and make DataLoaders

In [18]:
# --- Labels ---
# Encode the class labels as integers. The encoder is fitted on the training labels
# only and then reused for the test labels. LabelEncoder numbers the classes in
# sorted order, so 'ABNORMAL' -> 0 and 'NORMAL' -> 1.
label_encoder = LabelEncoder()
combined_metadata_train['diagnostic_superclass'] = label_encoder.fit_transform(combined_metadata_train['diagnostic_superclass'])
combined_metadata_test['diagnostic_superclass'] = label_encoder.transform(combined_metadata_test['diagnostic_superclass'])

# --- Features ---
# Standardise the numeric column and one-hot encode the categorical columns.
# Output column order is: age, then the sex / device / validated_by_human one-hot groups.
scaler_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())
])
encoder_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])
preprocessor = ColumnTransformer(
    transformers=[
        ('num', scaler_transformer, ['age']),
        ('cat', encoder_transformer, ['sex', 'device', 'validated_by_human']),
    ],
    # Always return a dense ndarray. With the default threshold the result is sparse
    # only when its density is low enough, so calling .toarray() on it is fragile.
    sparse_threshold=0,
)
# Fit on the training features only; the test features reuse the fitted statistics/categories.
std_combined_metadata_train = preprocessor.fit_transform(combined_metadata_train.drop(columns='diagnostic_superclass'))
std_combined_metadata_test = preprocessor.transform(combined_metadata_test.drop(columns='diagnostic_superclass'))

batch_size = 32
# Create DataLoaders: each item is (waveform, encoded metadata, integer label).
train_dataset = TensorDataset(torch.from_numpy(combined_ecg_train).float(),
                              torch.from_numpy(std_combined_metadata_train).float(),
                              torch.from_numpy(combined_metadata_train['diagnostic_superclass'].values).long())
# Shuffle the training set: the rows are ordered normal-first, abnormal-second.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(torch.from_numpy(combined_ecg_test).float(),
                             torch.from_numpy(std_combined_metadata_test).float(),
                             torch.from_numpy(combined_metadata_test['diagnostic_superclass'].values).long())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


## Define Models

In [None]:
from pytorch_tcn import TCN
# Number of distinct (non-missing) `device` values in the full metadata table; the
# models below use it as the input width of their device-embedding layer.
# NOTE(review): the one-hot width actually fed to the models comes from the encoder
# fitted on the training metadata — confirm both counts agree.
num_unique_devices = ptbxl_data['device'].nunique(dropna=True)
class TCNAutoencoder(nn.Module):
    """TCN encoder/decoder pair whose decoder is conditioned on embedded metadata.

    The waveform is encoded by a TCN, every metadata group is projected through
    its own linear layer, the projections are repeated along the last axis of the
    encoding, and a second TCN decodes the concatenation back to `num_inputs`
    channels.

    Args:
        num_inputs: Number of input channels given to the encoder TCN (also the
            decoder's output projection size).
        num_channels: Channel sizes passed to the encoder TCN; the decoder uses
            them in reverse order.
        kernel_size: Kernel size passed to both TCNs.
        dropout: Dropout rate passed to both TCNs.
        metadata_dims: Embedding sizes indexed as [age, sex, device, validated].
    """

    def __init__(self, num_inputs, num_channels, kernel_size, dropout, metadata_dims):
        super().__init__()
        # Settings shared by the encoder and decoder TCNs.
        tcn_options = dict(kernel_size=kernel_size, dropout=dropout, causal=True)

        self.encoder = TCN(num_inputs=num_inputs, num_channels=num_channels, **tcn_options)

        # One linear projection per metadata group.
        self.age_embedding = nn.Linear(1, metadata_dims[0])  # single numeric column
        self.sex_embedding = nn.Linear(2, metadata_dims[1])  # one-hot, 2 columns
        self.device_embedding = nn.Linear(num_unique_devices, metadata_dims[2])  # one-hot, one column per device
        self.validated_embedding = nn.Linear(2, metadata_dims[3])  # one-hot, 2 columns

        # The decoder sees the encoder channels plus all metadata embedding channels.
        decoder_input_dim = sum(metadata_dims) + num_channels[-1]
        self.decoder = TCN(
            num_inputs=decoder_input_dim,
            num_channels=num_channels[::-1],
            output_projection=num_inputs,
            **tcn_options,
        )

    def forward(self, x, metadata):
        """Reconstruct `x` from its encoding concatenated with the metadata embeddings.

        Args:
            x: Waveform batch; assumed (batch, channels, time) — TODO confirm.
            metadata: 2-D batch laid out as column 0 = age, columns 1-2 = sex,
                columns 3..-3 = device one-hot, last two columns = validated flag.

        Returns:
            The decoder TCN output.
        """
        latent = self.encoder(x)

        group_embeddings = [
            self.age_embedding(metadata[:, 0:1]),
            self.sex_embedding(metadata[:, 1:3]),
            self.device_embedding(metadata[:, 3:-2]),
            self.validated_embedding(metadata[:, -2:]),
        ]
        conditioning = torch.cat(group_embeddings, dim=-1)
        # Repeat the per-record embedding along the last axis of the encoding.
        conditioning = conditioning.unsqueeze(2).expand(-1, -1, latent.size(2))

        return self.decoder(torch.cat([latent, conditioning], dim=1))

In [None]:
class TCNClassifier(nn.Module):
    """TCN encoder with metadata embeddings followed by a small MLP classifier head.

    Has the same encoder and metadata-embedding layout as the autoencoder, but
    instead of a decoder the concatenated features are averaged over the last
    axis and passed through a two-layer classifier.

    Args:
        num_inputs: Number of input channels given to the encoder TCN.
        num_channels: Channel sizes passed to the encoder TCN.
        kernel_size: Kernel size passed to the encoder TCN.
        dropout: Dropout rate passed to the encoder TCN.
        metadata_dims: Embedding sizes indexed as [age, sex, device, validated].
        num_classes: Number of output logits.
    """

    def __init__(self, num_inputs, num_channels, kernel_size, dropout, metadata_dims, num_classes):
        super().__init__()
        self.encoder = TCN(
            num_inputs=num_inputs,
            num_channels=num_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            causal=True,
        )

        # One linear projection per metadata group.
        self.age_embedding = nn.Linear(1, metadata_dims[0])
        self.sex_embedding = nn.Linear(2, metadata_dims[1])
        self.device_embedding = nn.Linear(num_unique_devices, metadata_dims[2])
        self.validated_embedding = nn.Linear(2, metadata_dims[3])

        # Classifier head (replaces the autoencoder's decoder): encoder channels
        # plus every metadata embedding channel go in, class logits come out.
        hidden_units = 128
        encoder_output_dim = sum(metadata_dims) + num_channels[-1]
        self.classifier = nn.Sequential(
            nn.Linear(encoder_output_dim, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x, metadata):
        """Compute class logits for a batch of waveforms and their metadata.

        Args:
            x: Waveform batch; assumed (batch, channels, time) — TODO confirm.
            metadata: 2-D batch laid out as column 0 = age, columns 1-2 = sex,
                columns 3..-3 = device one-hot, last two columns = validated flag.

        Returns:
            Logits with one row per record and `num_classes` columns.
        """
        features = self.encoder(x)

        group_embeddings = [
            self.age_embedding(metadata[:, 0:1]),
            self.sex_embedding(metadata[:, 1:3]),
            self.device_embedding(metadata[:, 3:-2]),
            self.validated_embedding(metadata[:, -2:]),
        ]
        conditioning = torch.cat(group_embeddings, dim=-1)
        # Repeat the per-record embedding along the last axis of the encoder output.
        conditioning = conditioning.unsqueeze(2).expand(-1, -1, features.size(2))

        stacked = torch.cat([features, conditioning], dim=1)
        # Average over the last axis to get one feature vector per record.
        pooled = stacked.mean(dim=2)
        return self.classifier(pooled)

TODO: VERIFY AND ENSURE THIS IS CORRECT

In [None]:
# --- Hyperparameters ---
# The architecture values are shared by the autoencoder and the classifier so that
# the saved encoder weights can be loaded into the classifier's encoder.
batch_size = 32  # same value the DataLoaders were built with; not read again in this cell
num_inputs = 12  # Assuming 12 input channels in the ECG data
num_channels = [32, 64, 128]  # Example number of channels in each residual block of the encoder
kernel_size = 3  # Example kernel size for the TCN layers
dropout = 0.2  # Example dropout rate
metadata_dims = [10, 5, 20, 5]  # Example embedding dimensions for age, sex, and device, and validated

# --- Transfer the pretrained encoder ---
pretrained_autoencoder = TCNAutoencoder(num_inputs, num_channels, kernel_size, dropout, metadata_dims)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
pretrained_autoencoder.load_state_dict(torch.load("./models/tcn.pth", map_location=device))

model = TCNClassifier(num_inputs, num_channels, kernel_size, dropout, metadata_dims, 2)
# Only the encoder weights are copied; the classifier's metadata embeddings and
# head keep their fresh initialisation.
model.encoder.load_state_dict(pretrained_autoencoder.encoder.state_dict())

# Freeze the encoder layers
for param in model.encoder.parameters():
    param.requires_grad = False

model.to(device)

# Loss function and optimizer (only the parameters that are still trainable are optimised)
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)

# Training loop
num_epochs = 10
# NOTE(review): train mode also keeps dropout active inside the frozen encoder —
# confirm this is intended.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    seen_samples = 0
    for batch_data, batch_metadata, batch_labels in train_loader:
        # The dataset stores waveforms as (batch, time, channels); swap the last two axes for the model.
        ecg_data = batch_data.to(device).permute(0, 2, 1).float()
        batch_metadata = batch_metadata.to(device).float()
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(ecg_data, batch_metadata)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        # Accumulate a sample-weighted sum so the epoch figure is the mean over all samples.
        running_loss += loss.item() * batch_labels.size(0)
        seen_samples += batch_labels.size(0)

    # Report the epoch mean rather than the last batch's loss; NaN if the loader was empty.
    epoch_loss = running_loss / seen_samples if seen_samples else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")