In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import cv2
import random
import h5py

import sys
# sys.path.append('/kaggle/working/Depth-Anything-V2/metric_depth')

from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate import notebook_launcher
from accelerate import DistributedDataParallelKwargs

import transformers

import torch
import torchvision
from torchvision.transforms import v2
from torchvision.transforms import Compose
import torch.nn.functional as F
import albumentations as A

# from depth_anything_v2.dpt import DepthAnythingV2
from depth_model.fdepth_resnet_v2 import FastDepthV2


from metric_depth.util.loss import SiLogLoss
from metric_depth.dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop

### List of paths

In [None]:
def get_all_files(directory):
    """Recursively collect the paths of every file under ``directory``.

    Args:
        directory: Root directory to walk.

    Returns:
        A list of ``os.path.join(root, name)`` paths. Sub-directories and files
        are visited in sorted name order, so the list order is reproducible
        across runs and machines. ``os.walk`` on its own yields entries in
        arbitrary, filesystem-dependent order, and later cells index this list
        by position. Returns an empty list if ``directory`` does not exist,
        because ``os.walk`` ignores errors by default.
    """
    all_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting dirs in place makes the top-down walk visit sub-directories deterministically.
        dirs.sort()
        all_files.extend(os.path.join(root, name) for name in sorted(files))
    return all_files


# Directories of the NYU Depth V2 HDF5 splits inside the Kaggle input mount.
split_dirs = {
    'train': '/kaggle/input/nyu-depth-dataset-v2/nyudepthv2/train',
    'val': '/kaggle/input/nyu-depth-dataset-v2/nyudepthv2/val',
}
train_paths = get_all_files(split_dirs['train'])
val_paths = get_all_files(split_dirs['val'])

### Dataset preparation

