# EfficientDroneDepth — Drive-aware Training
Trains on the real **TartanAir/AbandonedFactory** environment if it is found on Google Drive. Otherwise, it auto-downloads a minimal subset (`lcam_front` images and depth) and trains end-to-end.
**Order:** Setup → Data → Remap → Sanity check → Train → Eval/Export → Figures.


## 1) Setup

In [None]:
# Install deps (quietly). `tartanair` provides the dataset download helper used in section 2.
!pip -q install torch torchvision timm numpy opencv-python scikit-image onnx onnxruntime rich matplotlib tartanair

import torch, platform, sys, os
# Log torch version, CUDA availability and host platform for this session.
print("torch", torch.__version__, "cuda", torch.cuda.is_available(), platform.platform())

# Mount Drive (force_remount=True remounts even if it is already mounted).
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Clone fresh: any previous checkout is deleted, then a shallow (depth 1) clone is made.
%cd /content
!rm -rf efficientdronedepth
!git clone --depth 1 https://github.com/malodept/EfficientDroneDepth.git efficientdronedepth
%cd /content/efficientdronedepth
# Ensure the repo root is on sys.path so later `from src.edd... import ...` cells resolve.
if "/content/efficientdronedepth" not in sys.path:
    sys.path.append("/content/efficientdronedepth")

# Persist runs to Drive: the repo-local `runs` path is replaced by a symlink to OUT
# (IPython substitutes $OUT from the Python variable), so everything later cells
# write under runs/ lands on Drive and outlives the Colab VM.
OUT="/content/drive/MyDrive/edd_runs_abandoned"
os.makedirs(OUT, exist_ok=True)
!rm -rf runs && ln -s "$OUT" runs && ls -la runs


## 2) Data — detect or download

In [None]:
import os, pathlib, shutil
from tartanair import tartanair as ta

DRIVE_ROOT = '/content/drive/MyDrive/tartanair_subset'
ENV_DIR = os.path.join(DRIVE_ROOT, 'AbandonedFactory')
DATA_OMNI = os.path.join(ENV_DIR, 'Data_omni')
os.makedirs(DRIVE_ROOT, exist_ok=True)

present = os.path.exists(DATA_OMNI) and any(
    os.path.isdir(os.path.join(DATA_OMNI, p)) for p in ['P0000','P0001','P0002','P0004','P0005']
)
print("Dataset present:", present)

if not present:
    ta.init(DRIVE_ROOT)
    ta.download_ground(
        env=["AbandonedFactory"],
        version=['omni'],
        traj=[],
        modality=['image', 'depth'],
        camera_name=['lcam_front'],
        unzip=True
    )
    print("Downloaded to:", DRIVE_ROOT)
else:
    print("Using existing dataset at:", DRIVE_ROOT)

!ls -R $DRIVE_ROOT/AbandonedFactory | head -n 60


## 3) Remap the dataset layout to the loader's format (`left/`, `depth/`)

In [None]:
import re

ROOT = os.path.join(DRIVE_ROOT, "AbandonedFactory", "Data_omni")
# Keep only directories whose name is exactly P#### (re.match would also accept
# prefixes such as "P0000.zip"). Sorted so the remap order and log are deterministic.
seqs = sorted(
    d for d in os.listdir(ROOT)
    if re.fullmatch(r"P\d{4}", d) and os.path.isdir(os.path.join(ROOT, d))
)
print("Found sequences:", seqs)

def fix_seq(pseq, root=None):
    """Copy one trajectory's camera files into the ``left/`` + ``depth/`` layout.

    Every PNG in ``<root>/<pseq>/image_lcam_front`` is copied to ``<root>/<pseq>/left``
    with ``_lcam_front`` removed from its stem and ``_left`` appended. Every PNG in
    ``<root>/<pseq>/depth_lcam_front`` is copied to ``<root>/<pseq>/depth`` with
    ``_lcam_front_depth`` removed and ``_depth`` appended.

    Existing targets are never overwritten, so re-running is cheap. Files are copied,
    not moved or linked, so the remapped data is a second copy on disk.

    Args:
        pseq: Trajectory folder name, e.g. ``"P0000"``.
        root: Folder holding the trajectory folders. Defaults to the notebook-level
            ``ROOT`` (the ``Data_omni`` directory).

    Returns:
        Number of files newly copied. Returns 0 when either source folder is missing;
        the trajectory is then skipped with a notice instead of silently.
    """
    import pathlib, shutil

    base = pathlib.Path(ROOT if root is None else root) / pseq
    src_img = base / "image_lcam_front"
    src_dep = base / "depth_lcam_front"
    if not (src_img.is_dir() and src_dep.is_dir()):
        print(f"skip {pseq}: missing image_lcam_front/ or depth_lcam_front/")
        return 0

    def _copy_renamed(src, dst, drop, suffix):
        # Copy each PNG in src to dst as <stem without `drop`><suffix>.png; skip existing targets.
        dst.mkdir(exist_ok=True)
        copied = 0
        for f in sorted(src.glob("*.png")):
            out = dst / f"{f.stem.replace(drop, '')}{suffix}.png"
            if not out.exists():
                shutil.copy2(f, out)
                copied += 1
        return copied

    n_img = _copy_renamed(src_img, base / "left", "_lcam_front", "_left")
    n_dep = _copy_renamed(src_dep, base / "depth", "_lcam_front_depth", "_depth")
    return n_img + n_dep

