# COMP90089 Biomedical Time-Series tutorial

This tutorial analyzes an ECG dataset.
The ECG dataset comprises long-term electrocardiogram (ECG) recordings of human subjects diagnosed with atrial fibrillation. Each recording contains two distinct ECG channels, both sampled at 250 Hz, and is annotated with four rhythm classes: AFIB (atrial fibrillation), AFL (atrial flutter), J (AV junctional rhythm), and N (all other rhythms). See the [PhysioNet database page](https://physionet.org/content/afdb/1.0.0/old/) for details.
 
The data has been divided into training, validation, and test sets. Each sample contains a 2-channel ECG segment lasting 10 s, i.e., 2500 points per channel (10 s × 250 Hz).
 


In [None]:
# First, ensure you have the necessary packages installed in your Google Colab environment.
# You can install them by running:
!pip install torch pandas numpy scikit-learn matplotlib

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.metrics import mean_absolute_error, mean_squared_error, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


In [None]:
# Set random seeds for reproducibility. Seeding NumPy alone is not enough:
# model weight initialisation and DataLoader shuffling draw from torch's RNG.
np.random.seed(0)
torch.manual_seed(0)

# Load the pre-split datasets from .pkl files.
# NOTE: unpickling can execute arbitrary code; only load pickles from a source
# you trust (here, the course repository).
DATA_URL = 'https://github.com/melbourne-cdth/comp90089_time_series_tutorial/raw/refs/heads/main/ECGless/'
x_train = pd.read_pickle(DATA_URL + 'x_train1.pkl')
x_val = pd.read_pickle(DATA_URL + 'x_val1.pkl')
x_test = pd.read_pickle(DATA_URL + 'x_test1.pkl')

y_train = pd.read_pickle(DATA_URL + 'y_train1.pkl')
y_val = pd.read_pickle(DATA_URL + 'y_val1.pkl')
y_test = pd.read_pickle(DATA_URL + 'y_test1.pkl')

print(x_train.shape)
print(y_train.shape)

# Plot channel index 0 of one example from each split, side by side.
# Each entry: (array, example index, legend label, subplot title).
preview_specs = [
    (x_train, 20, 'Train Feature 0', 'Training Data - ECG Channel 1'),
    (x_val, 10, 'Validation Feature 0', 'Validation Data - ECG Channel 1'),
    (x_test, 10, 'Test Feature 0', 'Test Data - ECG Channel 1'),
]

plt.figure(figsize=(15, 5))
for position, (split_data, sample_idx, legend_label, title) in enumerate(preview_specs, start=1):
    plt.subplot(1, len(preview_specs), position)
    plt.plot(split_data[sample_idx, 0], label=legend_label)
    plt.title(title)
    plt.xlabel('Sample index')
    plt.ylabel('ECG')
    plt.legend()

plt.tight_layout()
plt.show()


In [None]:
# Each 2500-point recording is cut into 5 consecutive windows of 500 points.
chunk_size = 500


class TimeSeriesDataset(Dataset):
    """Dataset that splits every multichannel recording into equal-width windows.

    Args:
        data: array indexed as ``data[idx]`` -> (channels, length); ``data.shape[2]``
            is the recording length.
        labels: per-recording labels, indexed in step with ``data``.
        chunk_size: width of each window in samples. Any tail shorter than
            ``chunk_size`` is dropped.

    Each item is ``(windows, label)``, where ``windows`` is a float32 tensor of
    shape (num_chunks, channels, chunk_size), e.g. (5, 2, 500) here.
    """

    def __init__(self, data, labels, chunk_size):
        self.data, self.labels = data, labels
        self.chunk_size = chunk_size
        # Number of whole windows that fit in one recording.
        self.num_chunks = data.shape[2] // chunk_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        recording = self.data[idx]
        target = self.labels[idx]
        width = self.chunk_size
        windows = []
        for k in range(self.num_chunks):
            start = k * width
            windows.append(recording[:, start:start + width])
        stacked = np.stack(windows, axis=0)  # (num_chunks, channels, chunk_size)
        return torch.tensor(stacked).float(), target


# Create DataLoaders.
# Only the training loader is shuffled. Evaluation gains nothing from shuffling,
# and a fixed order keeps test predictions in x_test/y_test order and makes
# the per-example visualisations repeatable.
batch_size = 256
train_dataset = TimeSeriesDataset(x_train, y_train, chunk_size)
val_dataset = TimeSeriesDataset(x_val, y_val, chunk_size)
test_dataset = TimeSeriesDataset(x_test, y_test, chunk_size)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Define the LSTM model
class LSTMClassifier(nn.Module):
    """LSTM that encodes each chunk independently and classifies the chunk average.

    Every chunk of a sample is run through the same stacked LSTM. The final
    hidden state of the top LSTM layer for each chunk is averaged across chunks
    and passed to a linear layer.

    Args:
        input_size: number of features per LSTM time step (last dim of ``x``).
        hidden_size: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of chunked sequences.

        Args:
            x: tensor of shape (batch, num_chunks, seq_len, input_size). Each chunk
                is treated as an independent sequence of ``seq_len`` steps.
                NOTE(review): TimeSeriesDataset in this notebook yields chunks shaped
                (channels, chunk_size), so with ``input_size = chunk_size`` the LSTM
                steps over the ECG channels rather than over time. Confirm this is
                intended.

        Returns:
            (logits, hn): ``logits`` has shape (batch, num_classes). ``hn`` holds the
            per-chunk final hidden states of the top LSTM layer, with shape
            (batch, num_chunks, hidden_size).
        """
        batch_size, num_chunks, seq_len, num_features = x.size()
        # Fold chunks into the batch dimension so each chunk is its own sequence.
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(batch_size * num_chunks, seq_len, num_features)

        # hn: (num_layers, batch_size * num_chunks, hidden_size)
        _, (hn, _) = self.lstm(x)

        # Take the top layer's final hidden state and regroup it per sample.
        hn = hn[-1].reshape(batch_size, num_chunks, -1)  # (batch, num_chunks, hidden)

        # Aggregate across chunks with the mean; other strategies (max, attention) are possible.
        aggregated_hidden_state = hn.mean(dim=1)

        out = self.fc(aggregated_hidden_state)  # (batch, num_classes)

        # Return both the logits and the per-chunk hidden states (used for visualisation).
        return out, hn

class TransformerModel(nn.Module):
    """Transformer-encoder model that predicts from the last sequence position.

    Args:
        input_size: feature size per step (``d_model``); must be divisible by ``nhead``.
        num_layers: number of encoder layers.
        nhead: number of attention heads.
        dim_feedforward: hidden size of each encoder layer's feed-forward block.
        output_size: number of outputs (classes for classification).
        problem_type: ``'classification'`` or ``'regression'``.

    ``forward`` returns raw scores (logits for classification), which is what
    ``nn.CrossEntropyLoss`` expects. Use ``predict_proba`` for probabilities.

    NOTE(review): ``forward`` returns a single tensor and expects 3-D input
    (batch, seq_len, input_size). The training/evaluation cells in this notebook
    unpack ``outputs, hidden = model(x)`` on 4-D chunked batches, so they must be
    adapted before this model is swapped in.
    """

    def __init__(self, input_size, num_layers, nhead, dim_feedforward, output_size, problem_type):
        super().__init__()
        self.problem_type = problem_type
        # TransformerEncoder clones this layer num_layers times. The attribute is
        # kept so that existing state_dict keys are unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(input_size, output_size)
        if problem_type == 'classification':
            # Applied only in predict_proba. Applying softmax in forward and then
            # training with CrossEntropyLoss (which log-softmaxes internally)
            # would apply softmax twice and flatten the gradients.
            self.activation = nn.Softmax(dim=1)

    def forward(self, x):
        """Return raw outputs of shape (batch, output_size) for x of shape (batch, seq_len, input_size)."""
        # The encoder is not batch_first, so move the sequence axis to the front.
        x = x.permute(1, 0, 2)  # (seq_len, batch, input_size)
        out = self.transformer_encoder(x)
        # Use the representation at the last sequence position.
        return self.fc(out[-1])

    def predict_proba(self, x):
        """Return class probabilities (softmax over ``forward``'s logits).

        Raises:
            ValueError: if the model was not built for classification.
        """
        if self.problem_type != 'classification':
            raise ValueError("predict_proba is only defined for classification models")
        return self.activation(self.forward(x))

# Hyperparameters for the LSTM classifier.
# input_size must equal the last dimension of each chunk the model receives.
# NOTE(review): TimeSeriesDataset yields (num_chunks, channels, chunk_size) per sample,
# so with input_size = chunk_size the LSTM steps over the channels with
# chunk_size features per step. Confirm this is intended.
input_size = chunk_size
problem_type = 'classification'
hidden_size = 64
num_layers = 4
output_size = 4  # number of rhythm classes (AFIB, AFL, J, N)

# Instantiate the model
model = LSTMClassifier(input_size, hidden_size, num_layers, output_size)

# To try the Transformer instead, uncomment below. The training and evaluation
# cells unpack `outputs, hidden = model(x)` and feed 4-D chunked batches, whereas
# TransformerModel returns a single tensor and expects (batch, seq_len, input_size),
# so those cells must be adapted too.
# input_size = x_train.shape[-1]
# num_layers = 2
# nhead = 2              # Number of attention heads (must divide input_size)
# dim_feedforward = 64
# output_size = 4
# problem_type = 'classification'
# model = TransformerModel(input_size, num_layers, nhead, dim_feedforward, output_size, problem_type)


In [None]:
# Loss for the task: MSE for regression, cross-entropy (on raw logits) for classification.
if problem_type == 'regression':
    criterion = nn.MSELoss()
elif problem_type == 'classification':
    criterion = nn.CrossEntropyLoss()
else:
    raise ValueError(f"Unknown problem_type: {problem_type!r}")

learning_rate = 1e-3
weight_decay = 3e-4
beta1 = 0.9
beta2 = 0.99

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    betas=(beta1, beta2)
)
# `verbose` is deprecated in recent PyTorch releases; LR reductions are reported manually below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)


