# AcTBeCalf - Testing the Dataset

---

**Group:**
- João Gabriel
- Gustavo Tironi

**Subject:**
- Deep Learning - Dário Oliveira

---

### Objective

Using the AcTBeCalf dataset, which contains more than 27 hours of labelled data and 2 weeks of unlabelled data from sensors on calves, we compare several models to find the best one for predicting calf behaviour from these sensors alone. The goal is to beat the simple models tried previously, which used only 2 or 4 classes.


### Setting the Seeds

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px  # For interactive graphics
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Setup done. Using: {device}")

### Configurations

---

All the parameters used during training and testing, along with file paths and dataset information (such as the sampling frequency `FREQ`).

In [None]:
class Config:
    LABELED_PATH = '/kaggle/input/calf-dl/AcTBeCalf.parquet'
    UNLABELED_PATH = '/kaggle/input/calf-dl/Time_Adj_Raw_Data.parquet'
    
    # Parameters
    FREQ = 25  # Hz
    WINDOW_SECONDS = 3
    WINDOW_SIZE = FREQ * WINDOW_SECONDS  # 75 samples (25 Hz x 3 s)
    CHANNELS = 3 # X, Y, Z
    
    # Train
    BATCH_SIZE = 64
    NUM_WORKERS = 2 # Using 2, with more than this it may struggle

print("Config done")

### General Inspection

---

We inspect the dataset's basic information and check for any errors in the new Parquet file before we start using the data.

In [None]:
print("INSPECTING LABELED DATASET (AcTBeCalf)")

try:
    df_labeled = pd.read_parquet(Config.LABELED_PATH)
    
    # Basic Info
    print(f"Shape: {df_labeled.shape}")
    print("Columns:", df_labeled.columns.tolist())
    print("\nData Types:")
    print(df_labeled.dtypes)
    
    # Null Check
    print("\nMissing Values  per column:")
    null_counts = df_labeled.isnull().sum()
    print(null_counts[null_counts > 0])
    print("\ncalfId Column:")
    unique_calves = df_labeled['calfId'].unique()
    print(f"Number of unique calves: {len(unique_calves)}")
    print(f"First 10 unique values: {unique_calves[:10]}")
    
    if df_labeled['calfId'].isnull().all():
        print("'calfId' column is NaN. Subject-based split is impossible.")
    
    print("\nbehaviour Column:")
    raw_labels = df_labeled['behaviour'].astype(str).unique()
    print(f"Total unique raw labels: {len(raw_labels)}")
    print("Sample of labels (first 20):")
    print(raw_labels[:20])
    
    print("\nLabel Counts (Top 10 most frequent):")
    print(df_labeled['behaviour'].astype(str).value_counts().head(10))
    
    print("\nLabel Counts (Top 10 LEAST frequent - Potential split errors):")
    print(df_labeled['behaviour'].astype(str).value_counts().tail(10))

    # Analyze 'segId'
    print("\nsegId Column:")
    n_segments = df_labeled['segId'].nunique()
    print(f"Total unique segments: {n_segments}")

except Exception as e:
    print(f"Error inspecting labeled data: {e}")

print("\n" + "="*50 + "\n")
try:
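    # head(5).collect() reads only the first rows; LazyFrame.fetch is deprecated in recent Polars
    df_unlabeled_head = pl.scan_parquet(Config.UNLABELED_PATH).head(5).collect()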
    
    print("Unlabeled Data Schema:")
    print(df_unlabeled_head.schema)
    
    print("\nFirst 5 rows:")
    print(df_unlabeled_head)
    
except Exception as e:
    print(f"Error inspecting unlabeled data: {e}")

# Dataset Preparation
### For TimeMAE and ResNet1D Experiments

Here we define the PyTorch Datasets for both labeled and unlabeled data. The goal is to prepare data in a **windowed format** suitable for 1D CNNs or masked autoencoders. This differs slightly from the method used later in the notebook.

---

#### **Labeled Dataset: `LabeledCalfDataset`**

1. **Purpose:** Create windows of signals with overlapping segments from labeled calf accelerometer data.
2. **Key Steps:**
   - Loads a Parquet file with labeled behavioral data.
   - Maps **string labels to integer codes** for training consistency.
   - Uses **segId** so that no window crosses a segment boundary.
   - Creates **sliding windows** of size `window_size` (default 250 samples = 10 s at 25 Hz; the notebook passes `Config.WINDOW_SIZE`, i.e. 75 samples = 3 s) with **50% overlap** for data augmentation.
   - Overlap increases dataset size while preserving temporal coherence.
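
As a rough illustration of the windowing step, the sketch below computes the window start offsets inside a single segment. The helper name `window_starts` and the toy sizes are assumptions made for this example; they are not part of the notebook.

```python
def window_starts(n_samples, window_size, stride):
    """Start offsets of every full window that fits inside one segment."""
    if n_samples < window_size:
        return []  # segment too short: it contributes no window
    return list(range(0, n_samples - window_size + 1, stride))

# Toy segment of 10 samples, windows of 4 samples, 50% overlap (stride = 2)
starts = window_starts(10, 4, 4 // 2)
print(starts)  # [0, 2, 4, 6]
```

Each window shares half of its samples with the next one, so halving the stride roughly doubles the number of training windows per segment.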

In [None]:
class LabeledCalfDataset(Dataset):
    def __init__(self, parquet_path, window_size=250):
        print(f"Loading Labeled Dataset: {parquet_path}...")
        self.data = pd.read_parquet(parquet_path)
        
        # Map labels to integers for consistent ordering
        cats = sorted(self.data['behaviour'].unique())
        self.label_map = {label: i for i, label in enumerate(cats)}
        self.int_to_label = {i: label for label, i in self.label_map.items()}
        
        self.indices = []
        
        # Group by segId to avoid breaking windows
        for seg_id, group in self.data.groupby('segId'):
            n = len(group)
            if n >= window_size:
                start_global = group.index[0]
                # Stride = window_size // 2 (50% overlap for natural data augmentation)
                for i in range(0, n - window_size + 1, window_size // 2):
                    self.indices.append((start_global + i, start_global + i + window_size))
        
        self.signals = self.data[['accX', 'accY', 'accZ']].values.astype('float32')
        
        # Map string labels to integer codes
        self.labels = self.data['behaviour'].map(self.label_map).values.astype('int64')
        print(f"Labeled Dataset Ready: {len(self.indices)} samples, {len(cats)} classes.")

    def __len__(self):
        return len(self.indices)
        
    def __getitem__(self, idx):
        s, e = self.indices[idx]
        x = self.signals[s:e]
        y = self.labels[s]
        # Transpose (Time, Channel) -> (Channel, Time) for PyTorch
        return torch.tensor(x).permute(1, 0), torch.tensor(y, dtype=torch.long)


class UnlabeledCalfDataset(Dataset):
    def __init__(self, parquet_path, window_size=250):
        print(f"Loading Unlabeled Dataset: {parquet_path}...")
        # Use Polars for memory efficiency
        df = pl.read_parquet(parquet_path)
        
        self.signals = df.select(['accX', 'accY', 'accZ']).to_numpy().astype('float32')
        timestamps = df.select('dateTime').to_numpy().flatten()  # array of int64 (ns) or datetime
        
        self.window_size = window_size
        self.valid_indices = []
        total = len(df)
        stride = window_size  # No overlap for speed
        
        # Gap detection is not implemented yet: the data is assumed to be continuous.
        # limit_ns is the longest duration (ns) a gap-free window should span:
        # 40 ms per sample at 25 Hz plus a 10% tolerance. It is kept for a future check.
        limit_ns = window_size * 40 * 1_000_000 * 1.1
        print("Calculating valid indices (assuming continuous data)...")
        
        # +1 so that the last full window is included
        for i in range(0, total - window_size + 1, stride):
            self.valid_indices.append(i)
             
        print(f"Unlabeled Dataset Ready: {len(self.valid_indices)} samples.")

    def __len__(self):
        return len(self.valid_indices)
        
    def __getitem__(self, idx):
        s = self.valid_indices[idx]
        # Use the window size given at construction, not the global Config value
        x = self.signals[s : s + self.window_size]
        return torch.tensor(x).permute(1, 0), torch.tensor(-1)


In [None]:
# Dataset instantiation
ds_labeled = LabeledCalfDataset(Config.LABELED_PATH, Config.WINDOW_SIZE)
ds_unlabeled = UnlabeledCalfDataset(Config.UNLABELED_PATH, Config.WINDOW_SIZE)

# Quick check of the first sample
x, y = ds_labeled[0]
print(f"\nTensor overview:")
print(f"Input Shape (C, L): {x.shape} (Expected: 3, {Config.WINDOW_SIZE})")
print(f"Label: {y} ('{ds_labeled.int_to_label[y.item()]}')")
print(f"Data type: {x.dtype}")

# DataLoader setup
dl_check = DataLoader(ds_labeled, batch_size=32, shuffle=True)
batch_x, batch_y = next(iter(dl_check))
print(f"Batch Shape: {batch_x.shape}")

In [None]:
# Quick count using the Dataset's internal DataFrame
df_counts = ds_labeled.data['behaviour'].value_counts().reset_index()
df_counts.columns = ['behaviour', 'count']

plt.figure(figsize=(10, 10))
sns.barplot(data=df_counts, x='count', y='behaviour', palette='viridis')
plt.title("Behavior Distribution (Labeled Set)")
plt.xlabel("Number of Samples (Raw Rows)")
plt.xscale('log')  # Log scale helps visualize rare classes
plt.grid(axis='x', alpha=0.3)
plt.show()

print("There are rare classes (Vocalization, Grooming|None) and very common classes.")


### Visualizing Sample Accelerometer Windows

This function plots a **3-axis accelerometer signal** for a specific behavior class from the labeled dataset. It helps us check whether the labels make sense: a lying calf should barely move, while a running calf should show clear peaks and valleys.

- **Behavior:**
  - If `index` is not provided, randomly selects a window corresponding to the requested label.
  - Converts the `(channels, time)` tensor to a DataFrame with axes labeled:
    - `X (Up/Down)`
    - `Y (Forward/Backward)`
    - `Z (Left/Right)`
  - Time axis is computed in seconds using the configured sampling frequency.

In [None]:
def plot_signal(dataset, label_name=None, index=None):
    """Plots the 3-axis accelerometer signal from a random window of a specific class."""
    
    # If no explicit index is given, sample a window belonging to the requested class
    if index is None:
        target_int = dataset.label_map[label_name]
        candidates = [i for i, idxs in enumerate(dataset.indices) 
                      if dataset.labels[idxs[0]] == target_int]
        if not candidates:
            return print(f"No data available for {label_name}")
        idx = random.choice(candidates)
    else:
        idx = index

    # Retrieve the tensor window and its label
    x_tensor, y_tensor = dataset[idx]
    label_str = dataset.int_to_label[y_tensor.item()]
    
    # Convert the tensor (channels-first) to a DataFrame for plotting
    data_np = x_tensor.permute(1, 0).numpy()
    df_plot = pd.DataFrame(
        data_np, 
        columns=['X (Up/Down)', 'Y (Forward/Backward)', 'Z (Left/Right)']
    )
    df_plot['Time'] = np.arange(len(df_plot)) / Config.FREQ
    
    # Create interactive plot
    fig = px.line(
        df_plot,
        x='Time',
        y=['X (Up/Down)', 'Y (Forward/Backward)', 'Z (Left/Right)'],
        title=f"Sample of Behavior: {label_str.upper()} (Index {idx})"
    )
    fig.update_layout(yaxis_title="Acceleration (g)", xaxis_title="Seconds")
    fig.show()

plot_signal(ds_labeled, label_name='lying')    # Should appear nearly flat
plot_signal(ds_labeled, label_name='running')  # Should show peaks and valleys

In [None]:
# Random sample to analyse
idx_rnd = random.randint(0, len(ds_unlabeled)-1)
x_u, _ = ds_unlabeled[idx_rnd]

data_np = x_u.permute(1, 0).numpy()
df_u = pd.DataFrame(data_np, columns=['X', 'Y', 'Z'])
df_u['Time'] = np.arange(len(df_u)) / Config.FREQ

fig = px.line(df_u, x='Time', y=['X', 'Y', 'Z'], 
              title=f"Unlabeled sample (Index {idx_rnd})")
fig.show()

### Behavior Taxonomy Mapping and Dataset Splitting

---

#### 1. Behavior Mapping to Super-Classes

The original dataset contains many fine-grained behaviors, some of which are rare or short-lived. To make the learning task feasible and reduce class imbalance, we map all behaviors into 19 super-classes: the 18 listed below plus a fallback `'other'` for labels that match no rule.

- **Rare Events:** e.g., cough, fall, vocalization → `'rare_event'`
- **Abnormal:** e.g., cross-suckle, tongue, abnormal movements → `'abnormal'`
- **Self-Reactive (SRS):** scratch, rub, stretch → `'srs'`
- **Elimination:** defecation or urination → `'elimination'`
- **Play Behaviors:** jump, headbutt, mount, object play → `'play'`
- **Social Interaction:** nudge, social sniff → `'social_interaction'`
- **Rumination:** `'rumination'`
- **Drinking:** `'drinking'`
- **Eating:** `'eating'`
- **Exploration:** sniffing (non-social) → `'sniff'`
- **Oral Manipulation:** `'oral_manipulation'`
- **Grooming:** self-grooming → `'grooming'`
- **Transitions:** rising, lying down → `'rising'`, `'lying_down_action'`
- **Locomotion:** run → `'running'`, walk/backward → `'walking'`
- **Base Postures:** lying → `'lying'`, standing → `'standing'`

This ensures that we have enough samples per class while maintaining semantic meaning. Rare or minor behaviors are grouped to avoid extremely sparse classes that are difficult for models to learn.
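
The mapping relies on rule order: a label is assigned to the first super-class whose keyword it contains. The sketch below shows that idea with a small rule table. The names `RULES` and `map_label` and the raw labels are hypothetical, chosen only for this illustration, and the table covers just a few of the classes.

```python
# Ordered (keywords, super_class) rules: the first matching rule wins.
RULES = [
    (("social", "nudge"), "social_interaction"),
    (("sniff",), "sniff"),
    (("groom",), "grooming"),
    (("ly",), "lying"),
]

def map_label(label, rules=RULES, default="other"):
    """Return the super-class of the first rule with a keyword found in the label."""
    label = str(label).lower().strip()
    for keywords, super_class in rules:
        if any(k in label for k in keywords):
            return super_class
    return default

# Hypothetical raw labels, only to show how rule order resolves overlaps
for raw in ["Social Sniff", "sniff", "social groom", "lying", "unknown"]:
    print(f"{raw!r} -> {map_label(raw)}")
```

Because the social rule comes first, a label containing both "social" and "sniff" is never counted as exploration, which is why more specific rules must precede broad substring checks.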

---
#### Train/Validation/Test Split

To ensure **subject-independent evaluation**, we split the dataset based on **`calfId`** whenever possible:

- **Train set:** 70% of calves
- **Validation set:** 15% of calves
- **Test set:** 15% of calves

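If there are very few calves (5 or fewer), we fall back to splitting by `segId`. Splitting by `calfId` guarantees that the same calf never appears in both training and testing, which prevents overfitting to the patterns of individual animals. The `segId` fallback does not give this guarantee, since segments from one calf can land in different splits.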
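
A minimal sketch of a subject-level split, using only the standard library. The function name `split_subjects` and the calf IDs are assumptions for this example; the notebook itself uses `train_test_split` from scikit-learn.

```python
import random

def split_subjects(subject_ids, train_frac=0.70, val_frac=0.15, seed=42):
    """Shuffle unique subject IDs and cut them into train/val/test groups."""
    ids = sorted(set(subject_ids))
    random.Random(seed).shuffle(ids)
    n_train = round(train_frac * len(ids))
    n_val = round(val_frac * len(ids))
    train = set(ids[:n_train])
    val = set(ids[n_train:n_train + n_val])
    test = set(ids[n_train + n_val:])
    return train, val, test

calf_ids = [f"calf_{i:02d}" for i in range(20)]  # hypothetical IDs
train_ids, val_ids, test_ids = split_subjects(calf_ids)

# No calf may appear in more than one split
assert not (train_ids & val_ids or train_ids & test_ids or val_ids & test_ids)
print(len(train_ids), len(val_ids), len(test_ids))
```

The split is done on the IDs first and the rows are filtered afterwards, so every row of a given calf ends up in exactly one split.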


In [None]:
from sklearn.model_selection import train_test_split

df_full = pd.read_parquet(Config.LABELED_PATH)

def map_behavior(label):
    # Convert to lowercase and string to avoid parsing issues
    label = str(label).lower().strip()
    
    # 1. RARE EVENTS (Highest priority because they are short and uncommon)
    if any(x in label for x in ['cough', 'fall', 'vocalisation']):
        return 'rare_event'
        
    # 2. ABNORMAL BEHAVIORS
    if any(x in label for x in ['cross-suckle', 'tongue', 'abnormal']):
        return 'abnormal'
        
    # 3. SRS (Self-Reactive: scratch, rub, stretch)
    if any(x in label for x in ['scratch', 'rub', 'stretch', 'srs']):
        return 'srs'
        
    # 4. ELIMINATION
    if 'defecat' in label or 'urinat' in label:
        return 'elimination'
        
    # 5. PLAY BEHAVIORS — High-energy activities
    # Includes jump, headbutt, mount, playing with objects
    if any(x in label for x in ['play', 'jump', 'headbutt', 'mount']):
        return 'play'
        
    # 6. SOCIAL INTERACTION
    # The paper groups "nudge" and "social sniff" here.
    if 'social' in label or 'nudge' in label:
        return 'social_interaction'
        
    # 7. RUMINATION
    if 'ruminat' in label:
        return 'rumination'
        
    # 8. DRINKING
    if 'drink' in label:
        return 'drinking'
        
    # 9. EATING
    if 'eat' in label:
        return 'eating'
        
    # 10. EXPLORATION (Non-social sniffing)
    if 'sniff' in label:
        return 'sniff'
        
    # 11. ORAL MANIPULATION
    if 'oral' in label or 'manipulation' in label:
        return 'oral_manipulation'
        
    # 12. GROOMING (Self-grooming)
    # "social_groom" would have been caught earlier under social interaction.
    if 'groom' in label:
        return 'grooming'
        
    # 13. TRANSITIONS (Checked before static postures)
    if 'rising' in label:
        return 'rising'
    if 'lying down' in label or 'lying-down' in label:
        return 'lying_down_action'
        
    # 14. LOCOMOTION
    if 'run' in label:
        return 'running'
    if 'walk' in label or 'backward' in label:
        return 'walking'
        
    # 15. BASE POSTURES (Fallback)
    if 'ly' in label:  # captures "lying"
        return 'lying'
    if 'stand' in label:
        return 'standing'
        
    return 'other'

# Apply taxonomy mapping based on the behavioral paper
print("Applying behavior taxonomy...")
df_full['behaviour_clean'] = df_full['behaviour'].apply(map_behavior).astype('category')

# Display resulting classes
classes = sorted(df_full['behaviour_clean'].unique())
print(f"Mapped Classes ({len(classes)}):")
for c in classes:
    print(f"  - {c}")

# Replace categorical labels with integer indices
label_map = {c: i for i, c in enumerate(classes)}
int_to_label = {i: c for c, i in label_map.items()}
df_full['label_code'] = df_full['behaviour_clean'].map(label_map).astype(int)

# --- SUBJECT-INDEPENDENT TRAINING SPLIT ---
unique_calves = df_full['calfId'].unique()

if len(unique_calves) > 5:
    print("\nSplitting by subject (calfId)")
    train_ids, temp_ids = train_test_split(unique_calves, test_size=0.30, random_state=42)
    val_ids, test_ids = train_test_split(temp_ids, test_size=0.50, random_state=42)
    col_split = 'calfId'
    
    train_df = df_full[df_full['calfId'].isin(train_ids)].reset_index(drop=True)
    val_df = df_full[df_full['calfId'].isin(val_ids)].reset_index(drop=True)
    test_df = df_full[df_full['calfId'].isin(test_ids)].reset_index(drop=True)

else:
    print("\nSplitting by segment (few calf IDs available)")
    # Safe fallback in case calfId is unreliable
    seg_ids = df_full['segId'].unique()
    train_ids, temp_ids = train_test_split(seg_ids, test_size=0.30, random_state=42)
    val_ids, test_ids = train_test_split(temp_ids, test_size=0.50, random_state=42)
    
    train_df = df_full[df_full['segId'].isin(train_ids)].reset_index(drop=True)
    val_df = df_full[df_full['segId'].isin(val_ids)].reset_index(drop=True)
    test_df = df_full[df_full['segId'].isin(test_ids)].reset_index(drop=True)

# Compute class weights for handling imbalance
count_series = train_df['label_code'].value_counts().sort_index()
counts = np.zeros(len(classes))
for idx, val in count_series.items():
    if idx < len(counts):
        counts[idx] = val

weights = len(train_df) / (len(classes) * (counts + 1))
class_weights = torch.tensor(weights, dtype=torch.float).to(device)

print(f"\nFinal Dataset Sizes:")
print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

### Sensor Preprocessing

This class handles **normalization and feature augmentation** for accelerometer data:

1. **Fit:** Computes the global **mean and standard deviation** for the X, Y, Z channels using only the training set.

2. **Transform:**  
   - Normalizes the 3-axis accelerometer signals using Z-score.  
   - Computes a **fourth channel: the magnitude** of the normalized acceleration vector (`sqrt(X_norm² + Y_norm² + Z_norm²)`).  
   - Returns a 4-channel array `[X_norm, Y_norm, Z_norm, Magnitude]` ready for model input.

This ensures consistent scaling and adds a useful summary feature for motion intensity.
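
The two steps can be sketched with NumPy on synthetic data. The random arrays below stand in for the real accelerometer columns and the helper `to_four_channels` is a name invented for this example. Following the notebook's code, the magnitude is taken over the normalized channels.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-ins for the accX/accY/accZ columns of the train and validation splits
train_xyz = rng.normal(loc=[0.0, 0.0, 1.0], scale=[0.2, 0.3, 0.1], size=(1000, 3))
val_xyz = rng.normal(loc=[0.0, 0.0, 1.0], scale=[0.2, 0.3, 0.1], size=(200, 3))

# Fit: statistics come from the training split only
mean = train_xyz.mean(axis=0)
std = train_xyz.std(axis=0)
std[std == 0] = 1.0  # guard against division by zero

def to_four_channels(xyz):
    """Z-score the three axes and append the magnitude of the normalized vector."""
    x_norm = (xyz - mean) / std
    mag = np.sqrt((x_norm ** 2).sum(axis=1, keepdims=True))
    return np.concatenate([x_norm, mag], axis=1)

train_4ch = to_four_channels(train_xyz)
val_4ch = to_four_channels(val_xyz)
print(train_4ch.shape, val_4ch.shape)
```

Reusing the training statistics for validation and test data avoids leaking information from those splits into the normalization.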


In [None]:
class SensorPreprocessor:
    def __init__(self):
        self.mean = None
        self.std = None
        
    def fit(self, dataframe):
        """ Calculate statistics only on the training set """
        print("Calculating global statistics (X, Y, Z)...")
        signals = dataframe[['accX', 'accY', 'accZ']].values
        self.mean = np.mean(signals, axis=0)
        self.std = np.std(signals, axis=0)
        # Safety against division by zero
        self.std[self.std == 0] = 1.0
        print(f"   Mean: {self.mean} | Std: {self.std}")
        
    def transform(self, dataframe):
        """ Normalize and create the 4th channel (Magnitude) """
        x = dataframe[['accX', 'accY', 'accZ']].values.astype(np.float32)
        x_norm = (x - self.mean) / self.std
        mag = np.sqrt(np.sum(x_norm**2, axis=1, keepdims=True))
        x_final = np.concatenate([x_norm, mag], axis=1)
        return x_final


### In-Memory Windowed Dataset

This dataset class prepares the **windowed input for PyTorch models**:

- **Preprocessing:** Applies the `SensorPreprocessor` to normalize signals and add the magnitude channel.  
- **Sliding windows:** Creates overlapping windows per segment to avoid mixing behaviors or calves.  
  - Training uses 50% overlap (`stride = window_size // 2`) for data augmentation.  
  - Validation/test use non-overlapping windows.  
- **Output format:** `(Channels, Time)` tensors `(4, window_size)` along with the corresponding label.  
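
To illustrate how windows stay inside a segment, the sketch below finds segment boundaries in a toy `segId` column and lists the valid window starts. The function name `segment_window_starts` and the toy values are assumptions made for this example.

```python
import numpy as np

def segment_window_starts(seg_ids, window_size, stride):
    """Global start indices of windows that never cross a segment boundary."""
    seg_ids = np.asarray(seg_ids)
    # Positions where the segment ID changes mark the first row of a new segment
    changes = np.where(seg_ids[:-1] != seg_ids[1:])[0] + 1
    starts = np.concatenate(([0], changes))
    ends = np.concatenate((changes, [len(seg_ids)]))
    out = []
    for start, end in zip(starts, ends):
        # The range is empty when the segment is shorter than one window
        out.extend(range(int(start), int(end) - window_size + 1, stride))
    return out

seg_ids = [1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3]  # three segments of length 4, 2 and 5
print(segment_window_starts(seg_ids, window_size=3, stride=3))  # [0, 6]
```

The middle segment has only two rows, so it produces no window, and no window mixes rows from two segments.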


In [None]:
class InMemoryCalfDataset(Dataset):
    def __init__(self, dataframe, window_size=75, stride=None, mode='train', preprocessor=None):
        self.window_size = window_size
        self.indices = []
        
        print(f"Processing Dataset ({mode})...")
        
        # Apply preprocessing immediately (saves CPU during training)
        if preprocessor:
            self.signals = preprocessor.transform(dataframe)  # Returns array (N_samples, 4)
        else:
            raise ValueError("A trained preprocessor must be provided!")

        # Labels and segment IDs
        self.labels = dataframe['label_code'].values.astype('int64')
        self.seg_ids = dataframe['segId'].values

        # Create sliding windows
        if stride is None:
            stride = window_size // 2 if mode == 'train' else window_size
        
        # Find segment boundaries to avoid mixing calves/behaviors
        changes = np.where(self.seg_ids[:-1] != self.seg_ids[1:])[0] + 1
        starts = np.concatenate(([0], changes))
        ends = np.concatenate((changes, [len(self.signals)]))
        
        for start, end in zip(starts, ends):
            n = end - start
            if n >= window_size:
                for i in range(start, end - window_size + 1, stride):
                    self.indices.append(i)

        print(f"{len(self.indices)} windows created.")

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        start = self.indices[idx]
        end = start + self.window_size
        
        # Slice from preprocessed array
        x = self.signals[start:end]  # (window_size, 4)
        y = self.labels[start]       # Label at start of the window
        
        # PyTorch expects (Channels, Time) -> permute to (4, window_size)
        return torch.tensor(x).permute(1, 0), torch.tensor(y, dtype=torch.long)


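# Fit the preprocessor on the training split only, then reuse it for every split
preprocessor = SensorPreprocessor()
preprocessor.fit(train_df)

# Instantiate datasets (the constructor raises ValueError without a preprocessor)
ds_train = InMemoryCalfDataset(train_df, Config.WINDOW_SIZE, mode='train', preprocessor=preprocessor)
ds_val   = InMemoryCalfDataset(val_df, Config.WINDOW_SIZE, mode='val', preprocessor=preprocessor)
ds_test  = InMemoryCalfDataset(test_df, Config.WINDOW_SIZE, mode='test', preprocessor=preprocessor)

# DataLoaders
dl_train = DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=Config.NUM_WORKERS, pin_memory=True)
dl_val   = DataLoader(ds_val, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=Config.NUM_WORKERS, pin_memory=True)
dl_test  = DataLoader(ds_test, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=Config.NUM_WORKERS, pin_memory=True)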

print(f"\nBatches per Epoch: {len(dl_train)}")

### Baseline

---

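Here we define `ResNetBaseline`, a torchvision ResNet-18 whose first convolution is adapted to the 4 input channels. Each window is given a height of 1 so that the 2D network can process the 1D signal. We use it as a baseline for comparison with TimeMAE, the main model of this notebook.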

In [None]:
import torchvision.models as models

class ResNetBaseline(nn.Module):
    def __init__(self, num_classes, in_channels=4):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        self.backbone.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        n_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(n_features, num_classes)

    def forward(self, x):
        x = x.unsqueeze(2) # (Batch, 4, 1, 75)
        x = self.backbone(x)
        return x

# Initialize Model
model = ResNetBaseline(num_classes=len(classes)).to(device)
print(f"Model created for {len(classes)} classes.")

### Training Loop for ResNet1D / TimeMAE Baseline
 
- **Training Loop:**
  - Iterate over epochs.
  - Training phase: forward pass, compute loss, backprop, update weights.
  - Validation phase: evaluate without gradients.
  - Track average loss and accuracy for both train and validation.

**Results:**  
- Train Accuracy: ~77%  
- Validation Accuracy: ~59%  

This confirms the baseline is worse than the Random Forest or LSTM+CNN models, showing how challenging this 19-class problem is for simple architectures. The model was trained several times with other learning rates and hyperparameters, but it never exceeded 60% in any of the tests.


In [None]:
from tqdm.notebook import tqdm

EPOCHS = 50
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.2, patience=10  # `verbose` is deprecated in recent PyTorch releases
)
criterion = nn.CrossEntropyLoss(weight=class_weights)

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

print("Starting training...")

# ===========================
# TRAINING LOOP
# ===========================
for epoch in range(EPOCHS):
    # --- Training Phase ---
    model.train()
    train_loss, correct, total = 0, 0, 0
    
    loop = tqdm(dl_train, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
        loop.set_postfix(loss=loss.item())

    avg_train_loss = train_loss / len(dl_train)
    train_acc = 100 * correct / total

    # --- Validation Phase ---
    model.eval()
    val_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for x, y in dl_val:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss = criterion(outputs, y)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()

    avg_val_loss = val_loss / len(dl_val)
    val_acc = 100 * correct / total

    # Record history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f} (Acc {train_acc:.1f}%) | "
              f"Val Loss={avg_val_loss:.4f} (Acc {val_acc:.1f}%)")
    
    scheduler.step(avg_val_loss)

# --- Plot Training History ---
plt.figure(figsize=(14, 5))

# Loss
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss', marker='o')
plt.plot(history['val_loss'], label='Validation Loss', marker='o')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid(True, alpha=0.3)

# Accuracy
plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Train Accuracy', marker='o')
plt.plot(history['val_acc'], label='Validation Accuracy', marker='o')
plt.xlabel("Epochs")
plt.ylabel("Accuracy (%)")
plt.title("Training and Validation Accuracy")
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### TimeMAE Configuration

This block defines the **hyperparameters for TimeMAE**, a **masked autoencoder for time series**:

- **Input & Patching:**  
  - Window of 3 seconds sampled at 25 Hz → 75 time steps.  
  - Divided into small patches of 5 steps each, forming 15 patches per window.  

- **Model Dimensions:**  
  - Embedding dimension: 64  
  - Attention heads: 4  
  - Dropout: 0.2  

- **Encoder/Decoder Depth:**  
  - Original paper uses deeper stacks (8 encoder, 6 decoder).  
  - Here, we reduce the depth (4 encoder, 2 decoder) to **avoid overfitting**, since our dataset is smaller than those used in the paper that introduced TimeMAE (such as HAR).  

- **Masking & Tasks:**  
  - Mask ratio: 60% (predict the masked patches).  
  - Codebook size for reconstruction: 128  

- **Optimization:**  
  - Large batch size (1024) for stability  
  - Learning rate 0.003, weight decay 1e-4, momentum tau 0.95  

- **Loss Weights:**  
  - `alpha` and `beta` weight the classification (MCC) and regression (MRR) losses, respectively.

Overall, this config prepares the **TimeMAE model for time series data** while keeping it **lighter and more stable** for a medium-sized, multi-class calf behavior dataset.
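
The patch and mask counts follow directly from these values. The short check below redoes the arithmetic with the numbers stated above; the lowercase variable names are chosen just for this example.

```python
FREQ = 25             # Hz
WINDOW_SECONDS = 3.0
PATCH_SIZE = 5
MASK_RATIO = 0.60

window_size = int(FREQ * WINDOW_SECONDS)    # time steps per window
num_patches = window_size // PATCH_SIZE     # patches per window
num_masked = int(MASK_RATIO * num_patches)  # patches hidden from the encoder
num_visible = num_patches - num_masked      # patches the encoder actually sees

# The window must divide evenly into patches, otherwise trailing samples are dropped
assert window_size % PATCH_SIZE == 0
print(window_size, num_patches, num_masked, num_visible)  # 75 15 9 6
```

With a 60% mask ratio the encoder sees only 6 of the 15 patches in each window.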


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import polars as pl
import numpy as np
import copy

class TimeMAEConfig:
    UNLABELED_PATH = '/kaggle/input/calf-dl/Time_Adj_Raw_Data.parquet'
    
    # Hyperparameters from the paper
    FREQ = 25
    WINDOW_SECONDS = 3.0
    WINDOW_SIZE = int(FREQ * WINDOW_SECONDS) 
    PATCH_SIZE = 5   # δ (delta in the paper)
    NUM_PATCHES = WINDOW_SIZE // PATCH_SIZE 
    
    # Model Dimensions (paper uses 64)
    EMBED_DIM = 64   
    NUM_HEADS = 4    
    DROPOUT = 0.2    
    
    # Depth (Paper: 8 Encoder, 6 Decoder)
    # Kaggle-safe: reduce the depth a little, but keep the proportion
    DEPTH_ENC = 4 
    DEPTH_DEC = 2
    
    # Masking & Tasks
    MASK_RATIO = 0.60
    CODEBOOK_SIZE = 128 # Vocabulary size
    
    # Optimization
    BATCH_SIZE = 1024
    LR = 0.003
    WEIGHT_DECAY = 1e-4
    MOMENTUM_TAU = 0.95
    
    # Loss Weights (Paper: alpha=1, beta search 1..10)
    ALPHA = 1.0 # MCC weight (classification)
    BETA = 1.0  # MRR weight (regression)

### Unlabeled TimeMAE Dataset

This dataset prepares **unlabeled accelerometer windows** for TimeMAE:

- Uses sliding windows of 75 time steps (3 s) with a stride of 25 → overlapping windows naturally augment the data.
- Returns each window as a **PyTorch tensor** in shape `(channels, time)` → `(3, 75)`.
- No labels are needed since TimeMAE is **self-supervised** (masked reconstruction task).
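
As an alternative view of the same windowing, the sketch below builds overlapping windows with NumPy's `sliding_window_view` (available in NumPy 1.20 and later). The toy signal is synthetic and this is not the notebook's implementation, which stores start indices and slices on demand.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Toy signal: 200 time steps, 3 channels (time, channels)
signal = np.arange(200 * 3, dtype=np.float32).reshape(200, 3)
window_size, stride = 75, 25

# All windows along the time axis, then keep every `stride`-th one.
# The window axis is appended last, so each window is already (channels, time).
windows = sliding_window_view(signal, window_size, axis=0)[::stride]
print(windows.shape)  # (6, 3, 75)
```

Unlike a loop over `range(0, n - window_size, stride)`, this also keeps a window that ends exactly on the last sample. The result is a view on the original array, so it uses no extra memory until it is copied.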


In [None]:
# --- Dataset and Collator (reusing the optimized logic) ---
class UnlabeledMaedDataset(Dataset):
    def __init__(self, parquet_path, window_size=75, stride=25):
        # Stride=25 creates overlap to increase the training data (natural data augmentation)
        df = pl.read_parquet(parquet_path)
        self.signals = df.select(['accX', 'accY', 'accZ']).to_numpy().astype('float32')
        self.window_size = window_size
        self.indices = [i for i in range(0, len(self.signals) - window_size, stride)]
        
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        start = self.indices[idx]
        window = self.signals[start : start + self.window_size]
        # (3, 75)
        return torch.from_numpy(window).permute(1, 0)

### Masking Collator for TimeMAE

This collator prepares **masked inputs** for the autoencoder:

- Receives a batch of windows `(B, 3, 75)`.
- Divides each window into `num_patches` patches.
- Randomly selects a fraction (`mask_ratio`, e.g., 60%) of patches to mask.
- Returns the batch and a **boolean mask** indicating which patches are masked (`True = masked`).
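
The masking step can also be written without a Python loop. The sketch below is a vectorized NumPy variant, not the collator used in this notebook: it ranks random noise per row and masks the lowest-ranked patches. The function name `random_patch_mask` is an assumption for this example.

```python
import numpy as np

def random_patch_mask(batch_size, num_patches, mask_ratio, rng):
    """Boolean mask (True = masked) with the same number of masked patches per row."""
    num_masked = int(mask_ratio * num_patches)
    noise = rng.random((batch_size, num_patches))
    # Double argsort turns the noise into per-row ranks 0..num_patches-1
    ranks = noise.argsort(axis=1).argsort(axis=1)
    return ranks < num_masked

rng = np.random.default_rng(0)
mask = random_patch_mask(batch_size=4, num_patches=15, mask_ratio=0.60, rng=rng)
print(mask.shape, mask.sum(axis=1))
```

Every row masks exactly `int(mask_ratio * num_patches)` patches, which keeps the number of visible tokens constant across the batch.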


In [None]:
class MaskingCollator:
    def __init__(self, num_patches, mask_ratio=0.60):
        self.num_patches = num_patches
        self.mask_ratio = mask_ratio
        
    def __call__(self, batch):
        batch_x = torch.stack(batch) # (B, 3, 75)
        B = batch_x.shape[0]
        num_masked = int(self.mask_ratio * self.num_patches)
        
        # Create boolean mask (True = masked)
        mask = torch.zeros(B, self.num_patches, dtype=torch.bool)
        for i in range(B):
            rand_idx = torch.randperm(self.num_patches)[:num_masked]
            mask[i, rand_idx] = True
            
        return batch_x, mask

### TimeMAE Model: Building Blocks and Overview

This block defines the **core architecture** of the TimeMAE (Masked Autoencoder for Time Series) model.

1. **Patch Embedding (`PatchEmbed`)**  
   - Converts a 1D time series `(B, 3, T)` into a sequence of **patch tokens** `(B, N_patches, D)`.  
   - Uses 1D convolution with kernel size equal to patch size for slicing.  
   - Acts as the **input feature encoder** for the transformer.

2. **Cross-Attention Block (`CrossAttentionBlock`)**  
   - Used in the **decoder** to reconstruct masked patches.  
   - Performs multi-head attention between masked queries and visible (encoded) tokens.  
   - Includes **residual connections**, layer normalization, and feed-forward network.

3. **TimeMAE Pretraining Model (`TimeMAE_Pretrain`)**  
   - **Encoder**:  
     - Patch embedding + positional embeddings.  
     - Transformer encoder processes visible (unmasked) patches.  
   - **Decoder**:  
     - Receives masked tokens and applies **cross-attention** to visible encoded tokens.  
     - Multi-layer decoder reconstructs masked representations.
   - **Target Encoder**:  
     - Gradient-free copy of the online encoder that provides stable teacher signals.  
     - Updated via **Exponential Moving Average (EMA)**.
   - **Codebook**:  
     - Discretizes latent embeddings into a finite set for classification-based reconstruction (MCC).  

4. **Forward Pass**  
   - Takes input windows and a **mask** indicating which patches are masked.  
   - Encoder processes **visible patches**.  
   - Decoder reconstructs **masked patches** using cross-attention.  
   - Target encoder provides reference embeddings for the masked patches.

5. **Losses (`get_losses`)**  
   - **MRR (Masked Representation Regression)**: MSE between predicted and target embeddings.  
   - **MCC (Masked Codeword Classification)**: Cross-entropy against codeword indices obtained by discretizing the target embeddings with the codebook.  
   - Combined loss encourages the model to **reconstruct masked patches accurately** while leveraging discrete representations for stability.

**Overall Idea:**  
TimeMAE learns to predict missing parts of the time series by masking random patches. The encoder-decoder architecture with a **momentum-updated target encoder** (which receives no gradients) stabilizes learning, while the **codebook** allows a hybrid regression+classification objective. For our problem, we reduced the depth of the encoder and decoder to limit **overfitting**, as our dataset is smaller and less complex than the datasets used in the original paper.
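
To make the EMA update concrete, here is a minimal sketch on a toy `nn.Linear` layer. The helper `ema_update` and the value `tau=0.99` are illustrative assumptions, not the notebook's configuration.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(online, target, tau):
    """target <- tau * target + (1 - tau) * online, parameter by parameter."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1.0 - tau)

online = nn.Linear(4, 4)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad = False  # the target never receives gradients

# Pretend an optimizer step moved every online weight by +1
with torch.no_grad():
    for p in online.parameters():
        p.add_(1.0)

before = target.weight.clone()
ema_update(online, target, tau=0.99)  # illustrative momentum value
shift = (target.weight - before).mean().item()
print(f"Target moved by {shift:.4f} after the online weights moved by 1.0")
```

With a momentum close to 1 the target follows the online encoder slowly, which is what keeps the regression targets from changing abruptly between steps.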


In [None]:
import copy
import torch.nn.functional as F

# --- Building Blocks ---
class PatchEmbed(nn.Module):
    """Feature Encoder: Conv1d for slicing"""
    def __init__(self, in_ch=3, patch_size=5, embed_dim=64):
        super().__init__()
        self.proj = nn.Conv1d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, 3, T) -> (B, N_patches, D)
        return self.proj(x).transpose(1, 2)


class CrossAttentionBlock(nn.Module):
    """Decoder block with cross-attention"""
    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim*4, dim)
        )
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key_value):
        attn_out, _ = self.cross_attn(query, key_value, key_value)
        x = self.norm1(query + self.dropout(attn_out))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


class TimeMAE_Pretrain(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        # Online encoder
        self.patch_embed = PatchEmbed(3, cfg.PATCH_SIZE, cfg.EMBED_DIM)
        self.pos_embed = nn.Parameter(torch.zeros(1, cfg.NUM_PATCHES, cfg.EMBED_DIM))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, cfg.EMBED_DIM))
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.normal_(self.mask_token, std=.02)
        
        enc_layer = nn.TransformerEncoderLayer(
            cfg.EMBED_DIM, cfg.NUM_HEADS, cfg.EMBED_DIM*4, cfg.DROPOUT,
            batch_first=True, norm_first=True
        )
        self.visible_encoder = nn.TransformerEncoder(enc_layer, cfg.DEPTH_ENC)
        self.decoupled_decoder = nn.ModuleList([
            CrossAttentionBlock(cfg.EMBED_DIM, cfg.NUM_HEADS, cfg.DROPOUT) 
            for _ in range(cfg.DEPTH_DEC)
        ])
        
        # Target encoder (copy of online encoder, frozen)
        self.target_patch_embed = copy.deepcopy(self.patch_embed)
        self.target_visible_encoder = copy.deepcopy(self.visible_encoder)
        for p in self.target_patch_embed.parameters(): p.requires_grad = False
        for p in self.target_visible_encoder.parameters(): p.requires_grad = False
        
        # Codebook for discretization
        self.codebook = nn.Parameter(torch.randn(cfg.CODEBOOK_SIZE, cfg.EMBED_DIM))

    @torch.no_grad()
    def update_target_encoder(self):
        """Update target encoder weights via EMA"""
        m = self.cfg.MOMENTUM_TAU
        for param_q, param_k in zip(self.patch_embed.parameters(), self.target_patch_embed.parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)
        for param_q, param_k in zip(self.visible_encoder.parameters(), self.target_visible_encoder.parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)

    def forward(self, x, mask):
        B = x.shape[0]
        
        # Online path
        x_patches = self.patch_embed(x) + self.pos_embed
        visible_tokens = x_patches[~mask].view(B, -1, self.cfg.EMBED_DIM)
        encoded_visible = self.visible_encoder(visible_tokens)
        
        pos_masked = self.pos_embed.expand(B, -1, -1)[mask].view(B, -1, self.cfg.EMBED_DIM)
        decoder_queries = self.mask_token + pos_masked
        
        x_dec = decoder_queries
        for layer in self.decoupled_decoder:
            x_dec = layer(query=x_dec, key_value=encoded_visible)
        
        # Target path
        with torch.no_grad():
            target_patches = self.target_patch_embed(x) + self.pos_embed
            masked_real_tokens = target_patches[mask].view(B, -1, self.cfg.EMBED_DIM)
            target_representations = self.target_visible_encoder(masked_real_tokens)
            target_representations = F.layer_norm(target_representations, target_representations.shape[-1:])
            
        return x_dec, target_representations
    
    def get_losses(self, pred_latents, target_latents):
        """Compute regression (MRR) and classification (MCC) losses"""
        loss_mrr = F.mse_loss(pred_latents, target_latents)
        
        flat_targets = target_latents.reshape(-1, self.cfg.EMBED_DIM)
        flat_preds = pred_latents.reshape(-1, self.cfg.EMBED_DIM)
        
        with torch.no_grad():
            sim_target = torch.matmul(flat_targets, self.codebook.t())
            target_ids = torch.argmax(sim_target, dim=-1)
            
        logits = torch.matmul(flat_preds, self.codebook.t())
        loss_mcc = F.cross_entropy(logits, target_ids)
        
        return loss_mrr, loss_mcc
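
The `forward` method relies on boolean indexing followed by `.view(B, -1, D)`. The toy sketch below (sizes chosen only for illustration) shows why this works only when every sample masks the same number of patches.

```python
import torch

B, N, D = 2, 6, 4  # illustrative sizes, not the notebook's configuration
tokens = torch.arange(B * N * D, dtype=torch.float32).view(B, N, D)

# Exactly three masked patches per sample
mask = torch.tensor([
    [True, False, True, False, False, True],
    [False, True, True, True, False, False],
])

# Boolean indexing flattens the batch and patch dimensions in row-major order,
# so the first three selected tokens belong to sample 0, the next three to sample 1.
visible = tokens[~mask].view(B, -1, D)
masked = tokens[mask].view(B, -1, D)
print(visible.shape, masked.shape)
```

If the rows had different mask counts, the flattened tokens could not be regrouped per sample with `view`, which is why the collator masks a fixed number of patches per window.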

### TimeMAE Pretraining Loop

In this block, we **set up and train the TimeMAE model** on unlabeled accelerometer data. First, we initialize the model, move it to the GPU if available, and prepare the unlabeled dataset with sliding windows. A `MaskingCollator` applies random masking to patches during batching, which is the core of the self-supervised training.

During training, each batch is passed through the model to obtain predictions for the masked patches and reference embeddings from the target encoder, which receives no gradients. The combined loss, a weighted sum of **MRR (representation regression)** and **MCC (codeword classification)** with weights `cfg.BETA` and `cfg.ALPHA`, is backpropagated. After each optimizer step, the target encoder is updated as an **exponential moving average** of the online encoder to stabilize learning.

We log progress roughly ten times per epoch and compute epoch averages for the total loss, MRR, and MCC. The scheduler reduces the learning rate when the average training loss plateaus (no validation split is used at this stage), and a checkpoint is saved after every epoch, with the best model stored separately. Overall, this loop implements **self-supervised pretraining**, teaching the model to reconstruct missing patches from partially observed time series and building robust embeddings for downstream tasks.
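
As a rough illustration of how `ReduceLROnPlateau` reacts to a loss that stops improving, the sketch below feeds it a constant loss. The single dummy parameter and the starting learning rate of `1e-3` are assumptions made only for this example.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter for the example
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3
)

lrs = []
for epoch_loss in [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]:  # loss never improves
    scheduler.step(epoch_loss)
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```

The first call sets the best loss, the next three calls count as "bad" epochs within the patience window, and the learning rate is multiplied by `factor` on the call after that.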


In [None]:
cfg = TimeMAEConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = TimeMAE_Pretrain(cfg).to(device)

# Unlabeled dataset and dataloader
ds_ssl = UnlabeledMaedDataset(cfg.UNLABELED_PATH, window_size=cfg.WINDOW_SIZE, stride=25)
dl_ssl = DataLoader(
    ds_ssl, 
    batch_size=cfg.BATCH_SIZE, 
    shuffle=True, 
    num_workers=2, 
    pin_memory=True,
    collate_fn=MaskingCollator(cfg.NUM_PATCHES, cfg.MASK_RATIO),
    drop_last=True 
)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.LR, weight_decay=cfg.WEIGHT_DECAY)

# Checkpoint handling
checkpoint_path = "/kaggle/working/timemae_pretrained.pth"
start_epoch = 0
best_loss = float('inf')

if os.path.exists(checkpoint_path):
    print(f"Checkpoint found: {checkpoint_path}")
    try:
        state = torch.load(checkpoint_path, map_location=device)
        if isinstance(state, dict) and 'model_state_dict' in state:
            model.load_state_dict(state['model_state_dict'])
            if 'optimizer_state_dict' in state:
                optimizer.load_state_dict(state['optimizer_state_dict'])
            if 'loss' in state:
                best_loss = state['loss']
        else:
            model.load_state_dict(state)
        print("Weights loaded successfully. Resuming training...")
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")

# LR Scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, verbose=True
)

# Training loop
EPOCHS = 10
history = {'loss': [], 'mrr': [], 'mcc': []}

print(f"Starting training: Batch={cfg.BATCH_SIZE}, LR={cfg.LR}, Tau={cfg.MOMENTUM_TAU}")

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    total_mrr = 0
    total_mcc = 0
    steps = len(dl_ssl)
    
    for batch_idx, (x, mask) in enumerate(dl_ssl):
        x, mask = x.to(device), mask.to(device)
        optimizer.zero_grad()
        
        # Forward
        pred_latents, target_latents = model(x, mask)
        
        # Compute losses
        loss_mrr, loss_mcc = model.get_losses(pred_latents, target_latents)
        loss = cfg.ALPHA * loss_mcc + cfg.BETA * loss_mrr
        
        loss.backward()
        optimizer.step()
        
        # Update target encoder via momentum
        model.update_target_encoder()
        
        # Accumulate metrics
        total_loss += loss.item()
        total_mrr += loss_mrr.item()
        total_mcc += loss_mcc.item()
        
        # Periodic logging
        if batch_idx % max(1, steps // 10) == 0 and batch_idx > 0:  # max() avoids modulo by zero when steps < 10
            print(f"  Epoch {epoch+1} [{batch_idx}/{steps}] Loss: {loss.item():.4f} (MRR: {loss_mrr.item():.4f})")
    
    # Epoch statistics
    avg_loss = total_loss / steps
    avg_mrr = total_mrr / steps
    avg_mcc = total_mcc / steps
    history['loss'].append(avg_loss)
    history['mrr'].append(avg_mrr)
    history['mcc'].append(avg_mcc)
    
    current_lr = optimizer.param_groups[0]['lr']
    print(f"End Epoch {epoch+1} | Avg Loss: {avg_loss:.4f} | MRR: {avg_mrr:.4f} | LR: {current_lr}")
    
    # Scheduler step
    scheduler.step(avg_loss)
    
    # Save checkpoints
    save_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }
    torch.save(save_dict, "timemae_last.pth")
    
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), "timemae_pretrained.pth")
        print(f"New best model saved! (Loss: {best_loss:.4f})")

In [None]:
class TimeMAE_Classifier(nn.Module):
    def __init__(self, cfg, num_classes):
        super().__init__()
        self.cfg = cfg
        
        # Backbone
        self.patch_embed = PatchEmbed(3, cfg.PATCH_SIZE, cfg.EMBED_DIM)
        self.pos_embed = nn.Parameter(torch.zeros(1, cfg.NUM_PATCHES, cfg.EMBED_DIM))
        
        # Transformer Encoder
        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.EMBED_DIM, 
            nhead=cfg.NUM_HEADS, 
            dim_feedforward=cfg.EMBED_DIM*4, 
            dropout=cfg.DROPOUT, 
            batch_first=True, 
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, cfg.DEPTH_ENC)
        
        # Classification Head
        self.norm = nn.LayerNorm(cfg.EMBED_DIM)
        self.head = nn.Linear(cfg.EMBED_DIM, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        
    def forward(self, x):
        x = self.patch_embed(x) + self.pos_embed
        x = self.encoder(x)
        x = x.mean(dim=1)
        x = self.norm(x)
        logits = self.head(x)
        return logits

    def load_pretrained(self, path):
        print(f"Loading pretrained weights from: {path}")
        checkpoint = torch.load(path, map_location='cpu')
        
        model_dict = self.state_dict()
        pretrained_dict = {}
        
        for k, v in checkpoint.items():
            if k.startswith('patch_embed') or k.startswith('pos_embed'):
                pretrained_dict[k] = v
            elif k.startswith('visible_encoder'):
                pretrained_dict[k.replace('visible_encoder', 'encoder')] = v
        
        # Keep only keys that exist in the classifier
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        
        missing, unexpected = self.load_state_dict(pretrained_dict, strict=False)
        print("Pretrained weights loaded. Decoder and target encoder ignored.")
        print(f"Missing keys (likely the classification head): {missing}")
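
The key remapping in `load_pretrained` can be checked without any weights. The sketch below uses a hypothetical helper, `remap_encoder_keys`, and example parameter names that follow PyTorch's usual naming for the modules defined above; treat the exact names as assumptions.

```python
def remap_encoder_keys(pretrain_keys):
    """Map pretraining parameter names onto the classifier's parameter names."""
    remapped = {}
    for key in pretrain_keys:
        if key.startswith(('patch_embed', 'pos_embed')):
            remapped[key] = key  # same name in both models
        elif key.startswith('visible_encoder'):
            remapped[key] = key.replace('visible_encoder', 'encoder', 1)
        # everything else (decoder, target encoder, codebook, mask token) is skipped
    return remapped

example_keys = [
    'patch_embed.proj.weight',
    'pos_embed',
    'mask_token',
    'visible_encoder.layers.0.linear1.weight',
    'target_visible_encoder.layers.0.linear1.weight',
    'decoupled_decoder.0.norm1.weight',
    'codebook',
]
mapping = remap_encoder_keys(example_keys)
for old, new in mapping.items():
    print(f"{old} -> {new}")
```

Because the check uses `startswith`, the `target_*` copies are never matched, so only the online encoder's weights reach the classifier.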


### Training Loop - TimeMAE

---

Now we can finally evaluate the encoder pretrained on our unlabeled data (more than 2,000 hours) by fine-tuning it on the labeled set.

---

### Results

As the logs show, the fine-tuned model did not beat the baseline in accuracy: it reached about 54% at best across the hyperparameter settings we tried.

In [None]:
num_classes = len(classes)
model_ft = TimeMAE_Classifier(cfg, num_classes).to(device)

import os
if os.path.exists("timemae_pretrained.pth"):
    model_ft.load_pretrained("timemae_pretrained.pth")
else:
    print("Warning: Pretrained file not found. Training from scratch.")

# Supervised dataloaders
dl_train = DataLoader(ds_train, batch_size=256, shuffle=True, num_workers=2)
dl_val = DataLoader(ds_val, batch_size=256, shuffle=False, num_workers=2)

# Optimizer and loss
optimizer = torch.optim.Adam(model_ft.parameters(), lr=0.002, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Training loop
print(f"\nStarting fine-tuning on {num_classes} classes...")
EPOCHS = 200

for epoch in range(EPOCHS):
    model_ft.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for x, y in dl_train:
        x, y = x.to(device), y.to(device)
        
        optimizer.zero_grad()
        logits = model_ft(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = logits.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
        
    # Validation
    model_ft.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for x, y in dl_val:
            x, y = x.to(device), y.to(device)
            logits = model_ft(x)
            loss = criterion(logits, y)
            val_loss += loss.item()
            _, predicted = logits.max(1)
            val_total += y.size(0)
            val_correct += predicted.eq(y).sum().item()
            
    # Logs
    acc_train = 100 * correct / total
    acc_val = 100 * val_correct / val_total
    avg_train_loss = train_loss / len(dl_train)
    avg_val_loss = val_loss / len(dl_val)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f} (Acc {acc_train:.1f}%) | "
              f"Val Loss={avg_val_loss:.4f} (Acc {acc_val:.1f}%)")


### Time Series Augmentation

This class implements **data augmentation for time series**. The `weak_aug` method applies small additive noise and slight scaling to the signals, simulating natural sensor variability. The `strong_aug` method builds on this by additionally performing **segment permutation** (shuffling parts of the signal) and **time masking** (zeroing out random segments), which helps the model learn to be robust to missing or reordered information. These augmentations are key for semi-supervised approaches like FixMatch, where the model benefits from diverse views of the same data.


In [None]:
# ============================================================================== 
# 1. TIME SERIES AUGMENTATION CLASS
# ==============================================================================
class TimeSeriesAugmentations:
    def __init__(self, device):
        self.device = device

    def weak_aug(self, x):
        """Weak augmentation: small additive noise and scaling"""
        B, C, T = x.shape
        sigma = 0.05
        noise = torch.randn_like(x, device=self.device) * sigma   # Additive noise
        scale = torch.rand(B, 1, 1, device=self.device) * 0.2 + 0.9  # Multiplicative scale
        return (x * scale) + noise

    def strong_aug(self, x):
        """Strong augmentation: segment permutation + time masking"""
        x = self.weak_aug(x)
        B, C, T = x.shape
        x_aug = x.clone()
        
        # Segment permutation
        if T > 10:
            num_segs = np.random.randint(2, 5)
            seg_len = T // num_segs
            for i in range(B):
                perm = torch.randperm(num_segs)
                segs = [x[i, :, p*seg_len : (p+1)*seg_len] for p in perm]
                shuffled = torch.cat(segs, dim=1)
                
                # Adjust length if needed
                curr_len = shuffled.shape[1]
                if curr_len < T:
                    pad = torch.zeros(C, T - curr_len, device=self.device)
                    shuffled = torch.cat([shuffled, pad], dim=1)
                elif curr_len > T:
                    shuffled = shuffled[:, :T]
                x_aug[i] = shuffled

        # Time masking
        mask_ratio = 0.2
        mask_len = int(T * mask_ratio)
        for i in range(B):
            start = np.random.randint(0, max(1, T - mask_len))
            x_aug[i, :, start : start+mask_len] = 0.0
            
        return x_aug
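
One side effect of the segment permutation is worth seeing on a concrete case: when the window length is not divisible by the number of segments, the leftover tail is dropped and replaced by zeros. The sketch below isolates that step in a hypothetical helper, `permute_segments`, for a single `(C, T)` window.

```python
import torch

def permute_segments(x, num_segs):
    """Shuffle equal-length segments of a (C, T) signal; zero-pad the leftover tail."""
    C, T = x.shape
    seg_len = T // num_segs
    perm = torch.randperm(num_segs).tolist()
    shuffled = torch.cat([x[:, p * seg_len:(p + 1) * seg_len] for p in perm], dim=1)
    # Samples beyond num_segs * seg_len are not part of any segment
    pad = torch.zeros(C, T - shuffled.shape[1], dtype=x.dtype)
    return torch.cat([shuffled, pad], dim=1)

x = torch.ones(3, 75)                  # a window of ones makes the padding visible
out = permute_segments(x, num_segs=4)  # 75 // 4 = 18 samples per segment
num_zeroed = int((out == 0).sum().item() // 3)
print(out.shape, num_zeroed)
```

With 75-sample windows, only 2, 3, and 4 segments are drawn by `np.random.randint(2, 5)`, and of those only 3 divides 75 evenly, so the other choices always zero out a few trailing samples in addition to the explicit time mask.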

### FixMatch Semi-Supervised Training

This block implements the **FixMatch semi-supervised framework** using a ResNet baseline for time series classification. It sets up the model, optimizer, learning rate scheduler, and checkpoint handling. The key idea is to combine **supervised loss on labeled data** with an **unsupervised loss on unlabeled data**. The unlabeled data is augmented twice: once weakly to generate pseudo-labels, and once strongly to enforce consistency. Only pseudo-labels with high confidence (above `FIXMATCH_THRESHOLD`) contribute to the unsupervised loss. Augmentations like additive noise, scaling, segment permutation, and time masking increase robustness. The code also manages preprocessor normalization, tracks metrics like training accuracy and mask rate, and saves checkpoints when validation performance improves.
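
The confidence-masked unsupervised loss can be followed on a toy batch. The logits below are made-up values for three unlabeled samples and three classes; only the threshold matches the notebook's setting.

```python
import torch
import torch.nn.functional as F

threshold = 0.90

# Made-up logits for the weak and strong views of three unlabeled samples
logits_weak = torch.tensor([[5.0, 0.0, 0.0],
                            [1.0, 0.8, 0.5],
                            [0.0, 0.0, 6.0]])
logits_strong = torch.tensor([[2.0, 1.0, 0.0],
                              [0.0, 3.0, 0.0],
                              [1.0, 0.0, 0.5]])

with torch.no_grad():
    probs = torch.softmax(logits_weak, dim=1)
    max_probs, pseudo_label = probs.max(dim=1)
    mask = (max_probs >= threshold).float()  # 1 where the weak view is confident

per_sample = F.cross_entropy(logits_strong, pseudo_label, reduction='none')
loss_unsup = (per_sample * mask).mean()  # averaged over all samples, not only confident ones
print(mask, pseudo_label, loss_unsup.item())
```

The second sample is uncertain under the weak view, so it contributes nothing to the loss. Averaging over the whole batch rather than over the confident samples means the unsupervised term shrinks automatically when few pseudo-labels pass the threshold.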


In [None]:
CHECKPOINT_PATH = "fixmatch_resnet.pth"
NUM_CLASSES = 19
LR = 10        # NOTE: unusually high for Adam, whose default learning rate is 1e-3
EPOCHS = 80
FIXMATCH_THRESHOLD = 0.90 
LAMBDA_U = 15.0          

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNetBaseline(NUM_CLASSES, in_channels=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-3)
augmenter = TimeSeriesAugmentations(device)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

# Make sure the preprocessor exists
if 'preprocessor' not in globals():
    preprocessor = SensorPreprocessor()
    # Fit preprocessor on labeled training data
    preprocessor.fit(train_df)

start_epoch = 0
best_val_acc = 0.0

# Load checkpoint if exists
if os.path.exists(CHECKPOINT_PATH):
    print(f"Checkpoint found: {CHECKPOINT_PATH}")
    try:
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_acc = checkpoint.get('best_val_acc', 0.0)
        print(f"Resuming from epoch {start_epoch}. Best validation accuracy: {best_val_acc:.2f}%")
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting from scratch.")
else:
    print("Starting FixMatch training from scratch.")

# Infinite iterator for unlabeled data
def cycle(iterable):
    while True:
        for x in iterable:
            yield x
unlabeled_iter = cycle(dl_ssl)

# Move preprocessor mean/std to GPU
mean_gpu = torch.tensor(preprocessor.mean, device=device, dtype=torch.float32).view(1, 3, 1)
std_gpu = torch.tensor(preprocessor.std, device=device, dtype=torch.float32).view(1, 3, 1)

print(f"Running FixMatch (Threshold={FIXMATCH_THRESHOLD}, Lambda={LAMBDA_U}) for {EPOCHS} epochs...")

for epoch in range(start_epoch, EPOCHS):
    model.train()
    
    train_loss_acc = 0
    sup_loss_acc = 0
    unsup_loss_acc = 0
    mask_count = 0
    total_unlabeled = 0
    train_correct = 0
    train_total = 0
    
    for batch_idx, (inputs_x, targets_x) in enumerate(dl_train):
        inputs_x, targets_x = inputs_x.to(device), targets_x.to(device)
        
        # Unlabeled batch
        try:
            batch_u = next(unlabeled_iter)
            inputs_u_raw = batch_u[0] if isinstance(batch_u, (list, tuple)) else batch_u
        except StopIteration:
            unlabeled_iter = cycle(dl_ssl)
            inputs_u_raw = next(unlabeled_iter)[0]
        inputs_u_raw = inputs_u_raw.to(device)
        
        # Convert 3 channels -> 4 channels on-the-fly
        with torch.no_grad():
            u_norm = (inputs_u_raw - mean_gpu) / std_gpu
            u_mag = torch.sqrt(torch.sum(u_norm**2, dim=1, keepdim=True))
            inputs_u = torch.cat([u_norm, u_mag], dim=1)
        total_unlabeled += inputs_u.size(0)
        
        # Supervised loss (weak augment)
        inputs_x_aug = augmenter.weak_aug(inputs_x)
        logits_x = model(inputs_x_aug)
        loss_sup = nn.CrossEntropyLoss(weight=class_weights)(logits_x, targets_x)
        
        # Training accuracy
        with torch.no_grad():
            _, pred_x = logits_x.max(1)
            train_correct += pred_x.eq(targets_x).sum().item()
            train_total += inputs_x.size(0)
        
        # Unsupervised loss (FixMatch)
        with torch.no_grad():
            inputs_u_weak = augmenter.weak_aug(inputs_u)
            logits_u_weak = model(inputs_u_weak)
            probs_u = torch.softmax(logits_u_weak, dim=1)
            max_probs, pseudo_label = torch.max(probs_u, dim=1)
            mask = max_probs.ge(FIXMATCH_THRESHOLD).float()
            mask_count += mask.sum().item()
        
        inputs_u_strong = augmenter.strong_aug(inputs_u)
        logits_u_strong = model(inputs_u_strong)
        loss_u_unreduced = nn.CrossEntropyLoss(reduction='none')(logits_u_strong, pseudo_label)
        loss_unsup = (loss_u_unreduced * mask).mean()
        
        # Total loss
        loss = loss_sup + (LAMBDA_U * loss_unsup)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_loss_acc += loss.item()
        sup_loss_acc += loss_sup.item()
        unsup_loss_acc += loss_unsup.item()
    
    # Epoch metrics
    avg_train_loss = train_loss_acc / len(dl_train)
    avg_sup = sup_loss_acc / len(dl_train)
    avg_unsup_weighted = (unsup_loss_acc / len(dl_train)) * LAMBDA_U
    mask_rate = mask_count / max(total_unlabeled, 1)
    train_acc = 100. * train_correct / train_total
    
    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dl_val:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = nn.CrossEntropyLoss(weight=class_weights)(logits, y)
            val_loss += loss.item()
            _, predicted = logits.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    avg_val_loss = val_loss / len(dl_val)
    val_acc = 100. * correct / total
    
    scheduler.step(avg_val_loss)
    
    # Logging
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.3f} [Sup: {avg_sup:.3f}, Unsup: {avg_unsup_weighted:.3f}]")
    print(f"   Train Acc: {train_acc:.2f}% | Mask Rate: {mask_rate:.1%}")
    print(f"   Val Loss:  {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    
    # Checkpoint
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"New best validation accuracy: {best_val_acc:.2f}%")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'best_val_acc': best_val_acc,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, CHECKPOINT_PATH)
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"fixmatch_ep{epoch+1}.pth")

### Observations on TimeMAE & FixMatch Performance

In our experiments, neither TimeMAE nor the FixMatch pipeline **outperformed simpler models** such as Random Forest or LSTM+CNN; they reached only around 54–56% accuracy. Several factors likely contributed:

- **Data scarcity:** Masked autoencoders (TimeMAE) and semi-supervised approaches require substantial unlabeled data to learn meaningful representations. With limited calves and short time series, the models may not generalize well.  
- **Overly complex models:** TimeMAE's encoder-decoder architecture has many parameters relative to the dataset size. Even reducing depth may not prevent overfitting.  
- **Augmentation sensitivity:** Transformations such as segment permutation or time masking may disrupt behavior patterns more than they help, especially for short windows (3s).  
- **Representation misalignment:** The latent space learned by the autoencoder may not capture the discriminative features needed for downstream classification. In FixMatch, a weak classifier likewise produces unreliable pseudo-labels, so the unsupervised loss contributes little.

Overall, these methods are **data-hungry and sensitive to hyperparameters**, which explains why simpler supervised models performed better in this problem.
