In [1]:
import os
import pandas as pd
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [2]:
class NYUDepthDataset(Dataset):
    """Paired RGB / depth dataset indexed by a CSV file.

    Each CSV row is read positionally: column 0 is the path of an RGB image and
    column 1 the path of its depth map. Paths are passed to ``Image.open`` as-is,
    so relative paths resolve against the current working directory.

    Args:
        csv_file: anything ``pd.read_csv`` accepts (path or file-like object).
        transform: optional callable applied to the whole sample dict.

    Items are ``{"image": <RGB PIL image>, "depth": <'L' PIL image>}`` unless a
    transform is given, in which case whatever the transform returns.
    """

    def __init__(self, csv_file, transform=None):
        # One row per (image, depth) pair.
        self.data = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Samplers may hand over a tensor index; pandas wants a plain Python value.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        rgb_file = self.data.iloc[idx, 0]
        depth_file = self.data.iloc[idx, 1]

        # Dict values are evaluated in order, so the RGB image is opened first.
        sample = {
            "image": Image.open(rgb_file).convert('RGB'),
            "depth": Image.open(depth_file).convert('L'),
        }

        return self.transform(sample) if self.transform else sample

class NYUDepthTransform:
    """Preprocess an ``{"image": ..., "depth": ...}`` sample.

    The RGB image is resized to ``img_size``, converted to a tensor and
    normalized with fixed per-channel mean/std values. The depth map is
    resized to the same size and converted to a tensor without normalization.

    ``img_transform`` is also usable on its own for inference-time images.
    """

    def __init__(self, img_size=(224,224)):
        rgb_steps = [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
        ]
        # NOTE(review): no interpolation is passed, so the depth map is resized
        # with Resize's default (bilinear per torchvision docs), which can blend
        # depth values across object edges. Nearest may be preferable, but any
        # existing checkpoint was trained with this setting — confirm before changing.
        depth_steps = [
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ]
        self.img_transform = transforms.Compose(rgb_steps)
        self.depth_transform = transforms.Compose(depth_steps)

    def __call__(self, sample):
        # Only the "image" and "depth" keys are carried over into the result.
        return {
            "image": self.img_transform(sample["image"]),
            "depth": self.depth_transform(sample["depth"]),
        }

class DepthEstimationModel(nn.Module):
    """Small convolutional encoder-decoder mapping an RGB image to a 1-channel map.

    Encoder: three (Conv2d 3x3 -> ReLU -> MaxPool2d(2)) stages, each halving the
    spatial size (224 -> 112 -> 56 -> 28 for a 224x224 input).
    Decoder: three (ConvTranspose2d k=4, s=2, p=1 -> ReLU) stages, each doubling
    it back (28 -> 56 -> 112 -> 224). The trailing ReLU makes outputs non-negative.

    Because of the three floor-halving pools, the output height/width equal the
    input's only when both are multiples of 8.
    """

    def __init__(self):
        super().__init__()

        # Layer positions inside each nn.Sequential are part of the state_dict
        # keys (e.g. "encoder.0.weight", "decoder.4.bias"), so the exact layer
        # order must be kept for saved checkpoints to load.
        encoder_layers = []
        for in_ch, out_ch in ((3, 64), (64, 128), (128, 256)):
            encoder_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        for in_ch, out_ch in ((256, 128), (128, 64), (64, 1)):
            decoder_layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ])
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"using device: {device}")

using device: cuda


In [5]:
# Rebuild the architecture and load the trained weights. map_location remaps
# the stored tensors onto the selected device, so a checkpoint saved from a
# CUDA session still loads on a CPU-only machine (without it torch.load raises).
model = DepthEstimationModel()
model.load_state_dict(
    torch.load("depth_estimation_model.pth", map_location=device, weights_only=True)
)
model.to(device)

def predict_depth(model, image_path, transform, output_path="predicted_depth.png", device=None):
    """Run ``model`` on one image and save the prediction as a grayscale PNG.

    Args:
        model: a torch ``nn.Module`` that takes a (1, 3, H, W) batch.
        image_path: path of the input image; it is converted to RGB.
        transform: object exposing ``img_transform``, which turns a PIL image
            into a (3, H, W) tensor (e.g. ``NYUDepthTransform()``).
        output_path: where the visualisation is written.
        device: device to run on. Defaults to the device of the model's first
            parameter, so the function no longer depends on a notebook-global
            ``device`` variable.

    Side effects: puts ``model`` into eval mode, writes ``output_path`` and
    prints a confirmation line. Returns ``None``.

    The saved map has the spatial size produced by ``transform.img_transform``
    (not the original image size). ``plt.imsave`` is called without
    vmin/vmax, so per matplotlib's defaults the values are stretched over the
    array's own min/max (a relative, not metric, visualisation).
    """
    if device is None:
        # Follow the model's weights instead of relying on hidden notebook state.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    # Close the source file promptly; convert() returns an independent in-memory image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    input_tensor = transform.img_transform(image).unsqueeze(0).to(device)

    with torch.inference_mode():
        # Drop the batch/channel singleton dims -> 2-D array for imsave.
        predicted_depth = model(input_tensor).squeeze().cpu().numpy()

    plt.imsave(output_path, predicted_depth, cmap="gray")
    print(f"predicted image save: {output_path}")

# Demo: predict depth for a single sample image and write the result to disk.
image_path = "./00000_colors.png"
predict_depth(model, image_path, transform=NYUDepthTransform())
    

predicted image save: predicted_depth.png