# Train with validation-based LR scheduling, early stopping, and best-checkpointing.
num_epochs = 150
patience = 10  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
patience_counter = 0
best_model_path = 'best_model.pth'

train_losses = []  # sample-weighted mean training loss per epoch
val_losses = []    # sample-weighted mean validation loss per epoch

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss_sum = 0.0
    train_count = 0
    for x_batch, y_batch in train_loader:
        if problem_type == 'classification':
            y_batch = y_batch.long()
        outputs, _ = model(x_batch)
        loss = criterion(outputs, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        train_loss_sum += loss.item() * x_batch.size(0)
        train_count += x_batch.size(0)

    train_loss = train_loss_sum / train_count
    train_losses.append(train_loss)

    # Validation phase. Loop names differ from the global x_val / y_val arrays
    # so those are not overwritten.
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for x_val_batch, y_val_batch in val_loader:
            if problem_type == 'classification':
                y_val_batch = y_val_batch.long()
            val_outputs, _ = model(x_val_batch)
            val_loss_sum += criterion(val_outputs, y_val_batch).item() * x_val_batch.size(0)
            val_count += x_val_batch.size(0)

    val_loss = val_loss_sum / val_count
    val_losses.append(val_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f'Reducing learning rate from {lr_before:.2e} to {lr_after:.2e}')

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save the best model
        torch.save(model.state_dict(), best_model_path)
        print(f'Best model saved at epoch {epoch+1} with validation loss: {val_loss:.4f}')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

# Restore the best-validation weights so later evaluation does not use the
# (up to `patience` epochs worse) final weights. Skipped if no checkpoint was
# ever written (e.g. every validation loss was NaN).
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(best_model_path, weights_only=True))

