In [1]:
import wfdb
import scipy
import os
import glob
import subprocess
import random
import pandas as pd
import numpy as np

In [2]:

DATA_ROOT = 'D:/ECG/processed/physionet2021'
PHYSIONET_PATH = 'D:/ECG/physionet.org/files/challenge-2021/1.0.3/training'

In [3]:
MANIFEST_PATH = f'{DATA_ROOT}/manifests'
FINE_TUNE_MANIFEST = f'{MANIFEST_PATH}/cinc'
MANIFEST_PATH_NOISE = f'{DATA_ROOT}/manifests_noise'
FINE_TUNE_MANIFEST_NOISE = f'{MANIFEST_PATH_NOISE}/cinc'
TOTAL_MANIFEST = f'{MANIFEST_PATH}/total/train.tsv'
PRE_TRAINING_MANIFEST = f"{MANIFEST_PATH}/cmsc"

In [4]:
FAIRSEQ_SIG_DIR = 'C:/Users/david/Documents/PythonScripts/fairseq-signals'
CONVERT_TO_CLOCS_MANIFEST = f'{FAIRSEQ_SIG_DIR}/fairseq_signals/data/ecg/preprocess/convert_to_clocs_manifest.py'
PRE_TRAINING_CONFIG = f'{FAIRSEQ_SIG_DIR}/examples/w2v_cmsc/config/pretraining'
FINE_TUNING_CONFIG = f'{FAIRSEQ_SIG_DIR}/examples/w2v_cmsc/config/finetuning/ecg_transformer'
FINE_TUNING_CONFIG_CMSC = f'{FAIRSEQ_SIG_DIR}/examples/w2v_cmsc/config/finetuning/ecg_transformer'

In [5]:
NOISE_DB = 'D:/ECG/physionet.org/files/nstdb/1.0.0'
WEIGHTS_PATH = f'{PHYSIONET_PATH}/weights.csv'

## Dataset Statistics

In [6]:
stats_df = pd.read_csv('dx_mapping_scored.csv', index_col='Dx')[['CPSC', 'CPSC_Extra', 'Chapman_Shaoxing', 'Georgia']]

In [7]:
stats_df

Unnamed: 0_level_0,CPSC,CPSC_Extra,Chapman_Shaoxing,Georgia
Dx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
atrial fibrillation,1221,153,1780,570
atrial flutter,0,54,445,186
bundle branch block,0,0,0,116
bradycardia,0,271,0,6
complete left bundle branch block,0,0,0,0
complete right bundle branch block,0,113,0,28
1st degree av block,722,106,247,769
incomplete right bundle branch block,0,86,0,407
left axis deviation,0,0,382,940
left anterior fascicular block,0,0,0,180


## Preprocessing

Step 1: Take 10 second, 500 fs ECG WFDB data from select databases, split them into two 5 second data segments.

In [5]:
cmd = f'python {FAIRSEQ_SIG_DIR}/fairseq_signals/data/ecg/preprocess/preprocess_physionet2021.py \
        {PHYSIONET_PATH} \
        --dest {DATA_ROOT} \
        --subset "WFDB_Ga, WFDB_CPSC2018, WFDB_CPSC2018_2, WFDB_ChapmanShaoxing"\
        --workers 1'

In [6]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))




Step 2: Take processed ECG data and split them into pretraining and fine-tuning sets. In our case we use Georgia (Ga) as our fine tuning set.

In [10]:
cmd = f'python {FAIRSEQ_SIG_DIR}/fairseq_signals/data/ecg/preprocess/manifest.py \
        {DATA_ROOT} \
        --dest {MANIFEST_PATH} \
        --subset "CPSC2018, CPSC2018_2, ChapmanShaoxing, Ga" \
        --combine_subsets Ga'

In [11]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))




Step 3: For only the fine-tuning set, add noise from MIT-BIH stress db. Noise comes in 360fs, so resample to 500fs. Randomly sample 5 seconds noise segments, then add to the original data.

