<a href="https://colab.research.google.com/github/med-adam-alimi/Development-of-a-Deep-Learning-model-for-ECG-based-arrhythmia-classification/blob/main/ECGV02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import warnings
# Ignore all warnings
warnings.filterwarnings("ignore")

In [None]:
# Pick the compute device as a plain string: 'cuda' when a GPU is visible to
# PyTorch, otherwise 'cpu'. Later cells pass this string to `.to(device)`.
device='cuda' if torch.cuda.is_available() else 'cpu'
# Show the selected device as the cell output.
device

In [None]:
# Colab-only: prompt for files to upload into the runtime.
# Presumably mitbih_train.csv and mitbih_test.csv, which the next cell reads
# by relative path — verify when running outside Colab.
from google.colab import files
uploaded = files.upload()

In [None]:
# Load both splits. header=None: the CSVs have no header row, so columns are
# numbered 0..N-1 and later cells treat the last column as the class label.
train_data = pd.read_csv("mitbih_train.csv", header=None)
test_data = pd.read_csv("mitbih_test.csv", header=None)

## Exploratory Data Analysis

In [None]:
# (rows, columns) of the training frame.
train_data.shape

In [None]:
# Number of rows per value of the last column (the label), ordered by label value.
train_data.iloc[:, -1].value_counts().sort_index()

In [None]:
# Preview the first rows of the training frame.
train_data.head()

In [None]:
# Preview the first rows of the test frame.
test_data.head()

In [None]:
# Per-column count of missing values for each split (one output line per column).
print(train_data.isnull().sum())
print(test_data.isnull().sum())

In [None]:
# Distinct values found in the label column.
train_data.iloc[:, -1].unique()

In [None]:
# Bar chart of rows per label value. value_counts() orders bars by descending
# count, not by label value.
train_data.iloc[:, -1].value_counts().plot(kind='bar', title='Class Distribution')
plt.show()

## Separate features and labels

In [None]:
# Every column but the last is a signal sample; the last column is the label.
signal_columns = slice(None, -1)
label_column = -1

X_train = train_data.iloc[:, signal_columns].to_numpy()
y_train = train_data.iloc[:, label_column].to_numpy()

X_test = test_data.iloc[:, signal_columns].to_numpy()
y_test = test_data.iloc[:, label_column].to_numpy()

## Apply cross-validation

In [None]:
from sklearn.model_selection import StratifiedKFold

# Always split the full labelled training frame. Sourcing X / y from
# train_data (instead of the X_train / y_train names this cell rebinds below)
# keeps the cell idempotent: re-running it never splits an already-split array.
X = train_data.iloc[:, :-1].values
y = train_data.iloc[:, -1].values
k_folds = 5
skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

# NOTE: each iteration overwrites the previous one, so after the loop only the
# LAST fold's train/validation partition is kept. The rest of the notebook
# therefore trains on a single stratified split (k_folds - 1 parts train,
# 1 part validation), not on k separately trained models.
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f'\nFOLD {fold + 1}')
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

## Visualize Random Heartbeats from the train_data

In [None]:
# Shuffle features and labels together (fixed seed), then draw five beats to plot
from sklearn.utils import shuffle
import random

X_train, y_train = shuffle(X_train, y_train, random_state=42)
random_indices = random.sample(range(len(X_train)), 5)

# One row of five panels, one heartbeat per panel
fig, axes = plt.subplots(1, 5, figsize=(10, 7))
for panel, idx in zip(axes, random_indices):
    panel.plot(X_train[idx])
    panel.set_title(f"Label: {int(y_train[idx])}")
    panel.set_xlabel("Time")
    panel.set_ylabel("Amplitude")
fig.tight_layout()
plt.show();

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# seaborn addresses the column by name, so wrap the label array in a named Series
import pandas as pd
y_train_series = pd.Series(y_train, name='Class')
class_frame = y_train_series.to_frame()

# Count plot of samples per class before any resampling
fig, ax = plt.subplots(figsize=(6,4))
sns.countplot(x='Class', data=class_frame, ax=ax)
ax.set_title('Class Distribution Before Balancing')
plt.show()

In [None]:
# Samples per class in the current training labels (before balancing).
print( pd.Series(y_train).value_counts())

## Balance the dataset by:

- Downsampling class 0 to 10,000 samples

- Upsampling classes 1 and 3 to 4,000 samples each

- Keeping the other classes (2 and 4) as they are

In [None]:
import pandas as pd
from sklearn.utils import resample
from imblearn.over_sampling import RandomOverSampler

