# DMNet Training on SUN RGB-D - Google Colab

**Complete end-to-end training pipeline for Direct Mixing ResNet (DMNet) on Google Colab with A100 GPU**

---

## 📋 Checklist Before Running:

- [ ] **Enable A100 GPU:** Runtime → Change runtime type → Hardware accelerator: GPU → GPU type: A100
- [ ] **Mount Google Drive:** Your code and dataset will be stored on Drive
- [ ] **Upload dataset to Drive:** `MyDrive/datasets/sunrgbd_15.tar.gz` (compressed, preprocessed 15-category dataset)
- [ ] **Expected Runtime:** ~2-3 hours for training

---

## 🎯 What This Notebook Does:

1. ✅ Verify A100 GPU is available
2. ✅ Mount Google Drive
3. ✅ Clone your repository to local disk (fast I/O)
4. ✅ Copy SUN RGB-D dataset to local disk (10-20x faster than Drive)
5. ✅ Install dependencies
6. ✅ Train DMNet (Direct Mixing ResNet) with all optimizations
7. ✅ Save checkpoints to Drive (persistent storage)
8. ✅ Generate training curves and analysis

---

## 🧠 About DMNet:

**DMNet** (Direct Mixing Network) is a 2-stream neural network architecture where:
- **RGB stream** processes color images
- **Depth stream** processes depth maps
- **Integrated Stream** combines both streams using learned scalar mixing weights at every layer

Unlike traditional fusion methods, DMNet performs integration **inside each convolution neuron** through scalar-based direct mixing:
- Per-stream weights (full kernels for RGB and Depth)
- Integrated weight (1×1 channel-wise for integrated features)
- Scalar mixing coefficients (α, γ) learned per layer to combine stream outputs

This allows the network to learn optimal integration strategies at every layer with minimal computational overhead!
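
As a rough illustration of the scalar mixing described above, the sketch below updates an integrated feature vector from the RGB and Depth stream outputs. It is a simplified pure-Python stand-in, not the repository's layer: the function name `direct_mix`, the per-channel weight `w_int`, and the exact update rule (integrated weight times the previous integrated features, plus α and γ times the stream outputs) are assumptions made for illustration only.

```python
def direct_mix(rgb_out, depth_out, integrated_prev, w_int, alpha, gamma):
    """Hypothetical per-channel update for the integrated stream.

    rgb_out, depth_out, integrated_prev, w_int: equal-length lists,
    one value per channel. alpha and gamma are scalar mixing
    coefficients shared by every channel of the layer.
    """
    return [
        w * z + alpha * r + gamma * d
        for r, d, z, w in zip(rgb_out, depth_out, integrated_prev, w_int)
    ]


# Toy values for a layer with two channels (assumed, not from the model).
rgb_out = [1.0, 2.0]
depth_out = [0.5, 0.0]
integrated_prev = [0.0, 1.0]
w_int = [1.0, 0.5]

mixed = direct_mix(rgb_out, depth_out, integrated_prev, w_int, alpha=0.5, gamma=0.5)
print(mixed)
```

Under this assumed rule, the mixing itself adds only two learned scalars per layer, which is consistent with the low overhead claimed above.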

---

**Let's get started!** 🚀

## 1. Environment Setup & GPU Verification

In [1]:
# Check GPU availability and specs
import torch
import subprocess

print("=" * 60)
print("GPU VERIFICATION")
print("=" * 60)

# Check PyTorch and CUDA
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    # Check if it's A100
    gpu_name = torch.cuda.get_device_name(0)
    if 'A100' in gpu_name:
        print("\n✅ A100 GPU detected - PERFECT for training!")
    elif 'V100' in gpu_name:
        print("\n✅ V100 GPU detected - Good for training (slower than A100)")
    elif 'T4' in gpu_name:
        print("\n⚠️  T4 GPU detected - Will be slower, consider upgrading to A100")
    else:
        print(f"\n⚠️  GPU: {gpu_name} - Consider using A100 for best performance")
else:
    print("\n❌ NO GPU DETECTED!")
    print("Please enable GPU: Runtime → Change runtime type → Hardware accelerator: GPU")
    raise RuntimeError("GPU is required for training")

print("\n" + "=" * 60)

GPU VERIFICATION
PyTorch version: 2.10.0+cu128
CUDA available: True
CUDA version: 12.8
GPU Device: NVIDIA A100-SXM4-80GB
GPU Memory: 79.25 GB

✅ A100 GPU detected - PERFECT for training!



In [2]:
# Detailed GPU info
!nvidia-smi

Wed Feb 25 01:19:02 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   32C    P0             54W /  400W |       6MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

## 2. Mount Google Drive

In [3]:
from google.colab import drive
import os
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

print("\n✅ Google Drive mounted successfully!")
print(f"\nDrive contents:")
!ls -la /content/drive/MyDrive/ | head -20

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

✅ Google Drive mounted successfully!

Drive contents:
total 3079872
-rw------- 1 root root        176 Sep 21  2019 06-lab2.gdoc
-rw------- 1 root root      21621 Sep 30  2024 113-1363667-3121001@USSR24093000064918@pre-paid.png
-rw------- 1 root root        176 Aug 13  2020 2020 summer final (1).gdoc
-rw------- 1 root root        176 Aug 13  2020 2020 summer final (2).gdoc
-rw------- 1 root root        176 Aug 13  2020 2020 summer final (3).gdoc
-rw------- 1 root root        176 Aug 13  2020 2020 summer final.gdoc
-rw------- 1 root root        176 Jul 11  2025 2025_Gabriel_Clinger_Contractor Agreement_BASE copy.gdoc
-rw------- 1 root root      32204 Apr 18  2022 2900 On First- Welcome Home Next Steps.docx
-rw------- 1 root root       8822 Jun 24  2017 A6.docx
-rw------- 1 root root      22204 Jan 21  2023 activity (1).xlsx
-rw------- 1 root root      22161 

## 3. Clone Repository to Local Disk (Fast I/O)

**Important:** We clone to `/content/` (local SSD) instead of Drive for 10-20x faster I/O

**Default:** Clone from GitHub (recommended - always gets latest code)

In [4]:
import os
from pathlib import Path

# Configuration
PROJECT_NAME = "Multi-Stream-Neural-Networks"
GITHUB_REPO = "https://github.com/clingergab/Multi-Stream-Neural-Networks.git"  # UPDATE THIS
LOCAL_REPO_PATH = f"/content/{PROJECT_NAME}"  # Local copy for fast I/O

print("=" * 60)
print("REPOSITORY SETUP")
print("=" * 60)

# Ensure we're in a valid directory
os.chdir('/content')
print(f"Starting in: {os.getcwd()}")

# Check if repo already exists (same session, rerunning cell)
if Path(LOCAL_REPO_PATH).exists() and Path(f"{LOCAL_REPO_PATH}/.git").exists():
    print(f"\n📁 Repo already exists: {LOCAL_REPO_PATH}")
    print(f"🔄 Pulling latest changes...")

    os.chdir(LOCAL_REPO_PATH)
    !git pull
    print("✅ Repo updated")

# Clone from GitHub (first run)
else:
    # Remove old incomplete copy if exists
    if Path(LOCAL_REPO_PATH).exists():
        print(f"\n🗑️  Removing incomplete repo copy...")
        !rm -rf {LOCAL_REPO_PATH}

    print(f"\n🔄 Cloning from GitHub...")
    print(f"   Repo: {GITHUB_REPO}")
    print(f"   Destination: {LOCAL_REPO_PATH}")

    !git clone {GITHUB_REPO} {LOCAL_REPO_PATH}

    # Verify clone succeeded
    if not Path(LOCAL_REPO_PATH).exists():
        raise RuntimeError(f"Failed to clone repository to {LOCAL_REPO_PATH}")

    print("✅ Repo cloned successfully")
    os.chdir(LOCAL_REPO_PATH)

# Verify repo structure
print(f"\n📂 Repository structure:")
!ls -la {LOCAL_REPO_PATH}

print(f"\n✅ Working directory: {os.getcwd()}")

REPOSITORY SETUP
Starting in: /content

📁 Repo already exists: /content/Multi-Stream-Neural-Networks
🔄 Pulling latest changes...
Already up to date.
✅ Repo updated

📂 Repository structure:
total 84
drwxr-xr-x 12 root root  4096 Feb 25 01:18 .
drwxr-xr-x  1 root root  4096 Feb 25 01:18 ..
drwxr-xr-x  5 root root  4096 Feb 25 01:18 configs
drwxr-xr-x  2 root root  4096 Feb 25 01:18 data
drwxr-xr-x  2 root root  4096 Feb 25 01:18 docs
drwxr-xr-x  3 root root  4096 Feb 25 01:18 experiments
drwxr-xr-x  9 root root  4096 Feb 25 01:19 .git
-rw-r--r--  1 root root   732 Feb 25 01:18 .gitattributes
drwxr-xr-x  3 root root  4096 Feb 25 01:18 .github
-rw-r--r--  1 root root   768 Feb 25 01:18 .gitignore
-rw-r--r--  1 root root  1084 Feb 25 01:18 LICENSE
drwxr-xr-x  2 root root  4096 Feb 25 01:18 notebooks
-rw-r--r--  1 root root   198 Feb 25 01:18 pytest.ini
-rw-r--r--  1 root root  3884 Feb 25 01:18 README.md
-rw-r--r--  1 root root   126 Feb 25 01:18 requirements.txt
drwxr-xr-x  2 ro

## 4. Install Dependencies

In [5]:
# Install required packages
print("Installing dependencies...")

# Quote the extras spec so the shell never treats the brackets as a glob pattern
!pip install -q h5py tqdm matplotlib seaborn "ray[tune]" kornia

# Verify installations
import h5py
import tqdm
import matplotlib
import seaborn
import ray
import kornia

print("✅ All dependencies installed!")
print(f"   h5py: {h5py.__version__}")
print(f"   matplotlib: {matplotlib.__version__}")
print(f"   ray: {ray.__version__}")
print(f"   kornia: {kornia.__version__}")


Installing dependencies...
✅ All dependencies installed!
   h5py: 3.15.1
   matplotlib: 3.10.0
   ray: 2.54.0
   kornia: 0.8.2


## 5. Copy SUN RGB-D Dataset to Local Disk

**Performance Note:** Local disk I/O is ~10-20x faster than Drive!

**Dataset:** SUN RGB-D 15-category preprocessed dataset with RGB + Depth (~2.5 GB)

In [None]:
from pathlib import Path
import os

# Paths: 3-way split (train/val/test) for standalone training
DRIVE_DATASET_TAR = "/content/drive/MyDrive/datasets/sunrgbd_15.tar.gz"  # Compressed file (2-stream: RGB + Depth)
LOCAL_DATASET_PATH = "/dev/shm/sunrgbd_15"  # Extracted location

# Paths: trainval (train/test only, no val) for k-fold CV in Ray Tune HPO
DRIVE_TRAINVAL_TAR = "/content/drive/MyDrive/datasets/sunrgbd_15_trainval.tar.gz"
LOCAL_TRAINVAL_PATH = "/dev/shm/sunrgbd_15_trainval"

print("=" * 60)
print("SUN RGB-D 15-CATEGORY DATASET SETUP (2-STREAM: RGB + DEPTH)")
print("=" * 60)

def extract_dataset(drive_tar, local_path, label):
    """Extract a dataset tarball from Drive to local disk if needed."""
    if Path(local_path).exists():
        print(f"\n[{label}] Already on local disk: {local_path}")
        train_rgb_count = len(list(Path(f"{local_path}/train/rgb").glob("*.png")))
        print(f"   Train samples: {train_rgb_count}")
        if Path(f"{local_path}/val").exists():
            val_rgb_count = len(list(Path(f"{local_path}/val/rgb").glob("*.png")))
            print(f"   Val samples: {val_rgb_count}")
        return True

    if Path(drive_tar).exists():
        print(f"\n[{label}] Found on Drive: {drive_tar}")
        print(f"   Copying compressed file to local disk...")

        tar_name = Path(drive_tar).name
        local_tar = f"/dev/shm/{tar_name}"

        !rsync -ah --info=progress2 {drive_tar} {local_tar}

        print(f"\n   Extracting dataset to local disk...")
        !tar -xzf {local_tar} -C /dev/shm/ 2>&1 | grep -v "Ignoring unknown extended header"

        !rm {local_tar}

        train_rgb_count = len(list(Path(f"{local_path}/train/rgb").glob("*.png")))
        print(f"   Extracted. Train samples: {train_rgb_count}")
        return True
    else:
        print(f"\n[{label}] NOT FOUND on Drive: {drive_tar}")
        return False

# Extract both datasets
extract_dataset(DRIVE_DATASET_TAR, LOCAL_DATASET_PATH, "3-way split")
ok = extract_dataset(DRIVE_TRAINVAL_TAR, LOCAL_TRAINVAL_PATH, "trainval (k-fold)")

if not ok:
    print(f"\n   To create trainval dataset:")
    print(f"   1. Run: python3 scripts/preprocess_sunrgbd_15.py --no-val-split")
    print(f"   2. Create tarball: tar -czf sunrgbd_15_trainval.tar.gz -C data sunrgbd_15_trainval")
    print(f"   3. Upload to Google Drive at: {DRIVE_TRAINVAL_TAR}")

print("\n" + "=" * 60)
print(f"3-way dataset: {LOCAL_DATASET_PATH}")
print(f"Trainval dataset: {LOCAL_TRAINVAL_PATH}")
print("=" * 60)
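
After extraction, it can be worth confirming that every RGB image has a matching depth map before training starts. The sketch below is one way to do that with `pathlib`. The helper name `count_stream_pairs` is hypothetical, and it assumes the `rgb/` and `depth/` folders of a split use identical file names, which this notebook does not state explicitly. The demo runs against a temporary directory; in the notebook you would pass `LOCAL_DATASET_PATH` instead.

```python
import tempfile
from pathlib import Path


def count_stream_pairs(dataset_root, split="train"):
    """Return (number of paired samples, sorted names missing from one stream).

    Assumes <root>/<split>/rgb/*.png and <root>/<split>/depth/*.png
    share file names (an assumption, not confirmed by the notebook).
    """
    root = Path(dataset_root) / split
    rgb = {p.name for p in (root / "rgb").glob("*.png")}
    depth = {p.name for p in (root / "depth").glob("*.png")}
    return len(rgb & depth), sorted(rgb ^ depth)


# Demo on a throwaway directory with one depth map deliberately missing.
with tempfile.TemporaryDirectory() as tmp:
    layout = {"rgb": ["00000.png", "00001.png"], "depth": ["00000.png"]}
    for stream, names in layout.items():
        stream_dir = Path(tmp) / "train" / stream
        stream_dir.mkdir(parents=True)
        for name in names:
            (stream_dir / name).touch()

    paired, unmatched = count_stream_pairs(tmp)

print(f"paired samples: {paired}, unmatched files: {unmatched}")
```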

## 6. Setup Python Path & Import DMNet

In [7]:
import sys
import os

# Remove cached modules
modules_to_reload = [k for k in sys.modules.keys() if k.startswith('src.')]
for module in modules_to_reload:
    del sys.modules[module]

# Add project to Python path
project_root = '/content/Multi-Stream-Neural-Networks'
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Verify project structure
print("Project structure:")
!ls -la {project_root}/src/models/

# Import LiNet and SUN RGB-D dataloader
print("\nImporting LiNet and dataloaders...")
from src.models.linear_integration.li_net3 import li_resnet18
from src.data_utils.sunrgbd_dataset import get_sunrgbd_dataloaders
from src.training.augmentation_config import AugmentationConfig


# Import Ray Tune
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler

print("✅ LINet3, dataloaders, and Ray Tune imported successfully!")

Project structure:
total 48
drwxr-xr-x 11 root root 4096 Feb 25 01:18 .
drwxr-xr-x  7 root root 4096 Feb 25 01:18 ..
drwxr-xr-x  2 root root 4096 Feb 25 01:18 abstracts
drwxr-xr-x  2 root root 4096 Feb 25 01:18 common
drwxr-xr-x  2 root root 4096 Feb 25 01:18 core
drwxr-xr-x  2 root root 4096 Feb 25 01:18 direct_mixing_activation
drwxr-xr-x  2 root root 4096 Feb 25 01:18 direct_mixing_bn
drwxr-xr-x  2 root root 4096 Feb 25 01:18 direct_mixing_conv
-rw-r--r--  1 root root 1076 Feb 25 01:18 __init__.py
drwxr-xr-x  4 root root 4096 Feb 25 01:18 linear_integration
drwxr-xr-x  2 root root 4096 Feb 25 01:18 multi_channel
drwxr-xr-x  2 root root 4096 Feb 25 01:18 utils

Importing LiNet and dataloaders...
✅ LINet3, dataloaders, and Ray Tune imported successfully!


In [8]:
# from scripts.benchmark_padding_vs_sequential import run_benchmark_suite
# run_benchmark_suite([3, 1], batch_size=96, use_torch_compile=True)

In [9]:
# from scripts.benchmark_padding_vs_sequential import run_benchmark_suite
# run_benchmark_suite([3, 1], batch_size=96)

In [10]:
# # Set random seed for reproducibility
# from src.utils.seed import set_seed

# SEED = 42
# DETERMINISTIC = False  # False = faster, True = fully reproducible

# print("=" * 60)
# print("RANDOM SEED CONFIGURATION")
# print("=" * 60)

# set_seed(SEED, deterministic=DETERMINISTIC)

# print(f"\n✅ Seed: {SEED}")
# print(f"   Deterministic: {DETERMINISTIC}")
# if DETERMINISTIC:
#     print("   Mode: Fully reproducible (slower)")
# else:
#     print("   Mode: Fast reproducible")

# print("\n" + "=" * 60)

## 7. Load SUN RGB-D Dataset

In [11]:
# # Verify dataset structure
# from pathlib import Path

# print("=" * 60)
# print("DATASET STRUCTURE VERIFICATION")
# print("=" * 60)

# dataset_root = Path(LOCAL_DATASET_PATH)

# print("\nDirectory structure:")
# print(f"  {dataset_root}/")
# print(f"    train/")
# print(f"      rgb/ - {len(list((dataset_root / 'train' / 'rgb').glob('*.png')))} images")
# print(f"      depth/ - {len(list((dataset_root / 'train' / 'depth').glob('*.png')))} images")
# print(f"      labels.txt")
# print(f"    val/")
# print(f"      rgb/ - {len(list((dataset_root / 'val' / 'rgb').glob('*.png')))} images")
# print(f"      depth/ - {len(list((dataset_root / 'val' / 'depth').glob('*.png')))} images")
# print(f"      labels.txt")
# print(f"    class_names.txt")
# print(f"    dataset_info.txt")

# # Read class names
# with open(dataset_root / 'class_names.txt', 'r') as f:
#     class_names = [line.strip() for line in f]

# print(f"\nClasses ({len(class_names)}):")
# for i, name in enumerate(class_names):
#     print(f"  {i}: {name}")

# print("\n" + "=" * 60)

In [12]:
# print("=" * 60)
# print("LOADING SUN RGB-D 15-CATEGORY DATASET (2-STREAM: RGB + DEPTH)")
# print("=" * 60)

# # Dataset configuration
# DATASET_CONFIG = {
#     'data_root': LOCAL_DATASET_PATH,
#     'batch_size': 96,  # Good balance for A100
#     'num_workers': 8,
#     'target_size': (416, 544),
#     'num_classes': 15,  # SUN RGB-D merged to 15 categories (labels 0-14)
#     'seed': SEED  # For reproducible data loading
# }

# # Augmentation configuration (per-stream control)
# # Set to 1.0 for baseline behavior, adjust to tune augmentation strength
# AUGMENTATION_CONFIG = AugmentationConfig(
#     rgb_aug_prob=1.0,    # Scales probability of RGB augmentations
#     rgb_aug_mag=1.0,     # Scales magnitude of RGB augmentations
#     depth_aug_prob=1.0,  # Scales probability of Depth augmentations
#     depth_aug_mag=1.0,   # Scales magnitude of Depth augmentations
# )

# print(f"Configuration:")
# for key, value in DATASET_CONFIG.items():
#     print(f"  {key}: {value}")

# print(f"\nAugmentation Configuration:")
# print(f"  {AUGMENTATION_CONFIG}")

# print(f"\nLoading dataset from: {DATASET_CONFIG['data_root']}")

# # Create reproducible dataloaders
# # normalize=False because GPU augmentation will handle normalization after augmentation
# train_loader, val_loader = get_sunrgbd_dataloaders(
#     data_root=DATASET_CONFIG['data_root'],
#     batch_size=DATASET_CONFIG['batch_size'],
#     num_workers=DATASET_CONFIG['num_workers'],
#     target_size=DATASET_CONFIG['target_size'],
#     seed=DATASET_CONFIG['seed'],  # Pass seed for reproducibility
#     normalize=False,  # GPU will normalize after augmentation
#     **AUGMENTATION_CONFIG.to_dict(),  # Pass augmentation params
# )

# print(f"\n✅ Dataset loaded successfully!")
# print(f"\nDataset Statistics:")
# print(f"  Train batches: {len(train_loader)}")
# print(f"  Val batches: {len(val_loader)}")
# print(f"  Train samples: {len(train_loader.dataset)}")
# print(f"  Val samples: {len(val_loader.dataset)}")
# print(f"  Batch size: {DATASET_CONFIG['batch_size']}")

# # Test loading a batch
# print(f"\nTesting batch loading...")
# rgb_batch, depth_batch, label_batch = next(iter(train_loader))
# print(f"  RGB shape: {rgb_batch.shape}")
# print(f"  Depth shape: {depth_batch.shape}")
# print(f"  Labels shape: {label_batch.shape}")
# print(f"  Labels min: {label_batch.min().item()}, max: {label_batch.max().item()}")

# print("\n" + "=" * 60)

## 8. Visualize Sample Data

Shows RGB images, depth maps, and scene labels from the dataset

In [13]:
# import matplotlib.pyplot as plt
# import numpy as np
# import torch

# # Visualize some samples from TRAINING set
# # Note: With normalize=False, data is in [0, 1] range (no denormalization needed)
# # GPU augmentation will normalize on-device during training, but raw data is [0, 1]
# print("Loading samples from TRAINING set...")
# rgb_batch, depth_batch, label_batch = next(iter(train_loader))

# print("\n" + "="*60)
# print("Creating visualization...")
# print("="*60 + "\n")

# fig, axes = plt.subplots(2, 4, figsize=(14, 7))

# for i in range(4):
#     rgb = rgb_batch[i].cpu()
#     depth = depth_batch[i].cpu()
#     label = label_batch[i].item()

#     # Data is already in [0, 1] range (normalize=False in dataloader)
#     # Clamp to handle any edge cases
#     rgb_vis = torch.clamp(rgb, 0, 1)
#     depth_vis = torch.clamp(depth, 0, 1)

#     # Plot RGB
#     axes[0, i].imshow(rgb_vis.permute(1, 2, 0))
#     axes[0, i].set_title(f"RGB - Class {label}", fontsize=10)
#     axes[0, i].axis('off')

#     # Plot Depth
#     axes[1, i].imshow(depth_vis.squeeze(), cmap='viridis')
#     axes[1, i].set_title(f"Depth - Class {label}", fontsize=10)
#     axes[1, i].axis('off')

# plt.suptitle('SUN RGB-D Training Data (RGB + Depth)', fontsize=14, fontweight='bold')
# plt.tight_layout()
# plt.show()

# print("✅ Sample visualization complete!")
# print("\nNote: Data is shown in raw [0, 1] range (before GPU augmentation/normalization).")
# print("During training, GPU augmentation applies: color jitter, blur, grayscale,")
# print("normalization, and random erasing on-device.")


## 8b. Hyperparameter Tuning with Ray Tune (Optional)

Perform a wide search for optimal hyperparameters using Ray Tune.
- **Parallel Trials:** Run multiple configurations simultaneously
- **K-fold Subset:** Each trial trains on one randomly assigned fold of a 5-fold stratified split
- **Bounded Duration:** Up to 120 epochs per trial, with early stopping (patience 15)
- **ASHA Scheduler:** Early stopping for bad trials
- **Uses fit() method:** Ensures consistency with main training (no custom training loop!)
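
The rung spacing and promotion rule behind ASHA-style early stopping can be sketched in a few lines. This is a simplified illustration, not Ray Tune's implementation: the helper names and the example values (`grace_period=10`, `reduction_factor=3`, `max_t=120`) are assumptions chosen for the sketch, not settings taken from this notebook.

```python
def rung_milestones(grace_period, reduction_factor, max_t):
    """Epochs at which trials are compared: grace_period * reduction_factor**k."""
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def survives_rung(score, scores_at_rung, reduction_factor):
    """Keep a trial if it ranks in the top 1/reduction_factor at this rung.

    scores_at_rung holds the scores of other trials that already
    reached the same rung (higher is better).
    """
    ranked = sorted(scores_at_rung + [score], reverse=True)
    keep = max(1, len(ranked) // reduction_factor)
    return score >= ranked[keep - 1]


milestones = rung_milestones(grace_period=10, reduction_factor=3, max_t=120)
others = [0.40, 0.50, 0.60, 0.65, 0.55]  # made-up validation accuracies

print(milestones)
print(survives_rung(0.70, others, reduction_factor=3))
print(survives_rung(0.45, others, reduction_factor=3))
```

Weak trials are stopped at the earliest rung, so most of the compute budget goes to configurations that are still competitive.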

In [14]:
import os
import time

# 1. Define Paths explicitly
mps_pipe_dir = "/tmp/nvidia-mps"
mps_log_dir = "/tmp/nvidia-log"

# 2. Create the directories (CRITICAL: Daemon fails if log dir doesn't exist)
os.makedirs(mps_pipe_dir, exist_ok=True)
os.makedirs(mps_log_dir, exist_ok=True)

# 3. Set Environment Variables for the current Python process
os.environ["CUDA_MPS_PIPE_DIRECTORY"] = mps_pipe_dir
os.environ["CUDA_MPS_LOG_DIRECTORY"] = mps_log_dir
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# 4. Configure GPU and Start Daemon using the SAME environment variables
# We use f-strings to pass the python variables into the shell command
print("Setting GPU to Exclusive Process Mode...")
!nvidia-smi -i 0 -c EXCLUSIVE_PROCESS

print("Starting MPS Daemon...")
# We explicitly pass the env vars to the shell command
!export CUDA_MPS_PIPE_DIRECTORY={mps_pipe_dir} && \
 export CUDA_MPS_LOG_DIRECTORY={mps_log_dir} && \
 nvidia-cuda-mps-control -d

# 5. Verify it is running
print("Verifying Daemon Status...")
time.sleep(1) # Give it a second to start
!ps -ef | grep mps

# Check if the pipe file actually exists
if os.path.exists(os.path.join(mps_pipe_dir, "control")):
    print("✅ MPS Control Pipe found. Setup success.")
else:
    print("❌ MPS Control Pipe NOT found. Check /tmp/nvidia-log for errors.")
    # Optional: Print logs if it failed
    !cat {mps_log_dir}/control.log

Setting GPU to Exclusive Process Mode...
Set compute mode to EXCLUSIVE_PROCESS for GPU 00000000:00:05.0.
All done.
Starting MPS Daemon...
Verifying Daemon Status...
root        5290       1  0 01:21 ?        00:00:00 nvidia-cuda-mps-control -d
root        5296    4513  0 01:21 ?        00:00:00 /bin/bash -c ps -ef | grep mps
root        5298    5296  0 01:21 ?        00:00:00 grep mps
✅ MPS Control Pipe found. Setup success.


In [None]:
import random
import numpy as np

import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
import torch
from collections import Counter
from sklearn.model_selection import StratifiedKFold

from src.models.linear_integration.li_net3 import li_resnet18
from src.training.optimizers import create_stream_optimizer
from src.training.schedulers import setup_scheduler
from src.data_utils.sunrgbd_dataset import SUNRGBDDataset
from src.training.augmentation_config import AugmentationConfig
from src.utils.seed import set_seed


class TrialTerminated(Exception):
    """Raised when a trial should be terminated early."""
    pass

class RayTuneReporter:
    """Callback for reporting metrics to Ray Tune during training."""

    def __init__(self, fold_idx=None):
        self.best_accuracy = 0.0
        self.best_loss = float('inf')
        self.fold_idx = fold_idx

    def on_epoch_end(self, epoch, logs):
        """Report current AND best metrics to Ray Tune."""
        # Track best metrics
        if logs['val_accuracy'] > self.best_accuracy:
            self.best_accuracy = logs['val_accuracy']
        if logs['val_loss'] < self.best_loss:
            self.best_loss = logs['val_loss']

        # Report both current and best metrics
        metrics = {
            "accuracy": logs['val_accuracy'],        # Current epoch
            "loss": logs['val_loss'],                # Current epoch
            "best_accuracy": self.best_accuracy,     # Best so far
            "best_loss": self.best_loss,             # Best so far
            "train_loss": logs['train_loss'],
            "train_accuracy": logs['train_accuracy'],
        }
        if self.fold_idx is not None:
            metrics["fold"] = self.fold_idx

        tune.report(metrics)

        # Early termination for bad trials
        if epoch == 31 and self.best_accuracy < 0.6:
            raise TrialTerminated(
                f"Trial terminated: best_accuracy={self.best_accuracy:.1%} < 0.6 "
                f"at epoch 31"
            )
        if epoch == 81 and self.best_accuracy < 0.7:
            raise TrialTerminated(
                f"Trial terminated: best_accuracy={self.best_accuracy:.1%} < 0.7 "
                f"at epoch 81"
            )


def train_linet_tune(config, data_root=None, target_size=None, seed=42):
    """
    Trainable function for Ray Tune using fit() method with k-fold CV.

    Each trial trains on 1 randomly-assigned fold (out of 5). The fold index
    is assigned BEFORE set_seed() so that different trials get different folds
    (set_seed() would make random.randint deterministic otherwise).

    Args:
        config: Ray Tune configuration dict with hyperparameters
        data_root: Path to trainval dataset root (no val/ split)
        target_size: Target image size (H, W)
        seed: Random seed for reproducible trials
    """
    # CRITICAL: Assign fold BEFORE set_seed(), because set_seed() seeds Python's
    # random module, so random.randint() after it returns the same fold
    # for every trial. This must come first for true randomness.
    fold_idx = random.randint(0, 4)

    # Seed this worker process for reproducibility (after fold assignment)
    set_seed(seed, deterministic=False)
    g = torch.Generator().manual_seed(seed)

    # Per-trial augmentation config
    aug_config = AugmentationConfig(
        rgb_aug_prob=config.get("rgb_aug_prob", 1.0),
        rgb_aug_mag=config.get("rgb_aug_mag", 1.0),
        depth_aug_prob=config.get("depth_aug_prob", 1.0),
        depth_aug_mag=config.get("depth_aug_mag", 1.0),
    )

    # Two dataset instances from the same train/ directory:
    # 1) train_dataset: augmentation ON (split='train')
    # 2) val_dataset:   augmentation OFF (split overridden to 'val')
    # Both load from the same trainval train/ dir; mmap shares OS pages.
    train_dataset = SUNRGBDDataset(
        data_root=data_root,
        split='train',
        target_size=target_size,
        normalize=False,  # GPU will normalize after augmentation
        **aug_config.to_dict(),
    )
    val_dataset = SUNRGBDDataset(
        data_root=data_root,
        split='train',
        target_size=target_size,
        normalize=False,
    )
    val_dataset.split = 'val'  # Disable augmentation in __getitem__

    # K-fold split (deterministic given seed, so every trial sees the same fold definitions)
    all_labels = train_dataset.labels
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    folds = list(skf.split(range(len(all_labels)), all_labels))
    train_indices, val_indices = folds[fold_idx]

    # Create fold subsets
    train_subset = torch.utils.data.Subset(train_dataset, train_indices)
    val_subset = torch.utils.data.Subset(val_dataset, val_indices)

    # Stratified sampling for training fold
    subset_labels = [all_labels[i] for i in train_indices]
    label_counts = Counter(subset_labels)

    # Compute class weights (inverse frequency)
    num_samples = len(subset_labels)
    class_weights = {label: num_samples / count for label, count in label_counts.items()}
    sample_weights = torch.tensor([class_weights[label] for label in subset_labels], dtype=torch.float32)

    # Create sampler
    train_sampler = torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=num_samples,
        replacement=True,
        generator=g
    )

    # Worker init function for reproducible data loading
    def worker_init_fn(worker_id):
        worker_seed = seed + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # Create subset dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=96,
        shuffle=False,  # Disabled when using sampler
        sampler=train_sampler,
        num_workers=2,
        prefetch_factor=2,
        persistent_workers=True,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    val_loader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=96,
        shuffle=False,
        num_workers=2,
        prefetch_factor=2,
        persistent_workers=False,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    # 2. Create Model (2-stream: RGB + Depth)
    model = li_resnet18(
        num_classes=15,
        stream_input_channels=[3, 1],  # RGB=3, Depth=1
        dropout_p=config["dropout_p"],
        device="cuda",
        use_amp=True
    )

    # 3. Create Optimizer with stream-specific learning rates
    optimizer = create_stream_optimizer(
        model,
        optimizer_type='adamw',
        stream_lrs=[config["lr_rgb"], config["lr_depth"]],
        stream_weight_decays=[config["wd_rgb"], config["wd_depth"]],
        shared_lr=config["lr_shared"],
        shared_weight_decay=config["wd_shared"]
    )

    # 4. Create Scheduler
    warmup_epochs = 5

    scheduler = setup_scheduler(
        optimizer,
        scheduler_type='cosine',
        eta_min=[config['s1_eta_min'], config['s2_eta_min'], config['eta_min']],
        t_max=config['t_max'],
        train_loader_len=len(train_loader),
        warmup_epochs=warmup_epochs,
        warmup_start_factor=0.2
    )

    # 5. Compile model (Keras-style API)
    model.compile(
        optimizer=optimizer,
        scheduler=scheduler,
        loss='cross_entropy',
        label_smoothing=config["label_smoothing"],
        gpu_augmentation=True,
        **aug_config.to_dict(),
    )

    # 6. Train using fit() with Ray Tune callback
    try:
        model.fit(
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=120,
            early_stopping=True,
            patience=15,
            grad_clip_norm=config["grad_clip_norm"],
            modality_dropout=True,
            modality_dropout_start=0,
            modality_dropout_ramp=20,
            modality_dropout_rate=config['modality_dropout_rate'],
            callbacks=[RayTuneReporter(fold_idx=fold_idx)],
            verbose=False
        )
    except TrialTerminated as e:
        print(f"\n{e}")
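
The inverse-frequency weighting fed to `WeightedRandomSampler` above can be checked on a toy label list without PyTorch. The sketch below repeats the same arithmetic (`num_samples / count` per class) and sums the weight mass per class. The helper name `inverse_frequency_weights` and the toy labels are illustrative assumptions.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Per-sample weights where each class weight is len(labels) / class count."""
    counts = Counter(labels)
    num_samples = len(labels)
    class_weights = {label: num_samples / count for label, count in counts.items()}
    return [class_weights[label] for label in labels]


# Toy, imbalanced labels: three samples of class 0, one of class 1.
labels = [0, 0, 0, 1]
weights = inverse_frequency_weights(labels)

# Total sampling mass per class; equal mass means equal draw probability.
mass = Counter()
for label, weight in zip(labels, weights):
    mass[label] += weight

print(weights)
print(dict(mass))
```

Because every class ends up with the same total weight, sampling with replacement draws each class with roughly equal probability regardless of how imbalanced the fold is.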

In [None]:
# =============================================================================
# WARM-START CONFIGURATION (Optional)
# =============================================================================
# Enable warm-starting to continue exploration from previous runs.
# HyperOptSearch's full TPE model state is saved/restored via cloudpickle,
# so the surrogate model retains all (config, result) observations across
# Colab sessions.
#
# IMPORTANT: Configs are tagged with a search space hash. If you change the
# search space, the old checkpoint and results are ignored automatically and
# exploration starts fresh.
# =============================================================================

import hashlib
import json as json_module  # avoid conflict with pandas
import os
import pandas as pd
from pathlib import Path
from ray.tune.search.sample import Domain

WARM_START_ENABLED = True  # Set to True to load previous results
WARM_START_CSV_PATH = "/content/drive/MyDrive/ray_tune_results/ray_tune_results.csv"  # Path on Google Drive (persistent)
EPOCH_HISTORY_CSV_PATH = "/content/drive/MyDrive/ray_tune_results/epoch_history.csv"  # Per-epoch metrics for HistoricalMedianStoppingRule
HYPEROPT_CHECKPOINT_DIR = "/content/drive/MyDrive/ray_tune_results"  # Directory for TPE model checkpoints


def get_search_space_hash(search_space: dict) -> str:
    """
    Generate a short hash to identify a search space configuration.

    This allows us to track which configs came from which search space,
    so we only warm-start from compatible configs.

    Args:
        search_space: Ray Tune search space dict

    Returns:
        8-character hash string
    """

    def _serialize_value(v):
        if isinstance(v, Domain):
            if hasattr(v, "categories"):
                return sorted([repr(c) for c in v.categories])
            # Include sampler type to distinguish loguniform from uniform etc.
            sampler_name = type(v.sampler).__name__ if hasattr(v, "sampler") else ""
            domain_str = repr(v.domain_str) if hasattr(v, "domain_str") else type(v).__name__
            return f"{sampler_name}:{domain_str}"
        return repr(v)

    space_repr = {k: _serialize_value(v) for k, v in sorted(search_space.items())}
    space_str = json_module.dumps(space_repr, sort_keys=True)
    return hashlib.md5(space_str.encode()).hexdigest()[:8]


# Ensure the directory exists on Google Drive
Path(WARM_START_CSV_PATH).parent.mkdir(parents=True, exist_ok=True)

# Note: Warm-start loading happens AFTER search_space is defined (in next cell)
# We just set the flag here
print(f"Warm-start: {'ENABLED' if WARM_START_ENABLED else 'DISABLED'}")
if WARM_START_ENABLED:
    print(f"   Results CSV: {WARM_START_CSV_PATH}")
    print(f"   Epoch history: {EPOCH_HISTORY_CSV_PATH}")
    print(f"   HyperOpt checkpoint dir: {HYPEROPT_CHECKPOINT_DIR}")
else:
    print("   To enable: set WARM_START_ENABLED = True")
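
The hashing idea in the cell above can be tried without Ray installed. The following is a minimal sketch, not the notebook's `get_search_space_hash`: it assumes a hypothetical plain-dict description of each parameter (`(sampler, low, high)` tuples) in place of Ray Tune `Domain` objects, and the name `toy_space_hash` is illustrative only. It shows that key order does not affect the hash, while a change in bounds does.

```python
import hashlib
import json


def toy_space_hash(space: dict) -> str:
    """Hash a plain-dict description of a search space (illustrative stand-in)."""
    # Sort keys so that insertion order does not affect the hash.
    canonical = json.dumps(
        {k: repr(v) for k, v in sorted(space.items())}, sort_keys=True
    )
    return hashlib.md5(canonical.encode()).hexdigest()[:8]


# Hypothetical parameter descriptions: (sampler name, lower bound, upper bound).
space_a = {"lr": ("loguniform", 1e-5, 1e-3), "dropout": ("uniform", 0.3, 0.7)}
space_b = {"dropout": ("uniform", 0.3, 0.7), "lr": ("loguniform", 1e-5, 1e-3)}  # reordered
space_c = {"lr": ("loguniform", 1e-5, 1e-2), "dropout": ("uniform", 0.3, 0.7)}  # wider lr range

hash_a, hash_b, hash_c = (toy_space_hash(s) for s in (space_a, space_b, space_c))
print("same space, different key order:", hash_a == hash_b)
print("different bounds:", hash_a == hash_c)
```

Because the hash is truncated to 8 hex characters, collisions are possible in principle, though unlikely for the handful of search spaces a single project tends to use.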

In [None]:
# Initialize Ray
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.search import ConcurrencyLimiter
from src.training.historical_median_stopping import HistoricalMedianStoppingRule

ray.shutdown()  # Clean shutdown of any previous Ray instance
ray.init(
    ignore_reinit_error=True,
    runtime_env={
        "env_vars": {
            "CUDA_MPS_PIPE_DIRECTORY": "/tmp/nvidia-mps",
            "CUDA_MPS_LOG_DIRECTORY": "/tmp/nvidia-log",
            "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
            # Ensure workers see the GPU as Device 0
            "CUDA_VISIBLE_DEVICES": "0"
        }
    }
)

SEED = 42

# Continuous search space for TPE optimization (k-fold CV)
search_space = {
    # Learning rates (log-uniform for order-of-magnitude exploration)
    "lr_rgb": tune.loguniform(5e-6, 1e-3),
    "lr_depth": tune.loguniform(2e-5, 1e-3),
    "lr_shared": tune.loguniform(1e-6, 5e-5),

    # Weight decay (log-uniform)
    "wd_rgb": tune.loguniform(1e-6, 5e-4),
    "wd_depth": tune.loguniform(5e-6, 1e-3),
    "wd_shared": tune.loguniform(1e-4, 1e-2),

    # Scheduler eta_min (log-uniform)
    "s1_eta_min": tune.loguniform(5e-8, 2e-6),
    "s2_eta_min": tune.loguniform(1e-7, 3e-6),
    "eta_min": tune.loguniform(1e-8, 1e-6),

    # Scheduler t_max (integer)
    "t_max": tune.randint(80, 111),

    # Regularization (uniform)
    "dropout_p": tune.uniform(0.3, 0.7),
    "label_smoothing": tune.uniform(0.01, 0.15),
    "grad_clip_norm": tune.uniform(0.5, 2.0),

    # Augmentation parameters (per-stream control)
    "rgb_aug_prob": tune.uniform(0.9, 1.8),
    "rgb_aug_mag": tune.uniform(0.9, 1.8),
    "depth_aug_prob": tune.uniform(0.9, 1.8),
    "depth_aug_mag": tune.uniform(0.9, 1.8),

    # Modality dropout
    "modality_dropout_rate": tune.uniform(0.1, 0.3),
}

print("=" * 60)
print("K-FOLD CV SEARCH SPACE (CONTINUOUS, 2-STREAM: RGB + DEPTH)")
print("=" * 60)

# HyperOptSearch uses best_accuracy for the TPE surrogate model
# (so it optimizes for running-best, not noisy per-epoch accuracy)
hyperopt_searcher = HyperOptSearch(
    metric="best_accuracy",
    mode="max",
    random_state_seed=SEED,
)

if WARM_START_ENABLED:
    current_hash = get_search_space_hash(search_space)
    hyperopt_checkpoint_path = os.path.join(HYPEROPT_CHECKPOINT_DIR, f"hyperopt_searcher_{current_hash}.pkl")
    _restored = False

    # Migrate old-format checkpoint (hyperopt_searcher.pkl + .hash) to new format
    old_checkpoint = os.path.join(HYPEROPT_CHECKPOINT_DIR, "hyperopt_searcher.pkl")
    old_hash_file = old_checkpoint + ".hash"
    if not os.path.exists(hyperopt_checkpoint_path) and os.path.exists(old_checkpoint) and os.path.exists(old_hash_file):
        with open(old_hash_file, "r") as f:
            old_hash = f.read().strip()
        if old_hash == current_hash:
            os.rename(old_checkpoint, hyperopt_checkpoint_path)
            os.remove(old_hash_file)
            print(f"   Migrated old checkpoint to {hyperopt_checkpoint_path}")

    if os.path.exists(hyperopt_checkpoint_path):
        try:
            hyperopt_searcher.restore(hyperopt_checkpoint_path)
            # Clear stale trial mappings from previous session.
            hyperopt_searcher._live_trial_mapping = {}
            # Remove any instance-level _setup_hyperopt that may have been
            # pickled from a previous session's monkey-patch.
            if '_setup_hyperopt' in hyperopt_searcher.__dict__:
                del hyperopt_searcher.__dict__['_setup_hyperopt']
            # Wrap _setup_hyperopt to preserve restored trials.
            # set_search_properties() (called by Tuner.fit()) triggers
            # _setup_hyperopt() which creates a NEW _hpopt_trials, destroying
            # our restored history. The wrapper re-injects the saved trials.
            _restored_trials = hyperopt_searcher._hpopt_trials
            _original_setup = hyperopt_searcher._setup_hyperopt
            def _patched_setup():
                _original_setup()
                hyperopt_searcher._hpopt_trials = _restored_trials
            hyperopt_searcher._setup_hyperopt = _patched_setup
            # Clear saved search space so Tuner can re-initialize it
            # from param_space via set_search_properties().
            hyperopt_searcher._space = None
            hyperopt_searcher.domain = None
            hyperopt_searcher._points_to_evaluate = None
            n_prev = len(hyperopt_searcher._hpopt_trials.trials)
            print(f"   Restored HyperOptSearch TPE model with {n_prev} previous trials (hash: {current_hash})")
            _restored = True
        except Exception as e:
            print(f"   Failed to restore HyperOptSearch checkpoint: {e}")
            print("   Starting fresh exploration.")
    if not _restored:
        print(f"Starting HyperOptSearch from scratch (hash: {current_hash})")

print(f"\nUsing trainval dataset (k-fold): {LOCAL_TRAINVAL_PATH}")

# Concurrency limiter for parallel trials
limited_search_alg = ConcurrencyLimiter(
    hyperopt_searcher,
    max_concurrent=3
)

print("=" * 60)
print("STARTING HYPERPARAMETER TUNING")
print("=" * 60)

# Configure Tuner: uses trainval dataset (no val/ split)
tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(
            train_linet_tune,
            data_root=LOCAL_TRAINVAL_PATH,
            target_size=(416, 544),
            seed=SEED
        ),
        resources={"cpu": 3, "gpu": 0.333}
    ),
    param_space=search_space,
    tune_config=tune.TuneConfig(
        # HistoricalMedianStoppingRule uses per-epoch "accuracy" for stopping
        # decisions (current epoch accuracy compared against running medians)
        scheduler=HistoricalMedianStoppingRule(
            historical_csv_path=EPOCH_HISTORY_CSV_PATH if WARM_START_ENABLED else None,
            search_space_hash=get_search_space_hash(search_space) if WARM_START_ENABLED else None,
            min_historical_epochs=20,
            metric="accuracy",
            mode="max",
            time_attr="training_iteration",
            grace_period=25,
            min_samples_required=5,
        ),
        search_alg=limited_search_alg,
        num_samples=13
    ),
)

# Run Tuning
results = tuner.fit()

# Remove monkey-patched _setup_hyperopt so save() doesn't pickle the closure
if '_setup_hyperopt' in hyperopt_searcher.__dict__:
    del hyperopt_searcher.__dict__['_setup_hyperopt']

# Get Best Result
best_result = results.get_best_result("best_accuracy", "max")

print("\n" + "=" * 60)
print("TUNING COMPLETE")
print("=" * 60)
print(f"Best Trial Config: {best_result.config}")
print(f"Best Trial Accuracy: {best_result.metrics['best_accuracy']:.4f}")
print(f"Best Trial Loss: {best_result.metrics['best_loss']:.4f}")
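
The restore logic in the cell above wraps `_setup_hyperopt` so that restored trials survive re-initialization. The pattern is easier to see on a toy object. This is a minimal sketch under stated assumptions: `ToySearcher`, `_setup`, and `trials` are hypothetical stand-ins, not Ray Tune's classes or attributes. It shows an instance-level wrapper re-injecting saved state after a setup call resets it, and the same wrapper being removed afterwards.

```python
class ToySearcher:
    """Stand-in for a searcher whose setup step resets its trial history."""

    def __init__(self):
        self.trials = []

    def _setup(self):
        # Mimics a setup routine that replaces the history with a fresh container.
        self.trials = []

    def set_search_properties(self):
        # Mimics the framework invoking setup when tuning starts.
        self._setup()


searcher = ToySearcher()
searcher.trials = ["trial_1", "trial_2"]  # pretend these were restored from disk

restored_trials = searcher.trials
original_setup = searcher._setup  # bound method, captured before patching


def patched_setup():
    original_setup()                   # run the normal setup (resets history)
    searcher.trials = restored_trials  # then re-inject the restored history


searcher._setup = patched_setup  # instance attribute shadows the class method
searcher.set_search_properties()
history_after_setup = list(searcher.trials)
print("after patched setup:", history_after_setup)

# Remove the instance-level patch so the object no longer carries the closure.
del searcher.__dict__["_setup"]
searcher.set_search_properties()
history_after_unpatch = list(searcher.trials)
print("after removing patch:", history_after_unpatch)
```

Deleting the instance attribute restores the class-level method, which is why the notebook removes the patch before calling `save()`: a locally defined closure stored on the instance generally cannot be pickled.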

In [None]:
# =============================================================================
# SAVE RESULTS FOR FUTURE WARM-STARTS
# =============================================================================
# Save the full results DataFrame to CSV for warm-starting future runs.
# Results are tagged with a search space hash so that only compatible configs
# are used when warm-starting (after a search space change, old configs are ignored).
# =============================================================================

import datetime

# Save with timestamp for history, plus a 'latest' version for easy warm-start
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Get full results DataFrame
results_df = results.get_dataframe()

# Tag with current search space hash for future compatibility filtering
current_hash = get_search_space_hash(search_space)
results_df['search_space_hash'] = current_hash

# Save timestamped version (for history)
timestamped_path = f"/content/drive/MyDrive/ray_tune_results/trial_{timestamp}.csv"
results_df.to_csv(timestamped_path, index=False)
print(f"📁 Saved results to: {timestamped_path}")
print(f"   Search space hash: {current_hash}")

# Save/update 'latest' version (for easy warm-start)
latest_path = WARM_START_CSV_PATH

# If a previous results file exists, merge with new results
if os.path.exists(latest_path):
    previous_df = pd.read_csv(latest_path)
    # Combine previous and new results
    combined_df = pd.concat([previous_df, results_df], ignore_index=True)
    # Remove duplicates based on config columns AND search_space_hash (keep best accuracy)
    config_cols = [c for c in combined_df.columns if c.startswith('config/')]
    dedup_cols = config_cols + ['search_space_hash']
    # Rank by best_accuracy, the same metric the searcher and the final ranking use
    combined_df = combined_df.sort_values('best_accuracy', ascending=False)
    combined_df = combined_df.drop_duplicates(subset=dedup_cols, keep='first')
    combined_df.to_csv(latest_path, index=False)

    # Count configs per search space
    hash_counts = combined_df['search_space_hash'].value_counts()
    print(f"📁 Updated {latest_path} with {len(results_df)} new trials")
    print(f"   Total unique configs: {len(combined_df)}")
    print("   Configs per search space:")
    for h, count in hash_counts.items():
        marker = " (current)" if h == current_hash else ""
        print(f"      {h}: {count} configs{marker}")
else:
    results_df.to_csv(latest_path, index=False)
    print(f"📁 Created {latest_path}")


# =============================================================================
# SAVE PER-EPOCH METRICS FOR HISTORICAL MEDIAN STOPPING
# =============================================================================
# Save per-epoch metrics (one row per trial per epoch) so that
# HistoricalMedianStoppingRule can use them as a median baseline
# in future sessions.
# =============================================================================

all_epoch_rows = []
for result in results:
    if result.metrics_dataframe is not None:
        epoch_df = result.metrics_dataframe.copy()
        epoch_df["trial_id"] = result.metrics.get("trial_id", id(result))
        epoch_df["search_space_hash"] = current_hash
        all_epoch_rows.append(epoch_df)

if all_epoch_rows:
    new_epoch_df = pd.concat(all_epoch_rows, ignore_index=True)

    # Merge with previous epoch history (single source of truth)
    if os.path.exists(EPOCH_HISTORY_CSV_PATH):
        previous_epoch_df = pd.read_csv(EPOCH_HISTORY_CSV_PATH)
        combined_epoch_df = pd.concat([previous_epoch_df, new_epoch_df], ignore_index=True)
        # Dedup on (trial_id, training_iteration) keeping last occurrence
        combined_epoch_df = combined_epoch_df.drop_duplicates(
            subset=["trial_id", "training_iteration"], keep="last"
        )
        combined_epoch_df.to_csv(EPOCH_HISTORY_CSV_PATH, index=False)
        print(f"\n📁 Updated epoch history: {EPOCH_HISTORY_CSV_PATH}")
        print(f"   Added {len(new_epoch_df)} epoch rows from {len(all_epoch_rows)} trials")
        print(f"   Total epoch rows: {len(combined_epoch_df)}")
    else:
        new_epoch_df.to_csv(EPOCH_HISTORY_CSV_PATH, index=False)
        print(f"\n📁 Created epoch history: {EPOCH_HISTORY_CSV_PATH}")
        print(f"   Saved {len(new_epoch_df)} epoch rows from {len(all_epoch_rows)} trials")
else:
    print("\n⚠️  No per-epoch metrics available (trials may have failed)")

# =============================================================================
# SAVE HYPEROPTSEARCH CHECKPOINT (TPE model state)
# =============================================================================
# Save the full TPE surrogate model so that next session's HyperOptSearch
# retains all (config, result) observations without re-running them.
# Each search space hash gets its own checkpoint file, so switching
# between search spaces preserves history for each one independently.
# =============================================================================

if WARM_START_ENABLED:
    hyperopt_searcher.save(hyperopt_checkpoint_path)
    n_trials = len(hyperopt_searcher._hpopt_trials.trials)
    print(f"\n📁 Saved HyperOptSearch checkpoint ({n_trials} trials) to {hyperopt_checkpoint_path}")

print("\n💡 To warm-start next run:")
print("   1. Set WARM_START_ENABLED = True in the warm-start cell")
print("   2. Run the notebook")
print("   3. Each search space gets its own TPE checkpoint → switch freely between configs")
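
The merge step in the cell above keeps one row per unique (config, search space hash) pair, preferring the row with the higher score. Here is a minimal sketch of that keep-best deduplication using plain dictionaries instead of pandas. The rows, the `score` field, and the helper name `keep_best` are hypothetical and chosen only for illustration.

```python
def keep_best(rows, key_fields, score_field):
    """Keep the highest-scoring row for each unique combination of key_fields."""
    best = {}
    for row in rows:
        key = tuple(row[f] for f in key_fields)
        if key not in best or row[score_field] > best[key][score_field]:
            best[key] = row
    # Return rows ordered from best to worst score.
    return sorted(best.values(), key=lambda r: r[score_field], reverse=True)


# Hypothetical rows from an earlier session.
previous = [
    {"config/lr": 1e-4, "search_space_hash": "aaaa1111", "score": 0.61},
    {"config/lr": 3e-4, "search_space_hash": "aaaa1111", "score": 0.58},
]
# Hypothetical rows from the current session.
new = [
    # Same config and hash as an earlier row, with a better score.
    {"config/lr": 1e-4, "search_space_hash": "aaaa1111", "score": 0.64},
    # Same config value, but tagged with a different search space.
    {"config/lr": 1e-4, "search_space_hash": "bbbb2222", "score": 0.50},
]

merged = keep_best(previous + new, ["config/lr", "search_space_hash"], "score")
for row in merged:
    print(row)
```

Including the hash in the key means an identical config sampled under two different search spaces is stored twice, so each space keeps its own history.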

In [None]:
# Analyze Top 10 Trials from Ray Tune (ranked by best_accuracy)
# With continuous search spaces, each trial has unique float values, so there is
# no grouping by config. Treat fold assignment as noise.
import pandas as pd

print("=" * 80)
print("TOP 10 TRIALS BY BEST ACCURACY (K-FOLD CV, CONTINUOUS SEARCH)")
print("=" * 80)

# Get all trials and convert to DataFrame
df = results.get_dataframe()

# Sort by best_accuracy (descending)
df_sorted = df.sort_values('best_accuracy', ascending=False)

# Select relevant columns for display
display_cols = [
    'best_accuracy',
    'config/lr_rgb', 'config/lr_depth', 'config/lr_shared',
    'config/wd_rgb', 'config/wd_depth', 'config/wd_shared',
    'config/s1_eta_min', 'config/s2_eta_min', 'config/eta_min', 'config/t_max',
    'config/dropout_p', 'config/label_smoothing', 'config/grad_clip_norm',
    'config/rgb_aug_prob', 'config/rgb_aug_mag',
    'config/depth_aug_prob', 'config/depth_aug_mag',
    'config/modality_dropout_rate',
]

# Get top 10 trials
top_10 = df_sorted[display_cols].head(10)

# Format for better display
top_10_formatted = top_10.copy()
top_10_formatted['best_accuracy'] = top_10_formatted['best_accuracy'].apply(lambda x: f"{x*100:.2f}%")

# Format scientific notation columns
sci_cols = [
    'config/lr_rgb', 'config/lr_depth', 'config/lr_shared',
    'config/wd_rgb', 'config/wd_depth', 'config/wd_shared',
    'config/s1_eta_min', 'config/s2_eta_min', 'config/eta_min',
]
for col in sci_cols:
    if col in top_10_formatted.columns:
        top_10_formatted[col] = top_10_formatted[col].apply(lambda x: f"{x:.2e}")

# Format float columns
float_cols = [
    'config/dropout_p', 'config/label_smoothing', 'config/grad_clip_norm',
    'config/rgb_aug_prob', 'config/rgb_aug_mag',
    'config/depth_aug_prob', 'config/depth_aug_mag',
    'config/modality_dropout_rate',
]
for col in float_cols:
    if col in top_10_formatted.columns:
        top_10_formatted[col] = top_10_formatted[col].apply(lambda x: f"{x:.3f}")

print(top_10_formatted.to_string(index=False))
print("\n" + "=" * 80)
print("\nNote: Each trial trained on 1 random fold (out of 5). To get robust")
print("estimates, revalidate top configs with full 5-fold CV.")
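
The closing note recommends revalidating top configs on all five folds. One way this could look is sketched below, under stated assumptions: `revalidate` and `fake_evaluate_fold` are hypothetical names, and the placeholder scorer returns made-up numbers purely to make the sketch runnable. A real version would train on the remaining folds and return validation accuracy for the held-out fold.

```python
import statistics


def revalidate(config, evaluate_fold, n_folds=5):
    """Score one config on every fold and summarize the spread."""
    scores = [evaluate_fold(config, fold) for fold in range(n_folds)]
    return {
        "mean": statistics.mean(scores),
        "std": statistics.stdev(scores),  # sample standard deviation across folds
        "scores": scores,
    }


def fake_evaluate_fold(config, fold):
    # Hypothetical placeholder: these values are NOT real results.
    # Replace with a function that trains and validates on the given fold.
    return 0.60 + 0.01 * fold


summary = revalidate({"lr_rgb": 1e-4}, fake_evaluate_fold)
print(f"accuracy: {summary['mean']:.3f} +/- {summary['std']:.3f} over {len(summary['scores'])} folds")
```

Comparing configs by mean accuracy together with the spread across folds helps separate a genuinely better config from one that happened to draw an easy fold during the search.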