In [1]:
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torch.nn.functional as F_nn
import random

from PIL import Image
from tqdm import tqdm

import time
import os

import matplotlib.pyplot as plt
import pandas as pd
import timm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def paired_transform(color_img, depth_img, output_size=224, max_depth=10.0, apply_color_jitter=True):
    if random.random() > 0.5:
        color_img = F.hflip(color_img)
        depth_img = F.hflip(depth_img)

    angle = random.uniform(-15, 15)
    color_img = F.rotate(color_img, angle, interpolation=Image.BILINEAR)
    depth_img = F.rotate(depth_img, angle, interpolation=Image.NEAREST)

    i, j, h, w = transforms.RandomResizedCrop.get_params(
        color_img, scale=(0.8, 1.0), ratio=(1.0, 1.0)
    )
    color_img = F.resized_crop(color_img, i, j, h, w, size=(output_size, output_size))
    depth_img = F.resized_crop(depth_img, i, j, h, w, size=(output_size, output_size))

    if apply_color_jitter:
        color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)
        color_img = color_jitter(color_img)

    # RGB → tensor
    color_tensor = F.to_tensor(color_img)
    color_tensor = F.normalize(color_tensor, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])

    # Depth → tensor (meters)
    depth_np = np.array(depth_img).astype(np.float32) / 1000.0
    depth_tensor = torch.from_numpy(depth_np).unsqueeze(0)
    depth_tensor = torch.clamp(depth_tensor, 0, max_depth)

    return color_tensor, depth_tensor
    

class DepthDataset(Dataset):
    def __init__(self, data_dir, paired_transform=None, max_depth=10.0):
        """
        Args:
            data_dir (str): Path to folder containing 'colors/' and 'depths/' subfolders.
            paired_transform (callable, optional): Function to apply same geometric transform to both RGB and depth.
            max_depth (float): Maximum depth value to scale depth maps (in meters).
        """
        self.data_dir = data_dir
        self.paired_transform = paired_transform
        self.max_depth = max_depth

        # Paths to color and depth folders
        self.color_dir = os.path.join(data_dir, "colors")
        self.depth_dir = os.path.join(data_dir, "depths")

        # List all RGB color images
        self.color_files = sorted([f for f in os.listdir(self.color_dir) if f.endswith("_colors.png")])

    def __len__(self):
        return len(self.color_files)

    def __getitem__(self, idx):
        # Load RGB image
        color_path = os.path.join(self.color_dir, self.color_files[idx])
        color_img = Image.open(color_path).convert("RGB")

        # Load corresponding depth image
        depth_file = self.color_files[idx].replace("_colors.png", "_depth.png")
        depth_path = os.path.join(self.depth_dir, depth_file)
        depth_img = Image.open(depth_path)

        # Apply paired transform if provided
        if self.paired_transform:
            color_tensor, depth_tensor = self.paired_transform(color_img, depth_img)
        else:
            # Convert RGB to tensor and normalize
            color_tensor = T.ToTensor()(color_img)
            color_tensor = T.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])(color_tensor)
            # Convert depth to tensor and scale to meters
            depth_np = np.array(depth_img).astype(np.float32) / 1000.0  # mm → meters
            depth_tensor = torch.from_numpy(depth_np).unsqueeze(0)
            depth_tensor = torch.clamp(depth_tensor, 0, self.max_depth)

        return color_tensor, depth_tensor


class DepthModel(nn.Module):
    def __init__(self, backbone_name='efficientnet_b3', pretrained=True):
        super().__init__()
        
        # --------------------------
        # Encoder: pretrained CNN
        # --------------------------
        # Use features_only=True to get intermediate feature maps for decoder
        self.encoder = timm.create_model(backbone_name, pretrained=pretrained, features_only=True)
        
        # Channels of encoder feature maps at each stage
        encoder_channels = self.encoder.feature_info.channels()  # e.g., [40, 48, 136, 384]
        last_ch = encoder_channels[-1]

        # --------------------------
        # Simple decoder: upsample to original resolution
        # --------------------------
        self.decoder = nn.Sequential(
            nn.Conv2d(last_ch, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(64, 1, kernel_size=3, padding=1)  # Output: single-channel depth map
        )

    def forward(self, x):
        # Encoder forward: returns list of feature maps at different stages
        features = self.encoder(x)
        x = features[-1]  # Use last-stage feature map for decoder

        # Decode to depth map
        depth = self.decoder(x)
        # Unsample to match input size
        depth = F_nn.interpolate(depth, size=(224, 224), mode='bilinear', align_corners=False)
        return depth


# --------------------------
# Training Loop
# --------------------------
def train_depth_model(model, train_loader, val_loader, epochs=10, freeze_encoder=True, lr_head=1e-3, lr_finetune=1e-4, model_path="depth_model.pth"):
    criterion = nn.MSELoss()
    
    if freeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_head)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

    best_val_loss = float("inf")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs} {'(Frozen Encoder)' if freeze_encoder else '(Fine-tune)'}")

        # Training
        model.train()
        train_loss = 0.0
        for images, depths in tqdm(train_loader):
            images, depths = images.to(device), depths.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, depths)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, depths in val_loader:
                images, depths = images.to(device), depths.to(device)
                outputs = model(images)
                loss = criterion(outputs, depths)
                val_loss += loss.item() * images.size(0)
                
        val_loss /= len(val_loader.dataset)
        rmse = np.sqrt(val_loss)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val RMSE: {rmse:.4f} m")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_path)
            print("✅ Best model saved.")

        # Scheduler step
        scheduler.step(val_loss)

    print("Training complete.")



