# Spatio-temporal Gait Characterization from a Wrist-Worn Inertial Measurement Units in Daily Life

## Objective
The primary goal of this project is to develop and validate algorithms that can accurately characterize spatio-temporal gait parameters using data from a wrist-worn IMU. 

Specific objectives include: 
* Apply open-source lower back gait detection algorithms to partially labeled lower back data to validate the best approach for generating labels for the whole dataset.
* Use generated labels from the lower back model to train a custom DCNN to make gait detection predictions on unsupervised wrist-worn sensor data.
* Based on this self-supervised learning approach, use the data from predicted walking bouts to characterize subjects’ gait focusing on walking speed.


In [None]:
# Imports
import os
import re
import sys
import time
import datetime
import pickle
import joblib
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from imblearn.over_sampling import RandomOverSampler
# from imblearn.over_sampling import SMOTE
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    confusion_matrix,
    precision_score,
    recall_score,
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)
import statsmodels.api as sm

# Mobgap Imports
import mobgap  # Ensure mobgap is properly installed and imported
from mobgap.pipeline import MobilisedPipelineHealthy, GsIterator
from mobgap.gait_sequences import GsdIluz, GsdIonescu, GsdAdaptiveIonescu
from mobgap.utils.conversions import to_body_frame
from mobgap.consts import GRAV_MS2
from mobgap.initial_contacts import IcdShinImproved, refine_gs
from mobgap.laterality import LrcUllrich
from mobgap.stride_length import SlZijlstra
from mobgap.turning import TdElGohary
from mobgap.walking_speed import WsNaive
from mobgap.cadence import CadFromIc

# Custom Imports
from data import NormalDataset, resize, get_inverse_class_weights
from utils import EarlyStopping
from scipy.interpolate import interp1d

## Lumbar Activity Classification and Gait Speed

In [None]:
# Define directories
subjects_dir = '/XXXX/'
output_folder = '/XXXX/'

# Ensure output folder exists
os.makedirs(output_folder, exist_ok=True)

# Preprocessing function for each file
def preprocess_file(file_path):
    print(f"Starting preprocessing for file: {file_path}")
    df = pd.read_csv(file_path)
    df.columns = df.columns.str.strip()
    if 'Time' in df.columns:
        df.rename(columns={'Time': 'timestamp'}, inplace=True)
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    # Standardize column names for compatibility
    df.rename(columns={
        'Accel-X (g)': 'acc_x',
        'Accel-Y (g)': 'acc_y',
        'Accel-Z (g)': 'acc_z',
        'Gyro-X (d/s)': 'gyr_x',
        'Gyro-Y (d/s)': 'gyr_y',
        'Gyro-Z (d/s)': 'gyr_z',
        'Mag-X': 'Mag_X',
        'Mag-Y': 'Mag_Y',
        'Mag-Z': 'Mag_Z'
    }, inplace=True)
    # Select relevant columns
    df = df[['timestamp', 'acc_x', 'acc_y', 'acc_z', 'gyr_x', 'gyr_y', 'gyr_z']]
    df[["acc_x", "acc_y", "acc_z"]] = (
        df[["acc_x", "acc_y", "acc_z"]] * GRAV_MS2
    )
    print(f"Completed preprocessing for file: {file_path}")
    return df

icd = IcdShinImproved()
lrc = LrcUllrich()
cad = CadFromIc()
sl = SlZijlstra()
speed = WsNaive()
turn = TdElGohary()