# Put features and label side by side so rows can be filtered by class
df = pd.concat([pd.DataFrame(X_train), pd.Series(y_train, name='label')], axis=1)

# (class value, target size, sample with replacement?) in the order the pieces
# are concatenated. target size None means "keep the class unchanged".
# The order is significant: it fixes the row order handed to the final shuffle.
resampling_plan = [
    (0.0, 10000, False),  # majority class: downsample
    (2.0, None, None),    # kept as-is
    (4.0, None, None),    # kept as-is
    (1.0, 4000, True),    # minority class: upsample
    (3.0, 4000, True),    # minority class: upsample
]

balanced_parts = []
for class_value, target_size, with_replacement in resampling_plan:
    class_rows = df[df['label'] == class_value]
    if target_size is not None:
        class_rows = resample(class_rows,
                              replace=with_replacement,
                              n_samples=target_size,
                              random_state=42)
    balanced_parts.append(class_rows)

# Stack the per-class pieces and shuffle the rows with a fixed seed
df_balanced = pd.concat(balanced_parts).sample(frac=1, random_state=42)

# Back to separate feature / label arrays
X_train_balanced = df_balanced.drop(columns='label').values
y_train_balanced = df_balanced['label'].values



In [None]:
# Samples per class after resampling.
print( pd.Series(y_train_balanced).value_counts())

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Count plot of samples per class in the rebalanced training labels
fig, ax = plt.subplots(figsize=(8, 5))
sns.countplot(x=y_train_balanced, palette='coolwarm', ax=ax)
ax.set_title("Class Distribution After Balancing")
ax.set_xlabel("Class")
ax.set_ylabel("Sample Count")
plt.show()

## Normalize the ECG signals

In [None]:
from sklearn.preprocessing import StandardScaler

# From here on, "train" means the rebalanced training set.
y_train = y_train_balanced

# Fit the per-column mean/std on the training data only, then reuse those
# statistics for the other splits.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_balanced)

# Transform validation and test from their UNSCALED sources rather than from
# the X_val / X_test names this cell assigns. Writing
# `X_val = scaler.transform(X_val)` would standardize the arrays a second time
# whenever the cell is re-run.
#   X[val_idx]              -> validation rows chosen by the split cell
#   test_data.iloc[:, :-1]  -> raw test signals (all columns except the label)
X_val = scaler.transform(X[val_idx])
X_test = scaler.transform(test_data.iloc[:, :-1].values)

## Visualize random heartbeats from the training data after balancing and normalizing it

In [None]:
import random

# Draw five rows of the balanced, standardized training set and plot each in its own panel
random_indices = random.sample(range(len(X_train)), 5)
fig, axes = plt.subplots(1, 5, figsize=(10, 7))
for panel, idx in zip(axes, random_indices):
    panel.plot(X_train[idx])
    panel.set_title(f"Label: {int(y_train[idx])}")
    panel.set_xlabel("Time")
    panel.set_ylabel("Amplitude")
fig.tight_layout()
plt.show();

## Convert to a PyTorch Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

class ECGDataset(Dataset):
    """In-memory dataset that pairs each signal row with its class label.

    Both inputs are converted to tensors once, at construction time:
    `data` becomes float32 and `labels` becomes int64 (torch.long).
    """

    def __init__(self, data, labels):
        signal_tensor = torch.tensor(data, dtype=torch.float32)
        label_tensor = torch.tensor(labels, dtype=torch.long)
        self.data, self.labels = signal_tensor, label_tensor

    def __len__(self):
        # Length is taken from the first dimension of the signal tensor.
        return len(self.data)

    def __getitem__(self, idx):
        # Return the (signal, label) pair stored at position idx.
        signal = self.data[idx]
        target = self.labels[idx]
        return signal, target

In [None]:
# One dataset per split
train_dataset, val_dataset, test_dataset = (
    ECGDataset(X_train, y_train),
    ECGDataset(X_val, y_val),
    ECGDataset(X_test, y_test),
)

# Same batch size everywhere; only the training loader reshuffles each epoch
loader_batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=loader_batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=loader_batch_size, shuffle=False)

# Show the three loaders as the cell output
train_dataloader, val_dataloader, test_dataloader

### Generate wavelet features (for the dual-branch model)

In [None]:
!pip install PyWavelets

