In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from nyudataset import NYUDepthDataset

In [3]:
class DepthNetV1(nn.Module):
    """Monocular depth estimator: truncated ResNet-18 encoder + transposed-conv decoder.

    The encoder keeps ResNet-18 up to and including ``layer3`` (256 channels,
    total stride 16). The decoder upsamples 2x four times with transposed
    convolutions and predicts one ReLU-clamped (non-negative) depth channel.

    Input:  (batch, 3, H, W); H and W should be divisible by 16 for the output
            to have the same spatial size as the input.
    Output: (batch, 1, H, W).

    Args:
        pretrained: if True (default), initialise the encoder with ImageNet
            weights; if False, start from random weights (no download).
    """

    def __init__(self, pretrained=True):
        super().__init__()

        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        base_model = models.resnet18(weights=weights)
        # ResNet children: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc.
        # Dropping the last three (layer4, avgpool, fc) leaves a 256-channel,
        # stride-16 feature map, matching the decoder's first in_channels.
        self.encoder = nn.Sequential(*list(base_model.children())[:-3])

        # Each stride-2 transposed conv (k=3, p=1, output_padding=1) exactly
        # doubles H and W; four of them undo the encoder's stride of 16.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # Stride-1 head: 64 -> 1 channel at full resolution.
            nn.ConvTranspose2d(64, 1, 3, stride=1, padding=1),
            nn.ReLU()  # depth is non-negative
        )

    def forward(self, X):
        # Invoke submodules via __call__ (not .forward) so registered hooks run.
        features = self.encoder(X)
        return self.decoder(features)

In [4]:
from math import exp

# Implementation from: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py