sampling_rate_hz = 100
# Process each subject folder
for subject_folder in os.listdir(subjects_dir):
    subject_path = os.path.join(subjects_dir, subject_folder)
    # Ensure it's a directory
    if os.path.isdir(subject_path):
        print(f"\nProcessing subject folder: {subject_folder}")
        
        # Find relevant CSV files in the folder based on the pattern
        csv_files = [f for f in os.listdir(subject_path) 
                     if re.match(rf"{subject_folder}-\w{{9}}-\d{{8}}-\d{{8}}\.csv", f)]
        
        # Check if any CSV files were found
        if not csv_files:
            print(f"No relevant CSV files found in {subject_folder}. Skipping this folder.")
            continue
        
        print(f"Found {len(csv_files)} relevant CSV file(s) in {subject_folder}: {csv_files}")
        
        # Initialize an empty list to store data from each file
        subject_data = []
        meta_file = os.path.join(subject_path, 'meta.csv')
        if not os.path.exists(meta_file):
            print(f"Meta file not found for subject {subject_path}. Skipping...")
            continue
            
        try:
            meta_df = pd.read_csv(meta_file, header=None)
            participant_metadata = meta_df.to_dict()
        except Exception as e:
            print(f"Error reading metadata for subject {subject_folder}: {str(e)}. Skipping...")
            continue
        # Process each relevant CSV file
        for csv_file in csv_files:
            file_path = os.path.join(subject_path, csv_file)
            print(f"Processing file: {csv_file}")
            
            # Preprocess the file and append the data
            df = preprocess_file(file_path)
            subject_data.append(df)
        
        # Combine data if any relevant files were found
        if subject_data:
            print(f"Combining data from {len(subject_data)} file(s) for subject {subject_folder}")
            combined_data = pd.concat(subject_data, ignore_index=True)
            subject_data[0]['timestamp'] = pd.to_datetime(subject_data[0]['timestamp'])
            df_time = subject_data[0]['timestamp'] 
            time_diffs= df_time.diff().dropna()
            avg_sampling_rate = time_diffs.mean()
            average_sampling_rate_second = avg_sampling_rate.total_seconds()
            sampling_rate_hz = abs(1/average_sampling_rate_second)
            print("sampling rate: ", sampling_rate_hz)
            # Sort data by timestamp in case of overlapping records
            combined_data.sort_values(by='timestamp', inplace=True)
            combined_data.attrs["participant_metadata"] = participant_metadata
            
            gsd = GsdIonescu()
            imu_data = to_body_frame(combined_data)
            gsd.detect(data=imu_data, sampling_rate_hz=sampling_rate_hz)
            gait_sequences = gsd.gs_list_
            """try:
                start_index = gait_sequences.loc[0, 'start']
                end_index = gait_sequences.loc[0, 'end']
                print(combined_data.iloc[start_index:end_index,:]['mapped_value'].unique())
            except Exception as e:
                print("No gait sequences", e)
                continue"""
            gs_iterator = GsIterator()
            for (_, gs_data), r in gs_iterator.iterate(imu_data, gait_sequences):
                icd = icd.clone().detect(gs_data, sampling_rate_hz=sampling_rate_hz)
                lrc = lrc.clone().predict(gs_data, icd.ic_list_, sampling_rate_hz=sampling_rate_hz)
                r.ic_list = lrc.ic_lr_list_
                turn = turn.clone().detect(gs_data, sampling_rate_hz=sampling_rate_hz)
                r.turn_list = turn.turn_list_

                refined_gs, refined_ic_list = refine_gs(r.ic_list)

                with gs_iterator.subregion(refined_gs) as ((_, refined_gs_data), rr):
                    cad = cad.clone().calculate(
                        refined_gs_data,
                        initial_contacts=refined_ic_list,
                        sampling_rate_hz=sampling_rate_hz
                    )
                    rr.cadence_per_sec = cad.cadence_per_sec_
                    sl = sl.clone().calculate(
                        refined_gs_data,
                        initial_contacts=refined_ic_list,
                        sampling_rate_hz=sampling_rate_hz,
                        sensor_height_m = 1.8
                    )
                    rr.stride_length_per_sec = sl.stride_length_per_sec_
                    speed = speed.clone().calculate(
                        refined_gs_data,
                        initial_contacts=refined_ic_list,
                        cadence_per_sec=cad.cadence_per_sec_,
                        stride_length_per_sec=sl.stride_length_per_sec_,
                        sampling_rate_hz=sampling_rate_hz
                    )
                    rr.walking_speed_per_sec = speed.walking_speed_per_sec_
            results = gs_iterator.results_
            results.ic_list
            gait_analysis_results = pd.concat(
                [
                    results.cadence_per_sec,
                    results.stride_length_per_sec,
                    results.walking_speed_per_sec,
                ],
                axis=1,
            )
            print(gait_analysis_results)
            subject_output_dir = os.path.join(output_folder, subject_folder)
            os.makedirs(subject_output_dir, exist_ok=True)
            gs_list_file = os.path.join(subject_output_dir, "gs_list.csv")
            gait_sequences.to_csv(gs_list_file)
            gait_analysis_results_file = os.path.join(subject_output_dir, "gait_analysis_results.csv")
            gait_analysis_results.to_csv(gait_analysis_results_file)

## Self-Superision: Lumbar Predictons are Labels for Wrist

In [None]:
# Store file path to folders with gait sequence predictions from lower back signal
gait_dir = "/XXXX/"
# Store file path to folders with signal from wrist sensor
wrist_dir = "/XXXX/"
# Output directory
output_dir = '/XXXX/'

# Extract list of subjects with lower back predictions
subjects = os.listdir(gait_dir)
# Exclude two subjects with no demographics data
subjects = [subject for subject in subjects if subject not in ['XXXX', 'XXXX']]


