## Multitask Learning

Instead of compressing the input (as the autoencoder did), we add supervision: the model predicts both braking intention and braking intensity, which forces it to learn why braking happens.

### Why Multitask Learning Is the Right Pivot

The HARD dataset exposed the real issues:
- Ambiguity between Light ↔ Normal braking
- Labels depend on future behavior
- Subtle temporal cues matter more than denoising

The autoencoder approach failed because compression discarded information, including the subtle temporal cues that distinguish the classes.

Multitask learning does the opposite: it adds supervision instead of compression.

By asking the model to solve two related tasks at once, we force it to learn representations that explain why braking happens.

In [None]:
# Sanity check (for MTL data generator)
import numpy as np

X = np.load("../data/X_train_hard_mtl.npy")
y_c = np.load("../data/y_class_train_hard_mtl.npy")
y_i = np.load("../data/y_int_train_hard_mtl.npy")

print(X.shape)     # (N, 75, 3)
print(y_c.shape)   # (N,)
print(y_i.shape)   # (N,)
print(y_i.min(), y_i.max())  # should be within [0,1]

(10500, 75, 3)
(10500,)
(10500,)
0.15013748255769468 0.9999489978076975
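The printed shapes and range can also be turned into hard checks that fail loudly if the data generator changes. A minimal sketch, assuming class labels are integer indices 0 to 2 and the `(N, 75, 3)` layout shown above; `check_mtl_arrays` is a hypothetical helper, not part of the project:

```python
import numpy as np

def check_mtl_arrays(X, y_class, y_int, seq_len=75, n_features=3, n_classes=3):
    """Hypothetical sanity check for the multitask arrays; returns per-class counts.

    seq_len, n_features and n_classes are assumptions taken from the shapes
    printed above, not values read from the data generator.
    """
    n = len(X)
    assert X.shape == (n, seq_len, n_features), f"unexpected X shape {X.shape}"
    assert y_class.shape == (n,) and y_int.shape == (n,), "label/sample count mismatch"
    assert np.isfinite(X).all(), "X contains NaN or inf"
    # CrossEntropyLoss needs integer class indices in [0, n_classes)
    assert np.all(y_class == np.round(y_class)), "class labels must be whole numbers"
    assert y_class.min() >= 0 and y_class.max() < n_classes, "class label out of range"
    # Intensity targets are expected to lie in [0, 1]
    assert y_int.min() >= 0.0 and y_int.max() <= 1.0, "intensity outside [0, 1]"
    return np.bincount(y_class.astype(np.int64), minlength=n_classes)
```

Calling `check_mtl_arrays(X, y_c, y_i)` returns the class counts, which is also a quick way to see how balanced the three braking classes are.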


## Multitask Learning Architecture for Braking Intention Prediction

### Motivation
Experiments on the HARD ambiguous braking dataset revealed that:
- Braking intention classes (Light vs Normal) overlap significantly
- Labels depend on future braking behavior
- Representation compression (autoencoders) degrades fine-grained temporal cues

To address this, we adopt **multitask learning**, which provides additional task-aligned supervision instead of compressing information.

---

### Core Idea
The model is trained to jointly solve two related tasks:
1. **Braking Intention Classification** (Light / Normal / Emergency)
2. **Brake Intensity Regression** (future braking strength ∈ [0,1])

By learning *what* the driver intends to do and *how strongly* they intend to brake, the model develops richer and more discriminative temporal representations.

---

### Model Architecture Overview

![Braking Intention Example](../img/1.jpg)


---

### Shared Backbone
- **Temporal CNN**: captures short-term local patterns (brake taps, fluctuations)
- **LSTM**: models long-term temporal dependencies and future intention buildup
- **Attention**: focuses on critical moments (onset and ramp-up of braking)

This shared backbone learns task-agnostic temporal features.
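To make the data flow concrete, here is a minimal PyTorch sketch of a CNN → LSTM → attention backbone. It is not the project's `MultitaskLSTMCNNAttention` (whose source is not shown here): the class name `SketchBackbone`, the layer sizes, the kernel size and the simple attention scheme (one learned score per time step, softmax over time) are all assumptions. Inputs are assumed to be shaped `(batch, 75, 3)` as in the sanity check.

```python
import torch
import torch.nn as nn

class SketchBackbone(nn.Module):
    """Hypothetical CNN -> LSTM -> attention backbone; all sizes are illustrative."""

    def __init__(self, n_features=3, conv_channels=32, hidden=64):
        super().__init__()
        # Temporal CNN: short-term local patterns; padding=2 keeps the length T
        self.conv = nn.Sequential(
            nn.Conv1d(n_features, conv_channels, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        # LSTM: longer-range dependencies over the convolutional features
        self.lstm = nn.LSTM(conv_channels, hidden, batch_first=True)
        # Attention: one scalar relevance score per time step
        self.score = nn.Linear(hidden, 1)

    def forward(self, x):
        # Conv1d expects (B, C, T), so swap the time and feature axes around it
        h = self.conv(x.transpose(1, 2)).transpose(1, 2)  # (B, T, C)
        h, _ = self.lstm(h)                                # (B, T, H)
        weights = torch.softmax(self.score(h), dim=1)      # (B, T, 1), sums to 1 over T
        context = (weights * h).sum(dim=1)                 # (B, H) attention-pooled features
        return context, weights.squeeze(-1)
```

The attention weights are returned alongside the pooled features so they can be plotted against time, which is one way to check whether the model actually focuses on braking onset and ramp-up.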

---

### Task-Specific Heads

#### 1. Braking Intention (Classification)
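- Output: 3 logits (Light, Normal, Emergency)
- Activation: none inside the model; `CrossEntropyLoss` applies log-softmax internally, so softmax is applied only at inference to turn logits into class probabilities
- Loss: CrossEntropyLoss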

#### 2. Brake Intensity (Regression)
- Output: single continuous value
- Represents future braking strength
- Loss: Mean Squared Error (MSE)
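A matching sketch of the two heads, operating on a shared feature vector such as the pooled backbone output. `SketchHeads`, `feat_dim=64` and the sigmoid on the regression output (used here only to keep predictions in (0, 1) like the targets) are assumptions. The classification head emits raw logits because `nn.CrossEntropyLoss` applies log-softmax itself.

```python
import torch
import torch.nn as nn

class SketchHeads(nn.Module):
    """Hypothetical task heads on top of a shared feature vector of size feat_dim."""

    def __init__(self, feat_dim=64, n_classes=3):
        super().__init__()
        # Classification head: raw logits, no softmax (CrossEntropyLoss handles it)
        self.cls = nn.Linear(feat_dim, n_classes)
        # Regression head: one value per sample; the sigmoid is an assumption
        self.reg = nn.Sequential(nn.Linear(feat_dim, 1), nn.Sigmoid())

    def forward(self, features):
        logits = self.cls(features)                  # (B, n_classes)
        intensity = self.reg(features).squeeze(-1)   # (B,), same shape as the targets
        return logits, intensity
```

At inference, class probabilities come from `torch.softmax(logits, dim=1)`. Squeezing the regression output to shape `(B,)` keeps it aligned with the intensity targets, so `nn.MSELoss` compares element by element instead of broadcasting.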

---

### Loss Function
The total training loss is a weighted sum:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{class}} + \lambda \, \mathcal{L}_{\text{reg}}$$

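Where:
- $\mathcal{L}_{\text{class}}$ is the cross-entropy loss of the primary task (intention classification)
- $\mathcal{L}_{\text{reg}}$ is the MSE loss of the auxiliary task (intensity regression)
- $\lambda$ controls how strongly the auxiliary supervision shapes the shared backbone; it was set to 0.5 in the initial design, and the training run below uses $\lambda = 0.8$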
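The weighted sum can be checked by hand. This sketch uses the same PyTorch criteria as the training code (the helper name `multitask_loss` is hypothetical). With uniform logits over three classes the cross-entropy is ln 3 ≈ 1.0986, and predicting 0.5 for targets of 0.7 gives an MSE of 0.2² = 0.04, so with λ = 0.5 the total is about 1.1186.

```python
import torch
import torch.nn as nn

criterion_class = nn.CrossEntropyLoss()
criterion_reg = nn.MSELoss()

def multitask_loss(class_logits, class_targets, int_pred, int_targets, lam):
    """Return (L_total, L_class, L_reg) with L_total = L_class + lam * L_reg."""
    loss_class = criterion_class(class_logits, class_targets)
    # Flatten both sides so MSELoss cannot silently broadcast (B, 1) against (B,)
    loss_reg = criterion_reg(int_pred.reshape(-1), int_targets.reshape(-1))
    return loss_class + lam * loss_reg, loss_class, loss_reg

# Worked example: uniform logits, constant intensity error of 0.2
logits = torch.zeros(2, 3)
total, lc, lr = multitask_loss(
    logits, torch.tensor([0, 2]),
    torch.full((2,), 0.5), torch.full((2,), 0.7),
    lam=0.5,
)
print(round(lc.item(), 4), round(lr.item(), 4), round(total.item(), 4))  # → 1.0986 0.04 1.1186
```

With the λ = 0.8 used in the training cell, the same inputs give 1.0986 + 0.8 × 0.04 ≈ 1.1306: λ linearly scales the auxiliary term's contribution to the loss and therefore to the gradients reaching the shared backbone.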

---

### Why Multitask Learning Works Here
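- Preserves subtle temporal details, since no bottleneck discards information
- Encourages the backbone to model how strongly the driver will brake, not just where the class boundaries lie
- Targets the Light vs Normal ambiguity: the continuous intensity target gives a training signal exactly where the class labels overlap
- Aligns the learning objective with physical meaning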

This approach is better suited for ambiguous, future-dependent braking scenarios than representation compression methods.

In [1]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.metrics import confusion_matrix, classification_report

from models.multitask_lstm_cnn_attention import MultitaskLSTMCNNAttention

In [3]:
# Load Multitask HARD dataset
X_train = np.load("../data/X_train_hard_mtl.npy")
X_val   = np.load("../data/X_val_hard_mtl.npy")
X_test  = np.load("../data/X_test_hard_mtl.npy")

y_class_train = np.load("../data/y_class_train_hard_mtl.npy")
y_class_val   = np.load("../data/y_class_val_hard_mtl.npy")
y_class_test  = np.load("../data/y_class_test_hard_mtl.npy")

y_int_train = np.load("../data/y_int_train_hard_mtl.npy")
y_int_val   = np.load("../data/y_int_val_hard_mtl.npy")
y_int_test  = np.load("../data/y_int_test_hard_mtl.npy")

In [4]:
# Convert to PyTorch tensors
X_train_t = torch.tensor(X_train, dtype = torch.float32)
X_val_t   = torch.tensor(X_val, dtype = torch.float32)
X_test_t  = torch.tensor(X_test, dtype = torch.float32)

y_class_train_t = torch.tensor(y_class_train, dtype = torch.long)
y_class_val_t   = torch.tensor(y_class_val, dtype = torch.long)
y_class_test_t  = torch.tensor(y_class_test, dtype = torch.long)

y_int_train_t = torch.tensor(y_int_train, dtype = torch.float32)
y_int_val_t   = torch.tensor(y_int_val, dtype = torch.float32)
y_int_test_t  = torch.tensor(y_int_test, dtype = torch.float32)

In [None]:
# Initialize model & losses
model = MultitaskLSTMCNNAttention()

criterion_class = nn.CrossEntropyLoss()
criterion_reg   = nn.MSELoss()

lambda_reg = 0.8            # changed from 0.4

optimizer = optim.Adam(model.parameters(), lr = 1e-3)

EPOCHS = 20
BATCH_SIZE = 64

In [6]:
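# Training loop (multitask)
def train_one_epoch_mtl(model, X, y_class, y_int, optimizer, batch_size):
    """One pass over the training set; returns (mean per-sample loss, accuracy).

    Uses the global criterion_class, criterion_reg and lambda_reg defined above.
    """
    model.train()
    total_loss = 0.0
    correct = 0

    idx = torch.randperm(len(X))   # reshuffle every epoch

    for i in range(0, len(X), batch_size):
        batch_idx = idx[i : i + batch_size]

        xb = X[batch_idx]
        yb_class = y_class[batch_idx]
        yb_int = y_int[batch_idx]

        optimizer.zero_grad()

        class_logits, int_pred = model(xb)

        loss_class = criterion_class(class_logits, yb_class)
        # Flatten in case the regression head returns (B, 1); MSELoss would
        # otherwise broadcast (B, 1) against (B,) into a (B, B) comparison
        loss_reg = criterion_reg(int_pred.reshape(-1), yb_int)

        loss = loss_class + lambda_reg * loss_reg
        loss.backward()
        optimizer.step()

        # loss is a batch mean: weight by batch size so the epoch average is per sample
        total_loss += loss.item() * len(batch_idx)

        preds = class_logits.argmax(dim=1)
        correct += (preds == yb_class).sum().item()

    acc = correct / len(X)
    return total_loss / len(X), acc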

In [None]:
# Validation function
def evaluate_mtl(model, X, y_class, y_int):
    """Full-batch evaluation; returns (combined loss, classification accuracy)."""
    model.eval()

    with torch.no_grad():
        class_logits, int_pred = model(X)

        loss_class = criterion_class(class_logits, y_class)
        # Same shape fix as in training: compare (N,) with (N,)
        loss_reg = criterion_reg(int_pred.reshape(-1), y_int)
        loss = loss_class + lambda_reg * loss_reg

        preds = class_logits.argmax(dim=1)
        correct = (preds == y_class).sum().item()

    acc = correct / len(X)
    return loss.item(), acc
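`evaluate_mtl` reports only classification accuracy, so it says nothing about whether the auxiliary intensity head is learning. A small sketch for that, assuming the regression output can be flattened to one value per sample; `intensity_metrics` is a hypothetical helper. It also reports the MAE of a constant prediction (the mean intensity) as a simple baseline the head should beat.

```python
import numpy as np

def intensity_metrics(pred, target):
    """Return (MAE, RMSE, baseline MAE) for brake-intensity predictions.

    The baseline always predicts the mean of `target` (the evaluated set's own
    mean, which slightly favours the baseline).
    """
    pred = np.asarray(pred, dtype=np.float64).reshape(-1)
    target = np.asarray(target, dtype=np.float64).reshape(-1)
    err = pred - target
    mae = float(np.abs(err).mean())
    rmse = float(np.sqrt(np.mean(err ** 2)))
    baseline_mae = float(np.abs(target - target.mean()).mean())
    return mae, rmse, baseline_mae

# Hypothetical usage with the notebook's tensors:
# with torch.no_grad():
#     _, int_pred = model(X_val_t)
# print(intensity_metrics(int_pred.cpu().numpy(), y_int_val))
```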

In [8]:
# Train Multitask model
for epoch in range(EPOCHS):
    
    train_loss, train_acc = train_one_epoch_mtl(
        model,
        X_train_t,
        y_class_train_t,
        y_int_train_t,
        optimizer,
        BATCH_SIZE
    )

    val_loss, val_acc = evaluate_mtl(
        model,
        X_val_t,
        y_class_val_t,
        y_int_val_t
    )

    print(
        f"[MTL] Epoch {epoch+1}/{EPOCHS} | "
        f"Train Acc: {train_acc:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

[MTL] Epoch 1/20 | Train Acc: 0.388 | Val Acc: 0.666
[MTL] Epoch 2/20 | Train Acc: 0.680 | Val Acc: 0.678
[MTL] Epoch 3/20 | Train Acc: 0.714 | Val Acc: 0.691
[MTL] Epoch 4/20 | Train Acc: 0.704 | Val Acc: 0.677
[MTL] Epoch 5/20 | Train Acc: 0.712 | Val Acc: 0.724
[MTL] Epoch 6/20 | Train Acc: 0.718 | Val Acc: 0.720
[MTL] Epoch 7/20 | Train Acc: 0.722 | Val Acc: 0.707
[MTL] Epoch 8/20 | Train Acc: 0.721 | Val Acc: 0.726
[MTL] Epoch 9/20 | Train Acc: 0.718 | Val Acc: 0.698
[MTL] Epoch 10/20 | Train Acc: 0.724 | Val Acc: 0.681
[MTL] Epoch 11/20 | Train Acc: 0.721 | Val Acc: 0.693
[MTL] Epoch 12/20 | Train Acc: 0.724 | Val Acc: 0.711
[MTL] Epoch 13/20 | Train Acc: 0.728 | Val Acc: 0.720
[MTL] Epoch 14/20 | Train Acc: 0.724 | Val Acc: 0.716
[MTL] Epoch 15/20 | Train Acc: 0.725 | Val Acc: 0.664
[MTL] Epoch 16/20 | Train Acc: 0.726 | Val Acc: 0.688
[MTL] Epoch 17/20 | Train Acc: 0.724 | Val Acc: 0.724
[MTL] Epoch 18/20 | Train Acc: 0.727 | Val Acc: 0.666
[MTL] Epoch 19/20 | Train Acc: 0.718 
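Validation accuracy fluctuates between 0.664 and 0.726 from epoch to epoch while training accuracy plateaus around 0.72, so the last-epoch weights are not necessarily the best ones. A common remedy is to keep a copy of the weights with the highest validation accuracy and restore them before testing. A minimal sketch; the `BestCheckpoint` class is hypothetical:

```python
import copy
import torch.nn as nn

class BestCheckpoint:
    """Keep an in-memory copy of the weights with the highest validation accuracy."""

    def __init__(self):
        self.best_acc = float("-inf")
        self.best_epoch = None
        self.state = None

    def update(self, model: nn.Module, val_acc: float, epoch: int) -> bool:
        """Store the model's weights if val_acc improves; return True if stored."""
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_epoch = epoch
            # deepcopy: state_dict() tensors share memory with the live parameters
            self.state = copy.deepcopy(model.state_dict())
            return True
        return False

    def restore(self, model: nn.Module) -> None:
        if self.state is None:
            raise RuntimeError("no checkpoint stored yet")
        model.load_state_dict(self.state)
```

Inside the epoch loop, `ckpt.update(model, val_acc, epoch + 1)` records the best epoch; calling `ckpt.restore(model)` before the test-set cell then evaluates those weights instead of the final ones.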

In [9]:
# Test set evaluation 
model.eval()
with torch.no_grad():
    class_logits, _ = model(X_test_t)
    preds = class_logits.argmax(dim=1).cpu().numpy()

test_acc = (preds == y_class_test).mean()

print(f"[MTL] Test Accuracy: {test_acc:.4f}")

[MTL] Test Accuracy: 0.7102


In [10]:
# Confusion Matrix
print("\nConfusion Matrix (MTL):")
print(confusion_matrix(y_class_test, preds))

print("\nClassification Report (MTL):")
print(classification_report(
    y_class_test,
    preds,
    target_names=["Light Braking", "Normal Braking", "Emergency Braking"]
))


Confusion Matrix (MTL):
[[593 199   2]
 [218 485 103]
 [  6 124 520]]

Classification Report (MTL):
                   precision    recall  f1-score   support

    Light Braking       0.73      0.75      0.74       794
   Normal Braking       0.60      0.60      0.60       806
Emergency Braking       0.83      0.80      0.82       650

         accuracy                           0.71      2250
        macro avg       0.72      0.72      0.72      2250
     weighted avg       0.71      0.71      0.71      2250
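
The per-class figures in the report follow directly from the confusion matrix (rows are true classes, columns are predicted classes, as in scikit-learn). This snippet recomputes them from the printed counts and breaks the 652 misclassified test samples down by class pair:

```python
import numpy as np

# Confusion matrix printed above (rows = true class, columns = predicted class)
cm = np.array([[593, 199,   2],
               [218, 485, 103],
               [  6, 124, 520]])
labels = ["Light", "Normal", "Emergency"]

support = cm.sum(axis=1)       # true samples per class
predicted = cm.sum(axis=0)     # predictions per class
recall = np.diag(cm) / support
precision = np.diag(cm) / predicted
accuracy = np.trace(cm) / cm.sum()

# Where the errors go: both off-diagonal cells for each class pair
errors = cm.sum() - np.trace(cm)
for i, j in [(0, 1), (1, 2), (0, 2)]:
    pair = cm[i, j] + cm[j, i]
    print(f"{labels[i]} <-> {labels[j]}: {pair} errors ({pair / errors:.1%})")
print(f"accuracy = {accuracy:.4f}")
# Light <-> Normal: 417 errors (64.0%)
# Normal <-> Emergency: 227 errors (34.8%)
# Light <-> Emergency: 8 errors (1.2%)
# accuracy = 0.7102
```

Roughly two thirds of the errors are Light ↔ Normal confusions and almost none are Light ↔ Emergency, consistent with the ambiguity described in the Motivation section.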



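Why increasing λ (to 0.8 in the run above) is expected to help:

- The shared backbone is pushed to encode braking strength more explicitly
- The classifier relies less on brittle class boundaries
- The continuous intensity signal should help separate:
  - Light vs Normal
  - Normal vs Emergency

Light vs Normal nevertheless remains the main source of error on the test set (199 + 218 misclassified samples), and Normal Braking is still the weakest class (F1 = 0.60).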