In [None]:
# NYU Depth V2 subset with ~40k samples (per the original note, the full NYU release is ~400k frames).
class NYU(torch.utils.data.Dataset):
    """NYU Depth V2 samples stored as one HDF5 file per sample (``rgb`` and ``depth`` datasets).

    Args:
        paths: List of HDF5 file paths, one per sample.
        mode: ``'train'`` or ``'val'``. Training applies augmentations, resizes the
            depth target together with the image, and appends ``Crop(size[0])``.
            Validation leaves the depth map at its native resolution.
        size: ``(width, height)`` network input size passed to ``Resize``.
            The train crop uses ``size[0]`` only (a square size is assumed).

    Each item is a dict with ``'image'`` (a (3, H, W) tensor normalised with
    ImageNet mean/std) and ``'depth'`` (an (H, W) tensor).
    """

    def __init__(self, paths, mode, size=(518, 518)):
        self.mode = mode  # 'train' or 'val'
        self.size = size
        self.paths = paths

        net_w, net_h = size
        is_train = mode == 'train'
        # Preprocessing from the Depth Anything V2 metric-depth code base.
        # NOTE(review): ensure_multiple_of=14 matches a ViT patch size; confirm it suits FastDepthV2.
        self.transform = Compose([
            Resize(
                width=net_w,
                height=net_h,
                resize_target=is_train,  # val keeps depth at native resolution for evaluation
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ] + ([Crop(size[0])] if is_train else []))

        # Train-time augmentations. Per the original note, the paper uses only horizontal
        # flips; colour jitter and Gaussian noise are additions made in this notebook.
        self.augs = A.Compose([
            A.HorizontalFlip(),
            A.ColorJitter(hue=0.1, contrast=0.1, brightness=0.1, saturation=0.1),
            A.GaussNoise(var_limit=25),
        ])

    def __getitem__(self, item):
        image, depth = self.h5_loader(self.paths[item])

        if self.mode == 'train':
            # Depth is passed as the mask so geometric augmentations stay aligned with the image.
            augmented = self.augs(image=image, mask=depth)
            image, depth = augmented['image'], augmented['mask']

        # Scale to [0, 1] before the ImageNet normalisation inside self.transform.
        # Assumes rgb is stored as 0-255 values (TODO confirm against the dataset).
        sample = self.transform({'image': image / 255.0, 'depth': depth})

        sample['image'] = torch.from_numpy(sample['image'])
        sample['depth'] = torch.from_numpy(sample['depth'])
        return sample

    def __len__(self):
        return len(self.paths)

    def h5_loader(self, path):
        """Read ``(rgb, depth)`` from one HDF5 file.

        ``rgb`` has its first axis moved last (presumably stored CHW, returned HWC);
        ``depth`` is returned as stored. The file is opened in a ``with`` block so the
        handle is always closed; the earlier version never closed it.
        """
        with h5py.File(path, "r") as h5f:
            rgb = np.transpose(np.array(h5f['rgb']), (1, 2, 0))
            depth = np.array(h5f['depth'])
        return rgb, depth

### Dataset demonstration

In [None]:
num_images = 5

# ImageNet mean/std applied by NormalizeImage in NYU; needed to undo normalisation for display.
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

# squeeze=False keeps axes 2-D so axes[i, j] also works when num_images == 1.
fig, axes = plt.subplots(num_images, 2, figsize=(10, 5 * num_images), squeeze=False)

train_set = NYU(train_paths, mode='train')

for i in range(num_images):
    # Samples are spaced 1000 apart; clamp so a smaller dataset does not raise IndexError.
    sample = train_set[min(i * 1000, len(train_set) - 1)]
    img, depth = sample['image'].numpy(), sample['depth'].numpy()

    # Undo normalisation. Clip because bicubic resizing can overshoot [0, 1],
    # which imshow would otherwise clip with a warning.
    img = np.clip(img * std + mean, 0, 1)

    axes[i, 0].imshow(np.transpose(img, (1, 2, 0)))
    axes[i, 0].set_title('Image')
    axes[i, 0].axis('off')

    im1 = axes[i, 1].imshow(depth, cmap='viridis', vmin=0)
    axes[i, 1].set_title('True Depth')
    axes[i, 1].axis('off')
    fig.colorbar(im1, ax=axes[i, 1])

plt.tight_layout()


In [None]:
def get_dataloaders(batch_size, num_workers=4, pin_memory=False):
    """Build the NYU train/val DataLoaders from the global ``train_paths`` / ``val_paths``.

    Args:
        batch_size: Training batch size. Validation always uses batch size 1, so each
            image is evaluated at its own resolution without padding.
        num_workers: Worker processes for each loader (default 4, as before).
        pin_memory: Forwarded to both loaders (default False, the DataLoader default).

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    train_dataset = NYU(train_paths, mode='train')
    val_dataset = NYU(val_paths, mode='val')

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,  # keep every training batch the same size
    )

    # No drop_last here: every validation sample must be evaluated.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_dataloader, val_dataloader

### Metrics function

In [None]:
def eval_depth(pred, target):
    """Compute standard depth-estimation metrics.

    Args:
        pred: Predicted depth tensor, same shape as ``target`` (any rank). Values should
            be > 0; non-positive values yield inf/NaN through ``log`` and division.
        target: Ground-truth depth tensor. Callers should pass only valid (> 0) pixels.

    Returns:
        Dict of detached 0-d tensors:
            ``d1``: fraction of elements with ``max(pred/target, target/pred) < 1.25``.
            ``abs_rel``: ``mean(|pred - target| / target)``.
            ``rmse``: root mean squared error.
            ``mae``: mean absolute error.
            ``silog``: ``sqrt(mean(d^2) - 0.5 * mean(d)^2)`` with ``d = log(pred) - log(target)``.
        Every value is NaN when the inputs are empty.
    """
    assert pred.shape == target.shape

    thresh = torch.max(target / pred, pred / target)

    # Divide by numel() rather than len(): len() counts only the first dimension,
    # which is wrong for anything other than 1-D input.
    d1 = (thresh < 1.25).sum().float() / thresh.numel()

    diff = pred - target
    diff_log = torch.log(pred) - torch.log(target)

    abs_rel = torch.mean(torch.abs(diff) / target)

    rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
    mae = torch.mean(torch.abs(diff))

    silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))

    return {'d1': d1.detach(), 'abs_rel': abs_rel.detach(), 'rmse': rmse.detach(), 'mae': mae.detach(), 'silog': silog.detach()}

### Training

In [None]:
def _train_one_epoch(model, train_dataloader, optim, criterion, device, max_depth):
    """Run one optimisation pass over ``train_dataloader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0

    for sample in tqdm(train_dataloader, total=len(train_dataloader)):
        img = sample['image'].to(device)
        depth = sample['depth'].to(device)

        optim.zero_grad()
        pred = model(img)

        # Supervise only finite depths in (1e-3, max_depth].
        mask = (depth > 1e-3) & (depth <= max_depth) & torch.isfinite(depth)
        loss = criterion(pred, depth, mask)

        loss.backward()
        optim.step()

        total_loss += loss.item()

    return total_loss / len(train_dataloader)


def _validate(model, val_dataloader, device, max_depth):
    """Evaluate ``model`` on ``val_dataloader`` (batch size 1) and return rounded mean metrics.

    Samples without any valid depth pixel are skipped. eval_depth would return NaN for
    them, and that NaN would poison the average and block checkpointing. If no sample
    can be evaluated, every metric is NaN.
    """
    model.eval()
    keys = ('d1', 'abs_rel', 'rmse', 'mae', 'silog')
    totals = {k: 0.0 for k in keys}
    n_evaluated = 0

    with torch.no_grad():
        for sample in tqdm(val_dataloader, total=len(val_dataloader)):
            img = sample['image'].to(device)
            # Drop the batch dimension (batch size is 1) so depth, pred and mask are all (H, W).
            depth = sample['depth'].to(device)[0]

            pred = model(img)
            # Upsample the prediction to the native ground-truth resolution.
            # Assumes the model returns (B, H, W), as the training loss also requires; TODO confirm.
            pred = F.interpolate(pred[:, None], depth.shape[-2:], mode='bilinear', align_corners=True)[0, 0]

            mask = (depth > 1e-3) & (depth <= max_depth) & torch.isfinite(depth)
            if not mask.any():
                continue

            cur_results = eval_depth(pred[mask], depth[mask])
            for k in keys:
                totals[k] += cur_results[k].item()
            n_evaluated += 1

    if n_evaluated == 0:
        return {k: float('nan') for k in keys}
    return {k: round(totals[k] / n_evaluated, 3) for k in keys}


def _save_history_plots(history, state_path):
    """Write the train-loss vs. val-SiLog curve and the val AbsRel curve as PNGs under ``state_path``."""
    epochs = range(1, len(history["train_loss"]) + 1)
    val_silog = [m["silog"] for m in history["val_metrics"]]
    absrel = [m["abs_rel"] for m in history["val_metrics"]]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, history["train_loss"], label="Train Loss", marker='o')
    # The validation SiLog metric is plotted as the validation "loss".
    ax.plot(epochs, val_silog, label="Val Loss", marker='s')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training vs Validation Loss")
    ax.legend()
    ax.grid(True)
    fig.savefig(f"{state_path}/train_val_loss_curve.png", dpi=150)
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, absrel, label="AbsRel (val)")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("AbsRel")
    ax.legend()
    fig.savefig(f"{state_path}/val_absrel_curve.png")
    plt.close(fig)