# --------------------------
# Run training
# --------------------------


# --------------------------
# Device
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# Model
# --------------------------
model = DepthModel().to(device)


# --------------------------
# Datasets & Loaders
# --------------------------
train_folder = "/kaggle/input/nyu-depth-split-dataset/nyu_split/train"
valid_folder = "/kaggle/input/nyu-depth-split-dataset/nyu_split/val"
max_depth = 10

train_dataset = DepthDataset(train_folder, paired_transform, max_depth)
val_dataset = DepthDataset(valid_folder, paired_transform, max_depth)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# --------------------------
# Freeze backbone initially
# --------------------------
for param in model.encoder.parameters():
    param.requires_grad = False

# --------------------------
# Training Parameters
# --------------------------
num_epoch_head = 5       # Train classifier head first
num_epoch_finetune = 25  # Fine-tune full model
MODEL_PATH = "/kaggle/working/depth_classifier.pth"

model = DepthModel().to(device)

# 1. Train decoder head first
train_depth_model(model, train_loader, val_loader, epochs=num_epoch_head, freeze_encoder=True, 
                  model_path= MODEL_PATH)

# 2. Fine-tune full model
for param in model.encoder.parameters():
    param.requires_grad = True
train_depth_model(model, train_loader, val_loader, epochs=num_epoch_finetune, freeze_encoder=False, 
                  model_path="depth_model.pth")

Using device: cpu


model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]


Epoch 1/5 (Frozen Encoder)


100%|██████████| 19/19 [01:32<00:00,  4.86s/it]


Train Loss: 3.3079 | Val Loss: 2.1087 | Val RMSE: 1.4521 m
✅ Best model saved.

Epoch 2/5 (Frozen Encoder)


100%|██████████| 19/19 [01:35<00:00,  5.00s/it]


Train Loss: 1.7199 | Val Loss: 1.6214 | Val RMSE: 1.2733 m
✅ Best model saved.

Epoch 3/5 (Frozen Encoder)


100%|██████████| 19/19 [01:35<00:00,  5.03s/it]


Train Loss: 1.3521 | Val Loss: 1.3686 | Val RMSE: 1.1699 m
✅ Best model saved.

Epoch 4/5 (Frozen Encoder)


100%|██████████| 19/19 [01:35<00:00,  5.01s/it]


Train Loss: 1.1196 | Val Loss: 1.2538 | Val RMSE: 1.1197 m
✅ Best model saved.

Epoch 5/5 (Frozen Encoder)


100%|██████████| 19/19 [01:35<00:00,  5.02s/it]


Train Loss: 1.1416 | Val Loss: 1.1631 | Val RMSE: 1.0785 m
✅ Best model saved.
Training complete.

Epoch 1/25 (Fine-tune)


100%|██████████| 19/19 [02:53<00:00,  9.15s/it]


Train Loss: 1.7817 | Val Loss: 1.4319 | Val RMSE: 1.1966 m
✅ Best model saved.

Epoch 2/25 (Fine-tune)


100%|██████████| 19/19 [02:57<00:00,  9.36s/it]


Train Loss: 0.9087 | Val Loss: 1.0512 | Val RMSE: 1.0253 m
✅ Best model saved.

Epoch 3/25 (Fine-tune)


100%|██████████| 19/19 [02:56<00:00,  9.29s/it]


Train Loss: 0.6816 | Val Loss: 0.7369 | Val RMSE: 0.8584 m
✅ Best model saved.

Epoch 4/25 (Fine-tune)


100%|██████████| 19/19 [02:48<00:00,  8.89s/it]


Train Loss: 0.6405 | Val Loss: 0.9133 | Val RMSE: 0.9557 m

Epoch 5/25 (Fine-tune)


100%|██████████| 19/19 [02:58<00:00,  9.37s/it]


Train Loss: 0.5135 | Val Loss: 0.6482 | Val RMSE: 0.8051 m
✅ Best model saved.