In [35]:
def generate_noise(noise_path):
    noise_dict = {}
    for noise_type in ['bw', 'ma', 'em']:
        samp = wfdb.rdsamp(f'{noise_path}/{noise_type}')
        x = samp[0]
        fs = samp[1]['fs']
        resampled_noise = scipy.signal.resample(x, round(x.shape[0]*500/fs))
        noise_dict[noise_type] = resampled_noise.T
    return noise_dict

In [56]:
def add_noise(noise_dict, data_root):
    for dataset in ['Ga']:
        isExist = os.path.exists(f'{data_root}/{dataset}_noise')
        if not isExist:
            os.makedirs(f'{data_root}/{dataset}_noise')
        for file in os.listdir(f'{data_root}/{dataset}'):
            sample = scipy.io.loadmat(f'{DATA_ROOT}/{dataset}/{file}')
            for noise_type in ['bw', 'ma', 'em']:
                start_idx = random.randint(0, 902778-2501)
                sampled_noise = noise_dict[noise_type][:, start_idx:start_idx+2500]
                sample['feats'][:2,] = sample['feats'][:2,]+sampled_noise
            scipy.io.savemat(f'{DATA_ROOT}/{dataset}_noise/{file}', sample)

In [36]:
noise_dict = generate_noise(NOISE_DB)

In [57]:
add_noise(noise_dict, DATA_ROOT)

Step 4: Repeat step 2 for the noisy data

In [62]:
cmd = f'python {FAIRSEQ_SIG_DIR}/fairseq_signals/data/ecg/preprocess/manifest.py \
        {DATA_ROOT} \
        --dest {MANIFEST_PATH_NOISE} \
        --subset "Ga_noise" \
        --combine_subsets Ga_noise'

In [63]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))




## Pre-training

We will train two different pretraining models, w2v_cmsc_rlm, and w2v_cmsc. 

In [7]:
cmd = f"python {CONVERT_TO_CLOCS_MANIFEST} \
            {TOTAL_MANIFEST} \
            --dest {MANIFEST_PATH}"
print(cmd)

python C:/Users/david/Documents/PythonScripts/fairseq-signals/fairseq_signals/data/ecg/preprocess/convert_to_clocs_manifest.py             D:/ECG/processed/physionet2021/manifests/total/train.tsv             --dest D:/ECG/processed/physionet2021/manifests


In [8]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))




In [9]:
cmd = f"fairseq-hydra-train \
            task.data={PRE_TRAINING_MANIFEST} \
            --config-dir {PRE_TRAINING_CONFIG} \
            --config-name w2v_cmsc_rlm"

In [None]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

In [10]:
cmd = f"fairseq-hydra-train \
            task.data={PRE_TRAINING_MANIFEST} \
            --config-dir {PRE_TRAINING_CONFIG} \
            --config-name w2v_cmsc"

In [None]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

## Fine-tuning

First, we will perform fine-tuning on the w2v_cmsc_rlm

In [10]:
PRETRAIN_CHECKPOINT = f'C:/Users/david/Documents/PythonScripts/ECGPretraining/outputs/2023-04-22/11-42-07/checkpoints/checkpoint100.pt'

1. With first two leads

In [7]:
cmd = f"fairseq-hydra-train \
        task.data={FINE_TUNE_MANIFEST} \
        model.model_path={PRETRAIN_CHECKPOINT} \
        --config-dir {FINE_TUNING_CONFIG} \
        --config-name diagnosis"

In [None]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

2. With noisy first two leads

In [68]:
cmd = f"fairseq-hydra-train \
        task.data={FINE_TUNE_MANIFEST_NOISE} \
        model.model_path={PRETRAIN_CHECKPOINT} \
        --config-dir {FINE_TUNING_CONFIG} \
        --config-name diagnosis"

In [None]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

Validation

In [11]:
FINETUNE_CHECKPOINT = f'C:/Users/david/Documents/PythonScripts/ECGPretraining/outputs/2023-04-29/16-42-15/checkpoints/checkpoint_last.pt'