def train_fn():
    """Train FastDepthV2 on NYU Depth V2, validating every epoch.

    Hyper-parameters are set inline below. After each epoch the function:
      * prints the mean training loss and the validation metrics;
      * writes a checkpoint (``model`` + ``optim`` state dicts) whenever validation AbsRel improves;
      * saves loss and AbsRel curves under ``state_path``.
    """
    device = "cuda:0"
    load_state = False
    state_path = './'
    resume_path = "/home/gremsy_guest/hyp_workspace/depth_v2/ours_checkpoints/11/last_checkpoint_16.pth"

    # NYU config
    batch_size = 128
    max_depth = 10
    num_epochs = 500
    # A cosine-with-warmup LR schedule was sketched previously but is disabled;
    # learning rates stay constant.

    print("CUDA available:", torch.cuda.is_available())
    print("CUDA device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

    model = FastDepthV2(max_depth=max_depth).to(device)

    # Separate learning rates: smaller for the encoder (backbone), larger for the decoder.
    optim = torch.optim.Adam([
        {"params": model.encoder.parameters(), "lr": 3e-4},
        {"params": model.decoder.parameters(), "lr": 3e-3},
    ], weight_decay=1e-5)

    print('Model created')

    criterion = SiLogLoss()  # authors' loss

    train_dataloader, val_dataloader = get_dataloaders(batch_size)
    print(f"size of train loader: {len(train_dataloader)}; val loader: {len(val_dataloader)}")

    # Best (lowest) validation AbsRel seen so far.
    best_val = 1e9
    history = {"train_loss": [], "val_loss": [], "val_metrics": []}

    if load_state:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_dataloader, optim, criterion, device, max_depth)
        results = _validate(model, val_dataloader, device, max_depth)

        # NaN metrics compare False here, so an unevaluable epoch never overwrites the best.
        if results['abs_rel'] < best_val:
            best_val = results['abs_rel']
            torch.save({
                "model": model.state_dict(),
                "optim": optim.state_dict(),
            }, f"{state_path}/checkpoint_best_{epoch}.pth")

        history["train_loss"].append(avg_loss)
        history["val_metrics"].append(results)

        print(f"epoch_{epoch}, train_loss={avg_loss:.5f}, val_metrics={results}")

        _save_history_plots(history, state_path)