In [None]:
import pywt
# --- Wavelet Feature Extraction ---
def compute_wavelet(signal):
    """Summarize one signal by its mean continuous-wavelet magnitude.

    Runs a Morlet CWT at the integer scales 1..30, takes the absolute value
    of the coefficients and averages over axis 0 of the coefficient array.
    Presumably axis 0 is the scale axis, so the result keeps one value per
    input sample — confirm against the PyWavelets `cwt` documentation.
    """
    scale_range = np.arange(1, 31)
    coefficients, _frequencies = pywt.cwt(signal, scales=scale_range, wavelet='morl')
    magnitudes = np.abs(coefficients)
    return magnitudes.mean(axis=0)

# Apply the wavelet summary row by row to the training, test and validation splits.
# NOTE(review): these arrays are only plotted in the next cell. The data loaders
# are built from X_train / X_val / X_test, and ECGTransformer.forward feeds the
# same input to both branches, so the wavelet features never reach the model —
# confirm whether they were meant to drive the frequency branch.
wavelet_train = np.array([compute_wavelet(x) for x in X_train])
wavelet_test = np.array([compute_wavelet(x) for x in X_test])
wavelet_val = np.array([compute_wavelet(x) for x in X_val])

## Visualize random wavelet representations from wavelet_train

In [None]:
# Select 5 random indices
random_indices = random.sample(range(len(wavelet_train)), 5)

fig, axes = plt.subplots(1, 5, figsize=(12, 6))
for panel, idx in zip(axes, random_indices):
    panel.plot(wavelet_train[idx])
    panel.set_title(f"Label: {int(y_train[idx])}")
    # compute_wavelet averages |CWT| over the scale axis, so the position along
    # each curve is still the sample index (time), not a frequency bin.
    panel.set_xlabel("Time (samples)")
    panel.set_ylabel("Mean |CWT| over scales")

# Lay the panels out once, after all of them have been drawn
fig.tight_layout()
plt.show()

In [None]:
# Model Architecture (Dual-Branch Transformer)
class ECGTransformer(nn.Module):
    """Two-branch transformer classifier.

    Each branch projects its input to 64 features, batch-normalizes them, runs
    a 3-layer transformer encoder and mean-pools over the sequence dimension.
    The two 64-d branch outputs are concatenated, reduced back to 64 features,
    passed through dropout and mapped to `num_classes` logits.

    Args:
        input_dim: size of the last dimension of the inputs.
        num_classes: number of output logits.
        dropout_rate: dropout probability applied before the classifier.

    Inputs are expected as (batch, seq_len, input_dim). The encoder layers are
    built with ``batch_first=True`` so that dimension 0 is treated as the
    batch. With the default (sequence-first) layout, a (batch, 1, input_dim)
    tensor would be read as one sequence of length ``batch``, self-attention
    would mix the samples of a mini-batch, and predictions would depend on
    batch composition.
    """

    def __init__(self, input_dim, num_classes, dropout_rate=0.5):
        super().__init__()

        # Temporal branch
        self.temporal_proj = nn.Linear(input_dim, 64)
        self.temporal_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=64, nhead=8, dim_feedforward=256,
                                       batch_first=True),
            num_layers=3
        )

        # Frequency branch (intended for wavelet features of the same length)
        self.freq_proj = nn.Linear(input_dim, 64)
        self.freq_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=64, nhead=8, dim_feedforward=256,
                                       batch_first=True),
            num_layers=3
        )

        # Combine branches
        self.combine = nn.Linear(128, 64)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(64, num_classes)

        # Batch normalization over the 64 projected features of each branch
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)

    def forward(self, x, x_freq=None):
        """Compute class logits.

        Args:
            x: time-domain input, (batch, seq_len, input_dim).
            x_freq: optional input for the frequency branch with the same
                layout as `x`. When omitted, `x` is fed to both branches,
                which is the original single-input behaviour.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.
        """
        freq_input = x if x_freq is None else x_freq

        # Temporal branch. BatchNorm1d wants channels on dim 1, hence the transposes.
        x_temp = self.temporal_proj(x)
        x_temp = self.bn1(x_temp.transpose(1, 2)).transpose(1, 2)
        x_temp = self.temporal_transformer(x_temp)
        x_temp = x_temp.mean(dim=1)  # pool over the sequence dimension

        # Frequency branch
        x_f = self.freq_proj(freq_input)
        x_f = self.bn2(x_f.transpose(1, 2)).transpose(1, 2)
        x_f = self.freq_transformer(x_f)
        x_f = x_f.mean(dim=1)

        # Combine branches
        combined = torch.cat([x_temp, x_f], dim=1)
        combined = self.combine(combined)
        combined = self.dropout(combined)
        return self.classifier(combined)