# Loop through all subjects
for folder_name in subjects:
    print(f"Processing subject: {folder_name}")
    # Start timer
    start_time = time.time()
    # Load wrist signal data
    wrist_folder_path = os.path.join(wrist_dir, folder_name)
    csv_path = os.path.join(wrist_folder_path, 'wrist_data.csv')
    df = pd.read_csv(csv_path, usecols=['accel_x', 'accel_y', 'accel_z'])
    signal_load_time = time.time() - start_time
    print(f"Wrist file load time: {signal_load_time:.2f} seconds")
    df.reset_index(inplace=True)  # Add row numbers as an "index" column

    # Start timer for walking_df creation
    walking_df_start_time = time.time()
    # Load gait sequence data
    gs_folder_path = os.path.join(gait_dir, folder_name, 'gs_list.csv')
    gs = pd.read_csv(gs_folder_path)
    # Expand gait sequence to a DataFrame of indices for walking intervals
    walking_df = pd.DataFrame({'index': np.concatenate([np.arange(row['start'], row['end'] + 1) for _, row in gs.iterrows()])})
    walking_df['lower_back_mapped_value'] = 1
    # Check for duplicates in walking_df
    if walking_df.duplicated(subset=['index']).any():
        print(f"Duplicate indices found in walking_df for subject: {folder_name}")
        walking_df.drop_duplicates(subset=['index'], keep='first', inplace=True)
    # Load gait analysis data
    analysis_folder_path = os.path.join(gait_dir, folder_name, 'gait_analysis_results.csv')
    gait_analysis = pd.read_csv(analysis_folder_path, usecols=['sec_center_samples', 'cadence_spm',
                                                               'stride_length_m', 'walking_speed_mps'])
    # Expand gait analysis intervals and add features
    gait_analysis['start'] = (gait_analysis['sec_center_samples'] - 50).astype(int)
    gait_analysis['end'] = (gait_analysis['sec_center_samples'] + 49).astype(int)
    
    analysis_intervals = []
    for row in gait_analysis.itertuples(index=False):
        for idx in range(row.start, row.end + 1):
            analysis_intervals.append({
                'index': idx,
                'cadence_spm': row.cadence_spm,
                'stride_length_m': row.stride_length_m,
                'walking_speed_mps': row.walking_speed_mps
            })
    analysis_intervals = pd.DataFrame(analysis_intervals)
    # Check for duplicates in analysis_intervals
    if analysis_intervals.duplicated(subset=['index']).any():
        print(f"Duplicate indices found in analysis_intervals for subject: {folder_name}")
        analysis_intervals.drop_duplicates(subset=['index'], keep='first', inplace=True)
    # Merge the walking intervals and analysis features
    walking_df = walking_df.merge(analysis_intervals, on='index', how='left')
    walking_df_time = time.time() - walking_df_start_time
    print(f"Time to create walking_df: {walking_df_time:.2f} seconds")
    
    # Step 2: Inner join walking data to wrist signal data
    df = df.merge(walking_df, on='index', how='inner')
    df = df[['index', 'accel_x', 'accel_y', 'accel_z', 'lower_back_mapped_value',
             'cadence_spm', 'stride_length_m', 'walking_speed_mps']]

    output_folder = os.path.join(output_dir, folder_name)
    os.makedirs(output_folder, exist_ok=True)
    output_path = os.path.join(output_folder, 'wrist_lower_back_df.csv')
    df.to_csv(output_path, index=False, chunksize=10**6)
    total_time = time.time() - start_time
    print(f"Finished processing subject: {folder_name} in {total_time:.2f} seconds")


## Wrist Activity Classification

In [None]:
sys.path.append("Oxford/")
device = 'cpu'
wrist_dir = '/XXXX/'
mapped_dir = '/XXXX/'

# List of participants
subjects = os.listdir(mapped_dir)
subjects = [subject for subject in subjects if subject not in ['.ipynb_checkpoints']]

# Initialize an empty list to store all subjects' processed data, select columns
processed_dfs = []
wrist_columns_to_load = ['accel_x', 'accel_y', 'accel_z']
mapped_columns_to_load = ['index', 'accel_x', 'accel_y', 'accel_z', 'lower_back_mapped_value']
# Maximum rows to process per participant
max_rows_per_participant = 70_000_000