In [12]:
cmd = f"fairseq-hydra-validate \
        common_eval.path={FINETUNE_CHECKPOINT} \
        task.data={FINE_TUNE_MANIFEST} \
        model.model_path={PRETRAIN_CHECKPOINT} \
        --config-dir {FINE_TUNING_CONFIG} \
        --config-name diagnosis"

In [13]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

[2023-05-06 01:44:06,443][fairseq_cli.validate][INFO] - loading model from C:/Users/david/Documents/PythonScripts/ECGPretraining/outputs/2023-04-29/16-42-15/checkpoints/checkpoint_last.pt
[2023-05-06 01:44:08,412][fairseq_signals.models.transformer][INFO] - Loaded pre-trained model parameters from C:/Users/david/Documents/PythonScripts/ECGPretraining/outputs/2023-04-22/11-42-07/checkpoints/checkpoint100.pt
[2023-05-06 01:44:08,432][fairseq_signals.utils.checkpoint_utils][INFO] - Loaded a checkpoint in 1.99s
[2023-05-06 01:44:08,433][fairseq_cli.validate][INFO] - num. shared model params: 62,041,754 (num. trained: 62,041,754)
[2023-05-06 01:44:08,434][fairseq_cli.validate][INFO] - num. expert model params: 0 (num. trained: 0)
[2023-05-06 01:44:08,583][fairseq_cli.validate][INFO] - {'_name': None,
 'checkpoint': {'_name': None, 'save_dir': 'checkpoints', 'restore_file': 'checkpoint_last.pt', 'finetune_from_model': None, 'reset_dataloader': False, 'reset_lr_scheduler': False, 'reset

Next, do the same for w2v_cmsc

In [6]:
PRETRAIN_CHECKPOINT_CMSC = f'C:/Users/david/Documents/PythonScripts/ECGPretraining/outputs/2023-05-02/19-58-34/checkpoints/checkpoint_last.pt'


1. With first two leads

In [8]:
cmd = f"fairseq-hydra-train \
        task.data={FINE_TUNE_MANIFEST} \
        model.model_path={PRETRAIN_CHECKPOINT_CMSC} \
        --config-dir {FINE_TUNING_CONFIG} \
        --config-name diagnosis"

In [None]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

2. With noisy first two leads

In [7]:
cmd = f"fairseq-hydra-train \
        task.data={FINE_TUNE_MANIFEST_NOISE} \
        model.model_path={PRETRAIN_CHECKPOINT_CMSC} \
        --config-dir {FINE_TUNING_CONFIG} \
        --config-name diagnosis"

In [None]:
result = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
print(result.stdout.decode('utf-8'))

## Classical classification for Baseline Models

In [8]:
from torch.utils.data import Dataset


class ECGDataset(Dataset):
    
    def __init__(self, manifest, data_root):
        self.manifest = pd.read_table(manifest, header=None, skiprows=1)
        self.data_root = data_root
        
    def __len__(self):
        return len(self.manifest)
    
    def __getitem__(self, index):
        file = self.manifest.iloc[index, 0]
        ecg_data = scipy.io.loadmat(f"{self.data_root}/{file}")
        return np.expand_dims(ecg_data['feats'], axis=0), ecg_data['label']
        # return ecg_data['feats'], ecg_data['label']

train_dataset = ECGDataset(TOTAL_MANIFEST, DATA_ROOT)
val_dataset = ECGDataset(f'{FINE_TUNE_MANIFEST}/valid.tsv', DATA_ROOT)
test_dataset = ECGDataset(f'{FINE_TUNE_MANIFEST}/test.tsv', DATA_ROOT)

In [9]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

In [10]:
train_features, train_labels = next(iter(train_dataloader))

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [37]:
class CNN(nn.Module):  
    def __init__(self, num_labels):
        super().__init__()

        self.conv11 = nn.Conv2d(1, 64, (1, 32), (1, 8))
        self.bn11 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.conv12 = nn.Conv2d(64, 64, (1, 32), (1, 8))
        self.bn12 = nn.BatchNorm2d(64)



        self.conv21 = nn.Conv2d(64, 128, (1, 32), (2, 8))
        self.bn21 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()
        self.conv22 = nn.Conv2d(128, 128, (1, 32), (2, 8))
        self.bn22 = nn.BatchNorm2d(128)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(128, num_labels)
    
    def forward(self, x):
        x = self.conv11(x)
        x = self.bn11(x)
        x = self.relu1(x)
        #x = self.conv12(x)
        #x = self.bn12(x)
        
        x = self.conv21(x)
        x = self.bn21(x)
        x = self.relu2(x)
        #x = self.conv22(x)
        #x = self.bn22(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        out = self.fc1(x)
        
        return out
    

model = CNN(num_labels=26)
model

CNN(
  (conv11): Conv2d(1, 64, kernel_size=(1, 32), stride=(1, 8))
  (bn11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv12): Conv2d(64, 64, kernel_size=(1, 32), stride=(1, 8))
  (bn12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv21): Conv2d(64, 128, kernel_size=(1, 32), stride=(2, 8))
  (bn21): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (conv22): Conv2d(128, 128, kernel_size=(1, 32), stride=(2, 8))
  (bn22): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=128, out_features=26, bias=True)
)

In [13]:
from torchvision import models

In [18]:
def get_resnet():
    model = models.resnet18(weights=None)
    #model = models.resnet50(weights=None)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(512, 26)
    #model.fc = nn.Linear(2048, 26)
    
    return model

In [19]:
model = get_resnet()

In [20]:
model

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [38]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

924058

In [39]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [40]:
import numpy as np
from tqdm import tqdm


In [41]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [42]:
device

device(type='cuda', index=0)

In [43]:
model.to(device)

CNN(
  (conv11): Conv2d(1, 64, kernel_size=(1, 32), stride=(1, 8))
  (bn11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv12): Conv2d(64, 64, kernel_size=(1, 32), stride=(1, 8))
  (bn12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv21): Conv2d(64, 128, kernel_size=(1, 32), stride=(2, 8))
  (bn21): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (conv22): Conv2d(128, 128, kernel_size=(1, 32), stride=(2, 8))
  (bn22): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=128, out_features=26, bias=True)
)

