In [29]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tsai.models.TST import TST
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.nn.utils import spectral_norm
from scipy import signal
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, roc_auc_score
import random

def set_seeds(seed=42):
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seeds()

# Prefer an accelerator when one is present: CUDA first, then Apple-silicon MPS,
# falling back to CPU. Batches and the model are moved to this device later on.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: mps


In [30]:
# Dataset: group steps by session and emit sliding windows of consecutive walking steps
class SensorDataset(Dataset):
    """Sliding windows of consecutive walking steps, built separately per session.

    Rows are grouped by session id. Session order follows first appearance, and
    row order within a session is kept. Every session with at least
    ``window_size`` rows yields each contiguous run of ``window_size`` rows as one
    sample. The sample's label is the label of the LAST row in the window.
    Sessions shorter than ``window_size`` contribute no samples.

    Args:
        X: indexable of per-step arrays. Each window is ``np.stack``-ed, so a
            sample has shape ``(window_size, *X[i].shape)``.
        y: per-step labels, positionally aligned with ``X``.
        session_ids: per-step session identifiers, positionally aligned with ``X``.
        window_size: number of consecutive steps per sample.
    """

    def __init__(self, X, y, session_ids, window_size=4):
        self.window_size = window_size
        self.X = []
        self.y = []

        # session id -> list of (features, label) in original row order
        by_session = {}
        for pos, sid in enumerate(session_ids):
            by_session.setdefault(sid, []).append((X[pos], y[pos]))

        for steps in by_session.values():
            # Non-positive when the session is shorter than the window -> no samples.
            n_windows = len(steps) - window_size + 1
            for start in range(max(n_windows, 0)):
                stop = start + window_size
                self.X.append(np.stack([features for features, _ in steps[start:stop]]))
                self.y.append(steps[stop - 1][1])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = torch.tensor(self.X[idx], dtype=torch.float32)
        label = torch.tensor(self.y[idx], dtype=torch.float32)
        return features, label

In [31]:
X_original = np.load('sensor_data.npy')  # Shape: (num_steps, time_steps, num_sensors)
metadata = pd.read_csv('combined_metadata.csv')
y = metadata['has_ms'].values

# Later cells index X_original with boolean masks derived from `metadata`, so the
# two files must describe the same steps (row i of metadata <-> X_original[i]).
# Only the count can be checked here; row order is assumed to match.
assert len(metadata) == len(X_original), (
    f"metadata has {len(metadata)} rows but sensor_data.npy has {len(X_original)} steps"
)
# Scaling and windowing downstream assume a 3-D (steps, time, sensors) array.
assert X_original.ndim == 3, f"expected a 3-D array, got shape {X_original.shape}"

In [32]:
# Split by session (not by individual step) so that overlapping windows built
# from one session never straddle train/val/test, while stratifying on MS status.
# NOTE(review): if one subject can have several sessions, the same subject can
# still land in more than one split — consider splitting by subject if available.
sessions = metadata['session_id'].values
unique_sessions = np.unique(sessions)

# Session-level label: 1 if ANY row of the session has has_ms == 1, else 0.
# A single groupby pass avoids re-scanning the whole frame once per session.
session_has_ms = (metadata['has_ms'] == 1).groupby(metadata['session_id']).any()
session_to_ms_status = {
    session_id: int(session_has_ms.get(session_id, False))
    for session_id in unique_sessions
}

# Session ids by MS status (order follows unique_sessions, i.e. sorted)
ms_sessions = [s for s, status in session_to_ms_status.items() if status == 1]
non_ms_sessions = [s for s, status in session_to_ms_status.items() if status == 0]

# Stratify manually: split MS and non-MS sessions separately 70 / 15 / 15.
train_ms, temp_ms = train_test_split(ms_sessions, test_size=0.3, random_state=42, shuffle=True)
train_non_ms, temp_non_ms = train_test_split(non_ms_sessions, test_size=0.3, random_state=42, shuffle=True)

# Halve each held-out 30% into validation and test
val_ms, test_ms = train_test_split(temp_ms, test_size=0.5, random_state=42, shuffle=True)
val_non_ms, test_non_ms = train_test_split(temp_non_ms, test_size=0.5, random_state=42, shuffle=True)