Epoch 6/25 (Fine-tune)


100%|██████████| 19/19 [02:48<00:00,  8.87s/it]


Train Loss: 0.5340 | Val Loss: 0.6896 | Val RMSE: 0.8304 m

Epoch 7/25 (Fine-tune)


100%|██████████| 19/19 [02:48<00:00,  8.85s/it]


Train Loss: 0.4729 | Val Loss: 0.6469 | Val RMSE: 0.8043 m
✅ Best model saved.

Epoch 8/25 (Fine-tune)


100%|██████████| 19/19 [02:54<00:00,  9.21s/it]


Train Loss: 0.3966 | Val Loss: 0.6635 | Val RMSE: 0.8146 m

Epoch 9/25 (Fine-tune)


100%|██████████| 19/19 [02:54<00:00,  9.17s/it]


Train Loss: 0.4506 | Val Loss: 0.6330 | Val RMSE: 0.7956 m
✅ Best model saved.

Epoch 10/25 (Fine-tune)


100%|██████████| 19/19 [02:51<00:00,  9.04s/it]


Train Loss: 0.4114 | Val Loss: 0.5776 | Val RMSE: 0.7600 m
✅ Best model saved.

Epoch 11/25 (Fine-tune)


100%|██████████| 19/19 [02:48<00:00,  8.88s/it]


Train Loss: 0.3801 | Val Loss: 0.7299 | Val RMSE: 0.8543 m

Epoch 12/25 (Fine-tune)


100%|██████████| 19/19 [02:49<00:00,  8.92s/it]


Train Loss: 0.3925 | Val Loss: 0.6260 | Val RMSE: 0.7912 m

Epoch 13/25 (Fine-tune)


100%|██████████| 19/19 [02:53<00:00,  9.14s/it]


Train Loss: 0.3614 | Val Loss: 0.7369 | Val RMSE: 0.8585 m

Epoch 14/25 (Fine-tune)


100%|██████████| 19/19 [02:51<00:00,  9.00s/it]


Train Loss: 0.3032 | Val Loss: 0.5945 | Val RMSE: 0.7710 m

Epoch 15/25 (Fine-tune)


100%|██████████| 19/19 [02:49<00:00,  8.91s/it]


Train Loss: 0.2716 | Val Loss: 0.5280 | Val RMSE: 0.7267 m
✅ Best model saved.

Epoch 16/25 (Fine-tune)


100%|██████████| 19/19 [02:48<00:00,  8.88s/it]


Train Loss: 0.2518 | Val Loss: 0.5259 | Val RMSE: 0.7252 m
✅ Best model saved.

Epoch 17/25 (Fine-tune)


100%|██████████| 19/19 [02:46<00:00,  8.79s/it]


Train Loss: 0.2355 | Val Loss: 0.5498 | Val RMSE: 0.7415 m

Epoch 18/25 (Fine-tune)


100%|██████████| 19/19 [02:46<00:00,  8.79s/it]


Train Loss: 0.2268 | Val Loss: 0.5449 | Val RMSE: 0.7382 m

Epoch 19/25 (Fine-tune)


100%|██████████| 19/19 [02:52<00:00,  9.08s/it]


Train Loss: 0.2162 | Val Loss: 0.5404 | Val RMSE: 0.7351 m

Epoch 20/25 (Fine-tune)


100%|██████████| 19/19 [02:55<00:00,  9.22s/it]


Train Loss: 0.2236 | Val Loss: 0.5054 | Val RMSE: 0.7109 m
✅ Best model saved.

Epoch 21/25 (Fine-tune)


100%|██████████| 19/19 [02:56<00:00,  9.29s/it]


Train Loss: 0.2184 | Val Loss: 0.5496 | Val RMSE: 0.7413 m

Epoch 22/25 (Fine-tune)


100%|██████████| 19/19 [02:58<00:00,  9.40s/it]


Train Loss: 0.2144 | Val Loss: 0.5273 | Val RMSE: 0.7262 m

Epoch 23/25 (Fine-tune)


100%|██████████| 19/19 [02:56<00:00,  9.27s/it]


Train Loss: 0.2149 | Val Loss: 0.5047 | Val RMSE: 0.7104 m
✅ Best model saved.

Epoch 24/25 (Fine-tune)


100%|██████████| 19/19 [03:13<00:00, 10.17s/it]


Train Loss: 0.2090 | Val Loss: 0.5219 | Val RMSE: 0.7224 m

Epoch 25/25 (Fine-tune)


100%|██████████| 19/19 [03:21<00:00, 10.62s/it]


Train Loss: 0.2059 | Val Loss: 0.5309 | Val RMSE: 0.7286 m
Training complete.