# Build the network with one input feature per column of X_train and five
# output classes, and move it to the selected device.
model = ECGTransformer(input_dim=X_train.shape[1], num_classes=5).to(device)
# Show the module tree as the cell output.
model

In [None]:
# Inverse-frequency class weights.
# minlength=5 guarantees one weight per class even if the highest class is
# absent from y_train (CrossEntropyLoss needs exactly one weight per class).
# clamp(min=1) keeps an empty class from producing an infinite weight.
# When every class is present the weights are simply 1 / count.
class_counts = np.bincount(y_train.astype(int), minlength=5)
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float).clamp(min=1)
class_weights = class_weights.to(device)

# Loss function with class weights
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer with weight decay (L2 regularization)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Learning rate scheduler: divide the LR by 10 after 3 epochs without
# improvement of the monitored value. The `verbose` argument is deprecated for
# LR schedulers and is not passed; read the current rate from
# optimizer.param_groups[0]['lr'] when needed.
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)



In [None]:
best_val_loss = float('inf')
patience = 5
patience_counter = 0
accumulation_steps = 2
train_losses, val_losses = [], []
train_accs, val_accs = [], []
epochs=3

# Mixed precision is only used on CUDA. On CPU, autocast and the gradient
# scaler are disabled and behave as pass-throughs.
device_type = torch.device(device).type
use_amp = device_type == 'cuda'
# Loss scaling keeps small fp16 gradients from underflowing to zero.
grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
num_train_batches = len(train_dataloader)

for epoch in range(epochs):
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  # Start each epoch with clean gradients so nothing leaks across epochs.
  optimizer.zero_grad()

  for i,(inputs, labels) in enumerate(train_dataloader):
    inputs, labels = inputs.to(device), labels.to(device)

    # Forward pass
    with torch.autocast(device_type=device_type, enabled=use_amp):
      outputs = model(inputs.unsqueeze(1))
      loss = loss_fn(outputs, labels)

    # Backward pass: only the value sent to backward() is divided by
    # accumulation_steps, so the logged loss below stays on the same scale as
    # the validation loss.
    grad_scaler.scale(loss / accumulation_steps).backward()

    # Step every `accumulation_steps` batches, and also on the final batch so
    # a trailing partial window is applied instead of being carried over.
    is_last_batch = (i + 1) == num_train_batches
    if (i + 1) % accumulation_steps == 0 or is_last_batch:
      # Unscale first so clipping sees the true gradient norm.
      grad_scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      grad_scaler.step(optimizer)
      grad_scaler.update()
      optimizer.zero_grad()

    running_loss += loss.item()
    _, predicted = outputs.max(1)
    total += labels.size(0)
    correct += predicted.eq(labels).sum().item()

  # Validation
  model.eval()
  val_loss = 0.0
  val_correct = 0
  val_total = 0

  with torch.no_grad():
    for inputs, labels in val_dataloader:
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs.unsqueeze(1))
      loss = loss_fn(outputs, labels)

      val_loss += loss.item()
      _, predicted = outputs.max(1)
      val_total += labels.size(0)
      val_correct += predicted.eq(labels).sum().item()

  # Calculate metrics (mean loss per batch, accuracy in percent)
  train_loss = running_loss / len(train_dataloader)
  val_loss = val_loss / len(val_dataloader)
  train_acc = 100. * correct / total
  val_acc = 100. * val_correct / val_total

  # Update learning rate
  scheduler.step(val_loss)

  # Store metrics
  train_losses.append(train_loss)
  val_losses.append(val_loss)
  train_accs.append(train_acc)
  val_accs.append(val_acc)

  print(f'Epoch: {epoch+1} | Train Loss: {train_loss:.5f}, Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.5f}, Val Acc: {val_acc:.2f}%')

  # Early stopping check: checkpoint on improvement, otherwise count towards patience
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), 'best_model.pth')
    patience_counter = 0
  else:
    patience_counter += 1
    if patience_counter >= patience:
      print("Early stopping triggered")
      break

In [None]:
# NOTE: this cell keeps training the same `model` / `optimizer` / `scheduler`
# objects used by the previous training cell; it does not start from scratch.
best_val_loss = float('inf')
accumulation_steps = 2
train_losses, val_losses ,test_losses= [], [],[]
train_accs, val_accs,test_accs = [], [],[]
epochs=100

