In [None]:
# --- SETUP ---
# Install the third-party packages this notebook imports. `%pip` (rather than
# `!pip`) installs into the environment of the running kernel.
# NOTE(review): versions are unpinned, so results may differ between runs on
# different dates — consider pinning once a working set is known.
%pip install datasets webdataset torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# 2. Login to Hugging Face
#
# Never paste the access token into the notebook: it would be saved with the
# file and leak when the notebook is shared or committed. The token is read
# from the HF_TOKEN environment variable instead; when that is unset (or
# empty), `login` receives None and falls back to its interactive prompt.
import os

from huggingface_hub import login

login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# --- TRANSFORMS ---
# RGB pipeline: resize to 224x224, then ToTensor (scales the image to [0,1]
# and yields a [C,H,W] tensor).
transform_img = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
    ]
)

# Same two-step pipeline under a separate name for depth maps.
# NOTE(review): no code visible in this notebook references transform_depth;
# preprocess() resizes the depth tensor directly instead.
transform_depth = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
    ]
)


def preprocess(example):
    """Attach model-ready ``image`` and ``depth`` tensors to one sample.

    Args:
        example: Mapping for a single dataset record. Reads ``example["jpg"]``
            (must support ``.convert("RGB")``; presumably a PIL image) and
            ``example["depth.npy"]`` (presumably a 2-D ``[H, W]`` array —
            TODO confirm against the dataset schema).

    Returns:
        The same ``example`` object, mutated in place, with two added keys:
        ``"image"`` (output of ``transform_img``) and ``"depth"`` (float32
        tensor with a leading channel axis, resized to 224x224).
    """
    # RGB branch: force three channels, then resize + convert to tensor.
    rgb_tensor = transform_img(example["jpg"].convert("RGB"))

    # Depth branch: float32 tensor, prepend a channel axis, then resize so the
    # target has the same spatial size as the image.
    depth_tensor = torch.tensor(example["depth.npy"], dtype=torch.float32).unsqueeze(0)
    depth_tensor = T.functional.resize(depth_tensor, (224, 224))

    # Store the results alongside the original fields.
    example["image"] = rgb_tensor
    example["depth"] = depth_tensor
    return example


In [None]:
from datasets import load_dataset

# streaming=True iterates the dataset lazily instead of downloading it first.
dataset = load_dataset(
    "adams-story/nyu-depthv2-wds",
    split="train",
    streaming=True,
)

# Restrict training to the first 500 streamed examples.
train_data = dataset.take(500)

In [None]:


# Sanity check: run the first streamed example through preprocess()
sample = preprocess(next(iter(train_data)))

# The keys should now include 'image' and 'depth' next to the raw fields.
print(sample.keys())
print(sample["image"].shape, sample["depth"].shape)


In [None]:

# Batch assembly for the DataLoader
def collate_fn(batch):
    """Preprocess every raw example in ``batch`` and stack the results.

    Args:
        batch: Iterable of raw dataset examples accepted by ``preprocess``.

    Returns:
        Tuple ``(images, depths)``: the per-example ``"image"`` tensors and
        ``"depth"`` tensors, each stacked along a new leading batch axis.
    """
    processed = [preprocess(raw) for raw in batch]
    image_batch = torch.stack([item["image"] for item in processed])
    depth_batch = torch.stack([item["depth"] for item in processed])
    return image_batch, depth_batch

# Preprocessing happens inside collate_fn, i.e. in the loader's worker processes.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# --- MODEL (simple convolutional encoder-decoder; no U-Net skip connections) ---
class DepthEstimationNet(nn.Module):
    """Small convolutional encoder-decoder mapping an RGB image to a depth map.

    The encoder applies three stride-2 3x3 convolutions (3 -> 64 -> 128 -> 256
    channels), each followed by ReLU, halving the spatial size three times.
    The decoder mirrors it with three stride-2 transposed convolutions
    (256 -> 128 -> 64 -> 1 channels), each doubling the spatial size; the last
    layer has no activation. For inputs whose height and width are divisible
    by 8 the output has the same spatial size as the input.
    """

    def __init__(self):
        super().__init__()
        # Shared hyper-parameters: every layer is a 3x3, stride-2 (de)convolution.
        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        up_cfg = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, **down_cfg),
            nn.ReLU(),
            nn.Conv2d(64, 128, **down_cfg),
            nn.ReLU(),
            nn.Conv2d(128, 256, **down_cfg),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, **up_cfg),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, **up_cfg),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, **up_cfg),
        )

    def forward(self, x):
        """Encode ``x`` and decode it to a single-channel prediction."""
        return self.decoder(self.encoder(x))

# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = DepthEstimationNet()
model = model.to(device)

In [None]:
# --- TRAINING ---
# Optimises `model` on `train_loader` with Adam and an L1 (mean absolute error)
# objective, printing the compute time of every batch and the mean loss of
# every epoch.
import time

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.L1Loss()  # Mean Absolute Error

EPOCHS = 15

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted explicitly so the average never relies on the loop variable
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations

    for i, batch in enumerate(train_loader):
        batch_start = time.perf_counter()
        imgs, depths = batch
        # The loader pins host memory, so the host-to-device copy can be
        # asynchronous; on CPU the flag has no effect.
        imgs = imgs.to(device, non_blocking=True)
        depths = depths.to(device, non_blocking=True)

        preds = model(imgs)
        loss = loss_fn(preds, depths)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() copies the scalar loss back to the host, so the timing below
        # covers the whole step for this batch.
        total_loss += loss.item()
        num_batches += 1

        batch_duration = time.perf_counter() - batch_start
        print(f"Batch {i+1} processing time: {batch_duration:.2f}s")

    epoch_duration = time.perf_counter() - start_time

    # An empty loader previously caused a NameError on `i` (or, on a re-run,
    # a silently wrong average computed from a stale `i`). Fail loudly instead.
    if num_batches == 0:
        raise RuntimeError(
            "train_loader yielded no batches; check the dataset stream and DataLoader settings"
        )

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1} done in {epoch_duration:.2f}s - Average Loss: {avg_loss:.4f}")



In [None]:
# Persist the trained weights (state_dict only) once all epochs have finished
final_weights = model.state_dict()
torch.save(final_weights, "NYUDEPTH.pt")


In [None]:
# Google Colab only: mount Drive and store a second copy of the weights there
# (Drive is mounted at /content/drive).
from google.colab import drive

drive.mount('/content/drive')

drive_checkpoint = "/content/drive/MyDrive/NYUDEPTH.pt"
torch.save(model.state_dict(), drive_checkpoint)
