In [None]:
import sys
import joblib
import time
import datetime
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from imblearn.over_sampling import RandomOverSampler

from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, f1_score, confusion_matrix, precision_score, recall_score
import torch.nn.functional as F

sys.path.append("Oxford/")
# Import functions from data.py
from data import NormalDataset, resize, get_inverse_class_weights

# Import function from utils.py
from utils import EarlyStopping

# Device on which the pre-trained SSL model (and every batch fed to it) is placed
device = "cpu"

In [None]:
# Paths to data
wrist_dir = "/domino/datasets/local/dataset/idea_fast/for_s3"
mapped_dir = "Outputs/Lower Back Predictions Mapped To Wrist/"

# Participants are the sub-directories of mapped_dir. Sorting makes the pid given to each
# participant (and therefore the participant-level train/val/test split downstream)
# independent of the arbitrary order returned by os.listdir.
subjects = sorted(subject for subject in os.listdir(mapped_dir) if subject not in ['.ipynb_checkpoints'])

# Initialize an empty list to store all subjects' processed data
processed_dfs = []

# Columns to load
accel_columns = ['accel_x', 'accel_y', 'accel_z']
wrist_columns_to_load = accel_columns
mapped_columns_to_load = ['index', 'accel_x', 'accel_y', 'accel_z', 'lower_back_mapped_value']

# Parse accelerometer columns straight to float32 (halves memory vs. the float64 default)
accel_dtypes = {column: 'float32' for column in accel_columns}

# Maximum number of unmapped rows kept per participant
max_rows_per_participant = 70_000_000
chunk_size = 1_000_000  # rows per read_csv chunk

for pid, subject in enumerate(subjects, start=1):
    print(f"Processing subject: {subject}")
    subject_start_time = time.time()

    # Load mapped signal data (rows that carry a label mapped from the lower-back sensor)
    mapped_signal_path = os.path.join(mapped_dir, subject, 'wrist_lower_back_df.csv')
    mapped_pd = pd.read_csv(mapped_signal_path, usecols=mapped_columns_to_load, dtype=accel_dtypes)

    # Record every mapped row index (before dropping NaNs) so none of them is re-added below with label 0
    mapped_indices_set = set(mapped_pd['index'])

    mapped_pd = (
        mapped_pd
        .rename(columns={'lower_back_mapped_value': 'label'})
        # NaN samples would otherwise flow through the window interpolation into the model
        .dropna(subset=accel_columns)
        .assign(pid=pid)
    )
    processed_dfs.append(mapped_pd)

    # Stream the whole wrist signal in chunks; the context manager closes the file even when we stop early
    whole_signal_path = os.path.join(wrist_dir, subject, 'combined_ax6_df.csv')
    rows_read = 0  # unmapped rows kept so far for this participant

    with pd.read_csv(whole_signal_path, usecols=wrist_columns_to_load, dtype=accel_dtypes,
                     chunksize=chunk_size) as reader:
        for chunk in reader:
            if rows_read >= max_rows_per_participant:
                print(f"Reached max rows ({max_rows_per_participant}) for participant {subject}.")
                break

            unmapped = (
                chunk
                .reset_index()  # adds an 'index' column with the global row numbers of the wrist file
                .dropna(subset=accel_columns)
                .loc[lambda frame: ~frame['index'].isin(mapped_indices_set)]
                .assign(label=0, pid=pid)  # non-mapped rows are labelled 0
                [['index', 'accel_x', 'accel_y', 'accel_z', 'label', 'pid']]
            )
            processed_dfs.append(unmapped)
            rows_read += len(unmapped)

    subject_total_time = time.time() - subject_start_time
    print(f"Finished processing subject {subject} in {subject_total_time:.2f} seconds")

# NOTE(review): within each subject the mapped rows are placed before all unmapped rows instead of
# in time ('index') order, so windows built from consecutive rows downstream can mix the two near
# that boundary — confirm this ordering is intended.
df_combined = pd.concat(processed_dfs, ignore_index=True)
processed_dfs.clear()  # the chunks now live in df_combined; drop the duplicate copies

print("Combined DataFrame info:")
df_combined.info()  # .info() prints directly and returns None
print(f"Label counts:\n{df_combined['label'].value_counts()}")

