# Adapting ATEK Data Samples for Depth Anything 2 Model

This notebook demonstrates how to adapt ATEK data samples to be compatible with the Depth Anything 2 model. We will cover the entire process, including loading the dataset, adapting it, and performing inference.

Depth Anything 2: https://github.com/DepthAnything/Depth-Anything-V2

## Import Required Libraries

First, we import all necessary libraries that will be used throughout the notebook.

In [None]:
import argparse
import glob
import os
from typing import List

import cv2
import matplotlib
import numpy as np
import torch
from atek.data_loaders.atek_wds_dataloader import load_atek_wds_dataset
from depth_anything_v2.dpt import DepthAnythingV2
from webdataset.filters import pipelinefilter

## Configuration and Initialization
Define the paths and configuration parameters that will be used to load the data and the model.

In [None]:
# Input/output locations used by the cells below.
# Each path can be overridden through an environment variable so the notebook
# can run on another machine without editing this cell; when the variable is
# unset the original hard-coded location is used.
#   ATEK_WDS_DIR           -> directory containing the ATEK WebDataset shard tars
#   DEPTH_ANYTHING_OUT_DIR -> directory used for output image paths
wds_dir = os.environ.get(
    "ATEK_WDS_DIR", "/data/home/ariak/datadir/0823_wds_for_test_depthanything2"
)
out_dir = os.environ.get(
    "DEPTH_ANYTHING_OUT_DIR", "/data/home/ariak/workdir/Depth-Anything-V2/output"
)

## Model Adaptor Class from ATEK to Depth Anything 2
This class converts samples from the ATEK dataset format into the input format expected by Depth Anything 2.

In [None]:
class depthAnything2Adaptor:
    """Adapts ATEK WebDataset samples to the dict layout used for Depth Anything 2."""

    @staticmethod
    def get_dict_key_mapping_all():
        """Return the mapping from ATEK sample keys to the keys used by this adaptor.

        Returns:
            dict: a freshly built ``{atek_key: adaptor_key}`` mapping; currently
            only the RGB camera images are mapped, under the key ``"image"``.
        """
        return {"mfcd#camera-rgb+images": "image"}

    def atek_to_depth_anything2(self, data):
        """Lazily convert ATEK samples into Depth Anything 2 samples.

        Args:
            data: iterable of dicts, each holding a torch tensor under ``"image"``
                (expected layout is [1, C, H, W] -- TODO confirm against the loader).

        Yields:
            dict: ``{"image": ndarray}`` where the array is the input image with
            the leading singleton dimension dropped and the channel axis moved
            last, i.e. [H, W, C].
        """
        for wds_sample in data:
            # Work on a detached copy so the source tensor is left untouched.
            image_tensor = wds_sample["image"].clone().detach()
            # [1, C, H, W] -> [C, H, W] -> [H, W, C]
            channels_first = image_tensor.squeeze(0)
            channels_last = channels_first.permute(1, 2, 0)
            yield {"image": channels_last.numpy()}

## Data Loading Function
`load_atek_wds_dataset_as_depth_anything_2` loads the ATEK dataset and applies the adaptor so that each sample is compatible with Depth Anything 2.

In [None]:
def load_atek_wds_dataset_as_depth_anything_2(
    urls: List,
    batch_size: int,
    repeat_flag: bool,
    shuffle_flag: bool = False,
):
    """Build an ATEK WebDataset loader whose samples are adapted for Depth Anything 2.

    Args:
        urls: shard locations forwarded to ``load_atek_wds_dataset``.
        batch_size: batch size forwarded to the loader.
        repeat_flag: forwarded unchanged to the loader.
        shuffle_flag: forwarded unchanged to the loader (defaults to ``False``).

    Returns:
        Whatever ``load_atek_wds_dataset`` returns for the given arguments, with
        ``depthAnything2Adaptor`` applied as the per-sample transform and
        ``simple_collation_fn`` used for collation.
    """
    adaptor = depthAnything2Adaptor()
    key_mapping = depthAnything2Adaptor.get_dict_key_mapping_all()
    # Wrap the adaptor's generator method as a webdataset pipeline stage.
    transform_stage = pipelinefilter(adaptor.atek_to_depth_anything2)()

    loader_kwargs = dict(
        batch_size=batch_size,
        dict_key_mapping=key_mapping,
        data_transform_fn=transform_stage,
        collation_fn=simple_collation_fn,
        repeat_flag=repeat_flag,
        shuffle_flag=shuffle_flag,
    )
    return load_atek_wds_dataset(urls, **loader_kwargs)