In [44]:
weights = pd.read_csv(WEIGHTS_PATH, index_col=0)
weights = torch.from_numpy(weights.to_numpy())

In [45]:
def compute_mcm(labels, outputs):
    num_recordings, num_classes = labels.size()
    A = torch.zeros((num_classes, num_classes))

    for i in range(num_recordings):
        normalization = float(max(torch.sum(torch.any(torch.cat((labels[i, :].reshape(-1,1), outputs[i, :].reshape(-1,1)), dim=1), dim=0)), 1))
        for j in range(num_classes):
            if labels[i, j]:
                for k in range(num_classes):
                    if outputs[i, k]:
                        A[j, k] += 1.0/normalization
    return A

def get_cinc_score(y_true, y_pred):
    confusion_matrix = compute_mcm(y_true, y_pred)
    observed_score = torch.nansum(confusion_matrix * weights)

    confusion_matrix_correct = compute_mcm(y_true, y_true)
    correct_score = torch.nansum(confusion_matrix_correct * weights)

    sinus = torch.zeros(64, 26, dtype=torch.int32)
    sinus[:,14] = 1
    confusion_matrix_inactive = compute_mcm(y_true, sinus)
    inactive_score = torch.nansum(confusion_matrix_inactive * weights)

    #print(y_true, y_pred)
    
    #print(observed_score, inactive_score, correct_score)

    if correct_score != inactive_score:
        score = ((observed_score - inactive_score) / (correct_score - inactive_score))
    else:
        score = 0.0
        
    return score

In [46]:
def load_table(table_file):
    table = list()
    with open(table_file, 'r') as f:
        for i, l in enumerate(f):
            arrs = [arr.strip() for arr in l.split(',')]
            table.append(arrs)

    num_rows = len(table)-1
    if num_rows<1:
        raise Exception('The table {} is empty.'.format(table_file))
    row_lengths = set(len(table[i])-1 for i in range(num_rows))
    if len(row_lengths)!=1:
        raise Exception('The table {} has rows with different lengths.'.format(table_file))
    num_cols = min(row_lengths)
    if num_cols<1:
        raise Exception('The table {} is empty.'.format(table_file))

    # Find the row and column labels.
    rows = [table[0][j+1] for j in range(num_rows)]
    cols = [table[i+1][0] for i in range(num_cols)]

    # Find the entries of the table.
    values = np.zeros((num_rows, num_cols), dtype=np.float64)
    for i in range(num_rows):
        for j in range(num_cols):
            value = table[i+1][j+1]
            values[i, j] = float(value)

    return rows, cols, values