# Mixed precision is only used on CUDA. On CPU, autocast and the gradient
# scaler are disabled and behave as pass-throughs.
device_type = torch.device(device).type
use_amp = device_type == 'cuda'
# Loss scaling keeps small fp16 gradients from underflowing to zero.
grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
num_train_batches = len(train_dataloader)


def _evaluate(data_loader):
  """Return (mean loss per batch, accuracy in percent) of `model` on `data_loader`."""
  model.eval()
  loss_sum, n_correct, n_seen = 0.0, 0, 0
  with torch.no_grad():
    for batch_inputs, batch_labels in data_loader:
      batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)
      batch_outputs = model(batch_inputs.unsqueeze(1))
      loss_sum += loss_fn(batch_outputs, batch_labels).item()
      n_seen += batch_labels.size(0)
      n_correct += batch_outputs.argmax(dim=1).eq(batch_labels).sum().item()
  return loss_sum / len(data_loader), 100. * n_correct / n_seen


for epoch in range(epochs):
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  # Start each epoch with clean gradients so nothing leaks across epochs.
  optimizer.zero_grad()

  for i,(inputs, labels) in enumerate(train_dataloader):
    inputs, labels = inputs.to(device), labels.to(device)

    # Forward pass
    with torch.autocast(device_type=device_type, enabled=use_amp):
      outputs = model(inputs.unsqueeze(1))
      loss = loss_fn(outputs, labels)

    # Backward pass: only the value sent to backward() is divided by
    # accumulation_steps, so the logged loss stays comparable to val/test loss.
    grad_scaler.scale(loss / accumulation_steps).backward()

    # Step every `accumulation_steps` batches, and also on the final batch so
    # a trailing partial window is applied instead of being carried over.
    is_last_batch = (i + 1) == num_train_batches
    if (i + 1) % accumulation_steps == 0 or is_last_batch:
      # Unscale first so clipping sees the true gradient norm.
      grad_scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      grad_scaler.step(optimizer)
      grad_scaler.update()
      optimizer.zero_grad()

    running_loss += loss.item()
    _, predicted = outputs.max(1)
    total += labels.size(0)
    correct += predicted.eq(labels).sum().item()

  # Validation drives the scheduler and the checkpoint. The test pass is for
  # monitoring only and never influences training decisions.
  val_loss, val_acc = _evaluate(val_dataloader)
  test_loss, test_acc = _evaluate(test_dataloader)

  # Training metrics (mean loss per batch, accuracy in percent)
  train_loss = running_loss / len(train_dataloader)
  train_acc = 100. * correct / total

  # Update learning rate
  scheduler.step(val_loss)

  # Store metrics
  train_losses.append(train_loss)
  val_losses.append(val_loss)
  test_losses.append(test_loss)
  train_accs.append(train_acc)
  val_accs.append(val_acc)
  test_accs.append(test_acc)
  # Log the first epoch and then every tenth
  if (epoch + 1) == 1 or (epoch + 1) % 10 == 0:
    print(f'Epoch: {epoch+1} | Train Loss: {train_loss:.5f}, Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.5f}, Val Acc: {val_acc:.2f}% |'
              f'Test Loss: {test_loss:.5f}, Test Acc: {test_acc:.2f}%')

  # No early stopping in this run, but still save the best model by validation loss
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), 'best_model.pth')


In [None]:
import matplotlib.pyplot as plt

epochs_range = range(1, len(train_losses) + 1)

fig, axs = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel, left to right:
# (losses, accuracies, loss legend label, accuracy legend label, panel title)
panel_specs = [
    (train_losses, train_accs, 'Train Loss', 'Train Accuracy', 'Train Loss and Accuracy'),
    (test_losses, test_accs, 'Test Loss', 'Test Accuracy', 'Test Loss and Accuracy'),
]

for loss_axis, (losses, accuracies, loss_label, acc_label, panel_title) in zip(axs, panel_specs):
    # Loss on the left y-axis
    loss_axis.plot(epochs_range, losses, label=loss_label, color='tab:blue')
    loss_axis.set_xlabel('Epoch')
    loss_axis.set_ylabel('Loss', color='tab:blue')
    loss_axis.tick_params(axis='y', labelcolor='tab:blue')
    loss_axis.set_title(panel_title)

    # Accuracy on a twin y-axis sharing the same epoch axis
    acc_axis = loss_axis.twinx()
    acc_axis.plot(epochs_range, accuracies, label=acc_label, color='tab:orange')
    acc_axis.set_ylabel('Accuracy (%)', color='tab:orange')
    acc_axis.tick_params(axis='y', labelcolor='tab:orange')