def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel of length ``window_size``.

    Taps are centred on index ``window_size // 2`` with standard deviation
    ``sigma``. The result sums to 1 and uses the default tensor dtype.
    """
    centre = window_size // 2
    two_var = float(2 * sigma ** 2)
    taps = []
    for offset in range(window_size):
        taps.append(exp(-(offset - centre) ** 2 / two_var))
    kernel = torch.tensor(taps, dtype=torch.get_default_dtype())
    return kernel / kernel.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian filter bank for SSIM.

    Args:
        window_size: side length of the square kernel.
        channel: number of input channels. One identical kernel is produced
            per channel, for use with ``F.conv2d(..., groups=channel)``.
        sigma: standard deviation of the Gaussian (default 1.5, the value the
            reference SSIM implementation uses).

    Returns:
        Float tensor of shape (channel, 1, window_size, window_size).
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)  # (W, 1)
    # Outer product of the 1-D kernel with itself -> separable 2-D Gaussian.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)  # (1, 1, W, W)
    # torch.autograd.Variable is deprecated (a no-op wrapper since PyTorch 0.4);
    # a plain tensor (requires_grad=False) behaves identically.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
    
def ssim(img1, img2, window_size = 11, size_average = True):
    """Structural similarity between two 4-D (batch, channel, H, W) tensors.

    A Gaussian window with ``img1``'s channel count is rebuilt on every call.
    It is moved to ``img1``'s CUDA device when needed, cast to ``img1``'s type,
    and then handed to ``_ssim``.
    """
    # Unpacking (rather than indexing) rejects inputs that are not 4-D.
    _batch, channel, _height, _width = img1.size()
    gauss_window = create_window(window_size, channel)

    if img1.is_cuda:
        gauss_window = gauss_window.cuda(img1.get_device())

    return _ssim(img1, img2, gauss_window.type_as(img1), window_size, channel, size_average)

In [5]:
def depth_loss(y_pred, y_true, ssim_weight=0.1):
    """Depth-regression loss: L1 error plus a weighted structural (SSIM) penalty.

        loss = mean|y_pred - y_true| + ssim_weight * (1 - SSIM(y_pred, y_true))

    Args:
        y_pred: predicted depth; must be 4-D (batch, channel, H, W) because
            ``ssim`` unpacks four dimensions.
        y_true: ground-truth depth with the same shape.
        ssim_weight: weight of the (1 - SSIM) term. The default 0.1 keeps the
            previous behaviour.

    Returns:
        Scalar loss tensor, differentiable w.r.t. ``y_pred``.
    """
    l1_term = F.l1_loss(y_pred, y_true)
    ssim_term = 1 - ssim(y_pred, y_true)
    return l1_term + ssim_weight * ssim_term

def delta_accuracy(y_pred, y_true, threshold=1.25, min_depth=0.0):
    """Delta accuracy: fraction of valid pixels with max(pred/gt, gt/pred) < threshold.

    Only pixels whose ground truth is ``> min_depth`` are scored. Previously,
    zero / missing ground truth produced ``inf`` or ``NaN`` ratios that were
    silently counted as misses, biasing the score downward. Non-positive
    predictions are always counted as misses; a negative prediction would
    otherwise yield a negative ratio and pass the test.

    Args:
        y_pred: predicted depth tensor.
        y_true: ground-truth depth tensor, broadcastable with ``y_pred``.
        threshold: ratio bound (1.25, 1.25**2 and 1.25**3 are commonly reported).
        min_depth: ground-truth values ``<= min_depth`` are treated as invalid.

    Returns:
        Python float in [0, 1], or ``nan`` if no pixel is valid.
    """
    y_pred, y_true = torch.broadcast_tensors(y_pred, y_true)
    valid = y_true > min_depth
    pred = y_pred[valid]
    target = y_true[valid]
    # target > 0 here, so pred / target is finite; pred == 0 gives target / pred = inf (a miss).
    ratio = torch.maximum(pred / target, target / pred)
    hits = (ratio < threshold) & (pred > 0)
    # mean() of an empty selection is nan, signalling "no valid pixels".
    return hits.float().mean().item()

In [6]:
# Data pipeline configuration.
DATA_DIR = "./nyu_depth_data"
IMAGE_SIZE = (256, 256)  # divisible by DepthNetV1's total stride of 16
BATCH_SIZE = 64
NUM_WORKERS = 4

# NOTE(review): whether this transform is also applied to the depth maps is
# decided inside NYUDepthDataset. If it is, Resize's default bilinear
# interpolation blends depth across object edges; NEAREST may be preferable.
basic_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

dataset_train = NYUDepthDataset(
    root_dir=DATA_DIR,
    csv_index="nyu2_train.csv",
    transform=basic_transform
)

dataset_test = NYUDepthDataset(
    root_dir=DATA_DIR,
    csv_index="nyu2_test.csv",
    transform=basic_transform
)

dataloader_train = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS
)

# Shuffling the evaluation set is unnecessary; a fixed order keeps
# evaluation runs deterministic.
dataloader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [7]:
SEED = 42
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# Seed the global RNG before building the model and iterating the loaders, so
# that decoder initialisation and the training shuffle order repeat across runs
# (cuDNN kernels may still introduce nondeterminism).
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Input shape  - (batch, 3, height, width)
# Output shape - (batch, 1, height, width)
model = DepthNetV1().to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

criterion = depth_loss

num_epochs = 10

In [None]:
for epoch in range(1, num_epochs + 1):
    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch, (images, depth_maps) in enumerate(dataloader_train, 1):
        images, depth_maps = images.to(device), depth_maps.to(device)

        # Call the module itself (not .forward) so registered hooks run.
        y_pred = model(images)
        loss = criterion(y_pred, depth_maps)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # detach(): adding `loss` itself would chain every batch's autograd graph
        # onto total_loss and keep those graphs alive for the whole epoch.
        total_loss += loss.detach()
        num_batches += 1
        if batch % 20 == 0:
            print(f"\t[{batch}/{len(dataloader_train)}] Batch depth loss: {loss.item():.4f}")

    # Guard the empty-loader case (previously a NameError / stale `batch`).
    avg_loss = float(total_loss) / num_batches if num_batches else float("nan")
    print(f"[{epoch}/{num_epochs}] Depth AVG loss: {avg_loss:.4f}")

In [None]:
# Persist the whole pickled module (architecture + weights).
# NOTE(review): loading this file requires DepthNetV1 to be importable under the
# same name; saving model.state_dict() is the more portable alternative.
checkpoint_path = "DepthNetV1-T.pth"
torch.save(model, checkpoint_path)