In [None]:
def _downsample_windows(windows, downsampled_window_size, batch_size=1024):
    """Linearly resample axis 1 of a (num_windows, window_size, channels) array to downsampled_window_size points.

    Windows are interpolated in batches: this avoids a Python loop over every window and
    axis while keeping the interpolation temporaries small.
    """
    from scipy.interpolate import interp1d

    num_windows, window_size, num_channels = windows.shape
    t_original = np.linspace(0, 1, window_size)  # original time points
    t_target = np.linspace(0, 1, downsampled_window_size)  # target time points
    downsampled = np.zeros((num_windows, downsampled_window_size, num_channels))
    for start in range(0, num_windows, batch_size):
        stop = start + batch_size
        interp_func = interp1d(t_original, windows[start:stop], kind="linear", axis=1, assume_sorted=True)
        downsampled[start:stop] = interp_func(t_target)
    return downsampled


def load_data_from_df(df, window_size=3000, target_freq=30, original_freq=100):
    """
    Window, downsample and split accelerometer data for the pre-trained SSL model.

    Downsampling follows Oxford's approach (linear interpolation); the SSL model expects
    30 Hz sampling and 30 s windows. Windows are built separately for each participant,
    so no window mixes samples from two participants. Samples at the end of a participant
    that do not fill a complete window are dropped. Participants (not windows) are assigned
    to train / validation / test (60/20/20). The training and validation sets are then
    balanced with RandomOverSampler and jittered with small Gaussian noise.

    Parameters:
        df (pd.DataFrame): data with 'accel_x', 'accel_y', 'accel_z', 'label' and 'pid' columns.
        window_size (int): samples per window at the original frequency (default: 3000 = 30 s at 100 Hz).
        target_freq (int): target frequency for downsampling (default: 30 Hz).
        original_freq (int): original frequency of the data (default: 100 Hz).

    Returns:
        tuple: (x_train, y_train, pid_train, x_val, y_val, pid_val, x_test, y_test, pid_test);
        each x array has shape (num_windows, downsampled_window_size, 3).

    Raises:
        ValueError: if no participant has enough samples for a single window.
    """
    accel_columns = ['accel_x', 'accel_y', 'accel_z']

    # Integer arithmetic avoids float truncation surprises (3000 * 30 // 100 = 900)
    downsampled_window_size = int(window_size * target_freq // original_freq)

    # Window and downsample each participant separately so no window straddles two pids
    X_parts, y_parts, pid_parts = [], [], []
    for participant, group in df.groupby('pid', sort=False):
        n_windows = len(group) // window_size
        if n_windows == 0:
            continue  # fewer samples than one full window
        n_used = n_windows * window_size

        windows = group[accel_columns].to_numpy()[:n_used].reshape(n_windows, window_size, 3)
        X_parts.append(_downsample_windows(windows, downsampled_window_size))

        # NOTE(review): mean().astype(int) truncates, so with 0/1 labels a window is labelled 1
        # only when every sample in it is 1 — confirm this is the intended labelling rule.
        window_labels = group['label'].to_numpy()[:n_used].reshape(n_windows, window_size)
        y_parts.append(window_labels.mean(axis=1).astype(int))
        pid_parts.append(np.full(n_windows, participant).astype(int))

    if not X_parts:
        raise ValueError(f"No participant has at least {window_size} samples; cannot form any window.")

    X_downsampled = np.concatenate(X_parts)
    y = np.concatenate(y_parts)
    pid = np.concatenate(pid_parts)

    # Assign participants to train, validation, and test sets (60/20/20 split)
    unique_pids = np.unique(pid)
    train_pids, test_pids = train_test_split(unique_pids, test_size=0.2, random_state=42)
    train_pids, val_pids = train_test_split(train_pids, test_size=0.25, random_state=41)  # 0.25 * 80% = 20%

    train_idx = np.isin(pid, train_pids)
    val_idx = np.isin(pid, val_pids)
    test_idx = np.isin(pid, test_pids)

    x_train, y_train, pid_train = X_downsampled[train_idx], y[train_idx], pid[train_idx]
    x_val, y_val, pid_val = X_downsampled[val_idx], y[val_idx], pid[val_idx]
    x_test, y_test, pid_test = X_downsampled[test_idx], y[test_idx], pid[test_idx]

    # Seeded generator so the added noise (and therefore the training data) is reproducible
    rng = np.random.default_rng(42)

    def oversample_with_noise(X, y, pid):
        """Randomly oversample the non-majority classes, carry pids along, and add N(0, 0.01) noise."""
        X_flat = X.reshape(X.shape[0], -1)
        ros = RandomOverSampler(random_state=42)
        X_resampled, y_resampled = ros.fit_resample(X_flat, y)

        # sample_indices_ maps every resampled row back to its source row, so pids line up exactly
        pid_resampled = pid[ros.sample_indices_]

        # Small random noise so oversampled copies are not exact duplicates
        X_resampled = X_resampled + rng.normal(0, 0.01, X_resampled.shape)

        return X_resampled.reshape(-1, downsampled_window_size, 3), y_resampled, pid_resampled

    # NOTE(review): oversampling the validation set changes its class balance, which biases the
    # validation precision used for early stopping — confirm this is intended.
    x_train, y_train, pid_train = oversample_with_noise(x_train, y_train, pid_train)
    x_val, y_val, pid_val = oversample_with_noise(x_val, y_val, pid_val)

    return (
        x_train, y_train, pid_train,
        x_val, y_val, pid_val,
        x_test, y_test, pid_test
    )

In [None]:
# Build participant-level train / validation / test splits from the combined data
(
    x_train, y_train, group_train,
    x_val, y_val, group_val,
    x_test, y_test, group_test,
) = load_data_from_df(
    df_combined,
    original_freq=100,  # sampling rate of the wrist recordings
    target_freq=30,     # sampling rate the SSL model expects
    window_size=3000,   # 30 s windows counted at the original 100 Hz
)

# Per-class window counts for the held-out splits
val_classes, val_counts = np.unique(y_val, return_counts=True)
test_classes, test_counts = np.unique(y_test, return_counts=True)

for heading, classes, counts in (
    ("Validation class distribution:", val_classes, val_counts),
    ("\nTest class distribution:", test_classes, test_counts),
):
    print(heading)
    for cls, count in zip(classes, counts):
        print(f"  Class {cls}: {count} instances")

Run the cell below to load the pretrained model without freezing any layers or adding adapters (the "no fine tuning" setting).

In [None]:
# Specify fine tuning approach used
fine_tuning_approach = "no fine tuning"

# Load the pretrained model.
# torch.hub reads GITHUB_TOKEN from the environment (if set) for its GitHub API calls, so export
# it outside the notebook instead of hardcoding it here. The token previously committed in this
# cell is exposed in history and should be revoked.
repo = 'OxWearables/ssl-wearables'

sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
sslnet = sslnet.to(device).float()

Alternatively, run the cell below to fine-tune with the convolutional (feature-extractor) layers frozen.

In [None]:
# Specify fine tuning approach used
fine_tuning_approach = "freeze conv layers"

# Load the pretrained model.
# torch.hub reads GITHUB_TOKEN from the environment (if set) for its GitHub API calls, so export
# it outside the notebook instead of hardcoding it here. The token previously committed in this
# cell is exposed in history and should be revoked.
repo = 'OxWearables/ssl-wearables'

sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
sslnet = sslnet.to(device).float()


def set_bn_eval(m):
    """Put a BatchNorm layer into eval mode so its running statistics stop updating."""
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.eval()


# Freeze the conv / bn / feature-extractor parameters and keep everything else trainable.
# i counts frozen parameter tensors (not individual weights).
i = 0
for name, param in sslnet.named_parameters():
    if "conv" in name or "bn" in name or "feature_extractor" in name:
        param.requires_grad = False
        i += 1
    else:
        param.requires_grad = True

# NOTE: a later model.train() call switches every BatchNorm layer back to training mode, so eval
# mode has to be re-applied after each model.train() for it to persist.
sslnet.apply(set_bn_eval)

# Print the number of parameter tensors frozen in the convolutional layers
print(f"Weights being frozen in the convolutional layers: {i}")

Alternatively, run the cell below to fine-tune with only the first residual block frozen.

In [None]:
# Specify fine tuning approach used
fine_tuning_approach = "freeze first residual block"

# Load the pretrained model.
# torch.hub reads GITHUB_TOKEN from the environment (if set) for its GitHub API calls, so export
# it outside the notebook instead of hardcoding it here. The token previously committed in this
# cell is exposed in history and should be revoked.
repo = 'OxWearables/ssl-wearables'

sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
sslnet = sslnet.to(device).float()


def set_bn_eval(m):
    """Put a BatchNorm layer into eval mode so its running statistics stop updating."""
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.eval()


# Freeze every parameter under feature_extractor.layer1 (the first residual block).
# i counts frozen parameter tensors (not individual weights).
i = 0
for name, param in sslnet.named_parameters():
    if name.startswith("feature_extractor.layer1"):
        param.requires_grad = False
        i += 1

# NOTE: this puts every BatchNorm layer in the model (not only those in layer1) into eval mode,
# and a later model.train() call switches them all back to training mode.
sslnet.apply(set_bn_eval)

# Print the number of parameter tensors frozen in the first residual block
print(f"Weights being frozen in the first residual block: {i}")

In [None]:
# Specify fine tuning approach used
fine_tuning_approach = "adapter layers"


# Define the adapter model
class AdapterModel(nn.Module):
    """Frozen pretrained feature extractor followed by a small trainable MLP head.

    Args:
        base_model: pretrained network exposing a ``feature_extractor`` sub-module.
        feature_dim: size of the feature vector produced by the extractor after the
            trailing singleton dimension is squeezed (1024 is passed for harnet30).
    """

    def __init__(self, base_model, feature_dim=1024):
        super().__init__()
        self.feature_extractor = base_model.feature_extractor
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        self.adapter = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)  # Output layer for binary classification
        )

    def train(self, mode=True):
        """Set the adapter's training mode while keeping the frozen feature extractor in eval mode.

        nn.Module.train() would otherwise switch the extractor back to training mode, letting
        its BatchNorm layers keep updating running statistics (and enabling any dropout inside
        it) even though its weights are frozen.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, x):
        features = self.feature_extractor(x)
        # assumes the extractor output is (batch, feature_dim, 1) — TODO confirm for other backbones
        features = features.squeeze(-1)
        return self.adapter(features)

# Load the base pretrained model.
# torch.hub reads GITHUB_TOKEN from the environment (if set) for its GitHub API calls, so export
# it outside the notebook instead of hardcoding it here. The token previously committed in this
# cell is exposed in history and should be revoked.
repo = 'OxWearables/ssl-wearables'

base_model = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
base_model = base_model.to(device).float()

# Wrap the base model with the adapter layers
model = AdapterModel(base_model, feature_dim=1024).to(device)

# Print the model structure
print(model)

In [None]:
# Wrap each split in a NormalDataset; transform=True is passed only for the training split
# (presumably enabling augmentation in data.py — confirm)
train_dataset = NormalDataset(x_train, y_train, group_train, name="training", transform=True)
val_dataset = NormalDataset(x_val, y_val, group_val, name="validation")
test_dataset = NormalDataset(x_test, y_test, group_test, name="test")

# Training batches are shuffled and loaded by two worker processes
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

# Evaluation loaders keep the dataset order and load in the main process
val_loader, test_loader = (
    DataLoader(dataset, batch_size=128, shuffle=False, num_workers=0)
    for dataset in (val_dataset, test_dataset)
)

In [None]:
# Define loss function that prioritizes precision
class PrecisionLoss(nn.Module):
    """Class-weighted negative log-likelihood for binary classification.

    Samples with label 0 are weighted by ``weighted_fp`` and samples with label 1 by
    ``weighted_fn``. Raising ``weighted_fp`` makes predicting 1 on a negative (a false
    positive) more costly, pushing the model towards higher precision.

    Args:
        weighted_fp: weight applied to the loss of label-0 samples.
        weighted_fn: weight applied to the loss of label-1 samples.
    """

    def __init__(self, weighted_fp=2, weighted_fn=1):
        super().__init__()
        self.weighted_fp = weighted_fp
        self.weighted_fn = weighted_fn

    def forward(self, outputs, labels):
        """Return the batch-averaged weighted loss.

        Args:
            outputs: raw logits of shape (batch, classes); columns 0 and 1 are used.
            labels: 0/1 targets of shape (batch,).
        """
        # log_softmax is numerically stable. log(softmax + 1e-6) would cap the loss at ~13.8
        # and give (near) zero gradient for confidently wrong predictions.
        log_probs = F.log_softmax(outputs, dim=1)
        pos_mask = labels.float()
        fp_loss = -log_probs[:, 0] * (1 - pos_mask) * self.weighted_fp
        fn_loss = -log_probs[:, 1] * pos_mask * self.weighted_fn
        # Both terms are averaged over the whole batch, so the sum is a per-sample weighted mean
        return fp_loss.mean() + fn_loss.mean()

# False positives cost ten times as much as false negatives, to favour precision
loss_fn = PrecisionLoss(weighted_fn=1, weighted_fp=10).to(device)

In [None]:
def _unpack_batch(batch, device):
    """Return (inputs, labels) of a DataLoader batch on `device`, ignoring trailing metadata such as pids."""
    inputs, labels = batch[0], batch[1]
    return inputs.to(device, dtype=torch.float), labels.to(device)


def _keep_frozen_batchnorm_in_eval(model):
    """Put BatchNorm layers whose parameters are all frozen back into eval mode.

    model.train() switches every sub-module to training mode, which would let frozen
    BatchNorm layers keep updating their running statistics.
    """
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            params = list(module.parameters(recurse=False))
            if params and not any(p.requires_grad for p in params):
                module.eval()


def _precision(true_positives, predicted_positives):
    """Precision for class 1, or 0.0 when nothing was predicted positive."""
    return true_positives / predicted_positives if predicted_positives > 0 else 0.0


def _validation_precision(model, val_loader, device):
    """Class-1 precision over the whole validation loader; leaves the model in eval mode."""
    model.eval()
    true_positives = predicted_positives = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = _unpack_batch(batch, device)
            predicted = model(inputs).argmax(dim=1)
            true_positives += ((predicted == 1) & (labels == 1)).sum().item()
            predicted_positives += (predicted == 1).sum().item()
    return _precision(true_positives, predicted_positives)


def train_with_precision(model, train_loader, val_loader, device, fine_tuning_approach, timestamp,
                         criterion=None, output_dir="Outputs/SSL Weights Saved"):
    """
    Train a model and keep the weights with the best validation precision.

    The weights are checkpointed every time validation precision improves. Training stops
    after `patience` epochs without improvement (or after `num_epochs`), and the best
    checkpoint is then loaded back into `model`. If precision never rises above 0, the
    final weights are saved instead, so the weights file always exists afterwards.

    Parameters:
    - model: the neural network to be trained (modified in place).
    - train_loader: DataLoader for training data; batches are (inputs, labels[, metadata]).
    - val_loader: DataLoader for validation data; same batch layout.
    - device: device to run the model on ('cpu' or 'cuda').
    - fine_tuning_approach: name of the approach; "adapter layers" optimises only trainable
      parameters. Also used in the weights file name.
    - timestamp: string used in the weights file name.
    - criterion: loss module; defaults to the notebook-level `loss_fn`.
    - output_dir: directory for the weights file (created if missing).

    Returns:
    - str: path of the saved weights file.
    """
    if criterion is None:
        criterion = loss_fn  # notebook-level PrecisionLoss instance

    # Adam with lr 0.0001; the adapter approach only hands the trainable parameters to the optimiser
    if fine_tuning_approach == "adapter layers":
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, amsgrad=True)

    num_epochs = 20
    patience = 10  # epochs without improvement before stopping early
    best_val_precision = 0
    epochs_without_improvement = 0

    os.makedirs(output_dir, exist_ok=True)
    weights_path = os.path.join(output_dir, f"model_{fine_tuning_approach}_{timestamp}.pt")
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        _keep_frozen_batchnorm_in_eval(model)

        train_losses = []
        train_true_positives = train_predicted_positives = 0
        for batch in train_loader:
            inputs, labels = _unpack_batch(batch, device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            predicted = outputs.argmax(dim=1)
            train_true_positives += ((predicted == 1) & (labels == 1)).sum().item()
            train_predicted_positives += (predicted == 1).sum().item()

        train_precision = _precision(train_true_positives, train_predicted_positives)
        val_precision = _validation_precision(model, val_loader, device)
        mean_train_loss = sum(train_losses) / len(train_losses) if train_losses else float("nan")

        print(f"Epoch [{epoch + 1}/{num_epochs}]")
        print(f"  Train Loss: {mean_train_loss:.4f}")
        print(f"  Train Precision: {train_precision:.4f}")
        print(f"  Validation Precision: {val_precision:.4f}")

        if val_precision > best_val_precision:
            best_val_precision = val_precision
            epochs_without_improvement = 0
            # Checkpoint the best epoch, not merely the last one
            torch.save(model.state_dict(), weights_path)
            checkpoint_saved = True
            print(f"  New best validation precision; weights saved as {weights_path}.")
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"Early stopping on epoch {epoch + 1} as validation precision did not improve for {patience} epochs.")
            break

    if checkpoint_saved:
        # Leave the in-memory model at its best epoch, matching the checkpoint on disk
        model.load_state_dict(torch.load(weights_path, map_location=device))
        print(f"Restored best weights (validation precision {best_val_precision:.4f}) from {weights_path}.")
    else:
        torch.save(model.state_dict(), weights_path)
        print(f"Validation precision never rose above 0; final weights saved as {weights_path}.")

    return weights_path

def predict(model, data_loader, device):
    """
    Run inference over every batch of a DataLoader with a PyTorch model.

    :param nn.Module model: PyTorch module to evaluate (switched to eval mode)
    :param data_loader: PyTorch DataLoader yielding (x, y, pid) batches
    :param str device: device the inputs are moved to ('cpu' or 'cuda')
    :return: true labels, model predictions, pids
    :rtype: (np.ndarray, np.ndarray, np.ndarray)
    """
    from tqdm import tqdm

    model.eval()
    labels_seen, preds_seen, pids_seen = [], [], []

    for x, y, pid in tqdm(data_loader):
        with torch.inference_mode():
            # Cast inputs to float32 to match the model's parameters
            logits = model(x.to(device, dtype=torch.float))
            labels_seen.append(y)
            # Predicted class = index of the largest logit, gathered on the CPU
            preds_seen.append(torch.argmax(logits, dim=1).cpu())
            pids_seen.extend(pid)

    y_true = torch.flatten(torch.cat(labels_seen)).numpy()
    y_pred = torch.flatten(torch.cat(preds_seen)).numpy()
    return y_true, y_pred, np.array(pids_seen)

In [None]:
# Tag this run's saved weights with the current date and time
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Adapter fine-tuning trains the wrapped `model`; every other approach trains `sslnet` itself
train_with_precision(
    model if fine_tuning_approach == "adapter layers" else sslnet,
    train_loader,
    val_loader,
    device,
    fine_tuning_approach,
    timestamp,
)

In [None]:
# Load the weights saved during training for this run.
# To evaluate an earlier run instead, point weights_path at that file AND set fine_tuning_approach
# to the approach that produced it; the state dict must match the model it is loaded into.
weights_path = os.path.join("Outputs/SSL Weights Saved", f"model_{fine_tuning_approach}_{timestamp}.pt")

# map_location lets weights saved on another device load here; weights_only avoids unpickling arbitrary objects
state_dict = torch.load(weights_path, map_location=device, weights_only=True)

if fine_tuning_approach == "adapter layers":
    # For adapter layer
    model.load_state_dict(state_dict)
else:
    sslnet.load_state_dict(state_dict)

In [None]:
# Evaluate on the test set with whichever model was trained above
if fine_tuning_approach == "adapter layers":
    true_labels, predicted_labels, pids = predict(model, test_loader, device)
else:
    true_labels, predicted_labels, pids = predict(sslnet, test_loader, device)

# Accuracy
test_accuracy = accuracy_score(true_labels, predicted_labels)
print(f"Test Accuracy: {test_accuracy:.2f}")

# Per-class precision / recall / F1
print("\nClassification Report:")
print(classification_report(true_labels, predicted_labels))

# Support-weighted F1 across classes
overall_f1 = f1_score(true_labels, predicted_labels, average='weighted')
print(f"\nOverall F1 Score: {overall_f1:.2f}")

# Precision of the positive class (label 1)
overall_precision = precision_score(true_labels, predicted_labels, average='binary')
print(f"\nOverall Precision: {overall_precision:.2f}")

# Fix the label set so the matrix is always 2x2 and lines up with the tick labels,
# even when a class is absent from both the true and the predicted labels
class_labels = [0, 1]
conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=class_labels)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels, ax=ax)
ax.set(xlabel="Predicted Labels", ylabel="True Labels", title="Confusion Matrix")
plt.show()