plt.tight_layout()
plt.show()



In [None]:
# One-row summary of the trained model. test_loss / test_acc hold whatever the
# training cell last assigned, i.e. the metrics of the final epoch it ran.
model_results = pd.DataFrame([{
    'model_name': model.__class__.__name__,
    'model_loss': test_loss,
    'model_acc': test_acc
}])

# Show the summary as the cell output.
model_results

## Make and evaluate random predictions with our model

In [None]:
def make_predictions(model: torch.nn.Module, data: list, device: torch.device = device):
    """Return class probabilities for every sample in `data`.

    Each sample gets two extra leading dimensions (positions 0 and 1), is
    moved to `device` and passed through `model` in eval / inference mode.
    The squeezed logits are turned into probabilities with a softmax over
    dimension 0.

    Args:
        model: network to query; it is switched to eval mode.
        data: iterable of sample tensors.
        device: device on which the forward pass runs.

    Returns:
        A tensor stacking one CPU probability vector per sample.
    """
    model.eval()
    with torch.inference_mode():
        probabilities = []
        for sample in data:
            # sample -> [1, *sample.shape] on the device -> extra dim at position 1
            model_input = sample.unsqueeze(0).to(device).unsqueeze(1)
            logits = model(model_input).squeeze()
            # Bring the probabilities back to the CPU for later processing
            probabilities.append(torch.softmax(logits, dim=0).cpu())

    # Turn the list of per-sample vectors into a single tensor
    return torch.stack(probabilities)

In [None]:
import random
#random.seed(42)
test_samples = []
test_labels = []
# Get a list of random row indices
random_indices = random.sample(range(len(X_test)), k=9)

# Draw the samples from X_test / y_test, not from the raw test_data frame.
# X_test has been passed through the scaler fitted on the training data, which
# is the representation the model was trained on. Raw rows would give the
# model inputs on a different scale.
for idx in random_indices:
    # Convert sample to PyTorch tensor
    sample = torch.tensor(X_test[idx], dtype=torch.float32)
    label = y_test[idx]
    test_samples.append(sample)
    test_labels.append(label)

# View the first test sample shape and label
print(f"Test sample image shape: {test_samples[0].shape}\nTest sample label: {test_labels[0]}")

In [None]:
# Make predictions on the sampled test beats with the trained model
pred_probs= make_predictions(model=model,
                             data=test_samples)

# View the first 5 rows of prediction probabilities
pred_probs[:5]

In [None]:
# Turn the prediction probabilities into prediction labels: for each sample
# (row), take the index of the largest probability.
pred_classes = pred_probs.argmax(dim=1)
pred_classes

In [None]:
# Are our predictions in the same form as our test labels?
# Convert the collected labels to an int64 tensor so both print the same way.
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)
print(f'test labels :{test_labels_tensor}')
print(f'pred calsses:{pred_classes}')

In [None]:
#Now let's visualize:
# Plot predictions
plt.figure(figsize=(9, 9))
nrows = 3
ncols = 3
for i, sample in enumerate(test_samples):
  # Create a subplot
  plt.subplot(nrows, ncols, i+1)

  plt.plot(sample.squeeze())

  # Convert both labels to plain ints so the title reads "Pred: 2 | Truth: 2"
  # instead of "Pred: tensor(2) | Truth: 2.0", and so the comparison below is
  # an ordinary integer comparison.
  pred_label = int(pred_classes[i].item())
  truth_label = int(test_labels[i])

  # Create the title text of the plot
  title_text = f"Pred: {pred_label} | Truth: {truth_label}"

  # Check for equality and change title colour accordingly
  if pred_label == truth_label:
      plt.title(title_text, fontsize=10, c="g") # green text if correct
  else:
      plt.title(title_text, fontsize=10, c="r") # red text if wrong
  plt.axis(False);

### Making a confusion matrix for further prediction evaluation

In [None]:
from tqdm.auto import tqdm
# Collect the raw logits for the whole test set, batch by batch.
y_preds=[]
model.eval()
with torch.no_grad():
  # The labels are not needed here. Loop names avoid shadowing the builtin `input`.
  for batch_inputs, _ in tqdm(test_dataloader,desc='Making predictions'):
    batch_logits = model(batch_inputs.to(device).unsqueeze(1))
    #Put predictions on CPU for evaluation
    y_preds.append(batch_logits.cpu())
# Concatenate list of predictions into a tensor
y_pred_tensor = torch.cat(y_preds)