# Loop through each participant
for pid, subject in enumerate(subjects, start=1):
    print(f"Processing subject: {subject}")
    subject_start_time = time.time()
    # Load mapped signal data
    mapped_signal_path = os.path.join(mapped_dir, subject, 'wrist_lower_back_df.csv')
    mapped_pd = pd.read_csv(mapped_signal_path, usecols=mapped_columns_to_load)
    # Rename 'lower_back_mapped_value' to 'label'
    mapped_pd.rename(columns={'lower_back_mapped_value': 'label'}, inplace=True)
    # Convert accelerometer values to float32
    mapped_pd[['accel_x', 'accel_y', 'accel_z']] = mapped_pd[['accel_x', 'accel_y', 'accel_z']].astype('float32')
    # Add participant ID column
    mapped_pd['pid'] = pid
    # Store mapped indices in a set for fast lookups
    mapped_indices_set = set(mapped_pd['index'])
    # Add mapped DataFrame to the list
    processed_dfs.append(mapped_pd)
   
    # Load entire wrist signal data
    whole_signal_path = os.path.join(wrist_dir, subject, 'wrist_data.csv')
    chunk_size = 1_000_000  # Chunk size
    rows_read = 0  # Track the number of rows read
    for chunk_idx, chunk in enumerate(pd.read_csv(whole_signal_path, usecols=wrist_columns_to_load, chunksize=chunk_size)):
        chunk_start_time = time.time()
        # Break the loop if max rows are reached
        if rows_read >= max_rows_per_participant:
            print(f"Reached max rows ({max_rows_per_participant}) for participant {subject}.")
            break
        # Create the `index` column for whole signal data
        chunk.reset_index(inplace=True)  # Adds 'index' column with global row numbers
        # Drop rows where accelerometer values are NaN
        chunk = chunk.dropna(subset=['accel_x', 'accel_y', 'accel_z']).copy()
        # Filter out rows in chunk that have an index present in mapped_pd
        chunk = chunk[~chunk['index'].isin(mapped_indices_set)]
        # Add participant ID column
        chunk.loc[:, 'pid'] = pid
        # Convert accelerometer values to float32
        chunk.loc[:, ['accel_x', 'accel_y', 'accel_z']] = chunk[['accel_x', 'accel_y', 'accel_z']].astype('float32')
        # Assign label 0 for non-mapped data
        chunk.loc[:, 'label'] = 0
        # Reorder columns to match mapped_df
        chunk = chunk[['index', 'accel_x', 'accel_y', 'accel_z', 'label', 'pid']]
        # Add chunk to the processed list
        processed_dfs.append(chunk)
        # Update the row count
        rows_read += len(chunk)
    subject_total_time = time.time() - subject_start_time
    print(f"Finished processing subject {subject} in {subject_total_time:.2f} seconds")

# Concatenate all processed chunks into a single DataFrame
df_combined = pd.concat(processed_dfs, ignore_index=True)
print(f"Combined DataFrame info:")
print(df_combined.info())
print(f"Label counts:\n{df_combined['label'].value_counts()}")


def load_data_from_df(df, window_size=3000, target_freq=30, original_freq=100):
    """
    Adjusts the downsampling method to match Oxford's approach using linear interpolation.
    The function prepares data for a pre-trained SSL model, which expects 30Hz sampling and 30s windows.
    It also balances the training and validation sets using RandomOverSampler.

    Parameters:
        df (pd.DataFrame): Input DataFrame containing accelerometer data.
        window_size (int): Number of samples per window at the original frequency (default: 3000 for 100Hz).
        target_freq (int): Target frequency for downsampling (default: 30Hz).
        original_freq (int): Original frequency of the data (default: 100Hz).

    Returns:
        Train, validation, and test splits with balanced training and validation sets.
    """
    # Calculate the downsampled window size
    downsampled_window_size = int(window_size * (target_freq / original_freq))  # 3000 * (30/100) = 900
    # Trim excess data that doesn't fit into full windows of original size
    num_windows = len(df) // window_size
    df = df.iloc[:num_windows * window_size]
    # Reshape the df into windows of shape (num_windows, window_size, 3)
    X = df[['accel_x', 'accel_y', 'accel_z']].to_numpy().reshape(num_windows, window_size, 3)
    y = df['label'].values[:num_windows * window_size].reshape(num_windows, window_size).mean(axis=1).astype(int)
    pid = df['pid'].values[:num_windows * window_size].reshape(num_windows, window_size)[:, 0].astype(int)

    # Downsample X using linear interpolation (Oxford approach)
    t_original = np.linspace(0, 1, window_size)  # Original time points
    t_target = np.linspace(0, 1, downsampled_window_size)  # Target time points
    X_downsampled = np.zeros((num_windows, downsampled_window_size, 3))  # Preallocate array
    for i in range(num_windows):
        for axis in range(3):  # Loop over accel_x, accel_y, accel_z
            interp_func = interp1d(t_original, X[i, :, axis], kind="linear", assume_sorted=True)
            X_downsampled[i, :, axis] = interp_func(t_target)

    # Assign participants to train, validation, and test sets (60/20/20 split)
    unique_pids = np.unique(pid)
    train_pids, test_pids = train_test_split(unique_pids, test_size=0.2, random_state=42)
    train_pids, val_pids = train_test_split(train_pids, test_size=0.25, random_state=41)  # 0.25 * 80% = 20%
    train_idx = np.isin(pid, train_pids)
    val_idx = np.isin(pid, val_pids)
    test_idx = np.isin(pid, test_pids)
    x_train, y_train, pid_train = X_downsampled[train_idx], y[train_idx], pid[train_idx]
    x_val, y_val, pid_val = X_downsampled[val_idx], y[val_idx], pid[val_idx]
    x_test, y_test, pid_test = X_downsampled[test_idx], y[test_idx], pid[test_idx]

    # Balance the training and validation sets using RandomOverSampler
    def oversample_with_noise(X, y, pid):
        # Flatten features for oversampling
        X_flat = X.reshape(X.shape[0], -1)
        ros = RandomOverSampler(random_state=42)
        X_resampled, y_resampled = ros.fit_resample(X_flat, y)
        # Expand pid to match the resampled data
        pid_resampled = ros.fit_resample(pid.reshape(-1, 1), y)[0].ravel()
        # Add small random noise to avoid duplicates
        noise = np.random.normal(0, 0.01, X_resampled.shape)
        X_resampled += noise
        return X_resampled.reshape(-1, downsampled_window_size, 3), y_resampled, pid_resampled

    x_train, y_train, pid_train = oversample_with_noise(x_train, y_train, pid_train)
    x_val, y_val, pid_val = oversample_with_noise(x_val, y_val, pid_val)

    return (
        x_train, y_train, pid_train,
        x_val, y_val, pid_val,
        x_test, y_test, pid_test
    )