# Draw the per-epoch training and validation loss curves on one axes.
_, loss_ax = plt.subplots(figsize=(10, 4))
for series, series_name in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    loss_ax.plot(series, label=series_name)
loss_ax.set_title('Training and Validation Loss Over Epochs')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
plt.show()

In [None]:
# Evaluate the model on the test set.
# Predictions and labels are collected batch by batch so they stay aligned.
# `x_test_chunk` and `hidden_states` intentionally remain defined after the loop:
# the next cell visualises the last batch.
test_outputs_list = []
y_test_list = []
model.eval()
with torch.no_grad():
    for x_test_chunk, y_test_chunk in test_loader:
        test_outputs, hidden_states = model(x_test_chunk)
        test_outputs_list.append(test_outputs)
        y_test_list.append(y_test_chunk)

# Concatenate all outputs. The labels get a new name so the `y_test` array loaded
# earlier is not overwritten; re-running the DataLoader cell would otherwise pair
# x_test with labels in loader order.
y_test_pred = torch.cat(test_outputs_list, dim=0)
y_test_true = torch.cat(y_test_list, dim=0)

# Calculate metrics
if problem_type == 'regression':
    y_test_actual = y_test_true.numpy()
    predicted = y_test_pred.numpy()
    mae = mean_absolute_error(y_test_actual, predicted)
    rmse = np.sqrt(mean_squared_error(y_test_actual, predicted))
    print(f'Mean Absolute Error (MAE): {mae:.4f}')
    print(f'Root Mean Squared Error (RMSE): {rmse:.4f}')

    # Plot predictions vs actual values
    plt.figure(figsize=(10, 4))
    plt.plot(predicted, label='Predicted', alpha=0.7)
    plt.plot(y_test_actual, label='Actual', alpha=0.7)
    plt.title('Predicted vs Actual Values')
    plt.xlabel('Sample')
    plt.ylabel('Value')
    plt.legend()
    plt.show()

