# Experiment B

Objective: Train a 1D-CNN Image Encoder on the SPH dataset.

Split the data into train, validation, test subsets.

Example code for finding f1_thresholds

```
def find_threshold_f1(trues, logits, eps=1e-9):
    precision, recall, thresholds = precision_recall_curve(trues, logits)
    f1_scores = 2 * precision * recall / (precision + recall + eps)
    threshold = float(thresholds[np.argmax(f1_scores)])  
    return threshold
```

trues = true labels (binarized)
logits = raw per-label scores from the model (here: sigmoid outputs, i.e. probabilities)

For each label, there will be individual thresholds.

Filter out class samples where value counts < 100.

In [None]:
import sys
import h5py
from glob import glob
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score, precision_score, recall_score, f1_score

from albumentations.pytorch import ToTensorV2
from transformers import AutoTokenizer, AutoModel
import albumentations as A

from sklearn.model_selection import train_test_split
from tqdm import tqdm
import cv2
import random
import tarfile

In [None]:
import warnings
# Blanket filter: hides every warning category for the rest of the session,
# including ones that could point at real problems in later cells.
warnings.filterwarnings('ignore')

In [None]:
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus every CUDA device when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed all generators once, with a fixed value
set_seed(42)

## Notebook Setup

In [None]:
class CONFIG:
    """Notebook-wide settings, read as class attributes by the cells below."""

    # --- run control -------------------------------------------------------
    debug = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- data loading ------------------------------------------------------
    batch_size = 256
    num_workers = 8
    ecg_sr = 128   # sampling rate the ECG is resampled to

    # --- optimisation ------------------------------------------------------
    optimizer = torch.optim.Adam
    head_lr = 0.001
    image_encoder_lr = 0.001
    epochs = 20
    patience = 5   # presumably for an LR scheduler / early stopping - not read in the cells shown
    factor = 0.8   # presumably an LR decay factor - not read in the cells shown

    # --- image model -------------------------------------------------------
    model_name = 'resnet18'
    image_embedding_size = 512
    size = 224     # image size

    # --- text model --------------------------------------------------------
    text_encoder_model = 'emilyalsentzer/Bio_ClinicalBERT'
    text_tokenizer = 'emilyalsentzer/Bio_ClinicalBERT'
    text_embedding_size = 768
    max_length = 200

    # --- shared by image and text encoder ----------------------------------
    pretrained = True
    trainable = True
    temperature = 10.0

    # --- projection head (used for both image and text encoder) ------------
    num_projection_layers = 1
    projection_dim = 128
    dropout = 0.0

In [None]:
# Activation name -> zero-argument factory that builds a fresh activation module.
_ACTIVATION_DICT = {
    'relu': nn.ReLU,
    'tanh': nn.Tanh,
    'none': nn.Identity,
    'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.2),
}


