In [None]:
import os, cv2, numpy as np
import time
from glob import glob
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

## Load & Process data

In [None]:
class NYUDepth(Dataset):
    """RGB / depth pairs read from an NYU-Depth style directory tree.

    Expected layout under ``root``:

    * ``test/official/rgb_<id>.png`` and ``test/official/depth_<id>.png``
    * ``train/**/rgb_<id>.png`` and ``train/**/depth_<id>.png`` (any nesting)

    An RGB file and a depth file form a pair when they sit in the same
    folder and share the same ``<id>``. Files without a partner are ignored.

    Args:
        root: Dataset root directory.
        split: ``'train'`` or ``'test'``.
        resize: ``(height, width)`` to resize both images to, or a falsy
            value to keep the stored resolution.

    Raises:
        AssertionError: If ``split`` is not ``'train'`` or ``'test'``.
        ValueError: If no RGB/depth pair is found.
    """

    def __init__(self, root, split='train', resize=(256, 320)):

        assert split in ['train', 'test'], "split must be 'train' or 'test'"
        self.resize = resize
        self.split = split

        if split == 'test':
            data_dir = os.path.join(root, 'test', 'official')
            rgb_paths = sorted(glob(os.path.join(data_dir, 'rgb_*.png')))
            depth_paths = sorted(glob(os.path.join(data_dir, 'depth_*.png')))
        else:  # train
            rgb_paths = sorted(glob(os.path.join(root, 'train', '**', 'rgb_*.png'), recursive=True))
            depth_paths = sorted(glob(os.path.join(root, 'train', '**', 'depth_*.png'), recursive=True))

        # Pair files on (folder, frame id). Keying on the frame id alone makes
        # equal ids from different sub-folders overwrite each other, which
        # silently drops samples and can match an RGB with a foreign depth map.
        rgb_by_key = {self._pair_key(p): p for p in rgb_paths}
        depth_by_key = {self._pair_key(p): p for p in depth_paths}

        self.pairs = [
            {'rgb': rgb_by_key[key], 'depth': depth_by_key[key]}
            for key in sorted(rgb_by_key.keys() & depth_by_key.keys())
        ]

        print(f"[NYUDepth] {split.upper()} - Found {len(self.pairs)} valid RGB-Depth pairs.")

        if len(self.pairs) == 0:
            raise ValueError("No matching RGB-Depth pairs found! Check file naming and structure.")

    @staticmethod
    def _pair_key(path):
        """Return ``(folder, frame_id)`` for a ``<prefix>_<frame_id>.<ext>`` file."""
        folder, name = os.path.split(path)
        return folder, name.split('_')[1].split('.')[0]

    def __len__(self):
        """Number of RGB/depth pairs."""
        return len(self.pairs)

    def __getitem__(self, i):
        """Load sample ``i``.

        Returns:
            dict with
            ``'image'``: float32 array ``(3, H, W)``, RGB order, pixel values
            divided by 255;
            ``'depth'``: float32 array ``(1, H, W)`` holding the stored depth
            values unchanged (no unit conversion is applied here);
            ``'path'``: path of the RGB file.

        Raises:
            OSError: If either file cannot be decoded by OpenCV.
        """
        pair = self.pairs[i]

        # RGB: OpenCV decodes to BGR, convert to RGB.
        img = cv2.imread(pair['rgb'])
        if img is None:
            raise OSError(f"Could not read RGB image: {pair['rgb']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Depth: IMREAD_UNCHANGED keeps the stored bit depth (e.g. 16-bit).
        dep = cv2.imread(pair['depth'], cv2.IMREAD_UNCHANGED)
        if dep is None:
            raise OSError(f"Could not read depth image: {pair['depth']}")
        dep = dep.astype(np.float32)

        # cv2.resize takes (width, height). Nearest-neighbour for depth so no
        # new depth values are invented by blending across object boundaries.
        if self.resize:
            img = cv2.resize(img, (self.resize[1], self.resize[0]), interpolation=cv2.INTER_LINEAR)
            dep = cv2.resize(dep, (self.resize[1], self.resize[0]), interpolation=cv2.INTER_NEAREST)

        # Scale RGB by 1/255 and move channels first.
        img = img.astype(np.float32) / 255.0
        img = np.moveaxis(img, -1, 0)  # HWC -> CHW

        # Depth: add a leading channel axis.
        dep = np.expand_dims(dep, 0)  # 1xHxW

        return {
            'image': img,      # (3, H, W)
            'depth': dep,      # (1, H, W)
            'path': pair['rgb']
        }

## Define Model

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    The first stage maps ``c_in`` channels to ``c_out``; the second keeps
    ``c_out``. ``padding=1`` with a 3x3 kernel leaves the spatial size
    unchanged.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so parameter names
        # stay net.0 / net.1 / net.3 / net.4.
        for stage_in in (c_in, c_out):
            stages.append(nn.Conv2d(stage_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` (N, c_in, H, W) -> (N, c_out, H, W)."""
        features = self.net(x)
        return features


class UNetDepth(nn.Module):
    """U-Net style encoder/decoder that predicts one non-negative depth channel.

    The encoder halves the resolution three times (``base`` -> ``8 * base``
    channels at the bottleneck); the decoder up-samples three times with
    transposed convolutions and concatenates the matching encoder features.

    Args:
        base: Channel width of the first encoder stage.

    Input is ``(N, 3, H, W)``; output is ``(N, 1, H, W)``. H and W do not
    have to be multiples of 8: up-sampled maps are zero-padded to the size
    of their skip connection before concatenation.
    """

    def __init__(self, base=32):
        super().__init__()
        # Submodule names and creation order are kept as-is so existing
        # checkpoints and the printed module layout remain valid.
        self.d1 = DoubleConv(3, base)
        self.d2 = DoubleConv(base, base * 2)
        self.p1 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(base * 2, base * 4)
        self.p2 = nn.MaxPool2d(2)
        self.bott = DoubleConv(base * 4, base * 8)
        self.u3 = nn.ConvTranspose2d(base * 8, base * 4, 2, 2)
        self.c3 = DoubleConv(base * 8, base * 4)
        self.u2 = nn.ConvTranspose2d(base * 4, base * 2, 2, 2)
        self.c2 = DoubleConv(base * 4, base * 2)
        self.u1 = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.c1 = DoubleConv(base * 2, base)
        self.out = nn.Conv2d(base, 1, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its last two dims equal those of ``skip``.

        Max-pooling floors odd sizes, so after a x2 up-sampling ``up`` can be
        one pixel short of ``skip``; it is never larger.
        """
        dh = skip.shape[-2] - up.shape[-2]
        dw = skip.shape[-1] - up.shape[-1]
        if dh == 0 and dw == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        """Predict depth for a batch of RGB images."""
        x1 = self.d1(x)
        x2 = self.d2(self.p1(x1))
        x3 = self.d3(self.p2(x2))
        # MaxPool2d holds no parameters, so p2 is reused for the third pooling.
        xb = self.bott(self.p2(x3))
        y3 = self.c3(torch.cat([self._match_size(self.u3(xb), x3), x3], dim=1))
        y2 = self.c2(torch.cat([self._match_size(self.u2(y3), x2), x2], dim=1))
        y1 = self.c1(torch.cat([self._match_size(self.u1(y2), x1), x1], dim=1))
        d = self.out(y1)
        return torch.relu(d)  # depth must be non-negative

In [None]:
def ssim(x, y, C1=0.01**2, C2=0.03**2, kernel_size=3):
    """Mean structural dissimilarity ``(1 - SSIM) / 2`` between two tensors.

    Local statistics are taken with a ``kernel_size`` x ``kernel_size``
    average pool at stride 1 and ``kernel_size // 2`` padding.

    Args:
        x, y: Tensors accepted by ``F.avg_pool2d`` (same shape).
        C1, C2: Stabilising constants of the SSIM formula.
        kernel_size: Side length of the averaging window.

    Returns:
        Scalar tensor in ``[0, 1]``; 0 when the inputs are identical.
    """
    half = kernel_size // 2

    def local_mean(t):
        return F.avg_pool2d(t, kernel_size, stride=1, padding=half)

    mean_x = local_mean(x)
    mean_y = local_mean(y)

    mean_x_sq = mean_x * mean_x
    mean_y_sq = mean_y * mean_y
    mean_prod = mean_x * mean_y

    # Var[x] = E[x^2] - E[x]^2, Cov[x, y] = E[xy] - E[x]E[y]
    var_x = local_mean(x * x) - mean_x_sq
    var_y = local_mean(y * y) - mean_y_sq
    cov_xy = local_mean(x * y) - mean_prod

    numerator = (2 * mean_prod + C1) * (2 * cov_xy + C2)
    denominator = (mean_x_sq + mean_y_sq + C1) * (var_x + var_y + C2)
    similarity = numerator / denominator

    dissimilarity = torch.clamp((1 - similarity) / 2, 0, 1)
    return dissimilarity.mean()


def depth_loss(pred, gt):
    """L1 depth error plus a first-order smoothness penalty on ``pred``.

    The smoothness term is the mean absolute difference between horizontally
    and vertically adjacent predictions, weighted by 0.1. It looks only at
    the prediction; no image edges are involved.

    Args:
        pred, gt: 4-D tensors ``(N, C, H, W)`` of the same shape.

    Returns:
        Scalar tensor ``L1 + 0.1 * (|d/dx pred| + |d/dy pred|)``.
    """
    abs_err = (pred - gt).abs().mean()

    # Differences between neighbouring columns / rows.
    horiz = (pred[:, :, :, 1:] - pred[:, :, :, :-1]).abs().mean()
    vert = (pred[:, :, 1:, :] - pred[:, :, :-1, :]).abs().mean()

    return abs_err + 0.1 * (horiz + vert)


def depth_ssim_loss(pred, gt, alpha=1.0, beta=1.0, gamma=0.1, max_depth=10000.0):
    """Weighted sum of L1, SSIM and smoothness losses for depth prediction.

    Args:
        pred, gt: 4-D tensors ``(N, C, H, W)`` of the same shape.
        alpha: Weight of the L1 term.
        beta: Weight of the SSIM term.
        gamma: Weight of the smoothness term.
        max_depth: Value that maps to 1.0 when scaling depth into ``[0, 1]``
            for the SSIM term. Anything above it is clamped, so it should be
            at least the largest depth in ``gt``. The default keeps the
            previous hard-coded 10000.

    Returns:
        Tuple ``(total_loss, l1, ssim_loss, smooth_loss)``: ``total_loss`` is
        a tensor that carries gradients, the other three are Python floats
        for logging.

    Raises:
        ValueError: If ``max_depth`` is not positive.
    """
    if max_depth <= 0:
        raise ValueError("max_depth must be positive")

    # 1. L1 loss on the raw depth values.
    l1 = torch.mean(torch.abs(pred - gt))

    # 2. SSIM loss on depth rescaled (and clamped) to [0, 1].
    pred_norm = torch.clamp(pred / max_depth, 0, 1)
    gt_norm = torch.clamp(gt / max_depth, 0, 1)
    ssim_loss = ssim(pred_norm, gt_norm)

    # 3. Smoothness: mean absolute difference between neighbouring
    #    predictions. It uses only ``pred`` (it is not edge-aware).
    dx = torch.mean(torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:]))
    dy = torch.mean(torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :]))
    smooth_loss = dx + dy

    # Weighted total.
    total_loss = alpha * l1 + beta * ssim_loss + gamma * smooth_loss
    return total_loss, l1.item(), ssim_loss.item(), smooth_loss.item()

In [None]:
# Shared settings for both splits.
DATA_ROOT = '/kaggle/input/nyuv2-official-split-dataset'
IMAGE_SIZE = (240, 320)  # (height, width) handed to the dataset's resize
BATCH_SIZE = 8
NUM_WORKERS = 4

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The training loop copies batches with non_blocking=True, which only
# overlaps with compute when the host tensors are in pinned memory.
use_pinned = device == 'cuda'

train_ds = NYUDepth(DATA_ROOT, 'train', resize=IMAGE_SIZE)
val_ds = NYUDepth(DATA_ROOT, 'test', resize=IMAGE_SIZE)
train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=use_pinned)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=use_pinned)

model = UNetDepth(base=32).to(device)
opt = optim.AdamW(model.parameters(), lr=1e-3)
# T_max counts calls to sched.step(); 30 is half a cosine period.
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)
print(model)

[NYUDepth] TRAIN - Found 621 valid RGB-Depth pairs.
[NYUDepth] TEST - Found 654 valid RGB-Depth pairs.
UNetDepth(
  (d1): DoubleConv(
    (net): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (d2): DoubleConv(
    (net): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)

## Training

In [None]:
num_epoch = 5
for epoch in range(num_epoch):
    model.train()
    loss_sum = l1_sum = ssim_sum = smooth_sum = 0.0
    start_time = time.perf_counter()

    for batch in train_ld:
        img = batch['image'].to(device, non_blocking=True)
        dep = batch['depth'].to(device, non_blocking=True)

        opt.zero_grad()
        pred = model(img)

        # Total loss plus its three components (L1, SSIM, smoothness).
        loss, l1, ssim_val, smooth = depth_ssim_loss(pred, dep, alpha=1.0, beta=1.0, gamma=0.1)

        loss.backward()
        opt.step()

        # Accumulate for the per-epoch averages.
        loss_sum += loss.item()
        l1_sum += l1
        ssim_sum += ssim_val
        smooth_sum += smooth

    # Advance the LR schedule once per epoch. Stepping it after every batch
    # would run through the scheduler's T_max within a fraction of an epoch
    # and make the learning rate swing up and down many times per epoch.
    sched.step()

    end_time = time.perf_counter()
    num_batches = len(train_ld)

    # NOTE(review): if the L1 term stays orders of magnitude above the other
    # terms, the depth targets are presumably on a raw integer scale --
    # confirm the dataset's depth encoding and consider rescaling.
    print(f"Epoch {epoch + 1}/{num_epoch} | "
          f"Loss: {loss_sum / num_batches:.4f} | "
          f"L1: {l1_sum / num_batches:.4f} | "
          f"SSIM: {ssim_sum / num_batches:.4f} | "
          f"Smooth: {smooth_sum / num_batches:.4f} | "
          f"Time: {end_time - start_time:.2f}s")

Epoch 1/5 | Loss: 24510.4734 | L1: 24509.9739 | SSIM: 0.4804 | Smooth: 0.1914 | Time: 8.04s
Epoch 2/5 | Loss: 24510.4026 | L1: 24509.9005 | SSIM: 0.4803 | Smooth: 0.2188 | Time: 7.88s
Epoch 3/5 | Loss: 24535.6659 | L1: 24535.1585 | SSIM: 0.4801 | Smooth: 0.2714 | Time: 7.93s
Epoch 4/5 | Loss: 24531.3735 | L1: 24530.8612 | SSIM: 0.4799 | Smooth: 0.3243 | Time: 7.91s
Epoch 5/5 | Loss: 24514.3156 | L1: 24513.7976 | SSIM: 0.4798 | Smooth: 0.3811 | Time: 7.89s