# Execution
(
    x_train, y_train, group_train,
    x_val, y_val, group_val,
    x_test, y_test, group_test
) = load_data_from_df(
    df_combined,
    window_size=3000,   # Original window size at 100Hz
    target_freq=30,     # Target frequency for the SSL model
    original_freq=100   # Original frequency of the input data
)

# Count occurrences of each label in validation and test sets
val_classes, val_counts = np.unique(y_val, return_counts=True)
test_classes, test_counts = np.unique(y_test, return_counts=True)
print("Validation class distribution:")
for cls, count in zip(val_classes, val_counts):
    print(f"  Class {cls}: {count} instances")
print("\nTest class distribution:")
for cls, count in zip(test_classes, test_counts):
    print(f"  Class {cls}: {count} instances")


# Load the pretrained model
os.environ['GITHUB_TOKEN'] = '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'
repo = 'OxWearables/ssl-wearables'
sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
sslnet = sslnet.to(device).float()
# Specify fine tuning approach used
fine_tuning_approach = "no fine tuning"

# Approach 1: Freeze the convolutional layers while keeping linear layers trainable
fine_tuning_approach = "freeze conv layers"
os.environ['GITHUB_TOKEN'] = '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'
repo = 'OxWearables/ssl-wearables'
sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
sslnet = sslnet.to(device).float()
def set_bn_eval(m):
    if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d):
        m.eval()
i = 0
for name, param in sslnet.named_parameters():
    # Check if the parameter belongs to convolutional layers (typically in feature extractor)
    if "conv" in name or "bn" in name or "feature_extractor" in name:
        param.requires_grad = False  # Freeze convolutional layers
        i += 1
    else:
        param.requires_grad = True  # Keep linear layers trainable

# Apply the batch normalization setting
sslnet.apply(set_bn_eval)
print(f"Weights being frozen in the convolutional layers: {i}")

# Approach 2:Freeze the first residual block
fine_tuning_approach = "freeze first residual block"
os.environ['GITHUB_TOKEN'] = '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'
repo = 'OxWearables/ssl-wearables'
sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
sslnet = sslnet.to(device).float()
def set_bn_eval(m):
    if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d):
        m.eval()
i = 0
for name, param in sslnet.named_parameters():
    # Check if the parameter belongs to the first residual block
    if name.startswith("feature_extractor.layer1"):
        param.requires_grad = False
        i += 1
# Apply the batch normalization setting
sslnet.apply(set_bn_eval)
print(f"Weights being frozen in the first residual block: {i}")

# Approach 3: Adapter Layers
class AdapterModel(nn.Module):
    def __init__(self, base_model, feature_dim=1024):  # Set feature_dim based on feature extractor output
        super(AdapterModel, self).__init__()
        self.feature_extractor = base_model.feature_extractor
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        self.adapter = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)  # Output layer for binary classification
        )

    def forward(self, x):
        x = self.feature_extractor(x) 
        x = x.squeeze(-1)  
        x = self.adapter(x)  
        return x

fine_tuning_approach = "adapter layers"
os.environ['GITHUB_TOKEN'] = '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'
repo = 'OxWearables/ssl-wearables'
base_model = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=2, pretrained=True, weights_only=False)
base_model = base_model.to(device).float()
# Wrap the base model with the adapter layers
model = AdapterModel(base_model, feature_dim=1024).to(device)
print(model)

# Construct datasets
train_dataset = NormalDataset(x_train, y_train, group_train, name="training", transform=True)
val_dataset = NormalDataset(x_val, y_val, group_val, name="validation")
test_dataset = NormalDataset(x_test, y_test, group_test, name="test")

# Construct dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=0,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=0,
)

# Custom Precision Loss Function
class PrecisionLoss(nn.Module):
    def __init__(self, weighted_fp=2, weighted_fn=1):
        super().__init__()
        self.weighted_fp = weighted_fp
        self.weighted_fn = weighted_fn
    def forward(self, outputs, labels):
        probs = F.softmax(outputs, dim=1)
        pos_mask = labels.float()  # Use float precision for labels
        fp_loss = -torch.log(probs[:, 0] + 1e-6) * (1 - pos_mask) * self.weighted_fp
        fn_loss = -torch.log(probs[:, 1] + 1e-6) * pos_mask * self.weighted_fn
        return fp_loss.mean() + fn_loss.mean()
