# Multimodal Depression Detection: Full Model Pipeline

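This notebook demonstrates a complete pipeline for multimodal depression detection using deep learning. The task is framed as regression of the `label` score in the AVEC2014 (`avec14`) data from aligned face frames and two CSV feature streams (`faps` and `rppg_physformer`). The workflow includes: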

- Cloning the project repository
- Installing required dependencies
- Mounting Google Drive and extracting the dataset
- Preprocessing and fixing dataset paths
- Writing the custom dataset, model, and training scripts
- Training and evaluating the model

---
### Clone Project Repository & Set Working Directory
This cell clones the Multimodal Depression project from GitHub and sets the working directory to the encoder folder.

In [None]:
!git clone https://github.com/Amcky/Multimodal-Depression.git
%cd /content/Multimodal-Depression/encoder

Cloning into 'Multimodal-Depression'...
remote: Enumerating objects: 20, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 20 (delta 2), reused 14 (delta 1), pack-reused 0 (from 0)[K
Receiving objects: 100% (20/20), 24.78 KiB | 3.54 MiB/s, done.
Resolving deltas: 100% (2/2), done.
/content/Multimodal-Depression/encoder


---
### 🛠️ Install Required Python Packages
This cell installs all necessary Python libraries for model training and data processing.

In [None]:
!pip install torch torchvision pandas numpy tqdm pytorch_lightning lightning


Collecting pytorch_lightning
  Using cached pytorch_lightning-2.5.3-py3-none-any.whl.metadata (20 kB)
Collecting lightning
  Using cached lightning-2.5.3-py3-none-any.whl.metadata (39 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
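The install log above is truncated. As a quick sanity check (a sketch, not part of the original pipeline), the versions that actually ended up installed can be read back with the standard-library `importlib.metadata`. `torchmetrics`, which `model.py` imports, is not in the `pip install` line; it is assumed here to arrive as a dependency of `pytorch_lightning`. The helper name `installed_versions` is illustrative.

```python
from importlib import metadata

def installed_versions(distributions):
    """Map each distribution name to its installed version string, or None if missing."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Distribution names as published on PyPI (note the hyphen in pytorch-lightning).
required = ["torch", "torchvision", "pandas", "numpy", "tqdm",
            "pytorch-lightning", "lightning", "torchmetrics"]
for name, version in installed_versions(required).items():
    print(f"{name:20s} {version or 'NOT INSTALLED'}")
```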

---
### Mount Google Drive
This cell mounts your Google Drive to access datasets and pre-trained weights.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


---
### Unzip Dataset
This cell extracts `dataset.zip` from Google Drive into `/content`.

In [None]:
import zipfile
import os

zip_path = '/content/drive/MyDrive/dataset.zip'
unzip_dir = '/content'
os.makedirs(unzip_dir, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(unzip_dir)
print("Unzipped dataset to", unzip_dir)

Unzipped dataset to /content
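Re-running the unzip cell extracts the whole archive again. Below is a minimal sketch of an idempotent variant. It assumes that an existing `/content/dataset` folder (the folder the `ls` output in the next cell shows) means extraction has already happened. `extract_once` is a hypothetical helper, and `ZipFile.testzip()` is used to catch a corrupt archive before anything is written, at the cost of reading the archive twice.

```python
import os
import zipfile

def extract_once(zip_path, unzip_dir, marker_dir):
    """Extract zip_path into unzip_dir unless marker_dir already exists.

    Returns True if the archive was extracted, False if extraction was skipped.
    """
    if os.path.isdir(marker_dir):
        return False
    with zipfile.ZipFile(zip_path, "r") as zf:
        bad = zf.testzip()  # name of the first member with a bad CRC, or None
        if bad is not None:
            raise zipfile.BadZipFile(f"Corrupt member in archive: {bad}")
        zf.extractall(unzip_dir)
    return True

# Usage with the paths from the cell above:
# extract_once('/content/drive/MyDrive/dataset.zip', '/content', '/content/dataset')
```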


---
### List Dataset Contents
This cell lists `/content` to verify that the `dataset` folder was extracted.

In [None]:
!ls /content

dataset  drive	Multimodal-Depression  sample_data
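`ls /content` only confirms that the top-level `dataset` folder exists. The sketch below walks a directory a couple of levels deep and counts the files directly inside each folder. This gives a quicker picture of whether the `frames_align`, `faps`, and `rppg_physformer` trees under `avec14` were all extracted. `summarize_tree` is a hypothetical helper.

```python
import os

def summarize_tree(root, max_depth=2):
    """Return (relative_dir, n_files) for every directory up to max_depth below root.

    n_files counts only the files directly inside that directory.
    """
    root = os.path.abspath(root)
    summary = []
    for dirpath, dirnames, filenames in os.walk(root):
        rel = os.path.relpath(dirpath, root)
        depth = 0 if rel == "." else rel.count(os.sep) + 1
        if depth >= max_depth:
            dirnames[:] = []  # do not descend below max_depth
        else:
            dirnames.sort()   # deterministic traversal order
        summary.append((rel, len(filenames)))
    return summary

# Example: for rel, n in summarize_tree("/content/dataset/avec14"): print(f"{rel:45s} {n}")
```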


---
### Fix Dataset Paths in CSV
This cell rewrites the relative paths in `multimodal_labels.csv` as absolute paths under `/content` and saves the file in place.

In [None]:
import pandas as pd
import os

csv_path = "/content/dataset/avec14/multimodal_labels.csv"
base_dir = "/content"

df = pd.read_csv(csv_path)
for col in ["frames_path", "faps_path", "rppg_path"]:
    df[col] = df[col].apply(lambda p: os.path.abspath(os.path.join(base_dir, p)) if not os.path.isabs(p) else p)

df.to_csv(csv_path, index=False)
print(df.head())


                                         frames_path  \
0  /content/dataset/avec14/frames_align/train_Nor...   
1  /content/dataset/avec14/frames_align/train_Fre...   
2  /content/dataset/avec14/frames_align/train_Nor...   
3  /content/dataset/avec14/frames_align/train_Fre...   
4  /content/dataset/avec14/frames_align/train_Nor...   

                                           faps_path  \
0  /content/dataset/avec14/faps/train_Northwind/2...   
1  /content/dataset/avec14/faps/train_Freeform/20...   
2  /content/dataset/avec14/faps/train_Northwind/2...   
3  /content/dataset/avec14/faps/train_Freeform/20...   
4  /content/dataset/avec14/faps/train_Northwind/2...   

                                           rppg_path  label  
0  /content/dataset/avec14/rppg_physformer/train_...      3  
1  /content/dataset/avec14/rppg_physformer/train_...      3  
2  /content/dataset/avec14/rppg_physformer/train_...      3  
3  /content/dataset/avec14/rppg_physformer/train_...      3  
4  /content/data
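The cell above rewrites paths but never checks that they exist, and it overwrites `multimodal_labels.csv` in place. Below is a minimal sketch of an existence check over the three path columns. `missing_paths` is a hypothetical helper that takes row dictionaries, such as those returned by `df.to_dict("records")`.

```python
import os

def missing_paths(rows, columns):
    """Return (row_index, column, path) for each entry that is not an existing path.

    Non-string values (e.g. NaN from an empty CSV cell) are reported as missing too.
    """
    missing = []
    for i, row in enumerate(rows):
        for col in columns:
            path = row[col]
            if not isinstance(path, str) or not os.path.exists(path):
                missing.append((i, col, path))
    return missing

# Example with the DataFrame from the cell above:
# report = missing_paths(df.to_dict("records"), ["frames_path", "faps_path", "rppg_path"])
# print(f"{len(report)} missing paths"); print(report[:5])
```

Running this check before and after the `_video_aligned` correction in the next cell shows how many CSV paths that correction actually recovers.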

---
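### 🔧 Additional Path Correction for CSV Files
This cell repairs `faps`/`rppg` paths that do not exist on disk. If a missing path ends in `.csv` and a `_video_aligned.csv` variant exists, the cell substitutes that variant. It writes the result to `multimodal_labels_fixed.csv`.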

In [None]:
import pandas as pd
import os

label_csv = "/content/dataset/avec14/multimodal_labels.csv"
output_csv = "/content/dataset/avec14/multimodal_labels_fixed.csv"

df = pd.read_csv(label_csv)

def fix_path(path):
    if not os.path.exists(path) and path.endswith(".csv"):
        base, ext = os.path.splitext(path)
        new_path = base + "_video_aligned" + ext
        if os.path.exists(new_path):
            return new_path
    return path

for col in df.columns:
    if "faps" in col or "rppg" in col:
        df[col] = df[col].apply(fix_path)

df.to_csv(output_csv, index=False)
print(f"Fixed CSV saved to {output_csv}")


Fixed CSV saved to /content/dataset/avec14/multimodal_labels_fixed.csv
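`fix_path` silently returns the original path when neither file exists, so the fixed CSV can still contain dead paths. The sketch below applies the same fallback rule but returns an explicit status for each path, so the outcomes can be counted. `resolve_csv_path` and `audit` are hypothetical names; the `_video_aligned` suffix is the one used in the cell above.

```python
import os

SUFFIX = "_video_aligned"  # suffix tried by fix_path in the cell above

def resolve_csv_path(path, suffix=SUFFIX):
    """Return (resolved_path, status) where status is 'ok', 'fixed', or 'missing'."""
    if os.path.exists(path):
        return path, "ok"
    if path.endswith(".csv"):
        base, ext = os.path.splitext(path)
        candidate = base + suffix + ext
        if os.path.exists(candidate):
            return candidate, "fixed"
    return path, "missing"

def audit(paths):
    """Count statuses over many paths, e.g. one DataFrame column."""
    counts = {"ok": 0, "fixed": 0, "missing": 0}
    for p in paths:
        counts[resolve_csv_path(p)[1]] += 1
    return counts

# Example: print(audit(df["faps_path"])), print(audit(df["rppg_path"]))
```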


---
### Create Custom Dataset Script
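This cell writes `dataset.py`, which defines two classes. `VideoDataset` returns, for each row, the first aligned face frame (resized to 112×112 and ImageNet-normalized) plus the standardized `faps`/`rppg_physformer` CSV sequences. `VideoRegressionDataModule` wraps it; its collate function zero-pads CSV sequences to the longest one in the batch.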
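The CSV branch of `dataset.py` (`load_csv_features`) performs three steps on each sequence: it replaces NaN/Inf with 0, z-scores each feature column over the time axis, and clips the result to ±10. The NumPy restatement below makes that step checkable in isolation; `standardize_and_clip` is a hypothetical name.

Two properties follow from the math:

- A constant column maps to zeros.
- For T > 1, a column's z-scores are bounded in magnitude by sqrt(T - 1). This is Samuelson's inequality, which holds with the population standard deviation that `np.std` uses by default. As a result, the ±10 clip only changes standardized values when T > 101.

A third property matters for the model: after this step every column's temporal mean is zero (up to clipping). `CSVFeatureExtractor` in `model.py` mean-pools over time, so from this branch it receives essentially zero vectors.

```python
import numpy as np

def standardize_and_clip(arr, clip_threshold=10.0, eps=1e-6):
    """Mirror of the preprocessing in load_csv_features for a [T, D] array."""
    arr = np.nan_to_num(np.asarray(arr, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
    if arr.shape[0] > 1:  # per-sequence, per-column z-score over time
        mean = arr.mean(axis=0, keepdims=True)
        std = arr.std(axis=0, keepdims=True) + eps
        arr = (arr - mean) / std
    return np.clip(arr, -clip_threshold, clip_threshold)

rng = np.random.default_rng(0)
seq = rng.normal(5.0, 2.0, size=(50, 3)).astype(np.float32)
z = standardize_and_clip(seq)
print(z.mean(axis=0))  # approximately zero in every column: the temporal mean is removed
```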

In [None]:
new_code = """
# /content/Multimodal-Depression/encoder/dataset.py
import os
import glob
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import pytorch_lightning as pl
import warnings

# ---------- Utility functions ----------
def load_image(path):
    return Image.open(path).convert("RGB")

def load_csv_features(path):
    df = pd.read_csv(path)
    arr = df.values.astype('float32')

    # Replace NaN/Inf with 0
    if not np.isfinite(arr).all():
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    # Standardize per sequence (T x D)
    if arr.shape[0] > 1:
        mean = arr.mean(axis=0, keepdims=True)
        std = arr.std(axis=0, keepdims=True) + 1e-6
        arr = (arr - mean) / std

    # Clip extreme values to prevent blow-ups (no-op when all values are within range)
    clip_threshold = 10.0
    arr = np.clip(arr, -clip_threshold, clip_threshold)

    return torch.tensor(arr, dtype=torch.float32)

# Map between args.modalities and CSV column names
MODALITY_TO_COLUMN = {
    'frames_align': 'frames_path',
    'faps': 'faps_path',
    'rppg_physformer': 'rppg_path',
}

MODALITY_TYPES = {
    'frames_align': 'image',
    'faps': 'csv',
    'rppg_physformer': 'csv',
}

# ---------- Dataset ----------
class VideoDataset(Dataset):
    def __init__(self, args, labeldata, transform=None, stage='train'):
        self.args = args
        self.labeldata = labeldata.reset_index(drop=True)
        self.transform = transform
        self.modalities = args.modalities  # preserve order requested by the model

    def __len__(self):
        return len(self.labeldata)

    def __getitem__(self, idx):
        row = self.labeldata.iloc[idx]
        sample = {}

        for m in self.modalities:
            col = MODALITY_TO_COLUMN[m]
            path = row[col]

            if MODALITY_TYPES[m] == 'image':
                file_list = sorted(glob.glob(os.path.join(path, '*.jpg')))
                if not file_list:
                    raise FileNotFoundError(f"No .jpg files in {path}")
                img = load_image(file_list[0])  # first frame
                img = self.transform(img) if self.transform else T.ToTensor()(img)
                sample[m] = img  # [3,H,W]

            else:  # CSV modalities
                if not os.path.isfile(path):
                    raise FileNotFoundError(f"CSV file not found: {path}")
                csv_tensor = load_csv_features(path)  # [T,D]
                sample[m] = csv_tensor

        score = torch.tensor(float(row['label']), dtype=torch.float32)
        vid_id = os.path.basename(str(row[MODALITY_TO_COLUMN['frames_align']]))
        return sample, score, vid_id

# ---------- DataModule ----------
class VideoRegressionDataModule(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.labeldata = pd.read_csv(args.label_file)

        self.base_transform = T.Compose([
            T.Resize((112, 112)),  # resize first for consistency
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

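        # NOTE: no train/dev/test split is applied; validation and test reuse the
        # training rows, so val/test metrics are not held-out estimates.
        self.train_dataset = VideoDataset(args, self.labeldata, transform=self.base_transform, stage='train')
        self.val_dataset = self.train_dataset
        self.test_dataset = self.val_dataset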

    def _loader(self, ds, shuffle):
        return DataLoader(
            ds, batch_size=self.batch_size, shuffle=shuffle,
            num_workers=min(self.num_workers, 2),  # heed Colab worker warning
            pin_memory=True, collate_fn=self._collate
        )

    @staticmethod
    def _collate(batch):
        # batch: list of (sample_dict, score, vid_id)
        samples, scores, ids = zip(*batch)
        batch_dict = {}
        for k in samples[0].keys():
            v = samples[0][k]
            if isinstance(v, torch.Tensor) and v.ndim == 3:
                batch_dict[k] = torch.stack([s[k] for s in samples], dim=0)
            else:  # csv [T,D]
                ts = [s[k] for s in samples]
                T_max = max(t.shape[0] for t in ts)
                Ds = ts[0].shape[1]
                out = torch.zeros(len(ts), T_max, Ds, dtype=ts[0].dtype)
                for i, t in enumerate(ts):
                    out[i, :t.shape[0]] = t
                batch_dict[k] = out
        scores = torch.stack(scores, dim=0)  # [B]
        return batch_dict, scores, list(ids)

    def train_dataloader(self): return self._loader(self.train_dataset, True)
    def val_dataloader(self):   return self._loader(self.val_dataset, False)
    def test_dataloader(self):  return self._loader(self.test_dataset, False)

"""

with open("/content/Multimodal-Depression/encoder/dataset.py", "w") as f:
    f.write(new_code)
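`VideoRegressionDataModule` uses the same rows for training, validation, and test, and it never reads `--train_data`/`--val_data`/`--test_data`. Below is a minimal sketch of deriving the split from each path. It assumes that every path contains a folder named `<split>_<task>`, such as `train_Northwind`; only the `train_` prefix is visible in the printed CSV. The `dev_` and `test_` prefixes and the `split_of` helper are hypothetical and should be checked against the actual folder names first.

```python
import os

def split_of(path, splits=("train", "dev", "test")):
    """Return the split named by the first '<split>_...' folder in path, or None.

    Assumes layouts like .../avec14/<modality>/<split>_<task>/... (hypothetical for dev/test).
    """
    for part in os.path.normpath(path).split(os.sep):
        for split in splits:
            if part.startswith(split + "_"):
                return split
    return None

print(split_of("/content/dataset/avec14/frames_align/train_Northwind/example"))  # train
```

With pandas, `labeldata["split"] = labeldata["frames_path"].map(split_of)` would tag each row. `labeldata[labeldata["split"] == "train"]` and its `dev`/`test` counterparts could then be passed to separate `VideoDataset` instances.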


---
### Create Model Script
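This cell writes `model.py`, which contains the following parts:

- Per-modality feature extractors: an IResNet-50 image backbone and small MLPs for the CSV streams.
- Linear projectors to a shared `fusion_dim`.
- A Transformer encoder that fuses one token per modality.
- A regression head.
- The ablation switches selected by `--ablation`.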
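`CrossModalAttention` wraps `nn.TransformerEncoderLayer`, whose `nn.MultiheadAttention` requires `d_model` to be divisible by `nhead`. With `fusion_dim=512`, a head count of 6 therefore fails when the layer is constructed, while the training command's `--attn_heads 8` gives 64 dimensions per head. The pre-flight check below can surface this with a clearer message. It is a hypothetical helper, written in pure Python so it can run before any model is built.

```python
def check_attention_config(fusion_dim, attn_heads, max_heads=16):
    """Return the per-head dimension, or raise ValueError listing valid head counts."""
    if attn_heads <= 0 or fusion_dim % attn_heads != 0:
        valid = [h for h in range(1, max_heads + 1) if fusion_dim % h == 0]
        raise ValueError(
            f"fusion_dim={fusion_dim} is not divisible by attn_heads={attn_heads}; "
            f"valid head counts up to {max_heads}: {valid}"
        )
    return fusion_dim // attn_heads

print(check_attention_config(512, 8))  # 64
print(check_attention_config(128, 8))  # 16, the dimension used by the 'smaller_fusion_dim' ablation
```

In `main.py`, the check could be called right after `parser.parse_args()` with `args.fusion_dim` and `args.attn_heads`. Keep in mind that the `smaller_fusion_dim` ablation overrides the dimension to 128 inside the model.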

In [None]:

new_code = """

import os
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torchmetrics import MeanAbsoluteError, MeanSquaredError
from backbones import iresnet
from torch.optim.lr_scheduler import LambdaLR

# -----------------------------
# Extra Metrics (PCC & CCC)
# -----------------------------
def pearson_corr(x, y, eps=1e-8):
    x = x.view(-1)
    y = y.view(-1)
    vx = x - x.mean()
    vy = y - y.mean()
    num = torch.sum(vx * vy)
    den = torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)) + eps
    return num / den

def concordance_corr(x, y, eps=1e-8):
    x = x.view(-1)
    y = y.view(-1)
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(unbiased=False), y.var(unbiased=False)  # population variances, matching cov below
    cov = ((x - mx) * (y - my)).mean()
    return (2 * cov) / (vx + vy + (mx - my) ** 2 + eps)

# -----------------------------
# Image Feature Extractor
# -----------------------------
class ImageFeatureExtractor(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.model = iresnet.iresnet50(pretrained=False)

        if not getattr(args, 'ablation', None) == 'no_pretrain':
            ckpt = getattr(args, "webface_ckpt", "/content/drive/MyDrive/webface_r50.pth")
            if getattr(args, 'pretrain', None) == 'webface' and os.path.isfile(ckpt):
                state = torch.load(ckpt, map_location='cpu')
                self.model.load_state_dict(state, strict=False)

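        self._freeze_bn()
        self.out_dim = self.model.fc.out_features

    def _freeze_bn(self):
        # Put BatchNorm2d layers in eval mode so they use their stored running statistics.
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def train(self, mode=True):
        # Lightning switches the model to training mode with .train() when fitting starts,
        # which would otherwise put these BatchNorm2d layers back into training mode.
        super().train(mode)
        self._freeze_bn()
        return self

    def forward(self, x):
        return self.model(x)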

# -----------------------------
# CSV Feature Extractor
# -----------------------------
class CSVFeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden=256, out_dim=128, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
            nn.ReLU(inplace=True)
        )
        self.out_dim = out_dim

    def forward(self, x):
        if x.ndim == 3:
            x = x.mean(dim=1)  # temporal mean pooling
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return self.net(x)

# -----------------------------
# Cross-Modal Transformer Fusion
# -----------------------------
class CrossModalAttention(nn.Module):
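    # num_heads must divide dim (e.g. 512 / 8 = 64 per head). attn_dropout is accepted but
    # unused: TransformerEncoderLayer applies its single `dropout` inside attention as well.
    def __init__(self, dim, num_heads=8, num_layers=2, mlp_ratio=4.0,
                 dropout=0.1, attn_dropout=0.1, use_norm=True):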
        super().__init__()
        self.use_norm = use_norm
        self.ln = nn.LayerNorm(dim) if use_norm else nn.Identity()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=num_heads,
            dim_feedforward=int(mlp_ratio * dim),
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        x = self.ln(x)
        x = self.encoder(x)
        return x

# -----------------------------
# Main Model
# -----------------------------
class VideoRegressionModel(LightningModule):
    def __init__(self, args, csv_input_dims=None):
        super().__init__()
        self.save_hyperparameters(ignore=['csv_input_dims'])
        self.args = args
        self.modalities = args.modalities

        attn_heads  = getattr(args, "attn_heads", 8)
        attn_layers = getattr(args, "attn_layers", 2)
        fusion_dim  = getattr(args, "fusion_dim", 512)

        if args.ablation == 'smaller_fusion_dim':
            fusion_dim = 128

        # -------------------------------
        # ABLATION: modality removal
        # -------------------------------
        if args.ablation == 'frames_only':
            self.modalities = ['frames_align']
        elif args.ablation == 'csv_only':
            self.modalities = [m for m in args.modalities if m in ('faps', 'rppg_physformer')]

        self.extractors = nn.ModuleDict()
        raw_dims = []
        for m in self.modalities:
            if m == 'frames_align':
                self.extractors[m] = ImageFeatureExtractor(args)
                raw_dims.append(self.extractors[m].out_dim)
            elif m in ('faps', 'rppg_physformer'):
                if csv_input_dims and m in csv_input_dims:
                    in_dim = csv_input_dims[m]
                else:
                    in_dim = 29 if m == 'faps' else 64
                self.extractors[m] = CSVFeatureExtractor(
                    input_dim=in_dim, hidden=256, out_dim=128, dropout=0.1
                )
                raw_dims.append(self.extractors[m].out_dim)

        # -------------------------------
        # ABLATION: projectors
        # -------------------------------
        self.projectors = nn.ModuleDict()
        if args.ablation == 'no_projectors':
            for m, d in zip(self.modalities, raw_dims):
                self.projectors[m] = nn.Identity()
        elif args.ablation == 'shared_projector':
            first_dim = raw_dims[0]
            shared_proj = nn.Sequential(nn.Linear(first_dim, fusion_dim), nn.GELU())
            for m in self.modalities:
                self.projectors[m] = shared_proj
        else:
            for m, d in zip(self.modalities, raw_dims):
                if d != fusion_dim:
                    proj = nn.Linear(d, fusion_dim)
                    nn.init.xavier_uniform_(proj.weight)
                    if proj.bias is not None:
                        nn.init.zeros_(proj.bias)
                    self.projectors[m] = nn.Sequential(proj, nn.GELU())
                else:
                    self.projectors[m] = nn.Identity()

        # -------------------------------
        # ABLATION: attention fusion
        # -------------------------------
        if args.ablation == 'no_attention':
            self.attention_fusion = nn.Identity()
        elif args.ablation == 'uni_modal_attention':
            self.attention_fusion = nn.ModuleList([
                CrossModalAttention(
                    dim=fusion_dim,
                    num_heads=attn_heads,
                    num_layers=attn_layers,
                    dropout=getattr(args, "attn_dropout", 0.1),
                    attn_dropout=getattr(args, "attn_dropout", 0.1),
                    use_norm=not (args.ablation == 'no_norm')
                ) for _ in self.modalities
            ])
        else:
            if args.ablation == 'attn_1layer':
                attn_layers = 1
            elif args.ablation == 'attn_2heads':
                attn_heads = 2
            use_norm = not (args.ablation == 'no_norm')
            self.attention_fusion = CrossModalAttention(
                dim=fusion_dim,
                num_heads=attn_heads,
                num_layers=attn_layers,
                mlp_ratio=4.0,
                dropout=getattr(args, "attn_dropout", 0.1),
                attn_dropout=getattr(args, "attn_dropout", 0.1),
                use_norm=use_norm
            )

        # -------------------------------
        # Regression head
        # -------------------------------
        if args.ablation == 'simple_regressor':
            self.regressor = nn.Linear(fusion_dim, 1)
        else:
            dropout_rate = 0.0 if args.ablation == 'no_dropout' else args.dropout_rate
            self.regressor = nn.Sequential(
                nn.LayerNorm(fusion_dim),
                nn.Linear(fusion_dim, 128),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(128, 1),
            )

        # -------------------------------
        # ABLATION: frozen extractors
        # -------------------------------
        if args.ablation == 'frozen_extractors':
            for p in self.extractors.parameters():
                p.requires_grad = False

        self.mae_metric = MeanAbsoluteError()
        self.mse_metric = MeanSquaredError()
        self.test_preds, self.test_tgts = [], []

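        # Xavier-initialise the newly added Linear layers only; the image backbone is skipped
        # so its (optionally WebFace-pretrained) weights, including its fc layer, are kept.
        for name, m in self.named_modules():
            if name.startswith('extractors.frames_align'):
                continue
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)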

    def forward(self, batch_dict):
        tokens = []
        for m in self.modalities:
            feat = self.extractors[m](batch_dict[m])
            feat = torch.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0)
            feat = self.projectors[m](feat)
            tokens.append(feat)

        if self.args.ablation == 'uni_modal_attention':
            attn_outs = []
            for feat, attn in zip(tokens, self.attention_fusion):
                feat = feat.unsqueeze(1)
                attn_outs.append(attn(feat).squeeze(1))
            x = torch.stack(attn_outs, dim=1)
        else:
            x = torch.stack(tokens, dim=1)
            x = self.attention_fusion(x)

        fused = x.mean(dim=1)
        yhat = self.regressor(fused)
        return yhat

    def _step(self, batch, stage):
        x_dict, y, _ = batch
        y = y.view(-1, 1)
        y_hat = self.forward(x_dict)
        loss = F.mse_loss(y_hat, y)

        if torch.isnan(loss) or torch.isinf(loss):
            self.log(f"{stage}_loss", 0.0, prog_bar=True)
            return torch.tensor(0.0, requires_grad=True)

        self.log(f'{stage}_loss', loss, prog_bar=True, on_epoch=True, batch_size=y.size(0))
        self.log(f'{stage}_mae',  self.mae_metric(y_hat, y), prog_bar=True, on_epoch=True, batch_size=y.size(0))
        self.log(f'{stage}_rmse', torch.sqrt(self.mse_metric(y_hat, y)), prog_bar=False, on_epoch=True, batch_size=y.size(0))
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, 'val')

    def test_step(self, batch, batch_idx):
        loss = self._step(batch, 'test')
        x_dict, y, _ = batch
        y_hat = self.forward(x_dict).detach()
        self.test_preds.append(y_hat.cpu())
        self.test_tgts.append(y.view(-1,1).cpu())
        return loss

    def on_test_epoch_end(self):
        if self.test_preds:
            yhat = torch.cat(self.test_preds, dim=0)
            yt   = torch.cat(self.test_tgts, dim=0)
            pcc = pearson_corr(yhat, yt)
            ccc = concordance_corr(yhat, yt)
            self.log('test_pcc', pcc, prog_bar=True)
            self.log('test_ccc', ccc, prog_bar=True)

            total_params = sum(p.numel() for p in self.parameters())
            trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            self.print(f"#Params total: {total_params/1e6:.2f}M, trainable: {trainable_params/1e6:.2f}M")

        self.test_preds.clear()
        self.test_tgts.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay
        )

        def lr_lambda(epoch):
            warmup = getattr(self.args, "warmup_epochs", 5)
            if epoch < warmup:
                return float(epoch + 1) / float(max(1, warmup))
            progress = (epoch - warmup) / float(max(1, self.args.max_epochs - warmup))
            return 0.5 * (1.0 + math.cos(progress * math.pi))

        scheduler = {
            'scheduler': LambdaLR(optimizer, lr_lambda=lr_lambda),
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [scheduler]






"""

with open("/content/Multimodal-Depression/encoder/model.py", "w") as f:
    f.write(new_code)
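`test_pcc` and `test_ccc` are computed once over all test predictions. The NumPy sketch below restates the two formulas as a cross-check. PCC ignores offset and scale, while CCC (Lin's concordance correlation coefficient) also penalizes a shifted or rescaled prediction. Both functions use population (ddof=0) moments, which makes CCC equal to PCC whenever predictions and targets share the same mean and variance. The function names are illustrative.

```python
import numpy as np

def pcc(pred, target):
    """Pearson correlation between two 1-D sequences."""
    x = np.asarray(pred, dtype=np.float64).ravel()
    y = np.asarray(target, dtype=np.float64).ravel()
    vx, vy = x - x.mean(), y - y.mean()
    return float((vx * vy).sum() / np.sqrt((vx ** 2).sum() * (vy ** 2).sum()))

def ccc(pred, target):
    """Concordance correlation coefficient using population moments throughout."""
    x = np.asarray(pred, dtype=np.float64).ravel()
    y = np.asarray(target, dtype=np.float64).ravel()
    mx, my = x.mean(), y.mean()
    cov = ((x - mx) * (y - my)).mean()
    return float(2 * cov / (x.var() + y.var() + (mx - my) ** 2))

t = np.array([1.0, 2.0, 3.0, 4.0])
# Predictions shifted by +5: PCC stays at 1 (up to rounding), CCC drops to 2.5 / 27.5.
print(pcc(t + 5, t), ccc(t + 5, t))
```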

---
### Create Main Training Script
This cell writes the main script (`main.py`) to train, validate, and test the model.
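`main.py` passes `--warmup_epochs` (default 5) and `--max_epochs` to `configure_optimizers` in `model.py`, which scales the AdamW learning rate per epoch. The schedule is a linear warmup to the base rate, followed by a half-cosine decay toward zero at `max_epochs`. The restatement of `lr_lambda` below lets you inspect the multiplier on its own; the standalone function name is illustrative.

```python
import math

def lr_factor(epoch, warmup=5, max_epochs=300):
    """Multiplier applied to the base learning rate at a given (0-based) epoch."""
    if epoch < warmup:
        return float(epoch + 1) / float(max(1, warmup))  # linear warmup
    progress = (epoch - warmup) / float(max(1, max_epochs - warmup))
    return 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine decay

base_lr = 1e-5  # --learning_rate used in the training command below
for epoch in (0, 4, 5, 40, 150, 299):
    print(epoch, base_lr * lr_factor(epoch))
```

With early stopping on `val_loss` (patience 40), a run may end long before the cosine phase approaches zero.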

In [None]:

new_code = """
# /content/Multimodal-Depression/encoder/main.py
import glob, os, argparse, re
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger  # same package as Trainer and callbacks
from dataset import VideoRegressionDataModule
from model import VideoRegressionModel

seed_everything(1)

def find_min_mae_file(file_list):
    min_mae, min_file = float('inf'), None
    mae_pattern = re.compile(r'val_mae=([\\d.]+)')
    for fp in file_list:
        m = mae_pattern.search(fp)
        if m:
            v = float(m.group(1))
            if v < min_mae:
                min_mae, min_file = v, fp
    return min_file

def find_latest_checkpoint(latest_dir: str):
    last_ckpt = os.path.join(latest_dir, "last.ckpt")
    if os.path.isfile(last_ckpt):
        return last_ckpt
    ckpts = glob.glob(os.path.join(latest_dir, "*.ckpt"))
    if not ckpts:
        return None
    def epoch_num(p):
        m = re.search(r"epoch=?(\\d+)", os.path.basename(p))  # matches 'epoch049' and 'epoch=049'
        return int(m.group(1)) if m else -1
    ckpts.sort(key=epoch_num)
    return ckpts[-1] if ckpts else None

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/content/dataset/avec14')
    parser.add_argument('--label_file', type=str, default='/content/dataset/avec14/multimodal_labels_fixed.csv')
    parser.add_argument('--train_data', nargs='+', default=['AVEC2014-train'])
    parser.add_argument('--val_data', nargs='+', default=['AVEC2014-dev'])
    parser.add_argument('--test_data', nargs='+', default=['AVEC2014-test'])
    parser.add_argument('--modalities', nargs='+', default=['frames_align', 'faps', 'rppg_physformer'])
    parser.add_argument('--num_frames', type=int, default=1)
    parser.add_argument('--frame_interval', type=int, default=1)
    parser.add_argument('--pretrain', type=str, default='webface')

    parser.add_argument('--save_dir', type=str,
                        default='/content/drive/MyDrive/Multimodal-Depression')

    parser.add_argument('--remove_rate', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--max_epochs', type=int, default=300)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-2)
    parser.add_argument('--dropout_rate', type=float, default=0.7)
    parser.add_argument('--fusion_dim', type=int, default=512)
    parser.add_argument('--attn_heads', type=int, default=8)  # must divide --fusion_dim
    parser.add_argument('--attn_layers', type=int, default=2)
    parser.add_argument('--attn_dropout', type=float, default=0.1)
    parser.add_argument('--warmup_epochs', type=int, default=5)
    parser.add_argument('--no_resume', action='store_true', help='Disable auto-resume from latest/last.ckpt')

    parser.add_argument(
        '--ablation',
        type=str,
        default='none',
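        # 'no_projectors' and 'shared_projector' assume every extractor outputs the same
        # dimension; with frames (512) and CSV (128) features they raise shape errors.
        choices=[
            'none', 'no_attention', 'attn_1layer', 'attn_2heads',
            'no_projectors', 'shared_projector',
            'frozen_extractors',
            'simple_regressor', 'no_dropout',
            'no_pretrain', 'smaller_fusion_dim',
            'no_norm', 'uni_modal_attention',
            'frames_only', 'csv_only'
        ],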
        help='Ablation study mode'
    )
    args = parser.parse_args()
    args.save_dir = os.path.abspath(args.save_dir)
    best_dir   = os.path.join(args.save_dir, "best")
    latest_dir = os.path.join(args.save_dir, "latest")
    logs_dir   = os.path.join(args.save_dir, "logs")
    for d in [best_dir, latest_dir, logs_dir]:
        os.makedirs(d, exist_ok=True)
        print(f"Folder ready: {d}")

    datamodule = VideoRegressionDataModule(args)
    model = VideoRegressionModel(args, csv_input_dims={'faps': 29, 'rppg_physformer': 23})

    best_checkpoint_callback = ModelCheckpoint(
        dirpath=best_dir,
        monitor="val_mae",
        mode='min',
        save_top_k=1,
        filename='best-{epoch:03d}-{val_mae:.2f}-{val_rmse:.2f}'
    )
    # Save a checkpoint every 50 epochs and keep a rolling "last.ckpt"
    latest_checkpoint_callback = ModelCheckpoint(
        dirpath=latest_dir,
        save_top_k=-1,
        every_n_epochs=50,
        save_last=True,
        filename='epoch{epoch:03d}'
    )

    print(f"Lightning will save BEST checkpoints to: {best_checkpoint_callback.dirpath}")
    print(f"Lightning will save LATEST checkpoints to: {latest_checkpoint_callback.dirpath}")
    print(f"Logger files will go to: {logs_dir}")

    early_stop_callback = EarlyStopping(monitor='val_loss', mode='min', patience=40)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer = Trainer(
        max_epochs=args.max_epochs,
        accelerator="auto",
        logger=CSVLogger(save_dir=logs_dir),
        callbacks=[best_checkpoint_callback, latest_checkpoint_callback, early_stop_callback, lr_monitor],
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm"
    )

    ckpt_path = None
    if not args.no_resume:
        ckpt_path = find_latest_checkpoint(latest_dir)
        if ckpt_path and os.path.isfile(ckpt_path):
            print(f">>> Resuming from checkpoint: {ckpt_path}")
        else:
            ckpt_path = None

    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)

    test_ckpt = best_checkpoint_callback.best_model_path
    if not test_ckpt or not os.path.isfile(test_ckpt):
        test_ckpt = find_latest_checkpoint(latest_dir)

    if test_ckpt and os.path.isfile(test_ckpt):
        print(f">>> Testing with checkpoint: {test_ckpt}")
        best_model = VideoRegressionModel.load_from_checkpoint(test_ckpt, csv_input_dims={'faps': 29, 'rppg_physformer': 23}, strict=False)
        trainer.test(best_model, datamodule=datamodule)
    else:
        print(">>> No checkpoint found, testing current in-memory model.")
        trainer.test(model, datamodule=datamodule)








"""

with open("/content/Multimodal-Depression/encoder/main.py", "w") as f:
    f.write(new_code)
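`find_latest_checkpoint` falls back to parsing epoch numbers only when `last.ckpt` is missing. With Lightning's default `auto_insert_metric_name=True`, the template `'epoch{epoch:03d}'` is expected to produce names like `epochepoch=049.ckpt`, which a bare `epoch(\d+)` pattern would not match. The sketch below is a tolerant parser that accepts both `epoch049` and `epoch=049`; the helper names are hypothetical.

```python
import os
import re

EPOCH_RE = re.compile(r"epoch=?(\d+)")

def epoch_from_name(path):
    """Largest number that follows 'epoch' or 'epoch=' in the file name, or -1 if none."""
    numbers = [int(n) for n in EPOCH_RE.findall(os.path.basename(path))]
    return max(numbers) if numbers else -1

def latest_by_epoch(paths):
    """Checkpoint path with the highest parsed epoch, or None for an empty list."""
    return max(paths, key=epoch_from_name) if paths else None

names = ["/ckpt/epochepoch=049.ckpt", "/ckpt/epochepoch=099.ckpt", "/ckpt/last.ckpt"]
print([epoch_from_name(n) for n in names])  # [49, 99, -1]
print(latest_by_epoch(names))               # /ckpt/epochepoch=099.ckpt
```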

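---
### Train and Evaluate
This cell runs `main.py` with only the `frames_align` and `faps` modalities, a batch size of 2, a learning rate of 1e-5, and 8 attention heads over 4 Transformer layers.

In [None]: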
!python main.py \
  --data_dir /content/dataset/avec14 \
  --label_file /content/dataset/avec14/multimodal_labels_fixed.csv \
  --train_data AVEC2014-train \
  --val_data AVEC2014-dev \
  --test_data AVEC2014-test \
  --modalities frames_align faps \
  --num_frames 1 \
  --frame_interval 1 \
  --pretrain webface \
  --save_dir /content/drive/MyDrive/Multimodal-Depression \
  --remove_rate 0.1 \
  --batch_size 2 \
  --num_workers 4 \
  --max_epochs 300 \
  --learning_rate 1e-5 \
  --weight_decay 1e-4 \
  --dropout_rate 0.7 \
  --fusion_dim 512 \
  --attn_heads 8 \
  --attn_layers 4 \
  --attn_dropout 0.1 \
  --ablation none
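The run writes its metrics through `CSVLogger(save_dir=logs_dir)`. Assuming the logger's default name and versioning, each run lands in `<logs_dir>/lightning_logs/version_<N>/metrics.csv`. In that file, training and validation rows are interleaved, and columns not logged at a given step are left empty. The standard-library sketch below finds the epoch with the lowest `val_mae` in such a file; `best_epoch_from_metrics` is a hypothetical helper.

```python
import csv
import glob
import os

def best_epoch_from_metrics(metrics_csv, monitor="val_mae"):
    """Return (epoch, value) of the row with the lowest `monitor` value, or None."""
    best = None
    with open(metrics_csv, newline="") as f:
        for row in csv.DictReader(f):
            raw = row.get(monitor, "")
            if raw in ("", None):
                continue  # rows written by other logging steps leave this column empty
            value = float(raw)
            if best is None or value < best[1]:
                best = (int(float(row["epoch"])), value)
    return best

# Assumed layout: CSVLogger defaults under the --save_dir used above.
logs_dir = "/content/drive/MyDrive/Multimodal-Depression/logs"
for path in sorted(glob.glob(os.path.join(logs_dir, "lightning_logs", "version_*", "metrics.csv"))):
    print(path, best_epoch_from_metrics(path))
```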