class Conv1dBlock(nn.Module):
    """Conv1d followed by optional batch norm, an activation, optional dropout and optional max-pooling.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to `nn.Conv1d`.
        act: key into `_ACTIVATION_DICT`.
        bn: add `nn.BatchNorm1d`; the convolution then has no bias.
        dropout: dropout probability, or None for no dropout layer.
        maxpool: max-pool kernel size, or None for no pooling layer.
        padding: explicit padding; None or 'same' resolves to `kernel_size // 2`.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 act='relu', bn=True, dropout=None,
                 maxpool=None, padding=None, stride=1):
        super().__init__()

        if padding in (None, 'same'):
            padding = kernel_size // 2

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            bias=not bn,
            stride=stride,
        )
        self.bn = nn.BatchNorm1d(out_channels) if bn else None
        self.act = _ACTIVATION_DICT[act]()
        self.dropout = nn.Dropout(dropout) if dropout is not None else None
        self.maxpool = nn.MaxPool1d(maxpool) if maxpool is not None else None

    def forward(self, x):
        """Apply conv -> bn -> activation -> dropout -> maxpool, skipping stages that are None."""
        out = self.conv(x)
        for stage in (self.bn, self.act, self.dropout, self.maxpool):
            if stage is not None:
                out = stage(out)
        return out


class LinearBlock(nn.Module):
    """Linear layer followed by optional batch norm, an activation and optional dropout.

    Args:
        in_channels, out_channels: input / output feature sizes of the linear layer.
        act: key into `_ACTIVATION_DICT`.
        bn: add `nn.BatchNorm1d`; the linear layer then has no bias.
        dropout: dropout probability, or None for no dropout layer.
    """

    def __init__(self, in_channels, out_channels, act='relu', bn=True, dropout=None):
        super().__init__()

        self.linear = nn.Linear(in_channels, out_channels, bias=not bn)
        self.bn = nn.BatchNorm1d(out_channels) if bn else None
        self.act = _ACTIVATION_DICT[act]()
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x):
        """Apply linear -> bn -> activation -> dropout, skipping stages that are None."""
        out = self.linear(x)
        for stage in (self.bn, self.act, self.dropout):
            if stage is not None:
                out = stage(out)
        return out


class ConvEncoder(nn.Module):
    """Stack of `Conv1dBlock`s; layer i maps channels[i-1] -> channels[i] using kernels[i].

    Args:
        in_channels: channel count of the input signal.
        channels: output channels per layer; its length sets the depth.
        kernels: kernel size per layer (indexed like `channels`).
        bn, dropout, maxpool, padding: forwarded unchanged to every block.
        stride: per-layer strides; None means stride 1 everywhere.
    """

    def __init__(self, in_channels, channels, kernels, bn=True, dropout=None, maxpool=2, padding=0, stride=None):
        super().__init__()

        depth = len(channels)
        if stride is None:
            stride = [1] * depth

        # Settings every block has in common
        shared = dict(bn=bn, dropout=dropout, maxpool=maxpool, padding=padding)

        self.in_layer = Conv1dBlock(in_channels, channels[0], kernels[0], stride=stride[0], **shared)
        self.conv_layers = nn.ModuleList([
            Conv1dBlock(channels[i - 1], channels[i], kernels[i], stride=stride[i], **shared)
            for i in range(1, depth)
        ])

    def forward(self, x):
        """Run the input block, then each remaining block in order."""
        out = self.in_layer(x)
        for block in self.conv_layers:
            out = block(out)
        return out


class ECGEncoder(nn.Module):
    """1D-CNN encoder: ConvEncoder stack -> flatten -> Linear -> ReLU -> Linear.

    Args:
        window: number of time samples per input. The first linear layer is sized
            from it, so inputs must yield the same convolutional output length.
        in_channels: number of input channels.
        channels: output channels of each convolutional layer.
        kernels: kernel size of each convolutional layer.
        linear: width of the hidden linear layer.
        output: size of the returned embedding.
    """

    def __init__(self,
                 window=1280,
                 in_channels=12,
                 channels=(32, 32, 64, 64, 128, 128, 256, 256),
                 kernels=(7, 7, 5, 5, 3, 3, 3, 3),
                 linear=512,
                 output=512):

        super().__init__()

        self.conv_encoder = ConvEncoder(in_channels, channels,  kernels, bn=True)

        # Probe the conv stack with a dummy input to learn the flattened size.
        # Done in eval mode so the all-zero probe does not leak into the
        # BatchNorm running statistics; the previous mode is restored after.
        was_training = self.conv_encoder.training
        self.conv_encoder.eval()
        try:
            with torch.no_grad():
                inpt = torch.zeros((1, in_channels, window), dtype=torch.float32)
                output_window = self.conv_encoder(inpt).shape[2]
        finally:
            self.conv_encoder.train(was_training)

        self.flatten = nn.Flatten()
        self.conv_to_linear = nn.Linear(output_window * channels[-1], linear)
        self.act = nn.ReLU()
        self.out_layer = nn.Linear(linear, output)

    def forward(self, x):
        """Encode a batch shaped (batch, in_channels, window) into (batch, output)."""
        x = self.conv_encoder(x)
        x = self.flatten(x)
        x = self.conv_to_linear(x)
        x = self.act(x)
        x = self.out_layer(x)
        return x

## Loading Data

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
!cp -r  '/content/drive/MyDrive/ECG Project (Shared Folder)/SPH Database/metadata.csv' "/content"
!cp -r '/content/drive/MyDrive/ECG Project (Shared Folder)/SPH Database/code.csv' "/content"

In [None]:
!cp -r "/content/drive/MyDrive/ECG Project (Shared Folder)/SPH Database/records.tar.gz" "/content"

In [None]:
# Unpack the copied archive into /content. The archive path is relative, so this
# relies on the working directory being where the archive was copied - TODO confirm.
with tarfile.open('records.tar.gz', 'r') as tar_ref:
    tar_ref.extractall('/content')

In [None]:
data_path = '/content/records'

In [None]:
# data_path already points at the extracted `records` folder (the same location the
# `ecg_file` column is built from), so the .h5 files sit directly inside it.
ecg_files = sorted(glob(f'{data_path}/*.h5'))

In [None]:
def load_h5(file):
    """Read the 'ecg' dataset from an HDF5 file.

    Returns:
        (signal, leads, fs): the signal cast to float32, plus the lead names and
        sampling rate, which are constants hard-coded here (not read from the file).
    """
    # NOTE(review): this order lists aVF before aVR/aVL - confirm against the dataset docs.
    leads = ('I', 'II', 'III', 'aVF', 'aVR', 'aVL', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
    fs = 500
    with h5py.File(file, 'r') as handle:
        signal = handle['ecg'][()]
    return signal.astype('float32'), leads, fs


def remove_nonprimary_code(x):
    """Flatten '+'-joined code strings and keep the unique codes outside [200, 500).

    Args:
        x: iterable of strings, each holding one or more integer codes joined by '+'.

    Returns:
        The kept code strings in first-seen order (codes 200-499 are dropped).
    """
    kept = []
    for statement in x:
        for code in statement.split('+'):
            value = int(code)
            if 200 <= value < 500:
                continue
            if code not in kept:
                kept.append(code)
    return kept


def codes_to_caption(codes):
    """Look up each code in the module-level `description_dict` and join the lower-cased descriptions with ', '."""
    return ', '.join(description_dict[int(code)].lower() for code in codes)

# Code -> description lookup used by codes_to_caption
description_dict = pd.read_csv('/content/code.csv').set_index('Code')['Description'].to_dict()

# One row per recording: ';'-separated AHA statements -> kept codes -> text caption -> file path
df = pd.read_csv('/content/metadata.csv')
aha_statements = df['AHA_Code'].str.split(';')
df['primary_codes'] = aha_statements.apply(remove_nonprimary_code)
df['label'] = df['primary_codes'].apply(codes_to_caption)
df['ecg_file'] = df['ECG_ID'].apply(lambda ecg_id: f'/content/records/{ecg_id}.h5')

In [None]:
load_h5(df['ecg_file'].values[0])

(array([[ 0.02160645,  0.02160645,  0.02079773, ...,  0.05758667,
          0.0552063 ,  0.05041504],
        [ 0.1303711 ,  0.13195801,  0.13195801, ...,  0.03759766,
          0.03839111,  0.03518677],
        [ 0.10882568,  0.10961914,  0.11120605, ..., -0.02000427,
         -0.01600647, -0.01439667],
        ...,
        [-0.05441284, -0.05441284, -0.05279541, ..., -0.02479553,
         -0.02879333, -0.03439331],
        [-0.03759766, -0.0368042 , -0.0368042 , ..., -0.02160645,
         -0.02160645, -0.02400208],
        [-0.06240845, -0.06240845, -0.06240845, ...,  0.        ,
          0.00080013,  0.00080013]], dtype=float32),
 ('I', 'II', 'III', 'aVF', 'aVR', 'aVL', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'),
 500)

In [None]:
class ImageEncoder(nn.Module):
    """ECG encoder with a linear classification head that returns one raw logit per class.

    Args:
        config: settings object; `config.image_embedding_size` sets the encoder output size.
        num_classes: number of output logits (defaults to 47, the value previously hard-coded).
    """

    def __init__(self, config, num_classes=47):
        super().__init__()
        # Use the config that was passed in rather than reaching for the global
        self.config = config
        self.encoder = ECGEncoder(output=config.image_embedding_size)
        # Add some non-linear activation here like RELU because ECG encoder already returns linear layer so this will help.
        self.fc = nn.Linear(config.image_embedding_size, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes); no sigmoid is applied here."""
        x = self.encoder(x)
        x = self.fc(x)
        return x


