## **2.2 Multi-Task Model Architecture**

The model is formulated as a **multi-task learning system**, where a single shared representation supports multiple prediction objectives.

### **2.2.2 Shared Backbone: The “Master” Feature Finder**

The core of the architecture is a **ResNet-34 CNN** used as a **shared feature extractor** for all three tasks.

Instead of training separate networks for age, gender, and race, the model first learns a strong, general-purpose **face representation**, which is then reused by task-specific heads.

* **Input:** Preprocessed face images.
* **Output:** A compact embedding capturing facial structure, texture, and shape.
* **Benefit:** Since all three tasks backpropagate through the same backbone, every label shapes the representation at once. This tends to make it generalize across tasks rather than overfit to a single one.
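Below is a minimal sketch of this shared-backbone idea (hard parameter sharing). It uses a toy `nn.Linear` trunk as a stand-in for ResNet-34 so that it runs instantly. The class name `TinySharedModel` and all sizes are hypothetical. It shows that each task's loss, on its own, already sends a gradient into the shared trunk, and that is what lets all three labels shape one representation.

```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySharedModel(nn.Module):
    """Hypothetical toy model: one shared trunk, three task heads."""

    def __init__(self, in_dim: int = 32, emb_dim: int = 16) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, emb_dim), nn.ReLU())  # shared trunk
        self.heads = nn.ModuleDict({
            "age": nn.Linear(emb_dim, 9),
            "gender": nn.Linear(emb_dim, 2),
            "race": nn.Linear(emb_dim, 7),
        })

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        z = self.backbone(x)  # one embedding reused by every head
        return {task: head(z) for task, head in self.heads.items()}


torch.manual_seed(0)
model = TinySharedModel()
x = torch.randn(8, 32)
targets = {"age": torch.randint(0, 9, (8,)),
           "gender": torch.randint(0, 2, (8,)),
           "race": torch.randint(0, 7, (8,))}

# Backpropagate each task's loss separately and record the gradient norm
# it produces in the shared trunk's weight matrix.
shared_grad_norms = {}
for task in targets:
    model.zero_grad(set_to_none=True)
    out = model(x)
    F.cross_entropy(out[task], targets[task]).backward()
    shared_grad_norms[task] = model.backbone[0].weight.grad.norm().item()

print(shared_grad_norms)  # each task contributes a non-zero gradient to the trunk
```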

#### **2.2.2.1 The Input–Output Handshake**

Before implementing task-specific logic, a strict contract is defined:

* **Input:** A batch of face images resized and normalized to $(B, 3, 224, 224)$.
* **Output:** A 512-dimensional embedding per image.
* **Shape Guarantee:** A batch of size $B$ produces an output tensor of shape $(B, 512)$.

This contract ensures that downstream heads can be attached cleanly and independently.

#### **2.2.2.2 Surgery: Turning a Classifier into a Feature Finder**

A standard ResNet-34 is trained to classify 1,000 ImageNet categories. These final class predictions are not useful for facial attribute learning.

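* **The Operation:** The final 1,000-way classification layer (`fc`) is removed. In practice, it is replaced with `nn.Identity()`.
* **Stopping Point:** The network is truncated after the **Global Average Pooling** stage.
* **Cleanup:** The pooled tensor $(B, 512, 1, 1)$ is flattened into a clean $(B, 512)$ embedding. torchvision's ResNet `forward` already performs this flatten just before `fc`.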

This transforms ResNet-34 from an object classifier into a reusable feature extractor.

#### **2.2.2.3 Training Strategy: Let the Backbone Learn**

For the baseline model, the backbone is **fine-tuned**, not frozen.

* **Full Gradient Flow:** All backbone parameters remain trainable.
* **Shared Learning Signal:** Errors from age, gender, and race predictions all update the same representation.
* **Result:** The backbone learns facial features that are broadly useful across tasks.

#### **2.2.2.4 Trust-but-Verify Sanity Checks**

Before large-scale training, two sanity checks validate correctness:

1. **Shape Check:** A dummy input must produce exactly a 512-dimensional embedding.
2. **Gradient Check:** A backward pass must propagate gradients into backbone parameters.

These checks ensure the backbone is actively learning and not acting as a frozen observer.

In [25]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Literal

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

In [26]:
ff_train_df = pd.read_csv('/home/onkar/projects/data/fairface/splits/ff_train.csv')
ff_train_df.head(10)

| Unnamed: 0 | file | age | gender | race | service_test | split |
|---|---|---|---|---|---|---|
| 0 | train/83377.jpg | 50-59 | Female | Indian | False | train |
| 1 | train/84431.jpg | 20-29 | Male | Black | True | train |
| 2 | train/682.jpg | 20-29 | Female | Latino_Hispanic | True | train |
| 3 | train/5478.jpg | 40-49 | Male | Middle Eastern | False | train |
| 4 | train/45214.jpg | 30-39 | Male | Southeast Asian | False | train |
| 5 | train/11276.jpg | 40-49 | Male | East Asian | True | train |
| 6 | train/65931.jpg | 50-59 | Female | East Asian | True | train |
| 7 | train/72645.jpg | 20-29 | Female | Southeast Asian | False | train |
| 8 | train/86050.jpg | 10-19 | Female | Indian | False | train |
| 9 | train/44783.jpg | 30-39 | Female | Indian | False | train |


In [27]:
ff_valid_df = pd.read_csv('/home/onkar/projects/data/fairface/splits/ff_val.csv')
ff_valid_df.head(10)

| Unnamed: 0 | file | age | gender | race | service_test | split |
|---|---|---|---|---|---|---|
| 0 | train/3285.jpg | 20-29 | Female | East Asian | False | val |
| 1 | train/8561.jpg | 30-39 | Male | Southeast Asian | True | val |
| 2 | val/2503.jpg | 20-29 | Male | Southeast Asian | True | val |
| 3 | train/51659.jpg | 20-29 | Male | White | True | val |
| 4 | val/1482.jpg | 10-19 | Male | Southeast Asian | False | val |
| 5 | val/7815.jpg | 20-29 | Female | White | False | val |
| 6 | val/5324.jpg | 40-49 | Male | White | False | val |
| 7 | train/60461.jpg | 20-29 | Female | Middle Eastern | True | val |
| 8 | train/28122.jpg | 50-59 | Male | Latino_Hispanic | False | val |
| 9 | train/68788.jpg | 20-29 | Female | Indian | True | val |


In [28]:
ff_test_df = pd.read_csv('/home/onkar/projects/data/fairface/splits/ff_test.csv')
ff_test_df.head(10)

| Unnamed: 0 | file | age | gender | race | service_test | split |
|---|---|---|---|---|---|---|
| 0 | val/8440.jpg | 30-39 | Female | White | False | test |
| 1 | train/58829.jpg | 20-29 | Female | Black | True | test |
| 2 | train/25607.jpg | 30-39 | Male | East Asian | True | test |
| 3 | train/83916.jpg | 30-39 | Female | White | False | test |
| 4 | val/4051.jpg | 3-9 | Male | East Asian | True | test |
| 5 | train/75496.jpg | 30-39 | Female | Latino_Hispanic | False | test |
| 6 | train/53677.jpg | 30-39 | Male | Black | True | test |
| 7 | train/11646.jpg | 3-9 | Male | Indian | False | test |
| 8 | train/78166.jpg | 3-9 | Female | Middle Eastern | True | test |
| 9 | train/44393.jpg | 30-39 | Female | White | False | test |


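**LabelEncoder:** maps each label string (age bin, gender, race) to a contiguous integer id. The encoders are fitted on the training split only, so a label that never appears in training raises a `KeyError` at encode time.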

In [29]:
class LabelEncoder:
    def __init__(self):
        self.to_id: Dict[str, int] = {}
        self.to_name: Dict[int, str] = {}

    def fit(self, values) -> "LabelEncoder":
        uniq = sorted(pd.Series(values).astype(str).unique().tolist())
        self.to_id = {name: i for i, name in enumerate(uniq)}
        self.to_name = {i: name for name, i in self.to_id.items()}
        return self

    def encode_one(self, v) -> int:
        v = str(v)
        if v not in self.to_id:
            raise KeyError(f"Unseen label '{v}' (not in train).")
        return self.to_id[v]

@dataclass
class FairFaceEncoders:
    age: LabelEncoder
    gender: LabelEncoder
    race: LabelEncoder

class FairFaceEncoderBuilder:
    @staticmethod
    def fit(train_df: pd.DataFrame) -> FairFaceEncoders:
        return FairFaceEncoders(
            age=LabelEncoder().fit(train_df["age"]),
            gender=LabelEncoder().fit(train_df["gender"]),
            race=LabelEncoder().fit(train_df["race"]),
        )
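Because `LabelEncoder.fit` uses `sorted(...)`, the ids follow string order, not age order. The sketch below applies the same sorting rule to the age bins visible in the `head()` outputs above. The full set of bins in the data is not shown here, so this is only a subset.

```python
# Same rule as LabelEncoder.fit: sort the unique label strings, then enumerate.
age_bins_seen = ["50-59", "20-29", "10-19", "40-49", "30-39", "3-9"]
to_id = {name: i for i, name in enumerate(sorted(set(age_bins_seen)))}
print(to_id)
# {'10-19': 0, '20-29': 1, '3-9': 2, '30-39': 3, '40-49': 4, '50-59': 5}
```

This ordering is harmless for plain cross-entropy classification, where class ids are arbitrary. A loss or metric that treats age ids as ordered, such as an off-by-one-bin tolerance, would need an explicit bin order instead.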

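**FairFaceDataset:** loads each image with OpenCV and converts it from BGR to RGB. It then applies the mode-specific transforms and returns the image tensor together with the encoded labels and file metadata.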

In [30]:


class FairFaceDataset(Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        images_root: str | Path,
        encoders,
        *,
        mode: Literal["train", "eval"] = "train",
        train_tfms=None,
        eval_tfms=None,
    ):
        self.df = df.reset_index(drop=True).copy()
        self.root = Path(images_root)
        self.enc = encoders
        self.mode = mode

        if self.mode not in ("train", "eval"):
            raise ValueError(f"mode must be 'train' or 'eval', got: {self.mode!r}")
        if not self.root.exists():
            raise FileNotFoundError(f"images_root not found: {self.root}")

        required = {"file", "age", "gender", "race"}
        missing = required - set(self.df.columns)
        if missing:
            raise KeyError(f"Missing columns in df: {missing}. Have: {list(self.df.columns)}")

        # mode-specific transform requirement
        if self.mode == "train":
            if train_tfms is None:
                raise ValueError("mode='train' requires train_tfms (cannot be None).")
            self.img_tfms = train_tfms
        else:  # eval
            if eval_tfms is None:
                raise ValueError("mode='eval' requires eval_tfms (cannot be None).")
            self.img_tfms = eval_tfms

    def __len__(self) -> int:
        return len(self.df)

    def _img_path(self, rel_path: str) -> Path:
        return self.root / rel_path

    def _transform_image(self, img_path: Path) -> torch.Tensor:
        img_bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise FileNotFoundError(f"Image not found or unreadable: {img_path}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        return self.img_tfms(img_rgb)  # HxWx3 uint8 RGB array -> transformed tensor

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]

        rel_path = str(row["file"])
        img_path = self._img_path(rel_path)
        img_t = self._transform_image(img_path)

        y = {
            "age": torch.tensor(self.enc.age.encode_one(row["age"]), dtype=torch.long),
            "gender": torch.tensor(self.enc.gender.encode_one(row["gender"]), dtype=torch.long),
            "race": torch.tensor(self.enc.race.encode_one(row["race"]), dtype=torch.long),
        }

        meta = {"file": rel_path, "path": str(img_path), "mode": self.mode}
        
        return {"img_t": img_t, "y": y, "meta": meta}
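With PyTorch's default collation, the nested dict returned by `__getitem__` is batched key by key: tensors are stacked, and strings are gathered into lists. The self-contained sketch below uses a hypothetical `ToyFaceDataset` that mimics the same sample layout, with random 3x8x8 images instead of real files.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyFaceDataset(Dataset):
    """Hypothetical stand-in returning the same structure as FairFaceDataset."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, idx: int) -> dict:
        return {
            "img_t": torch.randn(3, 8, 8),
            "y": {
                "age": torch.tensor(idx % 9, dtype=torch.long),
                "gender": torch.tensor(idx % 2, dtype=torch.long),
                "race": torch.tensor(idx % 7, dtype=torch.long),
            },
            "meta": {"file": f"train/{idx}.jpg", "mode": "train"},
        }


loader = DataLoader(ToyFaceDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch["img_t"].shape)     # stacked images: (4, 3, 8, 8)
print(batch["y"]["age"].shape)  # one label tensor per task: (4,)
print(batch["meta"]["file"])    # strings are collected into a plain list
```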

In [31]:
# ==================== SETUP: Paths, Data, and Encoders ====================
from pathlib import Path

TARGET_H, TARGET_W = 224, 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

data_root = Path("/home/onkar/projects/data/fairface")
splits_dir = data_root / "splits"

ff_train_df = pd.read_csv(splits_dir / "ff_train.csv")
ff_val_df   = pd.read_csv(splits_dir / "ff_val.csv")
ff_test_df  = pd.read_csv(splits_dir / "ff_test.csv")

encoders = FairFaceEncoderBuilder.fit(ff_train_df)

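# Standard evaluation transforms (no augmentation).
# No Resize step: this assumes the stored face crops are already TARGET_H x TARGET_W (224x224).
# Both the (B, 3, 224, 224) contract and default batch collation require that size.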
eval_tfms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [32]:
# ==================== SETUP ALL COMPONENTS FOR DATASETS & DATALOADERS ====================

# Step 0: Define DataLoader parameters
import os

cpu_count = os.cpu_count() or 4
optimal_num_workers = min(8, max(2, cpu_count - 2))

common_params_optimized = {
    "batch_size": 256,
    "num_workers": optimal_num_workers,
    "pin_memory": True,
    "persistent_workers": True,
    "prefetch_factor": 4,
    "drop_last": True,
}

# Step 1: Define training transforms (Level 1 - recommended)
train_tfms_selected = T.Compose([
    T.ToTensor(),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Step 2: Create datasets
train_ds = FairFaceDataset(ff_train_df, data_root, encoders, mode="train", train_tfms=train_tfms_selected)
val_ds   = FairFaceDataset(ff_val_df,   data_root, encoders, mode="eval",  eval_tfms=eval_tfms)
test_ds  = FairFaceDataset(ff_test_df,  data_root, encoders, mode="eval",  eval_tfms=eval_tfms)

print("✓ Datasets created:")
print(f"  Train: {len(train_ds)} samples")
print(f"  Valid: {len(val_ds)} samples")
print(f"  Test:  {len(test_ds)} samples")
print()

# Step 3: Create DataLoaders
train_loader = DataLoader(train_ds, shuffle=True, **common_params_optimized)
valid_loader = DataLoader(val_ds, shuffle=False, **common_params_optimized)
test_loader  = DataLoader(test_ds,  shuffle=False, **common_params_optimized)

print("✓ DataLoaders created successfully!")
print(f"  train_loader: {len(train_loader)} batches")
print(f"  valid_loader: {len(valid_loader)} batches")
print(f"  test_loader : {len(test_loader)} batches")

✓ Datasets created:
  Train: 73273 samples
  Valid: 14655 samples
  Test:  9770 samples

✓ DataLoaders created successfully!
  train_loader: 286 batches
  valid_loader: 57 batches
  test_loader : 38 batches
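All three loaders above share `common_params_optimized`, which sets `drop_last=True`. That is reasonable for training, but for evaluation it silently skips the last partial batch. The arithmetic, using the dataset sizes printed above, can be checked with a hypothetical helper `batches_and_dropped`:

```python
import math


def batches_and_dropped(n_samples: int, batch_size: int, drop_last: bool) -> tuple[int, int]:
    """Hypothetical helper: (number of batches, samples skipped per epoch)."""
    if drop_last:
        return n_samples // batch_size, n_samples % batch_size
    return math.ceil(n_samples / batch_size), 0


sizes = {"train": 73273, "valid": 14655, "test": 9770}  # from the output above
for split, n in sizes.items():
    print(split, batches_and_dropped(n, 256, drop_last=True))
# train (286, 57)
# valid (57, 63)
# test (38, 42)
```

So each validation pass skips 63 samples and each test pass skips 42. Passing `{**common_params_optimized, "drop_last": False}` to `valid_loader` and `test_loader` would cover every sample, in 58 and 39 batches respectively.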


In [33]:
# ==================== DATALOADER PERFORMANCE TUNING ====================
# Heuristic starting point: leave ~2 cores for the main process and use 2-8 workers
import os
import psutil

cpu_count = os.cpu_count() or 4
available_memory_gb = psutil.virtual_memory().available / (1024**3)

optimal_num_workers = min(8, max(2, cpu_count - 2))

print(f"System: {cpu_count} CPUs, {available_memory_gb:.1f} GB available RAM")
print(f"Optimal num_workers: {optimal_num_workers}")
print()

# Same DataLoader parameters as the setup cell, re-declared so this cell also runs on its own
common_params_optimized = {
    "batch_size": 256,
    "num_workers": optimal_num_workers,
    "pin_memory": True,
    "persistent_workers": True,
    "prefetch_factor": 4,
    "drop_last": True,
}

System: 16 CPUs, 12.0 GB available RAM
Optimal num_workers: 8
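According to the PyTorch `DataLoader` documentation, `prefetch_factor` is the number of batches each worker loads in advance. Up to `num_workers * prefetch_factor` batches can therefore sit in host memory at once. The back-of-the-envelope estimate below covers float32 224x224 images only. It ignores labels, metadata, and pinned-memory staging copies, so it is not a measured figure. The helper `prefetch_bytes` is hypothetical.

```python
def prefetch_bytes(batch_size: int, num_workers: int, prefetch_factor: int,
                   c: int = 3, h: int = 224, w: int = 224,
                   bytes_per_value: int = 4) -> tuple[int, int]:
    """Hypothetical helper: (bytes per image batch, bytes for all prefetched batches)."""
    per_batch = batch_size * c * h * w * bytes_per_value  # float32 image tensor
    return per_batch, per_batch * num_workers * prefetch_factor


per_batch, in_flight = prefetch_bytes(batch_size=256, num_workers=8, prefetch_factor=4)
print(f"{per_batch / 2**20:.0f} MiB per batch, {in_flight / 2**30:.2f} GiB prefetched")
# 147 MiB per batch, 4.59 GiB prefetched
```

Against the 12.0 GB of available RAM reported above, that is a sizeable share. Lowering `prefetch_factor` is one knob to turn if memory becomes tight.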



In [None]:
# ==================== LEVEL 1: QUICK WINS (Fast) ====================
# ~15-20% speedup: Lighter augmentations + higher prefetch
train_tfms_fast = T.Compose([
    T.ToTensor(),
    T.RandomHorizontalFlip(p=0.5),           # Keep (very cheap)
    T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),  # Reduce intensity
    # REMOVE: RandomRotation (expensive)
    # REMOVE: RandomAffine (expensive)
    # REMOVE: RandomErasing (expensive)
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


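# ==================== LEVEL 1.5: BALANCED (Anti-Overfitting) ====================
# Stronger augmentations; slower than train_tfms_fast because of RandomRotation/RandomAffine
train_tfms_balanced = T.Compose([
    T.ToTensor(),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.RandomRotation(degrees=10),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_tfms_selected = train_tfms_balanced  # USE THIS (rebuild train_ds / train_loader afterwards)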

print("=" * 60)
print("SELECTED: train_tfms_balanced (Level 1.5 - Anti-Overfitting)")
print("Stronger augmentations to combat overfitting")
print("=" * 60)

SELECTED: train_tfms_fast (Level 1 - Quick Wins)
Expected speedup: ~15-20% with maintained accuracy

To use faster versions, replace with:
  - train_tfms_faster  : ~30-40% faster (fewer augmentations)
  - train_tfms_fastest : ~50%+ faster (minimal augmentations, debug only)


### 🚀 Performance Optimization: Fast Training Setup

Three levels of optimization for training speed (speedups are expected, approximate figures):
1. **Quick Wins** (~15-20% faster): Reduce augmentations, increase prefetch
2. **Medium Effort** (~30-40% faster): Move heavy augments to GPU, optimize batch size
3. **Max Speed** (~50%+ faster): Minimal augmentations + larger batch size (if memory allows)
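For Level 2, there is one caveat when moving augmentation onto the GPU. torchvision's `RandomHorizontalFlip` draws a single random number per call, so applying it to a whole `(B, C, H, W)` batch flips either every image or none. A minimal per-sample alternative is sketched below. It assumes the batch is already on the target device, and the function name `batch_hflip` is hypothetical.

```python
import torch


def batch_hflip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: flip each image in a (B, C, H, W) batch independently."""
    flip = torch.rand(x.shape[0], device=x.device) < p  # one coin toss per sample
    return torch.where(flip.view(-1, 1, 1, 1), x.flip(-1), x)  # flip along width


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imgs = torch.randn(4, 3, 2, 3, device=device)
out = batch_hflip(imgs, p=0.5)
flipped = [torch.equal(o, i.flip(-1)) for o, i in zip(out, imgs)]
print(flipped)  # per-sample decisions, not one decision for the whole batch
```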

### 2.2.2.1 Backbone interface (input -> embedding)

In [35]:
from torchvision import models
from torchinfo import summary

model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)

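# The 256x512 probe input only exercises the layer shapes. Adaptive average pooling
# still gives a 512-d feature before `fc` for any input size (the contract uses 224x224).
model_summary = summary(model, input_size=(1, 3, 256, 512),  # (B, C, H, W)
                        col_names=("input_size", "output_size", "num_params"),
                        depth=4)

print(model_summary)
print(model)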

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
ResNet                                   [1, 3, 256, 512]          [1, 1000]                 --
├─Conv2d: 1-1                            [1, 3, 256, 512]          [1, 64, 128, 256]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 128, 256]         [1, 64, 128, 256]         128
├─ReLU: 1-3                              [1, 64, 128, 256]         [1, 64, 128, 256]         --
├─MaxPool2d: 1-4                         [1, 64, 128, 256]         [1, 64, 64, 128]          --
├─Sequential: 1-5                        [1, 64, 64, 128]          [1, 64, 64, 128]          --
│    └─BasicBlock: 2-1                   [1, 64, 64, 128]          [1, 64, 64, 128]          --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 128]          [1, 64, 64, 128]          36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 128]          [1, 64, 64, 12

In [None]:
class ResNetBackbone34(nn.Module):
    """
    ResNet-34 backbone as a pure feature extractor.

    Input : (B, 3, 224, 224)
    Output: (B, 512)
    """
    def __init__(self, pretrained: bool = True):
        super().__init__()

        weights = models.ResNet34_Weights.DEFAULT if pretrained else None
        self.model = models.resnet34(weights=weights)

        # Remove classifier
        self.out_dim_final = self.model.fc.in_features  # 512
        self.model.fc = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)  # already (B, 512)

class FairFaceMultiTaskModel(nn.Module):
    """
    ResNet-34 backbone + 3 classification heads (age / gender / race).
    Uses final pooled embedding (B, 512) for all heads.
    """
    def __init__(
        self,
        *,
        num_age_classes: int = 9,
        num_gender_classes: int = 2,
        num_race_classes: int = 7,
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_p: float = 0.3,  # raised from 0.2 for stronger regularization
    ):
        super().__init__()

        self.backbone = ResNetBackbone34(pretrained=pretrained)

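        if freeze_backbone:
            # Frozen weights get no updates, but BatchNorm running stats
            # still change whenever the model is in train() mode.
            for p in self.backbone.parameters():
                p.requires_grad = False

        out_final = self.backbone.out_dim_final  # 512

        # Optional shared projection ("intermediate") before the task heads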
        self.shared = nn.Sequential(
            nn.Linear(out_final, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_p),
        )

        # Heads output logits (no activation)
        self.age_head = nn.Linear(512, num_age_classes)
        self.gender_head = nn.Linear(512, num_gender_classes)
        self.race_head = nn.Linear(512, num_race_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        f_final = self.backbone(x)  # f_final: (B, 512)

        z = self.shared(f_final)

        out = {
            "age": self.age_head(z),
            "gender": self.gender_head(z),
            "race": self.race_head(z),
        }

        return out  # logits dict
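The heads return raw logits, so training combines one cross-entropy term per task. Below is a minimal sketch of that combination. It assumes the `"age"`, `"gender"` and `"race"` keys used by both the model output and the dataset's `y` dict. `multitask_loss` and its `weights` argument are hypothetical. With equal weights it reproduces the plain sum used in the sanity test below.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def multitask_loss(logits: dict[str, torch.Tensor],
                   targets: dict[str, torch.Tensor],
                   weights: dict[str, float] | None = None) -> tuple[torch.Tensor, dict[str, float]]:
    """Hypothetical helper: weighted sum of per-task cross-entropy losses."""
    weights = weights or {task: 1.0 for task in logits}
    per_task = {task: F.cross_entropy(logits[task], targets[task]) for task in logits}
    total = sum(weights[task] * loss for task, loss in per_task.items())
    return total, {task: loss.item() for task, loss in per_task.items()}


torch.manual_seed(0)
B = 4
logits = {"age": torch.randn(B, 9), "gender": torch.randn(B, 2), "race": torch.randn(B, 7)}
targets = {"age": torch.randint(0, 9, (B,)),
           "gender": torch.randint(0, 2, (B,)),
           "race": torch.randint(0, 7, (B,))}
total, parts = multitask_loss(logits, targets)
print(total.item(), parts)  # the per-task values are handy for logging
```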

In [37]:
def sanity_test_fairface_model(model: nn.Module, *, batch_size: int = 1, h: int = 224, w: int = 224, device=torch.device("cpu")) -> None:
    model = model.to(device).train()

    # 1) Forward shape check
    x = torch.randn(batch_size, 3, h, w, dtype=torch.float32, device=device)
    out = model(x)

    assert isinstance(out, dict), f"Expected dict output, got {type(out)}"
    assert "age" in out and "gender" in out and "race" in out, f"Missing keys: {out.keys()}"

    age_logits = out["age"]
    gender_logits = out["gender"]
    race_logits = out["race"]

    print("age logits   :", tuple(age_logits.shape), age_logits.dtype)
    print("gender logits:", tuple(gender_logits.shape), gender_logits.dtype)
    print("race logits  :", tuple(race_logits.shape), race_logits.dtype)

    # infer class counts from the head modules instead of hardcoding them
    num_age = model.age_head.out_features
    num_gender = model.gender_head.out_features
    num_race = model.race_head.out_features

    assert age_logits.shape == (batch_size, num_age)
    assert gender_logits.shape == (batch_size, num_gender)
    assert race_logits.shape == (batch_size, num_race)

    # 2) Backward / gradient flow check (uses fake labels)
    y_age = torch.randint(0, num_age, (batch_size,), device=device, dtype=torch.long)
    y_gender = torch.randint(0, num_gender, (batch_size,), device=device, dtype=torch.long)
    y_race = torch.randint(0, num_race, (batch_size,), device=device, dtype=torch.long)

    ce = nn.CrossEntropyLoss()
    loss = ce(age_logits, y_age) + ce(gender_logits, y_gender) + ce(race_logits, y_race)

    model.zero_grad(set_to_none=True)
    loss.backward()

    # confirm backbone gets gradients (at least one param)
    got_grad = False
    for p in model.backbone.parameters():
        if p.requires_grad and p.grad is not None:
            got_grad = True
            break

    print("loss:", float(loss.detach().cpu()))
    print("backbone received gradients:", got_grad)

    assert got_grad, "Backbone did not receive gradients (check detach / freezing)."

    print("Sanity test passed.")


In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FairFaceMultiTaskModel(pretrained=True, freeze_backbone=False)
sanity_test_fairface_model(model, device=device)

age logits   : (1, 9) torch.float32
gender logits: (1, 2) torch.float32
race logits  : (1, 7) torch.float32
loss: 4.936391830444336
backbone received gradients: True
Sanity test passed.
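Here is a rough plausibility check on the printed loss. With freshly initialized heads the logits are small, so each softmax is close to uniform and each cross-entropy term is near `ln K` for `K` classes. Summing over the three heads gives a ballpark figure, not an exact prediction.

```python
import math

# ln(K) per task for K = 9 age, 2 gender, 7 race classes; the sum equals ln(9 * 2 * 7).
expected_init_loss = math.log(9) + math.log(2) + math.log(7)
print(round(expected_init_loss, 3))  # 4.836
```

The observed 4.936 is close to this value, which is consistent with untrained heads. A value far above it could point to unusually large initial logits.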


## 📊 Quick Reference: Training Speed Optimization

### What Slows Down Training?

| Factor | Bottleneck | Severity |
|--------|-----------|----------|
| **CPU Augmentations** | RandomRotation, RandomAffine, RandomErasing | 🔴 HIGH |
| **Image Decoding** | JPEG decoding with `cv2.imread` inside the worker processes (scales with `num_workers`) | 🟡 MEDIUM |
| **Batch Transfer** | CPU→GPU memory copy | 🟡 MEDIUM |
| **GPU Utilization** | Model size vs batch size mismatch | 🟢 LOW |

### Applied Optimizations

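✅ **Batch 1: DataLoader**
- Dynamic `num_workers` based on CPU count
- `persistent_workers=True` (workers are not restarted every epoch)
- `prefetch_factor=4` (each worker loads 4 batches in advance)
- `pin_memory=True` (fast CPU→GPU transfer)
- `drop_last=True` (avoids partial batches; the same parameters also feed the validation and test loaders, so their last partial batch is skipped too)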

✅ **Batch 2: Transforms (LEVEL 1)**
- Removed RandomRotation, RandomAffine, RandomErasing
- Kept RandomHorizontalFlip (negligible cost)
- Reduced ColorJitter intensity
- **Expected speedup: 15-20%**

✅ **Batch 3: GPU**
- `torch.backends.cudnn.benchmark=True` (auto-tuning)
- `torch.set_float32_matmul_precision("high")` (Tensor Cores)

### If Still Slow

**Option A: More aggressive (LEVEL 2)**
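```python
train_tfms_selected = T.Compose([  # Only horizontal flip
    T.ToTensor(),
    T.RandomHorizontalFlip(p=0.5),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Expected: 30-40% faster (rebuild train_ds / train_loader afterwards)
```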

**Option B: Maximum speed (LEVEL 3 - debugging only)**
```python
train_tfms_selected = T.Compose([  # No augmentations
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Expected: 50%+ faster, but reduced regularization
```

**Option C: Hardware**
- Increase `num_workers` (if CPU headroom exists)
- Use faster storage (NVMe > SSD > HDD)
- Increase batch size (if GPU memory allows)

### Verify It Works

Run the **DATALOADER SPEED BENCHMARK** cell above to see batch loading time. Target: **<50ms per batch**.
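Below is a minimal sketch of measuring per-batch loading time. It works with any iterable loader. `time_batches` is a hypothetical helper: it skips a few warm-up batches, which absorb worker start-up and the first prefetches, and then reports the mean wall-clock time per batch. The toy `TensorDataset` is there only so the sketch runs on its own; in the notebook you would pass `train_loader` instead.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_batches(loader, n_batches: int = 20, warmup: int = 3) -> float:
    """Hypothetical helper: mean wall-clock seconds per batch after `warmup` batches."""
    it = iter(loader)
    for _ in range(warmup):  # absorb worker start-up and the first prefetches
        if next(it, None) is None:
            return float("nan")
    n = 0
    start = time.perf_counter()
    for _ in range(n_batches):
        if next(it, None) is None:  # loader exhausted early
            break
        n += 1
    elapsed = time.perf_counter() - start
    return elapsed / n if n else float("nan")


toy = TensorDataset(torch.randn(512, 3, 32, 32), torch.randint(0, 9, (512,)))
loader = DataLoader(toy, batch_size=32, num_workers=0)
ms = time_batches(loader, n_batches=10) * 1000
print(f"{ms:.2f} ms per batch (target: < 50 ms)")
```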