def load_weights(weight_file):
    # Load the table with the weight matrix.
    rows, cols, values = load_table(weight_file)

    # Split the equivalent classes.
    rows = [set(row.split('|')) for row in rows]
    cols = [set(col.split('|')) for col in cols]
    assert(rows == cols)

    # Identify the classes and the weight matrix.
    classes = rows
    weights = values

    return classes, weights

def compute_modified_confusion_matrix(labels, outputs):
    # Compute a binary multi-class, multi-label confusion matrix, where the rows
    # are the labels and the columns are the outputs.
    num_recordings, num_classes = np.shape(labels)
    A = np.zeros((num_classes, num_classes))

    # Iterate over all of the recordings.
    for i in range(num_recordings):
        # Calculate the number of positive labels and/or outputs.
        normalization = float(max(np.sum(np.any((labels[i, :], outputs[i, :]), axis=0)), 1))
        # Iterate over all of the classes.
        for j in range(num_classes):
            # Assign full and/or partial credit for each positive class.
            if labels[i, j]:
                for k in range(num_classes):
                    if outputs[i, k]:
                        A[j, k] += 1.0/normalization

    return A

def compute_challenge_metric(weights, labels, outputs, classes, sinus_rhythm={'426783006'}):
    num_recordings, num_classes = np.shape(labels)
    if sinus_rhythm in classes:
        sinus_rhythm_index = classes.index(sinus_rhythm)
    else:
        raise ValueError('The sinus rhythm class is not available.')

    # Compute the observed score.
    A = compute_modified_confusion_matrix(labels, outputs)
    observed_score = np.nansum(weights * A)

    # Compute the score for the model that always chooses the correct label(s).
    correct_outputs = labels
    A = compute_modified_confusion_matrix(labels, correct_outputs)
    correct_score = np.nansum(weights * A)

    # Compute the score for the model that always chooses the sinus rhythm class.
    inactive_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_)
    inactive_outputs[:, sinus_rhythm_index] = 1
    A = compute_modified_confusion_matrix(labels, inactive_outputs)
    inactive_score = np.nansum(weights * A)

    if correct_score != inactive_score:
        normalized_score = float(observed_score - inactive_score) / float(correct_score - inactive_score)
    else:
        normalized_score = 0.0

    return normalized_score

In [47]:
classes, weights = load_weights(WEIGHTS_PATH)

In [48]:
from torchmetrics.classification import MultilabelAUROC

def eval_model(model, val_loader):
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.IntTensor()
    for data, target in val_loader:
        data, target = data.to(device, dtype=torch.float), target.to(device, dtype=torch.float)
        y_hat = torch.sigmoid(model(data))
        
        y_score = torch.cat((y_score, y_hat.detach().to('cpu')), dim=0)
        y_hat = (y_hat > 0.2).int().detach().to('cpu')
        target = target.detach().to('cpu').type(torch.int32)
        
        y_pred = torch.cat((y_pred, y_hat), dim=0)
        y_true = torch.cat((y_true, torch.squeeze(target, 1)), dim=0)
        
    auroc = MultilabelAUROC(num_labels=26, average="macro", thresholds=5)
    roc_auc = auroc(y_score, torch.squeeze(y_true, 1))
    print(torch.sum(y_pred))
    cinc_score = compute_challenge_metric(weights, y_true.numpy(), y_pred.numpy(), classes)

    return roc_auc, cinc_score