# Initialize with higher penalty for false positives
loss_fn = PrecisionLoss(weighted_fp=10, weighted_fn=1).to(device)


def train_with_precision(model, train_loader, val_loader, device, fine_tuning_approach, timestamp):
    if fine_tuning_approach == "adapter layers":
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, amsgrad=True)
    num_epochs = 20
    best_val_precision = 0 # Variable to track the best validation precision seen so far
    epochs_without_improvement = 0
    patience = 10     # Number of epochs to wait before stopping if no improvement
    for epoch in range(num_epochs):
        # Set the model to training mode
        model.train()
        train_losses = []  # Keeps track of loss during training
        # Initialize accumulators for true positives and false positives
        train_true_positives = 0
        train_false_positives = 0
        train_total_positives = 0
        
        # Training Loop
        for batch in train_loader:
            # Unpack the batch: inputs (features), labels (targets), optional metadata
            if len(batch) == 3:
                inputs, labels, _ = batch  # Extract inputs and labels
            else:
                inputs, labels = batch

            # Move inputs and labels to the specified device
            inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device)
            # Reset gradients to avoid accumulation from previous steps
            optimizer.zero_grad()
            # Forward pass: Compute model predictions
            outputs = model(inputs)
            # Compute the loss using the custom loss function (PrecisionLoss)
            loss = loss_fn(outputs, labels)
            # Backward pass: Compute gradients for all parameters
            loss.backward()
            # Update the model parameters based on the computed gradients
            optimizer.step()
            # Record the training loss for analysis
            train_losses.append(loss.item())
            # Convert model outputs to predicted labels (argmax gives class index)
            _, predicted = torch.max(outputs, 1)
             # Record true positives and false positives directly
            train_true_positives += ((predicted == 1) & (labels == 1)).sum().item()
            train_false_positives += ((predicted == 1) & (labels == 0)).sum().item()
            train_total_positives += (predicted == 1).sum().item()

        # Compute precision for training data
        train_precision = (
            train_true_positives / train_total_positives
            if train_total_positives > 0
            else 0.0
        )

        # Validation Loop:
        model.eval()  # Set model to evaluation mode (disables dropout, batch norm updates)
        # Initialize accumulators for validation precision
        val_true_positives = 0
        val_false_positives = 0
        val_total_positives = 0 
        # Disable gradient computation for validation (faster and saves memory)
        with torch.no_grad():
            for batch in val_loader:
                # Unpack the batch: inputs (features), labels (targets), optional metadata
                if len(batch) == 3:
                    inputs, labels, _ = batch
                else:
                    inputs, labels = batch
                # Move inputs and labels to the specified device
                inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device)
                # Forward pass: Compute model predictions
                outputs = model(inputs)
                # Convert model outputs to predicted labels
                # Record true positives and false positives directly
                _, predicted = torch.max(outputs, 1)
                val_true_positives += ((predicted == 1) & (labels == 1)).sum().item()
                val_false_positives += ((predicted == 1) & (labels == 0)).sum().item()
                val_total_positives += (predicted == 1).sum().item()
        # Compute precision for validation data
        val_precision = (
            val_true_positives / val_total_positives
            if val_total_positives > 0
            else 0.0
        )
        
        # Early stopping
        if val_precision > best_val_precision:
            best_val_precision = val_precision  # Update the best precision
            epochs_without_improvement = 0  # Reset the counter
        else:
            epochs_without_improvement += 1  # Increment counter if no improvement
        print(f"Epoch [{epoch + 1}/{num_epochs}]")
        print(f"  Train Loss: {sum(train_losses) / len(train_losses):.4f}")  # Average loss
        print(f"  Train Precision: {train_precision:.4f}")
        print(f"  Validation Precision: {val_precision:.4f}")
        
        if epochs_without_improvement >= patience:
            print(f"Early stopping on epoch {epoch + 1} as validation precision did not improve for {patience} epochs.")
            weights_path = os.path.join(f"Outputs/SSL Weights Saved/model_{fine_tuning_approach}_{timestamp}.pt")
            torch.save(model.state_dict(), weights_path)
            print(f"Weights saved for epoch {epoch + 1} as {weights_path}.")
            break

def predict(model, data_loader, device):
    predictions_list = []
    true_list = []
    pid_list = []
    model.eval()  # Set model to evaluation mode
    for i, (x, y, pid) in enumerate(tqdm(data_loader)):
        with torch.inference_mode():
            # Ensure input tensor matches model's precision
            x = x.to(device, dtype=torch.float)
            logits = model(x)
            true_list.append(y)
            pred_y = torch.argmax(logits, dim=1)
            predictions_list.append(pred_y.cpu())
            pid_list.extend(pid)

    # Combine results into numpy arrays
    true_list = torch.cat(true_list)
    predictions_list = torch.cat(predictions_list)
    return (
        torch.flatten(true_list).numpy(),
        torch.flatten(predictions_list).numpy(),
        np.array(pid_list),
    )