## Simple Collation Function
A simple function to collate batches of data.

In [None]:
def simple_collation_fn(batch):
    """Collate a batch by materialising it as a plain list.

    Args:
        batch: any iterable of samples.

    Returns:
        list: a new list holding the samples in iteration order.
    """
    # No stacking or merging: the samples are passed through as-is.
    return [sample for sample in batch]

## Load Depth Anything Model
Load the Depth Anything 2 model with specified configurations.

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Constructor arguments for each Depth Anything 2 encoder variant, keyed by
# encoder name.
model_configs = {
    encoder_name: {
        "encoder": encoder_name,
        "features": feature_dim,
        "out_channels": channel_list,
    }
    for encoder_name, feature_dim, channel_list in (
        ("vits", 64, [48, 96, 192, 384]),
        ("vitb", 128, [96, 192, 384, 768]),
        ("vitl", 256, [256, 512, 1024, 1024]),
        ("vitg", 384, [1536, 1536, 1536, 1536]),
    )
}

encoder = "vitl"  # or 'vits', 'vitb', 'vitg'

# Build the network, load its weights on the CPU, then move it to DEVICE and
# switch it to evaluation mode.
model = DepthAnythingV2(**model_configs[encoder])
# NOTE(review): torch.load unpickles the checkpoint; only load checkpoints from
# a trusted source.
model.load_state_dict(
    torch.load(
        f"/data/home/ariak/workdir/Depth-Anything-V2/checkpoints/depth_anything_v2_{encoder}.pth",
        map_location="cpu",
    )
)
model = model.to(DEVICE).eval()

## Perform Inference
Perform inference on the adapted dataset and show the results.

In [None]:
# Run Depth Anything 2 on every adapted ATEK frame, display the frame and its
# depth map, and save the frame under out_dir.

# The five shard tars shards-0000.tar ... shards-0004.tar inside wds_dir.
tar_list = [os.path.join(wds_dir, f"shards-000{i}.tar") for i in range(5)]
depth_anything2_dataset = load_atek_wds_dataset_as_depth_anything_2(
    tar_list,
    batch_size=1,
    repeat_flag=False,
    shuffle_flag=False,
)

# The loop reports files under out_dir, so make sure the directory exists.
os.makedirs(out_dir, exist_ok=True)

# CUDA autocast is only requested when the model actually lives on CUDA; on
# other devices it is explicitly disabled instead of relying on a warning.
with torch.inference_mode(), torch.autocast("cuda", enabled=(DEVICE == "cuda")):
    cur_img_id = 0
    for depth_anything2_dict_list in depth_anything2_dataset:
        for depth_anything2_dict in depth_anything2_dict_list:
            raw_image = depth_anything2_dict["image"]
            # NOTE(review): OpenCV treats 3-channel arrays as BGR; confirm the
            # channel order of the ATEK images matches what is expected here.
            depth = model.infer_image(raw_image)  # HxW raw depth map in numpy

            # cv2.imshow only renders floating-point images sensibly when the
            # values lie in [0, 1], so min-max scale the raw depth to 8 bits
            # for display. A constant (or NaN) map is shown as all black.
            depth_min = float(depth.min())
            depth_span = float(depth.max()) - depth_min
            if depth_span > 0:
                depth_vis = ((depth - depth_min) / depth_span * 255.0).astype(
                    np.uint8
                )
            else:
                depth_vis = np.zeros_like(depth, dtype=np.uint8)

            cv2.imshow("Raw Image", raw_image)
            cv2.setWindowTitle("Raw Image", f"Raw Image {cur_img_id}")
            cv2.imshow("Depth", depth_vis)
            cv2.setWindowTitle("Depth", f"Depth {cur_img_id}")
            # HighGUI windows are only drawn while its event loop is pumped.
            cv2.waitKey(1)

            # Actually write the file before reporting it as written.
            raw_path = os.path.join(out_dir, f"raw_{cur_img_id}.png")
            if cv2.imwrite(raw_path, raw_image):
                print("wrote to:", raw_path)
            else:
                print("failed to write:", raw_path)

            # Advance once per processed frame so every frame gets its own id.
            cur_img_id += 1