In [49]:
def train_model(model, train_dataloader, val_dataloader, n_epoch, optimizer, criterion): 
    for epoch in range(n_epoch):
        model.train()
        curr_epoch_loss = []
        CinC_history = []
        for data, target in tqdm(train_dataloader):
            data, target = data.to(device, dtype=torch.float), target.to(device, dtype=torch.float)

            y_hat = model(data)
            loss = criterion(y_hat, torch.squeeze(target, 1))
            loss.backward()
            optimizer.step()
            
            curr_epoch_loss.append(loss.cpu().data.numpy())
            
        print(f"Epoch {epoch}: curr_epoch_loss={np.mean(curr_epoch_loss)}")
        roc_auc, cinc_score = eval_model(model, val_dataloader)
        print(f"Epoch {epoch}: val_roc: {roc_auc}, CinC_score: {cinc_score}")
        CinC_history.append(cinc_score)
        if cinc_score >= np.max(CinC_history):
            torch.save(model.state_dict(), 'cnn.pt')
    return model

In [50]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [51]:
model = train_model(model, train_dataloader, val_dataloader, 10, optimizer, criterion)

100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [04:26<00:00,  2.17it/s]


Epoch 0: curr_epoch_loss=0.24435146152973175
tensor(5715)
Epoch 0: val_roc: 0.42860907316207886, CinC_score: 0.1887297693045686


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:45<00:00,  2.56it/s]


Epoch 1: curr_epoch_loss=0.20071753859519958
tensor(5203)
Epoch 1: val_roc: 0.4293023645877838, CinC_score: 0.15989659876194268


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:51<00:00,  2.50it/s]


Epoch 2: curr_epoch_loss=0.19508537650108337
tensor(5620)
Epoch 2: val_roc: 0.43475356698036194, CinC_score: 0.20327349343744922


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:50<00:00,  2.51it/s]


Epoch 3: curr_epoch_loss=0.18268758058547974
tensor(3735)
Epoch 3: val_roc: 0.4481239914894104, CinC_score: 0.12049196241736619


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:59<00:00,  2.42it/s]


Epoch 4: curr_epoch_loss=0.17641134560108185
tensor(4376)
Epoch 4: val_roc: 0.4330042600631714, CinC_score: 0.14364325481988385


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [04:00<00:00,  2.41it/s]


Epoch 5: curr_epoch_loss=0.17509521543979645
tensor(4770)
Epoch 5: val_roc: 0.4398568272590637, CinC_score: 0.15391858633293212


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:52<00:00,  2.49it/s]


Epoch 6: curr_epoch_loss=0.17308813333511353
tensor(2727)
Epoch 6: val_roc: 0.4412396550178528, CinC_score: 0.03850513117394164


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:32<00:00,  2.72it/s]


Epoch 7: curr_epoch_loss=0.17100264132022858
tensor(4156)
Epoch 7: val_roc: 0.44665002822875977, CinC_score: 0.186642833956303


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:42<00:00,  2.60it/s]


Epoch 8: curr_epoch_loss=0.16902294754981995
tensor(3818)
Epoch 8: val_roc: 0.4473414123058319, CinC_score: 0.17553322968289853


100%|████████████████████████████████████████████████████████████████████████████████| 579/579 [03:52<00:00,  2.49it/s]


Epoch 9: curr_epoch_loss=0.16772276163101196
tensor(3383)
Epoch 9: val_roc: 0.4565739333629608, CinC_score: 0.16875891449870645


In [None]:
torch.save(model.state_dict(), 'cnn.pt')

In [52]:
model.load_state_dict(torch.load('cnn.pt'))
model.to(device)

CNN(
  (conv11): Conv2d(1, 64, kernel_size=(1, 32), stride=(1, 8))
  (bn11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv12): Conv2d(64, 64, kernel_size=(1, 32), stride=(1, 8))
  (bn12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv21): Conv2d(64, 128, kernel_size=(1, 32), stride=(2, 8))
  (bn21): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (conv22): Conv2d(128, 128, kernel_size=(1, 32), stride=(2, 8))
  (bn22): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=128, out_features=26, bias=True)
)

In [53]:
eval_model(model, test_dataloader)

tensor(3376)


(tensor(0.4733), 0.14517453885142168)