# Combine MS and non-MS sessions for each split
train_sessions = train_ms + train_non_ms
val_sessions = val_ms + val_non_ms
test_sessions = test_ms + test_non_ms

_train_set, _val_set, _test_set = set(train_sessions), set(val_sessions), set(test_sessions)
assert not (_train_set & _val_set or _train_set & _test_set or _val_set & _test_set), \
    "a session was assigned to more than one split"

# Row-level boolean masks over `metadata` (and therefore over X_original / y)
train_indices = metadata['session_id'].isin(train_sessions)
val_indices = metadata['session_id'].isin(val_sessions)
test_indices = metadata['session_id'].isin(test_sessions)

In [33]:
# Working copy of the raw array. No downsampling is applied: every step keeps
# its full time resolution.
X = X_original.copy()

# Apply the session-level row masks from the split cell.
X_train, X_val, X_test = X[train_indices], X[val_indices], X[test_indices]
y_train, y_val, y_test = y[train_indices], y[val_indices], y[test_indices]

# Per-sensor standardisation: flatten (steps, time) so each sensor column gets a
# single mean/std. Fit on the training split only, then apply the SAME transform
# to validation and test. X_test must be scaled too, otherwise any evaluation on
# it would feed the model inputs on a different scale than it was trained on.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train.reshape(-1, X.shape[2])).reshape(X_train.shape)
X_val = scaler.transform(X_val.reshape(-1, X.shape[2])).reshape(X_val.shape)
X_test = scaler.transform(X_test.reshape(-1, X.shape[2])).reshape(X_test.shape)

# Per-row session IDs aligned with each split (used to build windows per session)
train_sessions_ids = sessions[train_indices]
val_sessions_ids = sessions[val_indices]
test_sessions_ids = sessions[test_indices]

window_size = 4  # consecutive walking steps per sample

# Create datasets
train_dataset = SensorDataset(X_train, y_train, train_sessions_ids, window_size=window_size)
val_dataset = SensorDataset(X_val, y_val, val_sessions_ids, window_size=window_size)
test_dataset = SensorDataset(X_test, y_test, test_sessions_ids, window_size=window_size)

# Create dataloaders (only training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

In [34]:
def train(model, train_loader, criterion, optimizer, epochs=20, scheduler=None, device=None):
    """Train ``model`` for ``epochs`` passes and print the mean loss of each epoch.

    Args:
        model: module to optimise; switched to training mode at the start.
        train_loader: sized iterable of ``(X_batch, y_batch)`` pairs.
        criterion: loss called as ``criterion(logits, float_targets)``, e.g.
            ``nn.BCEWithLogitsLoss``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        scheduler: optional scheduler whose ``step`` receives the epoch's mean
            training loss (e.g. ``ReduceLROnPlateau``).
        device: device each batch is moved to; defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        List with the mean training loss of every epoch.

    Raises:
        ValueError: if ``epochs > 0`` and ``train_loader`` yields no batches.
    """
    if epochs > 0 and len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            output = model(X_batch)
            # Drop only a trailing singleton output dim: [B, 1] -> [B]. A bare
            # .squeeze() would also collapse a final batch of size 1 into a 0-d
            # tensor, which then no longer matches y_batch's shape [1].
            if output.dim() > 1 and output.shape[-1] == 1:
                output = output.squeeze(-1)

            # BCE-style losses require float targets.
            loss = criterion(output, y_batch.float())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        history.append(avg_loss)
        print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}')

        if scheduler is not None:
            # NOTE(review): this monitors the TRAINING loss; plateau schedulers
            # are usually driven by a validation metric instead.
            scheduler.step(avg_loss)

    return history