In [None]:
# See if torchmetrics exists, if not, install it
try:
    import torchmetrics, mlxtend
    print(f"mlxtend version: {mlxtend.__version__}")
    assert int(mlxtend.__version__.split(".")[1]) >= 19, "mlxtend verison should be 0.19.0 or higher"
except:
    !pip install -q torchmetrics -U mlxtend
    import torchmetrics, mlxtend
    print(f"mlxtend version: {mlxtend.__version__}")

In [None]:
# Import mlxtend again and verify the version: the second dot-separated
# component of the version string must be at least 19.
import mlxtend
print(mlxtend.__version__)
assert int(mlxtend.__version__.split(".")[1]) >= 19

In [None]:
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix

# 'y_test' contains the target labels
# Build a 5-class confusion-matrix metric and evaluate it on the predicted
# class (argmax over the logits) against the true test labels.
confmat = ConfusionMatrix(num_classes=5, task='multiclass')
confmat_tensor = confmat(preds=y_pred_tensor.argmax(dim=1),
                         target=torch.tensor(y_test, dtype=torch.int64))  # Convert y_test to a PyTorch tensor

# Define class names
class_names = ['0', '1', '2', '3', '4']

# Plot the confusion matrix (mlxtend's plot_confusion_matrix, imported above)
fig, ax = plot_confusion_matrix(
    conf_mat=confmat_tensor.numpy(),
    class_names=class_names,
    figsize=(10, 7)
);

In [None]:
from sklearn.metrics import confusion_matrix
import itertools
import torch

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Draw a confusion matrix on the current matplotlib figure.

    NOTE: this definition shadows the `plot_confusion_matrix` imported from
    mlxtend.plotting in an earlier cell; after this cell runs, that earlier
    cell needs its import re-executed before it can be re-run.

    Args:
        cm: square array of counts, rows = true class, columns = predicted class.
        classes: tick labels, one per class, in row/column order.
        normalize: when True, divide every row by its sum so each row shows
            the distribution of predictions for that true class. Rows whose
            sum is zero (a class with no true samples) are left as all zeros
            instead of becoming NaN.
        title: figure title.
        cmap: matplotlib colormap used for the cells.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Divide only where the row total is non-zero; other cells keep the 0 from `out`.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cell annotations: two decimals for rates, integers for counts.
    # Text is white on dark cells (above half the maximum) and black elsewhere.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Hard class predictions: index of the largest logit per test sample.
y_pred_labels = y_pred_tensor.argmax(dim=1).cpu().numpy()
# y_test is already a NumPy array. Cast it to int directly, rather than
# round-tripping through a tensor, so true and predicted labels share an
# integer dtype and reports show "0".."4" rather than "0.0".."4.0".
y_true_labels = np.asarray(y_test).astype(int)

# --- Compute the confusion matrix using scikit-learn ---
cnf_matrix = confusion_matrix(y_true_labels, y_pred_labels)

plt.figure(figsize=(7, 7))
plot_confusion_matrix(cnf_matrix,
                      classes= ['0', '1', '2', '3', '4'],
                      normalize=True,
                      title='Normalized Confusion Matrix')
plt.show()

In [None]:
from sklearn.metrics import classification_report
# Text report of per-class metrics on the test set, printed with 4 decimals.
print(classification_report(y_true_labels, y_pred_labels, digits=4))

## ROC and AUC

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import torch # Make sure torch is imported if not already

n_classes = 5

# One-vs-rest encoding of the ground truth: one indicator column per class
y_true_bin = label_binarize(y_true_labels, classes=[0,1,2,3,4])

# Class scores: softmax over the logits of each test sample
y_score = torch.softmax(y_pred_tensor, dim=1).cpu().numpy()

# Per-class ROC points and area under the curve, keyed by class index
fpr, tpr, roc_auc = {}, {}, {}
for class_idx in range(n_classes):
    class_truth = y_true_bin[:, class_idx]
    class_score = y_score[:, class_idx]
    fpr[class_idx], tpr[class_idx], _ = roc_curve(class_truth, class_score)
    roc_auc[class_idx] = auc(fpr[class_idx], tpr[class_idx])

plt.figure(figsize=(10,8))
colors = ['blue', 'red', 'green', 'orange', 'purple']

