### Model


In [1]:
import torch.nn as nn
class ECG_FeatureExtractor(nn.Module):
    def __init__(self):
        super(ECG_FeatureExtractor, self).__init__()
        self.conv1 = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2)
        self.pool = nn.AdaptiveMaxPool1d(output_size=1)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(-1, 64)
        return x

class ECG_Classifier(nn.Module):
    def __init__(self):
        super(ECG_Classifier, self).__init__()
        self.fc1 = nn.Linear(64, 4)

    def forward(self, x):
        x = self.fc1(x)
        return x


### ERM Alogorithm

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Algorithm(torch.nn.Module):
    """
    A subclass of Algorithm implements a domain generalization algorithm.
    Subclasses should implement the following:
    - update()
    - predict()
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(Algorithm, self).__init__()
        self.hparams = hparams

    def update(self, minibatches, unlabeled=None):
        """
        Perform one update step, given a list of (x, y) tuples for all
        environments.
        Admits an optional list of unlabeled minibatches from the test domains,
        when task is domain_adaptation.
        """
        raise NotImplementedError

    def predict(self, x):
        raise NotImplementedError

class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM)
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(ERM, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.featurizer = ECG_FeatureExtractor()
        self.classifier = ECG_Classifier()
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

    def update(self, minibatches, unlabeled=None):
        """
        Perform one update step, given a list of (x, y) tuples for all
        environments.
        Admits an optional list of unlabeled minibatches from the test domains,
        when task is domain_adaptation.
        """
        self.network.train()
        self.optimizer.zero_grad()

        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        
        y_hat = self.network(all_x)
        loss = F.cross_entropy(y_hat, all_y)

        loss.backward()
        self.optimizer.step()
        return loss.item()

    def predict(self, x):
        self.network.eval()
        with torch.no_grad():
            x = x.unsqueeze(1)
            y_hat = self.network(x)
            return y_hat.argmax(dim=1)


### Dataloading

In [3]:
import torch
from dataloading import MultitaskDataset
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import pickle
import zarr
from sklearn.preprocessing import MinMaxScaler


Bad key "text.kerning_factor" on line 4 in
C:\Users\puranik\Anaconda3\envs\pytorch_venv\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
or from the matplotlib source distribution


In [4]:
def replace_y(x):
    mapping = {1: 0, 2: 1, 3: 2, 4: 3}
    return mapping.get(x, x)

def replace_d(x):
    mapping = {2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8, 11: 9, 13: 10, 14: 11, 15: 12, 16: 13, 17: 14}
    return mapping.get(x, x)

def loso(X,y,d):
    left_out_subject = 14
    idx = (d != left_out_subject)
    X_train = X[idx]
    y_train = y[idx]
    d_train = d[idx]

    # test data selecting just 14 subject
    idxt = (d == left_out_subject)
    X_test = X[idxt]
    y_test = y[idxt]
    d_test = d[idxt]


    return X_train, y_train, d_train, X_test, y_test, d_test

def data_loader():
    zarr_array = zarr.open("./dataset/chest_ECG_w60_mw60_ts256_cl2_cs1_fp[1.0].zarr/", mode="r")
    signal = zarr_array['raw_signal'][:]
    target_all = zarr_array['target'][:]
    subjects_all = zarr_array['subject'][:]

    SUBJECTS_IDS = list(range(2, 18))
    subjects = SUBJECTS_IDS[:]
    classes = [1, 2, 3, 4]

    subset_map = [
        idx
        for idx, i in enumerate(target_all)
        if i in classes and subjects_all[idx] in subjects
    ]

    idx = subset_map
    X = signal[idx]
    y = target_all[idx]
    d = subjects_all[idx]

    y_updated = np.vectorize(replace_y)(y)
    d_updated = np.vectorize(replace_d)(d)

    X_train, y_train, d_train, X_test, y_test, d_test = loso(X,y_updated,d_updated)
    
    scaler = MinMaxScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    multitask_data = MultitaskDataset(X_train_scaled, y_train, d_train, X_test_scaled, y_test, d_test)
    trainloader = multitask_data.train_loader(batch_size=32, shuffle=False)
    testloader = multitask_data.test_loader(batch_size=32, shuffle=False)

          
    return  trainloader,testloader

trainloader, testloader = data_loader()

In [5]:
hparams = {'lr': 0.001, 'weight_decay': 0.0001}
alg = ERM(input_shape=(60, 1), num_classes=4, num_domains=14, hparams=hparams)

### Preprocessing

In [8]:
def make_minibatches(dataset, batch_size):
    minibatches = []
    for env in dataset:
        n_samples = len(env)
        n_batches = int(np.ceil(n_samples / batch_size))

        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, n_samples)

            X_batch = []
            y_batch = []

            for j in range(start_idx, end_idx):
                X_batch.append(env[j][0])
                y_batch.append(env[j][1])

            minibatches.append((torch.stack(X_batch), torch.tensor(y_batch)))

    return minibatches

In [None]:
new = make_minibatches()

### Training

In [None]:
num_epochs = 20
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (x, y,d) in enumerate(dataloader):
        x, y, d = data
        minibatches = [(x, y)]
        loss = alg.update(minibatches)
        running_loss += loss
    print(f"Epoch {epoch + 1}: Loss = {running_loss / (i+1)}")

### Model evaluation

In [30]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        x, y, d = data

        
     
        outputs = alg.predict(x)
        
        outputs = outputs.unsqueeze(1)
         
        _, predicted = torch.max(outputs.data, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
print(f"Accuracy on test set: {100 * correct / total}%")

Accuracy on test set: 39.295212765957444%


In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# get predictions for the test set
y_true = []
y_pred = []
for data in testloader:
    x, y, d = data
    
    outputs = alg.predict(x)
    outputs = outputs.unsqueeze(1)
    _, predicted = torch.max(outputs.data, 1)
    print(predicted)
    y_true.extend(y.numpy())
    y_pred.extend(predicted.numpy())

# compute the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# plot the confusion matrix
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.set(xticks=np.arange(cm.shape[1]),
       yticks=np.arange(cm.shape[0]),
       xticklabels=['class 0', 'class 1', 'class 2', 'class 3'],
       yticklabels=['class 0', 'class 1', 'class 2', 'class 3'],
       ylabel='True label',
       xlabel='Predicted label')

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        text = ax.text(j, i, format(cm[i, j], 'd'),
                       ha="center", va="center", color="white")

fig.tight_layout()
plt.show()