In [None]:
# Launch training through accelerate's notebook launcher as a single process.
# train_fn hard-codes device "cuda:0" and does not use an Accelerator, so keep num_processes=1.
# Per the original author, an error printed when this cell finishes is harmless and can be ignored.
notebook_launcher(train_fn, args=(), num_processes=1)

### Inference

Note: the model predicts absolute (metric) depth, not relative depth.

In [None]:
# Path to a checkpoint written by train_fn (e.g. './checkpoint_best_<epoch>.pth').
checkpoint_path = ""
if not checkpoint_path:
    raise ValueError("Set checkpoint_path to a trained checkpoint before running inference.")

# max_depth should match the value used for training.
model = FastDepthV2(max_depth=10).to('cuda:0')
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location='cuda:0')
model.load_state_dict(checkpoint["model"])
model.eval()

In [None]:
num_images = 10
device = 'cuda:0'  # same device the model was loaded onto

# ImageNet mean/std applied by NormalizeImage in NYU; needed to undo normalisation for display.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# squeeze=False keeps axes 2-D so axes[i, j] also works when num_images == 1.
fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)

val_dataset = NYU(val_paths, mode='val')
model.eval()
for i in range(num_images):
    sample = val_dataset[i]
    img, depth = sample['image'], sample['depth']

    with torch.inference_mode():
        pred = model(img.unsqueeze(0).to(device))
        # In val mode the depth map keeps its native resolution, so resize the prediction to match.
        pred = F.interpolate(pred[:, None], depth.shape[-2:], mode='bilinear', align_corners=True)[0, 0]
    pred = pred.cpu()

    # Undo normalisation; clamp because bicubic resizing can overshoot [0, 1].
    img_vis = (img * std + mean).clamp(0, 1)

    axes[i, 0].imshow(img_vis.permute(1, 2, 0))
    axes[i, 0].set_title('Image')
    axes[i, 0].axis('off')

    # Shared colour scale so ground truth and prediction are directly comparable.
    vmax = max(depth.max().item(), pred.max().item())

    im1 = axes[i, 1].imshow(depth, cmap='viridis', vmin=0, vmax=vmax)
    axes[i, 1].set_title('True Depth')
    axes[i, 1].axis('off')
    fig.colorbar(im1, ax=axes[i, 1])

    im2 = axes[i, 2].imshow(pred, cmap='viridis', vmin=0, vmax=vmax)
    axes[i, 2].set_title('Predicted Depth')
    axes[i, 2].axis('off')
    fig.colorbar(im2, ax=axes[i, 2])

plt.tight_layout()