# One curve per class, labelled with its AUC
for class_idx, curve_color in enumerate(colors):
    plt.plot(fpr[class_idx], tpr[class_idx], color=curve_color, lw=2,
             label=f'Class {class_idx} (AUC = {roc_auc[class_idx]:.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level diagonal
plt.xlim([0, 1])
plt.ylim([0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Multi-class ROC Curve')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()



In [None]:
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import numpy as np

# y_true: true labels (shape: [n_samples])
# y_score: predicted probabilities (shape: [n_samples, n_classes])
# Binarize true labels for one-vs-rest approach
from sklearn.preprocessing import label_binarize

n_classes = y_score.shape[1]
# One-vs-rest indicator matrix built from the true labels
y_test_bin = label_binarize(y_true_labels, classes=np.arange(n_classes))

# One precision-recall curve per class, labelled with its average precision
plt.figure(figsize=(10, 8))
for class_idx in range(n_classes):
    class_truth = y_test_bin[:, class_idx]
    class_score = y_score[:, class_idx]
    precision, recall, _ = precision_recall_curve(class_truth, class_score)
    avg_prec = average_precision_score(class_truth, class_score)
    plt.plot(recall, precision, lw=2, label=f'Class {class_idx} (AP = {avg_prec:.2f})')

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Multi-class Precision-Recall Curve")
plt.legend()
plt.grid(True)
plt.show()



## Save and load our model

In [None]:
from pathlib import Path

# Directory for saved weights; created if missing (parents included).
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True,exist_ok=True )

# Create model save path
MODEL_NAME = "ECG_arrhythmia_classification_model_3.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# Save the model state dict
# NOTE: this writes the weights `model` holds right now (after the last
# training epoch). It is a separate file from the 'best_model.pth' checkpoint
# that the training cells write on validation-loss improvement.
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj=model.state_dict(), # only saving the state_dict() only saves the learned parameters
           f=MODEL_SAVE_PATH)

##### Now that we've got a saved model `state_dict()`, we can load it back in using a combination of `load_state_dict()` and `torch.load()`.

##### Since we're using `load_state_dict()`, we'll need to create a new instance of `ECGTransformer()` with the same input parameters as our saved model `state_dict()`.

In [None]:
# Create a new instance of ECGTransformer (the same class as our saved state_dict()),
# with the same input_dim and num_classes that were used for the saved model
loaded_model = ECGTransformer(input_dim=X_train.shape[1],num_classes=5)

# Load in the saved state_dict().
# map_location remaps the stored tensors onto the current device, so weights
# saved from a GPU session can still be restored on a CPU-only runtime.
loaded_model.load_state_dict(torch.load(f=MODEL_SAVE_PATH, map_location=device))

# Send model to the selected device
loaded_model= loaded_model.to(device)

In [None]:
# Compare the ORIGINAL model's parameters and buffers with the reloaded ones.
# Comparing loaded_model with itself would always report True.
original_state = model.state_dict()
loaded_state = loaded_model.state_dict()

# Both dicts must expose the same keys, and every tensor must match exactly.
# Tensors are moved to the CPU so the comparison works whatever device each
# model is on.
all_equal = original_state.keys() == loaded_state.keys() and all(
    torch.equal(original_state[key].cpu(), loaded_state[key].cpu())
    for key in original_state
)

print(all_equal)  # This will print True if all tensors are equal

#### Now that we've got a loaded model, we can evaluate it to make sure it performs the same as the original model did prior to saving

In [None]:
# Evaluate the reloaded model once on the test set.
# Evaluation is deterministic in eval mode, so one pass is enough. The LR
# scheduler is deliberately not stepped here: no training is happening, and
# stepping it would only mutate the optimizer's learning rate as a side effect.
loaded_model.eval()
loaded_test_loss = 0.0
test_correct = 0
test_total = 0

with torch.no_grad():
  for inputs, labels in test_dataloader:
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = loaded_model(inputs.unsqueeze(1))
    loss = loss_fn(outputs, labels)

    loaded_test_loss += loss.item()
    _, predicted = outputs.max(1)
    test_total += labels.size(0)
    test_correct += predicted.eq(labels).sum().item()

# Calculate metrics (mean loss per batch, accuracy in percent)
loaded_test_loss=loaded_test_loss/len(test_dataloader)
loaded_test_acc=100.*test_correct/test_total



In [None]:
# One-row summary for the reloaded model, in the same layout as model_results.
loaded_model_results = pd.DataFrame([{
    'model_name': loaded_model.__class__.__name__,
    'model_loss': loaded_test_loss,
    'model_acc': loaded_test_acc
}])
loaded_model_results

In [None]:
# Show the original model's summary again for side-by-side comparison.
model_results