# Get the current timestamp for saving weights
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if fine_tuning_approach == "adapter layers":
    train_with_precision(model, train_loader, val_loader, device, fine_tuning_approach, timestamp)
else:
    train_with_precision(sslnet, train_loader, val_loader, device, fine_tuning_approach, timestamp)

# Load the best model weights from early stopping
weights_path = os.path.join(f"XXXX.pt")

fine_tuning_approach = "no fine tuning"
if fine_tuning_approach == "adapter layers":
    model.load_state_dict(torch.load(weights_path))
else:
    sslnet.load_state_dict(torch.load(weights_path))

missing_keys, unexpected_keys = sslnet.load_state_dict(torch.load(weights_path), strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")

# Evaluate on the test set
if fine_tuning_approach == "adapter layers":
    true_labels, predicted_labels, pids = predict(model, test_loader, device)
else:
    true_labels, predicted_labels, pids = predict(sslnet, test_loader, device)

# Compute evaluation metrics
test_accuracy = accuracy_score(true_labels, predicted_labels)
print(f"Test Accuracy: {test_accuracy:.2f}")
print("\nClassification Report:")
print(classification_report(true_labels, predicted_labels))
overall_f1 = f1_score(true_labels, predicted_labels, average='weighted')  # 'macro', 'micro', or 'weighted'
print(f"\nOverall F1 Score: {overall_f1:.2f}")
overall_precision = precision_score(true_labels, predicted_labels, average='binary')
print(f"\nOverall Precision: {overall_precision:.2f}")

# Compute confusion matrix
conf_matrix = confusion_matrix(true_labels, predicted_labels)
# Plot confusion matrix using seaborn
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=[0, 1], yticklabels=[0, 1])
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

## Wrist Gait Speed