In [37]:
class TSTClassifier(nn.Module):
    """Window classifier: a shared TST encodes each walking step, an LSTM reads the steps.

    Every step of a window is encoded independently by the same tsai ``TST``
    (configured with ``c_out=64``). The per-step vectors form a sequence that a
    2-layer LSTM consumes. A linear head maps the top layer's final hidden state
    to ``c_out`` raw logits; no sigmoid is applied here.

    Args:
        c_in: number of sensor channels per time step.
        time_steps: time steps per walking step (TST ``seq_len`` / ``max_seq_len``).
        window_size: accepted but unused; the window length is read from the
            input in ``forward``.
        c_out: number of output logits.
    """

    def __init__(self, c_in, time_steps, window_size, c_out=1):
        super().__init__()
        step_dim = 64     # TST output size == LSTM input size
        lstm_hidden = 32  # LSTM hidden size == classifier input size

        # Creation order (encoder -> LSTM -> head) determines how parameter
        # initialisation consumes the global RNG; keep it fixed.
        self.tst = TST(
            c_in=c_in, seq_len=time_steps, c_out=step_dim, max_seq_len=time_steps,
            n_layers=3, n_heads=8, d_model=128, d_ff=256, dropout=0.2,
        )
        self.lstm = nn.LSTM(
            input_size=step_dim, hidden_size=lstm_hidden, num_layers=2,
            batch_first=True, dropout=0.2,
        )
        self.classifier = nn.Linear(lstm_hidden, c_out)

    def forward(self, x):
        """Map x of shape [batch, window, time_steps, n_features] to [batch, c_out] logits."""
        _, n_steps, _, _ = x.shape  # unpacking also rejects non-4-D input
        # Each step is transposed to [batch, n_features, time_steps] before encoding.
        step_vectors = [self.tst(x[:, k].transpose(1, 2)) for k in range(n_steps)]
        sequence = torch.stack(step_vectors, dim=1)  # [batch, window, step_dim]
        _, (h_n, _) = self.lstm(sequence)
        # h_n[-1] is the final hidden state of the top LSTM layer.
        return self.classifier(h_n[-1])

# Model dimensions come straight from the (steps, time_steps, sensors) training array.
input_dim = X_train.shape[2]   # sensor channels per time step
time_steps = X_train.shape[1]  # time steps per walking step

model = TSTClassifier(c_in=input_dim, time_steps=time_steps, window_size=window_size, c_out=1).to(device)

# The classifier head returns raw logits, so use the logits-aware BCE loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate once the monitored loss fails to improve for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

train(model, train_loader, criterion, optimizer, epochs=30, scheduler=scheduler)

# Validation function
def validate(model, val_loader, threshold=0.5, device=None):
    """Evaluate a binary (single-logit) classifier and print accuracy, AUC and a report.

    Args:
        model: module returning one logit per sample; switched to eval mode.
        val_loader: iterable of ``(X_batch, y_batch)`` pairs.
        threshold: probability above which a sample is predicted positive.
        device: device inputs are moved to; defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        ``(val_preds, val_targets)``: per-sample sigmoid probabilities and the
        matching targets, as flat Python lists.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    val_preds = []
    val_targets = []

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            logits = model(X_batch.to(device))
            # reshape(-1) rather than squeeze(): a final batch of one sample would
            # otherwise become a 0-d array, which list.extend cannot iterate.
            probs = torch.sigmoid(logits).reshape(-1).cpu().numpy()
            val_preds.extend(probs)
            val_targets.extend(y_batch.reshape(-1).cpu().numpy())

    if not val_targets:
        raise ValueError("val_loader yielded no samples")

    # Binary predictions
    val_preds_binary = [1 if p > threshold else 0 for p in val_preds]

    accuracy = accuracy_score(val_targets, val_preds_binary)
    # ROC AUC is undefined when the evaluated split contains a single class.
    if len(set(val_targets)) < 2:
        auc = float('nan')
    else:
        auc = roc_auc_score(val_targets, val_preds)

    print(f"Validation - Accuracy: {accuracy:.4f}, AUC: {auc:.4f}")
    # zero_division=0: a model that never predicts a class gets 0 precision for
    # it without emitting UndefinedMetricWarning.
    print(classification_report(val_targets, val_preds_binary, zero_division=0))

    return val_preds, val_targets

# Run validation. Assign the result so the cell does not echo the full
# prediction/target lists as its output.
val_preds, val_targets = validate(model, val_loader)

Epoch 1, Loss: 0.6355
Epoch 2, Loss: 0.6197
Epoch 3, Loss: 0.6171
Epoch 4, Loss: 0.6174
Epoch 5, Loss: 0.6175
Epoch 6, Loss: 0.6177
Epoch 7, Loss: 0.6201
Epoch 8, Loss: 0.6180
Epoch 9, Loss: 0.6171
Epoch 10, Loss: 0.6168
Epoch 11, Loss: 0.6169
Epoch 12, Loss: 0.6168
Epoch 13, Loss: 0.6171


KeyboardInterrupt: 