class TextEncoder(nn.Module):
    """Frozen Hugging Face text encoder that returns one embedding per input text.

    Args:
        config: settings object providing `pretrained`, `text_encoder_model`,
            `text_tokenizer`, `max_length` and `device`.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        if config.pretrained:
            self.model = AutoModel.from_pretrained(config.text_encoder_model)
        else:
            # from_config expects a configuration object, not a model id string
            from transformers import AutoConfig
            self.model = AutoModel.from_config(AutoConfig.from_pretrained(config.text_encoder_model))

        self.tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer)

        for p in self.model.parameters():
            p.requires_grad = False  # Set requires_grad to False for all parameters

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, texts):
        """Tokenize `texts` and return their detached embeddings."""
        input_ids, attention_mask = self.tokenize_texts(texts)
        embeddings = self.inputs_to_embeddings(input_ids, attention_mask)
        return embeddings

    def tokenize_texts(self, texts):
        """Tokenize with padding/truncation to `config.max_length` and move the tensors to `config.device`."""
        inputs = self.tokenizer(texts, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt')
        input_ids = inputs['input_ids'].detach().to(self.config.device)
        attention_mask = inputs['attention_mask'].detach().to(self.config.device)
        return input_ids, attention_mask

    def inputs_to_embeddings(self, input_ids, attention_mask):
        """Run the model and return the hidden state at `target_token_idx`, detached from the graph."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :].detach()


class ProjectionHead(nn.Module):
    """Project an embedding to `projection_dim` with a residual GELU/linear branch and layer norm.

    Args:
        embedding_dim: size of the incoming embedding.
        projection_dim: size of the output (default taken from CONFIG).
        dropout: dropout probability on the residual branch (default taken from CONFIG).
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CONFIG.projection_dim,
        dropout=CONFIG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Return layer_norm(dropout(fc(gelu(p))) + p), where p = projection(x)."""
        shortcut = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(branch + shortcut)