elif problem_type == 'classification':
    # Predicted class = index of the largest logit.
    y_test_pred_class = torch.argmax(y_test_pred, dim=1).numpy()
    y_test_class = y_test_true.int().numpy()
    accuracy = accuracy_score(y_test_class, y_test_pred_class)
    # zero_division=0: a class that is never predicted scores 0 instead of raising a warning.
    precision = precision_score(y_test_class, y_test_pred_class, average='macro', zero_division=0)
    recall = recall_score(y_test_class, y_test_pred_class, average='macro', zero_division=0)
    f1 = f1_score(y_test_class, y_test_pred_class, average='macro', zero_division=0)
    conf_matrix = confusion_matrix(y_test_class, y_test_pred_class)

    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print('Confusion Matrix:')
    print(conf_matrix)

In [None]:
# Visualise the per-chunk LSTM hidden states of one test example alongside its ECG.
# `hidden_states` and `x_test_chunk` come from the LAST batch of the evaluation
# loop, so this shows the first example of that batch.
sample_idx = 0
hn_example = hidden_states[sample_idx].detach().cpu().numpy()  # (num_chunks, hidden_size)

# Batches are laid out as (batch, num_chunks, channels, chunk_width).
num_chunks_shown = x_test_chunk.shape[1]
chunk_width = x_test_chunk.shape[-1]

fig, axes = plt.subplots(2, 1, figsize=(10, 12))  # 2 rows, 1 column

# Top plot: heatmap of hidden states (hidden units x chunks)
heatmap = axes[0].imshow(hn_example.T, aspect='auto', cmap='viridis')
axes[0].set_title('Heatmap of Hidden States')
axes[0].set_ylabel('Hidden Unit')
axes[0].set_xlabel('Chunk')
cb = fig.colorbar(heatmap, ax=axes[0], orientation='vertical')
cb.set_label('Activation')

# Bottom plot: channel index 0 of every chunk concatenated back into the full recording
ecg_signal = x_test_chunk[sample_idx, :, 0].reshape(-1).cpu().numpy()
axes[1].plot(ecg_signal)
axes[1].set_title('ECG Time Series')
axes[1].set_xlabel('Sample Index')
axes[1].set_ylabel('Value')

# Mark the boundaries between consecutive chunks.
for boundary in range(chunk_width, num_chunks_shown * chunk_width, chunk_width):
    axes[1].axvline(x=boundary, color='red', linestyle='--', linewidth=0.8)

plt.tight_layout()  # Adjust subplots to fit into the figure area.
plt.show()

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Extract per-sample representations from the trained model for visualisation.
def extract_features(model, data_loader, use_hidden=False):
    """Run ``model`` over ``data_loader`` and collect one feature vector per sample.

    Args:
        model: module whose forward returns ``(outputs, hidden)``, e.g.
            LSTMClassifier, where ``hidden`` is (batch, num_chunks, hidden_size).
        data_loader: yields ``(inputs, targets)`` batches.
        use_hidden: if False (default), return the model's final outputs (the
            classifier logits). If True, return the chunk-averaged hidden state
            ``hidden.mean(dim=1)``, i.e. the representation LSTMClassifier feeds
            to its final linear layer.

    Returns:
        ``(features, labels)`` as NumPy arrays with aligned rows.
    """
    model.eval()  # Set the model to evaluation mode
    device = next(model.parameters()).device
    features, labels = [], []
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs, hidden = model(inputs.to(device))
            batch_features = hidden.mean(dim=1) if use_hidden else outputs
            features.append(batch_features.cpu().numpy())
            labels.append(targets.numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


# Apply t-SNE to the features
def plot_tsne(features, labels):
    """Project ``features`` to 2-D with t-SNE and scatter-plot them coloured by ``labels``."""
    tsne = TSNE(n_components=2, random_state=42)
    reduced_features = tsne.fit_transform(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(reduced_features[:, 0], reduced_features[:, 1], c=labels, cmap='viridis', alpha=0.7)
    plt.colorbar(scatter, label='Class Label')
    plt.title('t-SNE Visualization of LSTM Features')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.show()


# Example usage: visualise the chunk-averaged LSTM hidden states (the input to the
# final linear layer), which is what the plot title describes, rather than the
# 4-dimensional logits.
features, labels = extract_features(model, train_loader, use_hidden=True)
plot_tsne(features, labels)