for s in seqs:
    fix_seq(s)

print("Remap done.")
!find "$DRIVE_ROOT/AbandonedFactory/Data_omni" -maxdepth 2 -type d -name left -o -name depth | sed 's/^/ - /' | head -n 20


## 4) Sanity check model forward

In [None]:
from src.edd.modeling import DPTSmall
import torch

# Push a random batch through an un-pretrained model to check wiring and output shape.
# This is a shape check only: eval() puts any dropout/batch-norm layers in inference
# mode, and no_grad() skips building an autograd graph for the 384x384 forward pass.
m = DPTSmall(pretrained=False).eval()
x = torch.randn(2, 3, 384, 384)
with torch.no_grad():
    y = m(x)
print("Output shape:", tuple(y.shape))


## 5) Train

In [None]:
DATA_ROOT = "/content/drive/MyDrive/tartanair_subset/AbandonedFactory"
print("DATA_ROOT =", DATA_ROOT)

epochs, batch, img_sz = 12, 8, 384
!python -m src.edd.train --data_root $DATA_ROOT --epochs {epochs} --batch_size {batch} --img_size {img_sz}


## 6) Evaluate, export to ONNX, and quantize

In [None]:
# Evaluate the checkpoint. This assumes the train step wrote runs/edd_midas.pt (confirm in
# src/edd/train.py). `--bench` presumably adds a speed benchmark (TODO confirm in src/edd/eval.py).
!python -m src.edd.eval --data_root $DATA_ROOT --ckpt runs/edd_midas.pt --bench
# Export the PyTorch checkpoint to an ONNX graph under runs/ (persisted to Drive via the symlink).
!python -m src.edd.export_onnx --ckpt runs/edd_midas.pt --onnx runs/edd_midas.onnx
# Dynamically quantize the exported ONNX model; the output name suggests INT8 weights.
!python -m src.edd.quantize_dynamic --onnx runs/edd_midas.onnx --out runs/edd_midas_int8.onnx


## 7) Save sample predictions

In [None]:
import os, torch, cv2
import numpy as np
from src.edd.modeling import DPTSmall
from src.edd.data import TartanAirDepth

os.makedirs("runs/figures", exist_ok=True)
m = DPTSmall(pretrained=False).eval()
# strict=False tolerates key mismatches. A silent mismatch would leave layers randomly
# initialised and produce meaningless figures, so report any keys that did not line up.
ckpt = torch.load("runs/edd_midas.pt", map_location="cpu")
load_res = m.load_state_dict(ckpt["model"], strict=False)
if load_res.missing_keys or load_res.unexpected_keys:
    print(
        f"warning: checkpoint mismatch — {len(load_res.missing_keys)} missing, "
        f"{len(load_res.unexpected_keys)} unexpected keys"
    )

ds = TartanAirDepth(DATA_ROOT, img_size=384, limit_samples=50, train=False)
n = min(4, len(ds))
saved = 0
with torch.no_grad():  # inference only; no autograd graph needed
    for i in range(n):
        ex = ds[i]
        y = m(ex["image"].unsqueeze(0)).squeeze().cpu().numpy()
        # Clamp negatives before scaling: casting negative floats to uint8 wraps around.
        # For non-negative outputs this is identical to plain max-normalisation.
        y = np.clip(y, 0, None)
        y = (y / (y.max() + 1e-6) * 255).astype("uint8")
        path = f"runs/figures/pred_{i}.png"
        if cv2.imwrite(path, y):
            saved += 1
        else:
            print("failed to write", path)
print("saved:", saved, "figures → runs/figures/")