class CLIPModel(nn.Module):
    """Contrastive ECG/text model: both modalities are projected into a shared space.

    Args:
        temperature: divides the cross-modal logits and multiplies the soft-target similarities.
        image_embedding: feature size fed to the image projection head.
        text_embedding: feature size fed to the text projection head.
    """

    def __init__(
        self,
        temperature=CONFIG.temperature,
        image_embedding=CONFIG.image_embedding_size,
        text_embedding=CONFIG.text_embedding_size,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder(CONFIG)
        self.text_encoder = TextEncoder(CONFIG)
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Compute the symmetric soft-target contrastive loss for a batch.

        Args:
            batch: mapping with 'image' (ECG tensor batch) and 'caption' (texts).

        Returns:
            (mean loss, image embeddings, text embeddings).
        """
        image_embeddings = self.image_to_embeddings(batch['image'])
        text_embeddings = self.text_to_embeddings(batch['caption'])

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean(), image_embeddings, text_embeddings

    def text_to_embeddings(self, texts):
        """Encode and project a batch of texts."""
        text_features = self.text_encoder(texts)
        text_embeddings = self.text_projection(text_features)
        return text_embeddings

    def image_to_embeddings(self, images):
        """Encode and project a batch of ECGs.

        Uses the backbone (`image_encoder.encoder`) directly: ImageEncoder.forward
        also applies the classification layer, whose output size does not match the
        `image_embedding`-sized input the projection head was built for.
        """
        image_features = self.image_encoder.encoder(images)
        image_embeddings = self.image_projection(image_features)
        return image_embeddings


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets.

    Args:
        preds: logits of shape (batch, classes).
        targets: target distribution with the same shape as `preds`.
        reduction: 'none' for one loss per row, 'mean' for their average.

    Returns:
        A (batch,) tensor for 'none', a scalar tensor for 'mean'.

    Raises:
        ValueError: for any other `reduction` (previously this silently returned None).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"unsupported reduction {reduction!r}; expected 'none' or 'mean'")


def nxn_cos_sim(A, B, dim=1):
    """Pairwise cosine similarity between the rows of two 2-D tensors.

    Both inputs are L2-normalised along `dim`; the result has shape (len(A), len(B)).
    """
    unit_a, unit_b = (F.normalize(mat, p=2, dim=dim) for mat in (A, B))
    return torch.mm(unit_a, unit_b.t())

In [None]:
class CLIP_ECG_Dataset(torch.utils.data.Dataset):
    """Dataset yielding {'image': resampled ECG array, 'caption': label text} per row.

    Args:
        df: frame with an 'ecg_file' column (paths) and a 'label' column (captions).
        config: settings object; `config.ecg_sr` is the target sampling rate.
    """

    def __init__(self, df, config):
        self.df = df
        # Keep the config that was passed in instead of the module-level global
        self.config = config

        self.ecg_files = self.df['ecg_file'].values
        self.captions = self.df['label'].values

    def __len__(self, ):
        return len(self.df)

    def __getitem__(self, idx):
        """Load one recording, keep at most its first 5000 samples and resample it."""
        ecg, leads, sr = load_h5(self.ecg_files[idx])
        ecg = ecg[:, :5000]  # presumably (leads, samples) - crops the second axis
        caption = self.captions[idx]
        image = self.process_ecg(ecg, sr)
        return {'image': image, 'caption': caption}

    def process_ecg(self, ecg, sr):
        """Resample `ecg` from rate `sr` to `config.ecg_sr` (new length = ecg_sr * n / sr, truncated to int)."""
        new_shape = int(self.config.ecg_sr * ecg.shape[1] / sr)
        ecg = resample(ecg, new_shape)
        return ecg


def resample(ecg, shape):
    """Resize a 2-D array to `shape` columns with cv2.resize, keeping the row count and dtype.

    NOTE(review): this is image-style interpolation, not signal resampling with an
    anti-aliasing filter - confirm that is acceptable for the downsampling done here.
    """
    n_rows = ecg.shape[0]
    return cv2.resize(ecg, (shape, n_rows)).astype(ecg.dtype)

In [None]:
# Hold out 30% of the records; the `valid` frame is what later cells turn into X_test / y_test.
train, valid = train_test_split(df, test_size=0.3, random_state=42)

In [None]:
def classes_from_captions(captions, threshold=100):
    """Return the class names that occur at least `threshold` times across `captions`.

    Each caption is split on ',' and the pieces are whitespace-stripped. The result
    follows `value_counts` order (most frequent first).
    """
    names = []
    for caption in captions:
        names.extend(piece.strip() for piece in caption.strip().split(','))
    frequency = pd.Series(names).value_counts()
    frequent = frequency.loc[frequency >= threshold]
    return frequent.index.to_list()

# Count how many classes reach the default 100-occurrence cut-off in each split
train_classes =  classes_from_captions(train['label'].values)
valid_classes =  classes_from_captions(valid['label'].values)
print(len(train_classes), len(valid_classes))

18 13


In [None]:
train_ds = CLIP_ECG_Dataset(train, CONFIG)
valid_ds = CLIP_ECG_Dataset(valid, CONFIG)

### Example of data objects

In [None]:
# ECG signal
train_ds[0]['image']

array([[-0.01913643,  0.06218195,  0.05179119, ..., -0.00331165,
        -0.01490986, -0.00428778],
       [-0.06293344, -0.00512707, -0.02361894, ..., -0.05451632,
        -0.07244587, -0.0581708 ],
       [-0.0437665 , -0.06813908, -0.07619953, ..., -0.05083656,
        -0.05754423, -0.05309439],
       ...,
       [-0.07562923, -0.04300499, -0.04617596, ..., -0.08976078,
        -0.09810352, -0.09272671],
       [-0.0504756 , -0.017416  , -0.01658916, ..., -0.07239342,
        -0.07944679, -0.07958317],
       [-0.04530811, -0.00483952, -0.00021254, ..., -0.06540203,
        -0.06481934, -0.06525326]], dtype=float32)

In [None]:
# Label(s)
train_ds[0]['caption']

'sinus bradycardia, atrial premature complex(es), sinus arrhythmia'

## Model Training Setup

### Create X_train, y_train, X_test, y_test variables

In [None]:
def _materialize(dataset):
    """Load every item of `dataset` once and return (signals, captions) as two parallel lists."""
    signals, captions = [], []
    for idx in range(len(dataset)):
        item = dataset[idx]
        signals.append(item['image'])
        captions.append(item['caption'])
    return signals, captions


X_train, y_train = _materialize(train_ds)
X_test, y_test = _materialize(valid_ds)

In [None]:
len(X_train), len(y_train), len(X_test), len(y_test)

(18039, 18039, 7731, 7731)

### Multi-hot encode the multi-label targets for y_train, y_test

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer

# Turn each caption back into its list of label names.
# NOTE(review): captions are split on ', ', so a description that itself contains a
# comma is broken into separate classes - presumably why fragments such as
# 'advanced (high-grade)' or 'nonconducted' show up in mlb.classes_. Verify against
# code.csv before trusting the per-class metrics.
y_train_labels = [caption.split(', ') for caption in y_train]
y_test_labels = [caption.split(', ') for caption in y_test]

# The class set is learned from the training captions only and then reused for the test captions
mlb = MultiLabelBinarizer()
y_train_binary = mlb.fit_transform(y_train_labels)
y_test_binary = mlb.transform(y_test_labels)

In [None]:
# Convert data and labels to tensors.
# X_train / X_test are Python lists of arrays; stack them into one ndarray first,
# because building a tensor straight from a list of ndarrays is very slow.
X_train_tensor = torch.as_tensor(np.stack(X_train), dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_binary, dtype=torch.float32)

X_test_tensor = torch.as_tensor(np.stack(X_test), dtype=torch.float32)
y_test_tensor = torch.tensor(y_test_binary, dtype=torch.float32)

In [None]:
# len(X_train_tensor), len(y_train_tensor), len(X_test_tensor), len(y_test_tensor)

In [None]:
print(mlb.classes_)

['2:1 av block' 'advanced (high-grade)' 'anterior mi' 'anteroseptal mi'
 'atrial fibrillation' 'atrial flutter' 'atrial premature complex(es)'
 'atrial premature complexes' 'av block' 'av conduction ratio n:d'
 'complete (third-degree)' 'early repolarization' 'extensive anterior mi'
 'incomplete right bundle-branch block' 'inferior mi'
 'junctional escape complex(es)' 'junctional premature complex(es)'
 'junctional tachycardia' 'left anterior fascicular block'
 'left atrial enlargement' 'left bundle-branch block'
 'left posterior fascicular block' 'left ventricular hypertrophy'
 'left-axis deviation' 'low voltage' 'mobitz type i (wenckebach)'
 'mobitz type ii' 'nonconducted' 'normal ecg' 'prolonged pr interval'
 'prolonged qt interval' 'right bundle-branch block'
 'right ventricular hypertrophy' 'right-axis deviation'
 'second-degree av block' 'short pr interval' 'sinus arrhythmia'
 'sinus bradycardia' 'sinus tachycardia' 'st deviation'
 'st deviation with t-wave change'
 'st-t change 

In [None]:
# Take the label names from the fitted binarizer instead of a hand-copied list, so the
# names always line up with the columns of y_train_binary / y_test_binary.
mlb_classes = [str(name) for name in mlb.classes_]
assert len(mlb_classes) == y_train_binary.shape[1], "label names and label matrix are out of sync"

### Create the TensorDataset and DataLoader objects

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Pair every ECG tensor with its multi-hot label row
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Both loaders share batch size and worker count; only the training loader shuffles
_loader_kwargs = dict(batch_size=CONFIG.batch_size, num_workers=CONFIG.num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
from torch.utils.data import random_split

# Split sizes: 70% of the training tensors for training, the remaining 30% for validation
train_len = int(0.7 * len(train_dataset))
val_len = len(train_dataset) - train_len

# Random, non-overlapping subsets of the training data
train_subset, val_subset = random_split(train_dataset, [train_len, val_len])

# Loaders for the subsets (this replaces the earlier full-dataset train_loader)
_subset_loader_kwargs = dict(batch_size=CONFIG.batch_size, num_workers=CONFIG.num_workers)
train_loader = DataLoader(train_subset, shuffle=True, **_subset_loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **_subset_loader_kwargs)

## Accuracy Metrics/Helper Code

In [None]:
def calc_metrics(image_embeddings, captions, class_embeddings, class_names):
    """Zero-shot metrics from cosine similarity between image and class embeddings.

    Returns a dict with 'accuracy' (top-1 class name found in the caption) plus
    '<name>_rocauc' and '<name>_prauc' per class; both are None for a class whose
    ground truth is constant across the batch.

    NOTE(review): membership is tested with substring `in` on the caption text, so a
    class name contained in a longer one also counts as present - confirm this is intended.
    """
    similarity = nxn_cos_sim(image_embeddings, class_embeddings, dim=1)
    top_names = [class_names[idx] for idx in similarity.argmax(dim=1)]
    hits = [name in caption for name, caption in zip(top_names, captions)]

    results = {'accuracy': np.mean(hits)}

    scores = similarity.detach().cpu().numpy()
    for col, name in enumerate(class_names):
        true = np.array([name in caption for caption in captions]).astype('int32')

        # AUC metrics are undefined when only one class is present
        if true.std() > 0:
            rocauc = roc_auc_score(true, scores[:, col])
            prauc = average_precision_score(true, scores[:, col])
        else:
            rocauc = prauc = None

        results[f'{name}_rocauc'] = rocauc
        results[f'{name}_prauc'] = prauc

    return results

def calc_accuracy(image_embeddings, captions, class_embeddings, class_names):
    """Fraction of samples whose most similar class name occurs (as a substring) in the caption."""
    similarity = nxn_cos_sim(image_embeddings, class_embeddings, dim=1)
    best_ids = similarity.argmax(dim=1)
    hits = [class_names[idx] in caption for idx, caption in zip(best_ids, captions)]
    return np.mean(hits)

In [None]:
class AvgMeter:
    """Running weighted average of a scalar metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running sum, count and average."""
        self.avg = self.sum = self.count = 0

    def update(self, val, count=1):
        """Add `val` with weight `count` and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

## Model Training

In [None]:
def plot_training_history(epoch_losses, epoch_accuracies):
    """Plot training loss and training accuracy per epoch, side by side.

    Args:
        epoch_losses: one loss value per epoch.
        epoch_accuracies: one accuracy value per epoch (same length as `epoch_losses`).
    """
    epochs = range(1, len(epoch_losses) + 1)
    panels = (
        (epoch_losses, 'Training loss', 'Loss'),
        (epoch_accuracies, 'Training accuracy', 'Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (values, title, ylabel) in zip(axes, panels):
        ax.plot(epochs, values, 'b', label=title)
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()
        # Tick every epoch that was actually run (previously fixed to 1..10)
        ax.set_xticks(list(epochs))

    fig.tight_layout()
    plt.show()

In [None]:
# Instantiate the model
model = ImageEncoder(CONFIG)

In [None]:
# Loss: sigmoid + binary cross-entropy on the raw logits (one term per label)
criterion = nn.BCEWithLogitsLoss()
# CONFIG.optimizer is Adam and CONFIG.image_encoder_lr equals Adam's default learning rate (1e-3)
optimizer = CONFIG.optimizer(model.parameters(), lr=CONFIG.image_encoder_lr)
num_epochs = 10
device = CONFIG.device

In [None]:
model.to(device)

In [None]:
for inputs, labels in train_loader:
    print(inputs.shape, labels)
    break

torch.Size([256, 12, 1280]) tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


### sigmoid training

In [None]:
from sklearn.metrics import precision_recall_curve

def find_threshold_f1(trues, logits, eps=1e-9):
    """Return the score threshold with the highest F1 on the precision-recall curve.

    Args:
        trues: binary ground truth for one label.
        logits: predicted scores for the same samples.
        eps: added to the denominator to avoid division by zero.
    """
    precision, recall, thresholds = precision_recall_curve(trues, logits)
    numerator = 2 * precision * recall
    denominator = precision + recall + eps
    best = np.argmax(numerator / denominator)
    return float(thresholds[best])

In [None]:
def train_model(model, criterion, optimizer, num_epochs, train_loader, val_loader, device):
    """Train `model` and, after every epoch, tune a per-label threshold on the validation set.

    Args:
        model: network returning one raw logit per label.
        criterion: loss on (logits, targets); assumed to return a batch mean.
        optimizer: optimizer over `model`'s parameters.
        num_epochs: number of passes over `train_loader`.
        train_loader, val_loader: loaders yielding (inputs, multi-hot labels).
        device: device that batches are moved to.

    Returns:
        (model, f1_scores_per_label, thresholds_per_label). Both dicts are keyed by
        the names in the module-level `mlb_classes` and hold one value per epoch.
    """
    # One list per label, appended to once per epoch
    f1_scores_per_label = {label: [] for label in mlb_classes}
    thresholds_per_label = {label: [] for label in mlb_classes}

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a batch mean: weight it by the batch size so the epoch value
            # is a true per-sample average, including for a smaller last batch
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Validation phase
        model.eval()
        val_preds_list = []
        val_labels_list = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = torch.sigmoid(model(inputs))

                val_preds_list.append(outputs.cpu().numpy())
                val_labels_list.append(labels.cpu().numpy())

        # Calculate the F1 score and threshold for each label
        val_preds_array = np.vstack(val_preds_list)
        val_labels_array = np.vstack(val_labels_list)

        for i, label in enumerate(mlb_classes):
            threshold = find_threshold_f1(val_labels_array[:, i], val_preds_array[:, i])
            # precision_recall_curve counts score >= threshold as positive, so use
            # the same rule here to reproduce the F1 the threshold was chosen for
            preds = (val_preds_array[:, i] >= threshold).astype(int)

            label_f1 = f1_score(val_labels_array[:, i], preds)
            f1_scores_per_label[label].append(label_f1)
            thresholds_per_label[label].append(threshold)

        epoch_loss = running_loss / max(samples_seen, 1)

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')

    return model, f1_scores_per_label, thresholds_per_label

In [None]:
trained_model, f1_scores_per_label, thresholds_per_label = train_model(model, criterion, optimizer, num_epochs, train_loader, val_loader, device)

Epoch 1/10, Loss: 0.0004390630069472515
Epoch 2/10, Loss: 0.00020747224003645363
Epoch 3/10, Loss: 0.00018480859700995442
Epoch 4/10, Loss: 0.00016999281538659098
Epoch 5/10, Loss: 0.00015760898688036443
Epoch 6/10, Loss: 0.00014542886805740172
Epoch 7/10, Loss: 0.00013448310014498605
Epoch 8/10, Loss: 0.0001262946164531531
Epoch 9/10, Loss: 0.00011669035089042394
Epoch 10/10, Loss: 0.00010979735969062126


In [None]:
f1_scores_df = pd.DataFrame(f1_scores_per_label)
print(f1_scores_df)

In [None]:
thresholds_per_label_df = pd.DataFrame(thresholds_per_label)
print(thresholds_per_label_df)

### sigmoid evaluation

In [None]:
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, average_precision_score

def _resolve_thresholds(threshold, labels):
    """Expand `threshold` (scalar, mapping or sequence) into one float per entry of `labels`."""
    if isinstance(threshold, dict):
        missing = [label for label in labels if label not in threshold]
        if missing:
            raise ValueError(f'no threshold given for labels: {missing}')
        return [float(threshold[label]) for label in labels]
    if np.ndim(threshold) == 0:
        return [float(threshold)] * len(labels)
    values = [float(value) for value in threshold]
    if len(values) != len(labels):
        raise ValueError(f'expected {len(labels)} thresholds, got {len(values)}')
    return values


def evaluate_model(model, test_loader, device, threshold=0.5, inclusive=False):
    """Score `model` on `test_loader` and compute per-label metrics.

    Args:
        model: network returning one raw logit per label.
        test_loader: loader yielding (inputs, multi-hot labels).
        device: device that batches are moved to.
        threshold: decision threshold on the sigmoid outputs. Either one scalar used
            for every label, a mapping {label name: threshold}, or a sequence ordered
            like the module-level `mlb_classes`.
        inclusive: when True a score equal to the threshold counts as positive (the
            rule precision_recall_curve uses for its thresholds); the default keeps
            the original strict `>` comparison.

    Returns:
        (f1, accuracy, roc_auc, pr_auc): four dicts keyed by label name. ROC AUC and
        PR AUC are None for a label with a single class in the ground truth.

    Raises:
        ValueError: if per-label thresholds do not cover every label.
    """
    thresholds = _resolve_thresholds(threshold, mlb_classes)

    model.eval()
    test_preds_list = []
    test_labels_list = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = torch.sigmoid(model(inputs))

            test_preds_list.append(outputs.cpu().numpy())
            test_labels_list.append(labels.cpu().numpy())

    # Calculate the F1 score, accuracy, AUC-ROC and AUC-PR for each label
    test_preds_array = np.vstack(test_preds_list)
    test_labels_array = np.vstack(test_labels_list)

    f1_scores_per_label = {}
    accuracy_per_label = {}
    roc_auc_scores_per_label = {}
    pr_auc_scores_per_label = {}

    for i, label in enumerate(mlb_classes):
        scores = test_preds_array[:, i]
        truth = test_labels_array[:, i]
        hits = scores >= thresholds[i] if inclusive else scores > thresholds[i]
        preds = hits.astype(int)

        f1_scores_per_label[label] = f1_score(truth, preds)
        accuracy_per_label[label] = accuracy_score(truth, preds)

        # Ranking metrics need both classes present in the ground truth
        if len(np.unique(truth)) > 1:
            roc_auc_scores_per_label[label] = roc_auc_score(truth, scores)
            pr_auc_scores_per_label[label] = average_precision_score(truth, scores)
        else:
            roc_auc_scores_per_label[label] = None
            pr_auc_scores_per_label[label] = None

    return f1_scores_per_label, accuracy_per_label, roc_auc_scores_per_label, pr_auc_scores_per_label

In [None]:
f1_scores_per_label_test, accuracy_per_label_test, roc_auc_scores_per_label_test, pr_auc_scores_per_label_test = evaluate_model(model, test_loader, device, threshold=0.5)

In [None]:
# Each metric dict is keyed by label name, so the raw frame has one row per diagnosis.
_metric_tables = {
    'F1 Score': f1_scores_per_label_test,
    'Accuracy': accuracy_per_label_test,
    'ROC AUC': roc_auc_scores_per_label_test,
    'PR AUC': pr_auc_scores_per_label_test,
}

# Transpose so that each row is a metric and each column is a diagnosis
baseline_results_df = pd.DataFrame(_metric_tables).T

- F1 Score: This is the harmonic mean of precision and recall. A higher F1 score means that the classifier has a better balance between precision and recall. A score of 1.0 is perfect, and a score of 0.0 is the worst.

- Accuracy: This is the proportion of correct results (both true positives and true negatives) among the total number of cases examined. A higher accuracy means that the classifier is correct more often. A score of 1.0 is perfect, and a score of 0.0 is the worst. For rare labels, accuracy can be close to 1.0 even when the label is never predicted, so read it together with the F1 score.

- ROC AUC: This is the area under the receiver operating characteristic curve. A higher ROC AUC means that the classifier is better at ranking positive instances above negative ones. A score of 1.0 is perfect and 0.5 corresponds to random guessing; a score below 0.5 means the ranking is worse than chance (systematically inverted).

- PR AUC: This is the area under the precision-recall curve (computed here as average precision). A higher PR AUC means that the classifier keeps precision high across recall levels. A score of 1.0 is perfect; the chance-level baseline equals the fraction of positive samples for that label, so rare labels have a very low baseline.

In [None]:
pd.set_option('display.max_columns', None)
baseline_results_df

Unnamed: 0,2:1 av block,advanced (high-grade),anterior mi,anteroseptal mi,atrial fibrillation,atrial flutter,atrial premature complex(es),atrial premature complexes,av block,av conduction ratio n:d,complete (third-degree),early repolarization,extensive anterior mi,incomplete right bundle-branch block,inferior mi,junctional escape complex(es),junctional premature complex(es),junctional tachycardia,left anterior fascicular block,left atrial enlargement,left bundle-branch block,left posterior fascicular block,left ventricular hypertrophy,left-axis deviation,low voltage,mobitz type i (wenckebach),mobitz type ii,nonconducted,normal ecg,prolonged pr interval,prolonged qt interval,right bundle-branch block,right ventricular hypertrophy,right-axis deviation,second-degree av block,short pr interval,sinus arrhythmia,sinus bradycardia,sinus tachycardia,st deviation,st deviation with t-wave change,st-t change due to ventricular hypertrophy,t-wave abnormality,tu fusion,varying conduction,ventricular preexcitation,ventricular premature complex(es)
F1 Score,0.0,0.0,0.0,0.0,0.753388,0.25,0.392593,0.0,0.0,0.0,0.0,0.0,0.0,0.233645,0.0,0.0,0.0,0.0,0.0,0.0,0.56,0.0,0.061538,0.0,0.32345,0.0,0.0,0.0,0.818472,0.0,0.0,0.850856,0.0,0.0,0.0,0.0,0.149533,0.893591,0.852174,0.136213,0.393574,0.0,0.595702,0.0,0.0,0.0,0.770992
Accuracy,0.998577,1.0,0.997801,0.996378,0.988229,0.996896,0.978787,0.999871,0.997154,0.999871,0.998965,0.998448,1.0,0.957573,0.995085,0.999353,0.997672,0.999612,0.99405,0.999353,0.998577,0.999871,0.99211,0.995473,0.967533,0.999741,0.999871,0.999871,0.815936,0.991851,0.998965,0.99211,1.0,0.993144,0.999612,0.999741,0.941146,0.977881,0.991204,0.932738,0.960936,0.996637,0.895356,0.999612,0.998189,0.999353,0.984478
ROC AUC,0.996703,,0.930996,0.935971,0.993662,0.984876,0.891075,0.934023,0.982665,0.986805,0.991325,0.871723,,0.88734,0.78685,0.954207,0.900435,0.960533,0.952643,0.798059,0.99977,0.985252,0.939656,0.908806,0.939171,0.991655,0.998836,0.925485,0.915366,0.900371,0.886702,0.996848,,0.857599,0.993099,0.449799,0.819298,0.994021,0.98797,0.909812,0.964605,0.943265,0.95127,0.810732,0.979507,0.756019,0.93158
PR AUC,0.407591,,0.162119,0.097391,0.869928,0.367964,0.376989,0.001957,0.171483,0.009709,0.281624,0.007611,,0.424138,0.032706,0.053134,0.021231,0.008861,0.283266,0.042363,0.935673,0.008696,0.350619,0.054358,0.287207,0.02404,0.1,0.001733,0.901976,0.16173,0.019721,0.959948,,0.098699,0.04177,0.000343,0.391952,0.957815,0.928528,0.452788,0.538256,0.134079,0.681351,0.001831,0.107899,0.017058,0.811596


In [None]:
baseline_results_df.to_csv('baseline_results.csv')

In [None]:
from google.colab import files

files.download('baseline_results.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>