In [6]:
def wrist_gait_speed():
    gait_dir = "/XXXX/" # Store file path to folders with gait sequence predictions from lower back signal
    wrist_dir = "/XXXX/" # Store file path to folders with signal from wrist sensor
    output_dir = '/XXXX/' # Output directory
    
    # Extract list of subjects with lower back predictions
    subjects = os.listdir(gait_dir)
    # Exclude two subjects with no demographics data
    subjects = [subject for subject in subjects if subject not in ['XXXXX', 'XXXXX']]

    # Capture accuracy metrics
    maes=[]
    r_s=[]
    adj_rs=[]
    mses=[]
    rmses=[]
    # Train regression model only on the first subject, validate with all other subjects
    train_first=True 
    # Loop through all subjects
    for folder_name in subjects:
        print(f"Processing subject: {folder_name}")
        start_time = time.time()
        # Load wrist signal data
        wrist_folder_path = os.path.join(wrist_dir, folder_name)
        csv_path = os.path.join(wrist_folder_path, 'wrist_data.csv')
        df = pd.read_csv(csv_path, usecols=['accel_x', 'accel_y', 'accel_z'])
        signal_load_time = time.time() - start_time
        print(f"Wrist file load time: {signal_load_time:.2f} seconds")
        df.reset_index(inplace=True)  # Add row numbers as an "index" column
    
        walking_df_start_time = time.time()
        # Load lumbar gait sequence data
        gs_folder_path = os.path.join(gait_dir, folder_name, 'gs_list.csv')
        gs = pd.read_csv(gs_folder_path)
        # Expand gait sequence to a DataFrame of indices for walking intervals
        walking_df = pd.DataFrame({'index': np.concatenate([np.arange(row['start'], row['end'] + 1) for _, row in gs.iterrows()])})
        walking_df['lower_back_mapped_value'] = 1
        # Check for duplicates in walking_df
        if walking_df.duplicated(subset=['index']).any():
            print(f"Duplicate indices found in walking_df for subject: {folder_name}")
            walking_df.drop_duplicates(subset=['index'], keep='first', inplace=True)
        
        # Load gait analysis data
        analysis_folder_path = os.path.join(gait_dir, folder_name, 'gait_analysis_results.csv')
        gait_analysis = pd.read_csv(analysis_folder_path, usecols=['sec_center_samples', 'cadence_spm',
                                                                   'stride_length_m', 'walking_speed_mps'])
        # Expand gait analysis intervals and add features
        gait_analysis['start'] = (gait_analysis['sec_center_samples'] - 50).astype(int)
        gait_analysis['end'] = (gait_analysis['sec_center_samples'] + 49).astype(int)
        analysis_intervals = []
        for row in gait_analysis.itertuples(index=False):
            for idx in range(row.start, row.end + 1):
                analysis_intervals.append({
                    'index': idx,
                    'cadence_spm': row.cadence_spm,
                    'stride_length_m': row.stride_length_m,
                    'walking_speed_mps': row.walking_speed_mps
                })
        analysis_intervals = pd.DataFrame(analysis_intervals)
        # Check for duplicates in analysis_intervals
        if analysis_intervals.duplicated(subset=['index']).any():
            print(f"Duplicate indices found in analysis_intervals for subject: {folder_name}")
            analysis_intervals.drop_duplicates(subset=['index'], keep='first', inplace=True)
        # Merge the walking intervals and analysis features
        walking_df = walking_df.merge(analysis_intervals, on='index', how='left')
        walking_df_time = time.time() - walking_df_start_time
        print(f"Time to create walking_df: {walking_df_time:.2f} seconds")
    
        # Step 2: Inner join walking data to wrist signal data
        df = df.merge(walking_df, on='index', how='inner')
        df=df[pd.isna(df.walking_speed_mps)==False]

        # Fitting the OLS regression
        if train_first ==True: # Only train regression on first patient 
            X_test=df[[ 'accel_x', 'accel_y', 'accel_z', 'cadence_spm', 'stride_length_m']][0:50000] # Test set is the last 50,000 observations of the first patient's data
            X=df[[ 'accel_x', 'accel_y', 'accel_z', 'cadence_spm', 'stride_length_m']][50000:] # Train is the rest of the first patient's data
            y_test=df['walking_speed_mps'][0:50000]
            y=df['walking_speed_mps'][50000:]
            X = sm.add_constant(X)  
            model = sm.OLS(y, X).fit()
            print(model.summary())
            train_first=False
        else:
            X_test=df[[ 'accel_x', 'accel_y', 'accel_z', 'cadence_spm', 'stride_length_m']] # Use all other patients to validate OLS results in terms of speed prediction 
            y_test=df['walking_speed_mps']
        X_test = sm.add_constant(X_test)
        predictions = model.predict(X_test)
        try:
            # Get evaluation metrics
            mae=mean_absolute_error(y_test, predictions)
            mse=mean_squared_error(y_test, predictions)
            rmse=np.sqrt(mse)
            r2 = r2_score(y_test, predictions)
            n = len(X_test)  
            p = 5
            adjusted_r2 = 1 - ((1 - r2) * (n - 1)) / (n - p - 1)
            print("Predictions vs Truth:")
            comp = pd.DataFrame({"Truth": y_test, "Prediction": predictions})
            print(comp)
            print("\nScoring Metrics:")
            print(f"Mean Absolute Error (MAE): {mae:.4f}")
            print(f"Mean Squared Error (MSE):{mse:.4f}")
            
            print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
            print(f"R² Score: {r2:.4f}")
            print(f"Adjusted R² Score:{adjusted_r2:.4f}")
            maes=maes+[float(mae)]
            r_s=r_s+ [float(r2)]
            adj_rs=adj_rs+[float(adjusted_r2)]
            mses=mses+ [float(mse)]
            rmses=rmses+[float(rmse)]
            results=[maes,r_s,adj_rs,mses,rmses]
            # Saving intermediate results in-case failure
            # with open("/XXXXX.pkl", "wb") as file:  # "wb" = write binary mode
            #     pickle.dump(results, file)
            # print("Data has been saved successfully!")
        except: #If it fails, add patient's ID to the list for further examination and continue on to next patient
            maes=maes+[folder_name]
            r_s=r_s+ [folder_name]
            adj_rs=adj_rs+[folder_name]
            mses=mses+ [folder_name]
            rmses=rmses+[folder_name]
            results=[maes,r_s,adj_rs,mses,rmses]
            # Saving intermediate results in-case failure
            # with open("/XXXX.pkl", "wb") as file:  # "wb" = write binary mode
            #     pickle.dump(results, file)
            # print("Data has been saved successfully, skipped file: ",folder_name)
    return model, results

# Function to average accuracy metrics
def avg_metrics(lst):
    numeric_values = [x for x in lst if isinstance(x, (int, float))]
    return sum(numeric_values) / len(numeric_values) 
    

OLS_model, OLS_metrics=wrist_gait_speed()

#Averaging results
avg_scores=[]
titles=["Average Mean Absolute Error (MAE):","Average R² Score:","Average Adjusted R² Score:","Average Mean Squared Error (MSE):","Average Root Mean Squared Error (RMSE):"]
for x in OLS_metrics:
    avg_scores=avg_scores+[avg_metrics(x)]
for x in range(len(avg_scores)):
    print(titles[x],round(avg_scores[x],4))

Average Mean Absolute Error (MAE): 0.0224
Average R² Score: 0.9627
Average Adjusted R² Score: 0.9627
Average Mean Squared Error (MSE): 0.0013
Average Root Mean Squared Error (RMSE): 0.0349
