# Prepare Environment
```
conda create --name hydranet pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia
conda activate hydranet
pip install opencv-python matplotlib ipykernel tqdm notebook
```

# Import Libraries
These are all the libraries you'll need throughout the notebook. The next several sections reuse code from the previous modules of this course so that the model architecture stays consistent for the transfer-learning exercise.

In [9]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from PIL import Image
import glob
from utils import Normalise, RandomCrop, ToTensor, RandomMirror, InvHuberLoss, AverageMeter, MeanIoU, RMSE
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from model_helpers import Saver, load_state_dict
import operator
import torch.nn.functional as F
from tqdm import tqdm
from torch.autograd import Variable
import cv2

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as co
from IPython.display import HTML
from base64 import b64encode
from IPython.display import Video

# Create Dataset
Set up the PyTorch dataset and data loaders for the NYU Depth dataset.

In [10]:
# Positional arguments for utils.Normalise, in this order:
# (image scale, per-channel mean, per-channel std, depth scale).
img_scale = 1.0 / 255      # maps 0-255 pixel values to [0, 1]
depth_scale = 5000.0       # presumably the divisor for raw depth PNG values — confirm in utils.Normalise

# Standard ImageNet channel statistics (RGB order).
img_mean = np.array([0.485, 0.456, 0.406])
img_std = np.array([0.229, 0.224, 0.225])

# mean/std are reshaped to (1, 1, 3) so they broadcast over an H x W x 3 image.
normalise_params = [
    img_scale,
    img_mean.reshape((1, 1, 3)),
    img_std.reshape((1, 1, 3)),
    depth_scale,
]

# Shared tail of both pipelines (the same transform instances are used by train and val).
transform_common = [Normalise(*normalise_params), ToTensor()]

# Training adds random mirroring and cropping in front of the shared tail.
crop_size = 400
transform_train = transforms.Compose([RandomMirror(), RandomCrop(crop_size)] + transform_common)
transform_val = transforms.Compose(transform_common)

train_batch_size = 4
val_batch_size = 4
train_file = "train_list_depth.txt"
val_file = "val_list_depth.txt"

# NOTE(review): these file lists are not used by HydranetDataset (it reads
# train_file / val_file); `depth` is also re-bound later by the inference cells.
images = sorted(glob.glob("nyud/rgb/*.png"))
seg = sorted(glob.glob("nyud/masks/*.png"))
depth = sorted(glob.glob("nyud/depth/*.png"))

class HydranetDataset(Dataset):
    """NYUDv2 multi-task dataset driven by a tab-separated list file.

    Each non-empty line of ``data_file`` holds tab-separated relative paths in
    the order: RGB image, segmentation mask, depth map. Paths are joined to
    ``root_dir`` when a sample is loaded.

    Args:
        data_file: path to the list file (decoded as UTF-8).
        transform: optional callable applied to the sample dict.
        root_dir: directory the relative paths are resolved against
            (default ``"nyud"``).
    """

    def __init__(self, data_file, transform=None, root_dir="nyud"):
        with open(data_file, "rb") as f:
            # rstrip("\r\n") also removes the CR of Windows line endings, which
            # would otherwise stay glued to the last path on every line.
            lines = [raw.decode("utf-8").rstrip("\r\n") for raw in f]
        # Blank lines (e.g. a stray empty line at the end) are not samples.
        self.datalist = [line.split("\t") for line in lines if line.strip()]
        self.root_dir = root_dir
        self.transform = transform
        # Dict keys for the mask columns, in list-file order after the image column.
        self.masks_names = ("segm", "depth")

    def __len__(self):
        return len(self.datalist)

    @staticmethod
    def _load_array(path):
        """Read an image file into a NumPy array and close the file handle."""
        with Image.open(path) as im:
            return np.array(im)

    def __getitem__(self, idx):
        """Return the sample dict ("image" plus one entry per mask name), optionally transformed."""
        # e.g. ["nyud/rgb/00000.png", "nyud/masks/00000.png", "nyud/depth/00000.png"]
        abs_paths = [os.path.join(self.root_dir, rpath) for rpath in self.datalist[idx]]
        sample = {"image": self._load_array(abs_paths[0])}

        for mask_name, mask_path in zip(self.masks_names, abs_paths[1:]):
            mask = self._load_array(mask_path)
            assert len(mask.shape) == 2, "Masks must be encoded without colourmap"
            sample[mask_name] = mask

        if self.transform:
            sample["names"] = self.masks_names
            sample = self.transform(sample)
            # the names key can be removed by the transformation
            if "names" in sample:
                del sample["names"]
        return sample
    
# Training loader: shuffled every epoch; the last incomplete batch is dropped so
# every step sees exactly `train_batch_size` samples.
trainloader = DataLoader(
    HydranetDataset(train_file, transform=transform_train),
    batch_size=train_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

# Validation loader: fixed order and every sample kept.
valloader = DataLoader(
    HydranetDataset(val_file, transform=transform_val),
    batch_size=val_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=True,
)

# Encoder
Define the layers needed for MobileNetV2, load the pre-trained weights, and freeze them so that only the decoder's weights are trained, which speeds up training.

In [11]:
def conv1x1(in_channels, out_channels, stride=1, groups=1, bias=False):
    """Pointwise (1x1) convolution with no padding.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: spatial stride.
        groups: number of channel groups.
        bias: whether the convolution has a bias term.

    Returns:
        An ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=stride,
        padding=0,
        groups=groups,
        bias=bias,
    )

def conv3x3(in_channels, out_channels, stride=1, dilation=1, groups=1, bias=False):
    """3x3 convolution whose padding equals its dilation.

    With stride 1 this preserves the spatial size. It is depthwise only when
    ``groups == in_channels``; with the default ``groups=1`` it is an ordinary
    dense 3x3 convolution.
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)

def batchnorm(num_features):
    """2D batch normalisation with learnable affine parameters, eps=1e-5 and momentum=0.1.

    https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
    """
    settings = {"eps": 1e-5, "momentum": 0.1, "affine": True}
    return nn.BatchNorm2d(num_features, **settings)

def convbnrelu(in_channels, out_channels, kernel_size, stride=1, groups=1, act=True):
    """Conv2d (no bias, padding ``kernel_size // 2``) -> BatchNorm -> optional ReLU6.

    Args:
        act: when False the trailing ReLU6 is omitted (linear output).

    Returns:
        An ``nn.Sequential`` of two or three modules.
    """
    layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=int(kernel_size / 2),
            groups=groups,
            bias=False,
        ),
        batchnorm(out_channels),
    ]
    if act:
        layers.append(nn.ReLU6(inplace=True))
    return nn.Sequential(*layers)

class InvertedResidualBlock(nn.Module):
    """MobileNetV2 inverted residual block (https://arxiv.org/abs/1801.04381).

    1x1 expansion -> 3x3 depthwise conv (carries the stride) -> 1x1 linear
    projection (no activation). The input is added back only when the output
    has the same shape, i.e. ``in_planes == out_planes`` and ``stride == 1``.
    """

    def __init__(self, in_planes, out_planes, expansion_factor, stride=1):
        super().__init__()
        hidden = in_planes * expansion_factor
        self.residual = stride == 1 and in_planes == out_planes
        expand = convbnrelu(in_planes, hidden, 1)
        depthwise = convbnrelu(hidden, hidden, 3, stride=stride, groups=hidden)
        project = convbnrelu(hidden, out_planes, 1, act=False)
        self.output = nn.Sequential(expand, depthwise, project)

    def forward(self, x):
        out = self.output(x)
        return out + x if self.residual else out

class MobileNetv2(nn.Module):
    """MobileNetV2 backbone (https://arxiv.org/abs/1801.04381) returning intermediate feature maps.

    A stride-2 stem (``layer1``) is followed by seven stages of inverted residual
    blocks (``layer2`` .. ``layer8``) built from ``mobilenet_config``.
    ``forward`` returns the outputs of the stages selected by ``return_idx``
    (index 0 = ``layer2`` ... index 6 = ``layer8``).

    Args:
        return_idx: an int or a list/tuple of stage indices to return. ``None``
            (the default) selects stages 1-6, whose channel counts are
            24, 32, 64, 96, 160 and 320, which are the inputs the decoder expects.
            ``MobileNetv2()`` therefore behaves as before; explicit values are
            now honoured instead of being silently ignored.
    """

    def __init__(self, return_idx=None):
        super().__init__()
        # expansion rate, output channels, number of repeats, stride
        self.mobilenet_config = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        self.in_channels = 32  # channels produced by the stem; updated while stages are built
        self.num_layers = len(self.mobilenet_config)
        self.layer1 = convbnrelu(3, self.in_channels, kernel_size=3, stride=2)

        if return_idx is None:
            return_idx = [1, 2, 3, 4, 5, 6]
        elif isinstance(return_idx, int):
            return_idx = [return_idx]
        # Copy so the caller's sequence is never aliased or mutated.
        self.return_idx = list(return_idx)

        # Stages are registered as layer2, layer3, ... in config order; only the
        # first block of each stage applies the stage's stride.
        for c_layer, (t, c, n, s) in enumerate(self.mobilenet_config, start=2):
            layers = []
            for idx in range(n):
                layers.append(InvertedResidualBlock(self.in_channels, c, expansion_factor=t, stride=s if idx == 0 else 1))
                self.in_channels = c
            setattr(self, "layer{}".format(c_layer), nn.Sequential(*layers))

        # Channel count of each returned feature map (default: [24, 32, 64, 96, 160, 320]).
        self._out_c = [self.mobilenet_config[idx][1] for idx in self.return_idx]

    def forward(self, x):
        """Return the list of stage outputs selected by ``self.return_idx``."""
        outs = []
        x = self.layer1(x)  # 32, x / 2
        outs.append(self.layer2(x))  # 16, x / 2
        outs.append(self.layer3(outs[-1]))  # 24, x / 4
        outs.append(self.layer4(outs[-1]))  # 32, x / 8
        outs.append(self.layer5(outs[-1]))  # 64, x / 16
        outs.append(self.layer6(outs[-1]))  # 96, x / 16
        outs.append(self.layer7(outs[-1]))  # 160, x / 32
        outs.append(self.layer8(outs[-1]))  # 320, x / 32
        return [outs[idx] for idx in self.return_idx]
    
encoder = MobileNetv2()
# map_location="cpu" lets the pretrained weights load on machines without a GPU;
# the whole model is moved to `device` when it is assembled.
encoder.load_state_dict(torch.load("mobilenetv2-e6e8dd43.pth", map_location="cpu"))

# Freeze the MobileNet weights: no gradients are computed for them, so only the
# decoder's weights are trained. NOTE: BatchNorm running statistics are buffers,
# not parameters, and still update whenever the model is in train() mode.
for param in encoder.parameters():
    param.requires_grad = False

# Decoder
Create the decoder and the initial model. Most of this matches the previous code, but notice that the parameter *return_backbone* has been added to the MTLWRefineNet constructor. Later it will be used to also return the shared feature map that feeds the task heads, which is needed to train a new head; you can see how this is done in the `forward` method.

In [12]:
def make_list(x):
    """Return ``x`` as a list.

    A list is returned as-is (same object), a tuple is converted to a list, and
    any other value is wrapped in a one-element list.
    """
    if isinstance(x, (list, tuple)):
        return x if isinstance(x, list) else list(x)
    return [x]
    
class CRPBlock(nn.Module):
    """CRP (chained residual pooling) block.

    Each of ``n_stages`` stages max-pools (5x5, stride 1, same size) the previous
    stage's result and applies a 1x1 conv. Every stage's output is added to the
    running sum that starts from the block input.

    Args:
        in_planes: channels of the block input (first stage's conv input).
        out_planes: channels produced by every stage.
        n_stages: number of pool + conv stages.
        groups: if truthy, every 1x1 conv uses ``groups=in_planes``
            (presumably intended for ``in_planes == out_planes`` — confirm).
    """

    def __init__(self, in_planes, out_planes, n_stages, groups=False):
        super().__init__()
        conv_groups = in_planes if groups else 1
        for stage in range(1, n_stages + 1):
            stage_in = in_planes if stage == 1 else out_planes
            setattr(
                self,
                "{}_outvar_dimred".format(stage),
                conv1x1(stage_in, out_planes, stride=1, bias=False, groups=conv_groups),
            )
        self.stride = 1
        self.n_stages = n_stages
        self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        top = x
        for stage in range(1, self.n_stages + 1):
            stage_conv = getattr(self, "{}_outvar_dimred".format(stage))
            top = stage_conv(self.maxpool(top))
            x = top + x
        return x

class MTLWRefineNet(nn.Module):
    """Multi-task Light-Weight RefineNet decoder with one head per task.

    ``forward`` takes the encoder feature maps (in the order MobileNetv2 returns
    them, shallow to deep), projects each to ``agg_size`` channels, fuses them
    from the deepest level upwards through CRP blocks with bilinear upsampling,
    and applies every head to the final fused feature map.

    Args:
        input_sizes: channel count of each encoder feature map, in the same order
            as the tensors passed to ``forward``. ``collapse_ind`` assumes exactly
            six levels.
        num_classes: output channels of each head, given as an int for a single
            head or a list/tuple such as ``(40, 1)`` for segmentation + depth.
        agg_size: common channel width used inside the decoder.
        n_crp: number of stages in each CRP block.
        return_backbone: if True, ``forward`` also appends the fused feature map
            that the heads consume to the returned list.
    """

    def __init__(self, input_sizes, num_classes, agg_size=256, n_crp=4, return_backbone=False):
        super().__init__()

        stem_convs = nn.ModuleList()
        crp_blocks = nn.ModuleList()
        adapt_convs = nn.ModuleList()
        heads = nn.ModuleList()

        # Reverse since we recover information from the end (deepest level first)
        input_sizes = list(reversed((input_sizes)))

        # Indices into the reversed levels; levels within one group are summed.
        # With MobileNetv2 the grouped levels share a resolution: (320, 160) at
        # x/32 and (96, 64) at x/16, followed by 32 at x/8 and 24 at x/4.
        self.collapse_ind = [[0, 1], [2, 3], 4, 5]

        # Only the last (highest-resolution) CRP block uses grouped 1x1 convs.
        groups = [False] * len(self.collapse_ind)
        groups[-1] = True

        for size in input_sizes:
            stem_convs.append(conv1x1(size, agg_size, bias=False))

        for group in groups:
            crp_blocks.append(self._make_crp(agg_size, agg_size, n_crp, group))
            adapt_convs.append(conv1x1(agg_size, agg_size, bias=False))

        self.stem_convs = stem_convs
        self.crp_blocks = crp_blocks
        # The last fused level is not upsampled further, so its adapt conv is
        # dropped. It is still constructed so that the random initialisation of
        # the layers built after it does not change.
        self.adapt_convs = adapt_convs[:-1]

        # make_list accepts a bare int for a single head (list(int) would raise).
        for n_out in make_list(num_classes):
            heads.append(
                nn.Sequential(
                    conv1x1(agg_size, agg_size, groups=agg_size, bias=False),
                    nn.ReLU6(inplace=False),
                    conv3x3(agg_size, n_out, bias=True),
                )
            )

        self.heads = heads
        self.relu = nn.ReLU6(inplace=True)
        self.return_backbone = return_backbone

    def forward(self, xs):
        """Decode encoder features ``xs`` into one output per head.

        The fused feature map is appended as well when ``return_backbone`` is set.
        """
        # Deepest level first; reversed() gives a new list, so the caller's is untouched.
        xs = list(reversed(xs))
        for idx, (conv, x) in enumerate(zip(self.stem_convs, xs)):
            xs[idx] = conv(x)

        # Collapse layers: sum the levels inside each group.
        c_xs = [sum([xs[idx] for idx in make_list(c_idx)]) for c_idx in self.collapse_ind]

        for idx, (crp, x) in enumerate(zip(self.crp_blocks, c_xs)):
            # Fuse with the (already upsampled) result of the previous, coarser level.
            if idx == 0:
                y = self.relu(x)
            else:
                y = self.relu(x + y)
            y = crp(y)
            if idx < (len(c_xs) - 1):
                y = self.adapt_convs[idx](y)
                # Upsample to the next level's spatial size.
                y = F.interpolate(
                    y,
                    size=c_xs[idx + 1].size()[2:],
                    mode="bilinear",
                    align_corners=True,
                )

        outs = [head(y) for head in self.heads]

        if self.return_backbone:
            outs.append(y)

        return outs

    @staticmethod
    def _make_crp(in_planes, out_planes, stages, groups):
        """Wrap a single CRPBlock in an nn.Sequential."""
        layers = [CRPBlock(in_planes, out_planes, stages, groups)]
        return nn.Sequential(*layers)
        
        
# Head output channels: 40 segmentation classes, 1 depth channel.
num_classes = (40, 1)
decoder = MTLWRefineNet(encoder._out_c, num_classes)

# Chain encoder -> decoder and wrap it for (optional) multi-GPU execution.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hydranet = nn.DataParallel(nn.Sequential(encoder, decoder)).to(device)
print("Model has {} parameters".format(sum(p.numel() for p in hydranet.parameters())))

Model has 3070057 parameters


# Loss
Create the loss functions to train the depth and segmentation heads.

In [13]:
# Segmentation label value excluded from the cross-entropy loss.
ignore_index = 255
# Depth value excluded from the depth loss (also passed to RMSE during validation).
ignore_depth = 0

crit_segm = nn.CrossEntropyLoss(ignore_index=ignore_index)
crit_depth = InvHuberLoss(ignore_index=ignore_depth)

# Optimizers
Create the encoder and decoder optimizers.

In [14]:
# SGD hyper-parameters for the encoder and the decoder.
lr_encoder = 1e-2
lr_decoder = 1e-3
momentum_encoder = 0.9
momentum_decoder = 0.9
weight_decay_encoder = 1e-5
weight_decay_decoder = 1e-5

# One optimizer per sub-network, encoder first. The encoder's parameters were
# frozen (requires_grad=False), so they never receive a .grad; PyTorch's SGD
# skips such parameters, so the first optimizer leaves the encoder unchanged.
optims = [
    torch.optim.SGD(module.parameters(), lr=lr, momentum=mom, weight_decay=wd)
    for module, lr, mom, wd in (
        (encoder, lr_encoder, momentum_encoder, weight_decay_encoder),
        (decoder, lr_decoder, momentum_decoder, weight_decay_decoder),
    )
]

# Train
Define the `train` and `validate` functions.

In [15]:
# Initial "best" values for the validation metrics and how a new value is judged
# better, paired by position with the metrics list built in the training loop
# ([MeanIoU, RMSE]): mean IoU should increase (operator.gt), RMSE should
# decrease (operator.lt).
init_vals = (0.0, 10000.0)
comp_fns = [operator.gt, operator.lt]
ckpt_dir = "./base_model"
ckpt_path = "checkpoint.pth.tar"

# NOTE(review): batch_size and val_batch_size are not read anywhere below; the
# DataLoaders were already built with train_batch_size / val_batch_size = 4, so
# rebinding val_batch_size here does not change `valloader`.
batch_size = 16
val_batch_size = 16
val_every = 5  # validate (and possibly checkpoint) every `val_every` epochs
loss_coeffs = (0.5, 0.5)  # weights of the (segmentation, depth) losses, in crits order
n_epochs = 500

# Checkpoint manager from model_helpers (project helper; argument semantics are
# inferred from their names — confirm in model_helpers).
saver = Saver(
    #args=locals(), # Causes issues in vscode
    args={"items": None},
    ckpt_dir=ckpt_dir,
    best_val=init_vals,
    condition=comp_fns,
    save_several_mode=all,  # presumably: save only when all metrics improve — TODO confirm
)

# Resume from an existing checkpoint if there is one. Only the epoch, best values
# and model weights are requested; optimizer/scheduler state is not restored, so
# the learning rates restart from their initial values on resume.
start_epoch, _, state_dict = saver.maybe_load(ckpt_path=ckpt_path, keys_to_load=["epoch", "best_val", "state_dict"],)
load_state_dict(hydranet, state_dict)

# Assumes maybe_load returns None when no checkpoint exists -> train from scratch.
if start_epoch is None:
    start_epoch = 0

# Step decay (x0.1 per milestone). NOTE(review): MultiStepLR counts its own
# step() calls (one per epoch in the training loop), not absolute epochs. With
# start_epoch == 0 the milestones are [1, 101, 201, 301, 401], so the first decay
# already happens after the very first epoch; after a resume the first decay
# comes start_epoch + 1 epochs after the resume point. Confirm this is intended.
opt_scheds = []
for opt in optims:
    opt_scheds.append(torch.optim.lr_scheduler.MultiStepLR(opt, np.arange(start_epoch + 1, n_epochs, 100), gamma=0.1))


def train(model, opts, crits, dataloader, loss_coeffs=(1.0,), grad_norm=0.0):
    """Run one training epoch over ``dataloader``.

    Args:
        model: network returning a list of outputs, one per task, ordered like
            ``dataloader.dataset.masks_names`` and ``crits``.
        opts: optimizers; all are zeroed and stepped on every batch.
        crits: loss functions, one per task.
        dataloader: yields dict samples with an "image" tensor and one target
            tensor per name in ``dataloader.dataset.masks_names``.
        loss_coeffs: per-task loss weights. ``zip`` stops at the shortest of
            outputs/targets/crits/loss_coeffs, so the default ``(1.0,)`` trains
            only the first task.
        grad_norm: if > 0, the total gradient norm is clipped to this value.
    """
    model.train()
    # Put inputs where the model lives instead of assuming CUDA whenever it is
    # available (which breaks a CPU-resident model on a GPU machine).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_meter = AverageMeter()
    pbar = tqdm(dataloader)

    for sample in pbar:
        images = sample["image"].float().to(device)
        targets = [sample[k].to(device) for k in dataloader.dataset.masks_names]
        outputs = model(images)  # Forward

        loss = 0.0
        for out, target, crit, loss_coeff in zip(outputs, targets, crits, loss_coeffs):
            # Upsample each prediction to the target's H x W (assumes targets are
            # B x H x W — TODO confirm); squeeze(dim=1) drops the singleton
            # channel of a 1-channel head and is a no-op otherwise.
            pred = F.interpolate(out, size=target.size()[1:], mode="bilinear", align_corners=False)
            loss += loss_coeff * crit(pred.squeeze(dim=1), target.squeeze(dim=1))

        # Backward
        for opt in opts:
            opt.zero_grad()
        loss.backward()
        if grad_norm > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
        for opt in opts:
            opt.step()

        # A single .item() call (one host sync) serves both the meter and the bar.
        loss_value = loss.item()
        loss_meter.update(loss_value)
        pbar.set_description(
            "Loss {:.3f} | Avg. Loss {:.3f}".format(loss_value, loss_meter.avg)
        )

def validate(model, metrics, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return the final metric values.

    Every metric is reset, then updated batch by batch with NumPy
    (prediction, target) pairs; predictions are bilinearly upsampled to the
    target resolution first.

    Args:
        model: network returning one output per task, ordered like
            ``dataloader.dataset.masks_names`` and ``metrics``.
        metrics: objects providing ``reset()``, ``update(pred, target)``,
            ``val()`` and a ``name`` attribute.
        dataloader: yields dict samples with an "image" tensor and one target
            tensor per name in ``dataloader.dataset.masks_names``.

    Returns:
        Tuple of metric values, in the order of ``metrics``.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    for metric in metrics:
        metric.reset()

    pbar = tqdm(dataloader)

    def get_val(metrics):
        # Returns (values tuple, "name : value | ..." summary string).
        results = [(m.name, m.val()) for m in metrics]
        vals = tuple(val for _, val in results)
        out = ["{} : {:4f}".format(name, val) for name, val in results]
        return vals, " | ".join(out)

    with torch.no_grad():
        for sample in pbar:
            # Get the Data
            images = sample["image"].float().to(device)
            # Metrics consume NumPy arrays, so targets never need to visit the GPU.
            targets = [sample[k].squeeze(dim=1).cpu().numpy() for k in dataloader.dataset.masks_names]

            outputs = model(images)  # Forward

            # Accumulate metrics (no backward pass during validation).
            for out, target, metric in zip(outputs, targets, metrics):
                pred = F.interpolate(out, size=target.shape[1:], mode="bilinear", align_corners=False)
                metric.update(pred.squeeze(dim=1).cpu().numpy(), target)
            pbar.set_description(get_val(metrics)[1])
    vals, _ = get_val(metrics)
    print("----" * 5)
    return vals

Train the initial model

In [16]:
# NOTE(review): the saved "epoch" is the epoch just completed, so resuming with
# range(start_epoch, ...) repeats it; the range also includes n_epochs itself.
for i in range(start_epoch, n_epochs + 1):
    print("Epoch {:d}".format(i))
    train(hydranet, optims, [crit_segm, crit_depth], trainloader, loss_coeffs)

    # One scheduler step per epoch (MultiStepLR milestones count these steps).
    for scheduler in opt_scheds:
        scheduler.step()

    if i % val_every != 0:
        continue

    # Order must match the saver's best_val/condition pairs: (mean IoU, RMSE).
    metrics = [MeanIoU(num_classes[0]), RMSE(ignore_val=ignore_depth)]
    # validate() already runs its loop under torch.no_grad().
    vals = validate(hydranet, metrics, valloader)
    saver.maybe_save(new_val=vals, dict_to_save={"state_dict": hydranet.state_dict(), "epoch": i})

Epoch 205


Loss 1.092 | Avg. Loss 1.394: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]
meaniou : 0.075640 | rmse : 0.979666: 100%|██████████| 164/164 [01:49<00:00,  1.50it/s]


--------------------
Epoch 206


Loss 1.207 | Avg. Loss 1.376: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]


Epoch 207


Loss 1.318 | Avg. Loss 1.359: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 208


Loss 1.482 | Avg. Loss 1.340: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]


Epoch 209


Loss 1.290 | Avg. Loss 1.366: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 210


Loss 1.457 | Avg. Loss 1.269: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]
meaniou : 0.098136 | rmse : 0.908937: 100%|██████████| 164/164 [01:49<00:00,  1.50it/s]


--------------------
Epoch 211


Loss 1.100 | Avg. Loss 1.260: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 212


Loss 1.354 | Avg. Loss 1.290: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 213


Loss 1.230 | Avg. Loss 1.322: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]


Epoch 214


Loss 1.182 | Avg. Loss 1.280: 100%|██████████| 198/198 [00:15<00:00, 13.06it/s]


Epoch 215


Loss 1.078 | Avg. Loss 1.220: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]
meaniou : 0.114806 | rmse : 0.912386: 100%|██████████| 164/164 [01:49<00:00,  1.50it/s]


--------------------
Epoch 216


Loss 1.212 | Avg. Loss 1.218: 100%|██████████| 198/198 [00:14<00:00, 13.21it/s]


Epoch 217


Loss 0.972 | Avg. Loss 1.193: 100%|██████████| 198/198 [00:15<00:00, 13.07it/s]


Epoch 218


Loss 1.401 | Avg. Loss 1.266: 100%|██████████| 198/198 [00:15<00:00, 13.06it/s]


Epoch 219


Loss 1.288 | Avg. Loss 1.161: 100%|██████████| 198/198 [00:15<00:00, 13.09it/s]


Epoch 220


Loss 1.205 | Avg. Loss 1.159: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]
meaniou : 0.111765 | rmse : 0.862699: 100%|██████████| 164/164 [01:48<00:00,  1.52it/s]


--------------------
Epoch 221


Loss 1.382 | Avg. Loss 1.158: 100%|██████████| 198/198 [00:15<00:00, 13.12it/s]


Epoch 222


Loss 1.035 | Avg. Loss 1.167: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]


Epoch 223


Loss 1.485 | Avg. Loss 1.166: 100%|██████████| 198/198 [00:15<00:00, 13.12it/s]


Epoch 224


Loss 1.288 | Avg. Loss 1.188: 100%|██████████| 198/198 [00:15<00:00, 13.08it/s]


Epoch 225


Loss 1.362 | Avg. Loss 1.132: 100%|██████████| 198/198 [00:15<00:00, 13.06it/s]
meaniou : 0.128287 | rmse : 0.886979: 100%|██████████| 164/164 [01:50<00:00,  1.49it/s]


--------------------
Epoch 226


Loss 1.140 | Avg. Loss 1.183: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 227


Loss 1.158 | Avg. Loss 1.172: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 228


Loss 0.916 | Avg. Loss 1.118: 100%|██████████| 198/198 [00:15<00:00, 13.07it/s]


Epoch 229


Loss 1.302 | Avg. Loss 1.180: 100%|██████████| 198/198 [00:15<00:00, 13.08it/s]


Epoch 230


Loss 1.081 | Avg. Loss 1.112: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]
meaniou : 0.147989 | rmse : 0.893857: 100%|██████████| 164/164 [01:47<00:00,  1.52it/s]


--------------------
Epoch 231


Loss 0.962 | Avg. Loss 1.118: 100%|██████████| 198/198 [00:14<00:00, 13.24it/s]


Epoch 232


Loss 1.229 | Avg. Loss 1.079: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 233


Loss 0.881 | Avg. Loss 1.106: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 234


Loss 1.126 | Avg. Loss 1.062: 100%|██████████| 198/198 [00:15<00:00, 13.08it/s]


Epoch 235


Loss 1.170 | Avg. Loss 1.054: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]
meaniou : 0.147022 | rmse : 0.909321: 100%|██████████| 164/164 [01:47<00:00,  1.52it/s]


--------------------
Epoch 236


Loss 1.235 | Avg. Loss 1.068: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 237


Loss 1.177 | Avg. Loss 1.086: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]


Epoch 238


Loss 1.044 | Avg. Loss 1.060: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 239


Loss 1.042 | Avg. Loss 1.014: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 240


Loss 0.945 | Avg. Loss 1.052: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]
meaniou : 0.163417 | rmse : 0.826619: 100%|██████████| 164/164 [01:49<00:00,  1.50it/s]


--------------------
Epoch 241


Loss 1.094 | Avg. Loss 1.009: 100%|██████████| 198/198 [00:14<00:00, 13.27it/s]


Epoch 242


Loss 0.942 | Avg. Loss 1.092: 100%|██████████| 198/198 [00:15<00:00, 13.09it/s]


Epoch 243


Loss 1.003 | Avg. Loss 1.033: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 244


Loss 0.866 | Avg. Loss 1.022: 100%|██████████| 198/198 [00:15<00:00, 12.99it/s]


Epoch 245


Loss 0.984 | Avg. Loss 1.039: 100%|██████████| 198/198 [00:15<00:00, 13.09it/s]
meaniou : 0.175906 | rmse : 0.816754: 100%|██████████| 164/164 [01:49<00:00,  1.49it/s]


--------------------
Epoch 246


Loss 0.968 | Avg. Loss 0.973: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 247


Loss 0.937 | Avg. Loss 1.031: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 248


Loss 0.910 | Avg. Loss 1.006: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]


Epoch 249


Loss 1.152 | Avg. Loss 0.991: 100%|██████████| 198/198 [00:15<00:00, 13.12it/s]


Epoch 250


Loss 0.875 | Avg. Loss 0.994: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]
meaniou : 0.186285 | rmse : 0.829028: 100%|██████████| 164/164 [01:50<00:00,  1.48it/s]


--------------------
Epoch 251


Loss 0.887 | Avg. Loss 0.960: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 252


Loss 0.891 | Avg. Loss 0.999: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 253


Loss 0.898 | Avg. Loss 0.971: 100%|██████████| 198/198 [00:15<00:00, 13.08it/s]


Epoch 254


Loss 0.964 | Avg. Loss 0.997: 100%|██████████| 198/198 [00:15<00:00, 13.07it/s]


Epoch 255


Loss 1.109 | Avg. Loss 0.960: 100%|██████████| 198/198 [00:15<00:00, 13.04it/s]
meaniou : 0.190064 | rmse : 0.835472: 100%|██████████| 164/164 [01:50<00:00,  1.48it/s]


--------------------
Epoch 256


Loss 1.118 | Avg. Loss 0.940: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 257


Loss 1.086 | Avg. Loss 0.962: 100%|██████████| 198/198 [00:14<00:00, 13.23it/s]


Epoch 258


Loss 1.076 | Avg. Loss 0.929: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 259


Loss 0.804 | Avg. Loss 0.947: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]


Epoch 260


Loss 0.872 | Avg. Loss 0.919: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]
meaniou : 0.188537 | rmse : 0.820088: 100%|██████████| 164/164 [01:50<00:00,  1.48it/s]


--------------------
Epoch 261


Loss 0.806 | Avg. Loss 0.903: 100%|██████████| 198/198 [00:15<00:00, 13.12it/s]


Epoch 262


Loss 1.185 | Avg. Loss 0.977: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 263


Loss 0.965 | Avg. Loss 0.917: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]


Epoch 264


Loss 0.971 | Avg. Loss 0.917: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]


Epoch 265


Loss 0.998 | Avg. Loss 0.908: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]
meaniou : 0.208602 | rmse : 0.808919: 100%|██████████| 164/164 [01:50<00:00,  1.48it/s]


--------------------
Epoch 266


Loss 0.843 | Avg. Loss 0.922: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]


Epoch 267


Loss 0.860 | Avg. Loss 0.926: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 268


Loss 0.795 | Avg. Loss 0.918: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]


Epoch 269


Loss 1.115 | Avg. Loss 0.948: 100%|██████████| 198/198 [00:15<00:00, 13.12it/s]


Epoch 270


Loss 1.076 | Avg. Loss 0.909: 100%|██████████| 198/198 [00:14<00:00, 13.22it/s]
meaniou : 0.199809 | rmse : 0.807512: 100%|██████████| 164/164 [01:50<00:00,  1.48it/s]


--------------------
Epoch 271


Loss 1.018 | Avg. Loss 0.893: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 272


Loss 0.826 | Avg. Loss 0.920: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 273


Loss 0.879 | Avg. Loss 0.896: 100%|██████████| 198/198 [00:15<00:00, 13.02it/s]


Epoch 274


Loss 0.722 | Avg. Loss 0.861: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 275


Loss 0.740 | Avg. Loss 0.937: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]
meaniou : 0.202367 | rmse : 0.810330: 100%|██████████| 164/164 [01:51<00:00,  1.47it/s]


--------------------
Epoch 276


Loss 0.762 | Avg. Loss 0.928: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]


Epoch 277


Loss 0.865 | Avg. Loss 0.908: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 278


Loss 0.784 | Avg. Loss 0.861: 100%|██████████| 198/198 [00:15<00:00, 13.07it/s]


Epoch 279


Loss 0.907 | Avg. Loss 0.868: 100%|██████████| 198/198 [00:15<00:00, 13.06it/s]


Epoch 280


Loss 0.956 | Avg. Loss 0.895: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]
meaniou : 0.216132 | rmse : 0.811019: 100%|██████████| 164/164 [01:51<00:00,  1.48it/s]


--------------------
Epoch 281


Loss 0.945 | Avg. Loss 0.879: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 282


Loss 0.637 | Avg. Loss 0.853: 100%|██████████| 198/198 [00:15<00:00, 13.12it/s]


Epoch 283


Loss 0.785 | Avg. Loss 0.850: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]


Epoch 284


Loss 1.060 | Avg. Loss 0.877: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]


Epoch 285


Loss 0.743 | Avg. Loss 0.812: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]
meaniou : 0.220676 | rmse : 0.817767: 100%|██████████| 164/164 [01:50<00:00,  1.49it/s]


--------------------
Epoch 286


Loss 0.771 | Avg. Loss 0.816: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 287


Loss 0.698 | Avg. Loss 0.877: 100%|██████████| 198/198 [00:15<00:00, 13.09it/s]


Epoch 288


Loss 0.940 | Avg. Loss 0.851: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 289


Loss 0.874 | Avg. Loss 0.836: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]


Epoch 290


Loss 0.932 | Avg. Loss 0.901: 100%|██████████| 198/198 [00:15<00:00, 13.10it/s]
meaniou : 0.220818 | rmse : 0.791059: 100%|██████████| 164/164 [01:50<00:00,  1.49it/s]


--------------------
Epoch 291


Loss 0.612 | Avg. Loss 0.819: 100%|██████████| 198/198 [00:14<00:00, 13.25it/s]


Epoch 292


Loss 1.027 | Avg. Loss 0.853: 100%|██████████| 198/198 [00:14<00:00, 13.30it/s]


Epoch 293


Loss 0.949 | Avg. Loss 0.844: 100%|██████████| 198/198 [00:15<00:00, 13.20it/s]


Epoch 294


Loss 0.842 | Avg. Loss 0.814: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 295


Loss 0.794 | Avg. Loss 0.806: 100%|██████████| 198/198 [00:15<00:00, 13.20it/s]
meaniou : 0.228635 | rmse : 0.796994: 100%|██████████| 164/164 [01:53<00:00,  1.44it/s]


--------------------
Epoch 296


Loss 0.889 | Avg. Loss 0.785: 100%|██████████| 198/198 [00:14<00:00, 13.26it/s]


Epoch 297


Loss 0.874 | Avg. Loss 0.800: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 298


Loss 0.667 | Avg. Loss 0.800: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]


Epoch 299


Loss 0.768 | Avg. Loss 0.779: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 300


Loss 0.889 | Avg. Loss 0.798: 100%|██████████| 198/198 [00:14<00:00, 13.22it/s]
meaniou : 0.213478 | rmse : 0.842136: 100%|██████████| 164/164 [01:50<00:00,  1.48it/s]


--------------------
Epoch 301


Loss 0.933 | Avg. Loss 0.780: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]


Epoch 302


Loss 0.977 | Avg. Loss 0.808: 100%|██████████| 198/198 [00:14<00:00, 13.22it/s]


Epoch 303


Loss 0.981 | Avg. Loss 0.818: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]


Epoch 304


Loss 0.664 | Avg. Loss 0.782: 100%|██████████| 198/198 [00:14<00:00, 13.21it/s]


Epoch 305


Loss 0.842 | Avg. Loss 0.779: 100%|██████████| 198/198 [00:15<00:00, 13.18it/s]
meaniou : 0.211166 | rmse : 0.798123: 100%|██████████| 164/164 [01:51<00:00,  1.47it/s]


--------------------
Epoch 306


Loss 0.838 | Avg. Loss 0.778: 100%|██████████| 198/198 [00:14<00:00, 13.27it/s]


Epoch 307


Loss 0.761 | Avg. Loss 0.785: 100%|██████████| 198/198 [00:14<00:00, 13.26it/s]


Epoch 308


Loss 0.766 | Avg. Loss 0.798: 100%|██████████| 198/198 [00:14<00:00, 13.23it/s]


Epoch 309


Loss 0.806 | Avg. Loss 0.777: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 310


Loss 0.720 | Avg. Loss 0.773: 100%|██████████| 198/198 [00:15<00:00, 13.06it/s]
meaniou : 0.234403 | rmse : 0.802896: 100%|██████████| 164/164 [01:45<00:00,  1.55it/s]


--------------------
Epoch 311


Loss 0.706 | Avg. Loss 0.795: 100%|██████████| 198/198 [00:14<00:00, 13.22it/s]


Epoch 312


Loss 0.688 | Avg. Loss 0.749: 100%|██████████| 198/198 [00:14<00:00, 13.24it/s]


Epoch 313


Loss 0.622 | Avg. Loss 0.772: 100%|██████████| 198/198 [00:15<00:00, 13.17it/s]


Epoch 314


Loss 0.688 | Avg. Loss 0.785: 100%|██████████| 198/198 [00:15<00:00, 13.20it/s]


Epoch 315


Loss 0.738 | Avg. Loss 0.777: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]
meaniou : 0.234400 | rmse : 0.814736: 100%|██████████| 164/164 [01:48<00:00,  1.51it/s]


--------------------
Epoch 316


Loss 0.731 | Avg. Loss 0.746: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 317


Loss 0.740 | Avg. Loss 0.733: 100%|██████████| 198/198 [00:15<00:00, 13.15it/s]


Epoch 318


Loss 0.624 | Avg. Loss 0.771: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 319


Loss 0.808 | Avg. Loss 0.749: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]


Epoch 320


Loss 0.609 | Avg. Loss 0.752: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]
meaniou : 0.234563 | rmse : 0.822322: 100%|██████████| 164/164 [01:51<00:00,  1.47it/s]


--------------------
Epoch 321


Loss 0.746 | Avg. Loss 0.726: 100%|██████████| 198/198 [00:14<00:00, 13.24it/s]


Epoch 322


Loss 0.682 | Avg. Loss 0.739: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]


Epoch 323


Loss 1.023 | Avg. Loss 0.764: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]


Epoch 324


Loss 0.686 | Avg. Loss 0.766: 100%|██████████| 198/198 [00:15<00:00, 13.13it/s]


Epoch 325


Loss 0.786 | Avg. Loss 0.712: 100%|██████████| 198/198 [00:15<00:00, 13.16it/s]
meaniou : 0.233812 | rmse : 0.825800: 100%|██████████| 164/164 [01:48<00:00,  1.51it/s]


--------------------
Epoch 326


Loss 0.604 | Avg. Loss 0.696: 100%|██████████| 198/198 [00:15<00:00, 13.14it/s]


Epoch 327


Loss 0.627 | Avg. Loss 0.738: 100%|██████████| 198/198 [00:15<00:00, 13.19it/s]


Epoch 328


Loss 0.718 | Avg. Loss 0.722: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]


Epoch 329


Loss 0.623 | Avg. Loss 0.729: 100%|██████████| 198/198 [00:15<00:00, 13.11it/s]


Epoch 330


Loss 0.660 | Avg. Loss 0.707: 100%|██████████| 198/198 [00:15<00:00, 13.07it/s]
meaniou : 0.134420 | rmse : 1.024391:   2%|▏         | 4/164 [00:03<01:55,  1.38it/s]

# Inference
Let's load the best model and visualize what it learned.

In [None]:
checkpoint = 'base_model/checkpoint.pth.tar'
# Re-wraps the same encoder/decoder module objects, so loading here also updates
# the weights seen by `hydranet`.
model = nn.DataParallel(nn.Sequential(encoder, decoder)).to(device)
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint, map_location=device)['state_dict'])
res = model.eval()  # assigned so the notebook does not print the whole module tree

Get and display a random image.

In [None]:
# Pre-processing constants: same scale and ImageNet statistics as the training transforms.
IMG_SCALE  = 1./255
IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

def prepare_img(img):
    """Scale an H x W x 3 image (0-255 values) to [0, 1] and standardise each channel."""
    scaled = img * IMG_SCALE
    return (scaled - IMG_MEAN) / IMG_STD

NUM_CLASSES = 40
CMAP = np.load('cmap_nyud.npy')  # colour lookup table, indexed by predicted class id

# Choose one of the frames in data/ at random and display it
images_files = glob.glob('data/*.png')
idx = np.random.randint(len(images_files))
img_path = images_files[idx]

img = np.array(Image.open(img_path))
plt.imshow(img)
plt.show()

Get and display the model's output for that image.

In [None]:
def pipeline(img):
    """Run the base model on one image and return ``(depth, segm)`` for display.

    Args:
        img: image array, assumed H x W x 3 (it is normalised by ``prepare_img``).

    Returns:
        depth: absolute value of the predicted depth map, resized to the input's H x W.
        segm: uint8 colour image built from ``CMAP`` and the per-pixel argmax over the
            first ``NUM_CLASSES`` segmentation channels, resized to H x W.
    """
    dsize = img.shape[:2][::-1]  # cv2.resize takes (width, height)
    # Put the input wherever the model's weights live.
    device = next(model.parameters()).device
    with torch.no_grad():
        # (H, W, C) -> (1, C, H, W) float tensor. A plain tensor replaces the
        # deprecated torch.autograd.Variable wrapper, which is a no-op in modern PyTorch.
        img_tensor = torch.from_numpy(prepare_img(img).transpose(2, 0, 1)[None]).float().to(device)
        segm, depth = model(img_tensor)
        segm = cv2.resize(segm[0, :NUM_CLASSES].cpu().numpy().transpose(1, 2, 0),
                          dsize,
                          interpolation=cv2.INTER_CUBIC)
        depth = cv2.resize(depth[0, 0].cpu().numpy(),
                           dsize,
                           interpolation=cv2.INTER_CUBIC)
    segm = CMAP[segm.argmax(axis=2)].astype(np.uint8)
    depth = np.abs(depth)
    return depth, segm
    
depth, segm = pipeline(img)

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 20))
ax1.imshow(img)
ax1.set_title('Original', fontsize=30)
ax2.imshow(segm)
ax2.set_title('Predicted Segmentation', fontsize=30)
# vmax autoscales to the prediction's maximum. NOTE(review): a fixed vmax=80 looks like
# a leftover from an outdoor-depth setup; indoor NYUD depths (validation RMSE ~0.8 in
# the log) would use only the darkest part of the colormap. Confirm the depth units.
ax3.imshow(depth, cmap="plasma", vmin=0)
ax3.set_title("Predicted Depth", fontsize=30)
plt.show()

Let's normalize the depth and visualize it.

In [None]:
def depth_to_rgb(depth):
    """Colour-map a depth array with 'plasma' and return it as uint8 RGB.

    A fresh ``Normalize()`` with no vmin/vmax is used, so the range adapts to this
    array's own min/max. Colours are therefore relative per call, not comparable
    across frames. The alpha channel returned by ``to_rgba`` is dropped.
    """
    scalar_map = cm.ScalarMappable(norm=co.Normalize(), cmap='plasma')
    rgba = scalar_map.to_rgba(depth)
    return (rgba[:, :, :3] * 255).astype(np.uint8)

depth_rgb = depth_to_rgb(depth)

# Stack image / segmentation / depth top to bottom
new_img = np.concatenate([img, segm, depth_rgb], axis=0)
plt.imshow(new_img)
plt.show()

Let's see how it looks as a video.

In [None]:
video_files = sorted(glob.glob("data/*.png"))
if not video_files:
    raise FileNotFoundError("no frames matched data/*.png")

# Set the Model to Eval on GPU
if torch.cuda.is_available():
    _ = model.cuda()
_ = model.eval()

# Run the pipeline: each output frame stacks input / segmentation / depth vertically
result_video = []
for img_path in video_files:
    image = np.array(Image.open(img_path))
    h, w, _ = image.shape
    depth, seg = pipeline(image)
    stacked = cv2.vconcat([image, seg, depth_to_rgb(depth)])
    # Frames are RGB (decoded by PIL) but cv2.VideoWriter expects BGR
    result_video.append(cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))

os.makedirs('output', exist_ok=True)

# Frame size is (width, height), three images tall; assumes all frames share a size
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('output/out.mp4', fourcc, 15, (w, 3 * h))
try:
    for frame in result_video:
        out.write(frame)
finally:
    out.release()

# Extend the Model
There are many ways to transfer weights from one model to another; this approach uses composition to feed the backbone output of the base model into an additional head. The base model is constructed in the same way, except that the parameter *return_backbone* is now set to `True`. This model is loaded with weights from the checkpoint, and those weights are frozen so that training cannot change the heads of the base model. Freezing also allows larger batch sizes and, ultimately, faster training. Next, the architecture of the new head is defined, and the forward method shows how the backbone output is taken from the base model and passed to the new head. To minimize additional complexity, the new head is trained on the existing depth data. However, this still lets us experiment with loss functions and optimizers that differ from those used to train the base model.

In [None]:
ckpt_dir = "./extended_model"
# NOTE(review): this path is relative to the working directory, not ckpt_dir.
# Confirm whether loading an existing extended checkpoint here is intended, since the
# following cells display the model's output "before training".
ckpt_path = "checkpoint.pth.tar"

# Defined here so this cell does not rely on another cell having set them:
# starting "best" values and improvement tests for (mean IoU, RMSE).
# Higher IoU and lower RMSE are better.
init_vals = (0.0, 10000.0)
comp_fns = [operator.gt, operator.lt]

saver = Saver(
    args={"items": None},  # args=locals() causes issues in vscode
    ckpt_dir=ckpt_dir,
    best_val=init_vals,
    condition=comp_fns,
    save_several_mode=all,
)

num_classes = (40, 1)  # (segmentation classes, depth output channels)
backbone_path = "./base_model/checkpoint.pth.tar"

class Extended_Model(nn.Module):
    """Frozen base HydraNet (``net1``) composed with a new trainable depth head (``net2``).

    ``net1`` is the base encoder/decoder built with ``return_backbone=True`` and loaded
    from ``backbone_path``. Its parameters are frozen, and it is kept in eval mode (see
    ``train``). ``net2`` maps the 256-channel backbone output to a 1-channel depth map.
    """

    def __init__(self):
        super().__init__()

        # Create an instance of the base model and load the saved weights.
        # map_location="cpu" lets a GPU-saved checkpoint load on any machine; the
        # caller is expected to move the whole model to its device afterwards.
        encoder = MobileNetv2()
        decoder = MTLWRefineNet(encoder._out_c, num_classes, return_backbone=True)
        self.net1 = nn.DataParallel(nn.Sequential(encoder, decoder))
        self.net1.load_state_dict(torch.load(backbone_path, map_location="cpu")['state_dict'])

        # Freeze the trained weights (no gradient updates to the base model)
        for param in self.net1.parameters():
            param.requires_grad = False

        # New head: a 1x1 conv (256->128), then two unpadded 3x3 convs that each trim
        # one pixel per border. The training loss upsamples the result to the target size.
        conv1 = nn.Conv2d(256, 128, 1, 1)
        conv2 = nn.Conv2d(128, 64, 3, 1)
        conv3 = nn.Conv2d(64, 1, 3, 1)
        # NOTE(review): the final ReLU6 clamps predicted depth to [0, 6]. Confirm this
        # covers the target depth range.
        self.net2 = nn.DataParallel(nn.Sequential(conv1, nn.ReLU6(inplace=True),
                                                  conv2, nn.ReLU6(inplace=True),
                                                  conv3, nn.ReLU6(inplace=True)))

    def train(self, mode=True):
        """Set train/eval mode on the new head while always keeping ``net1`` in eval mode.

        ``requires_grad=False`` only stops gradient updates. Normalisation layers in
        training mode still update their running statistics (BatchNorm, presumably
        present in MobileNetv2; confirm). That would silently change the base heads'
        outputs while the new head trains.
        """
        super().train(mode)
        self.net1.eval()
        return self

    def forward(self, x):
        """Return ``(out, y)``: the base model's outputs and the new head's depth.

        ``out`` is whatever ``net1`` returns. Its last element is the backbone feature
        map that ``net2`` consumes.
        """
        out = self.net1(x)
        y = self.net2(out[-1])
        return out, y
    

hydranet = nn.DataParallel(Extended_Model()).to(device)

# Total parameter count, frozen backbone included
param_count = sum(p.numel() for p in hydranet.parameters())
print("Model has {} parameters".format(param_count))

start_epoch, _, state_dict = saver.maybe_load(
    ckpt_path=ckpt_path,
    keys_to_load=["epoch", "best_val", "state_dict"],
)
load_state_dict(hydranet, state_dict)

The display pipeline is modified to process the new head.

In [None]:
def pipeline(img):
    """Run the extended model on one image and return ``(depth1, depth2, segm)``.

    depth1 comes from the base model's depth head and depth2 from the new head. Both
    are resized to the input's H x W and passed through ``np.abs``. segm is the uint8
    ``CMAP`` colouring of the per-pixel argmax over the first ``NUM_CLASSES``
    segmentation channels.
    """
    dsize = img.shape[:2][::-1]  # cv2.resize takes (width, height)
    # Put the input wherever the model's weights live.
    device = next(hydranet.parameters()).device
    with torch.no_grad():
        # (H, W, C) -> (1, C, H, W) float tensor. A plain tensor replaces the
        # deprecated torch.autograd.Variable wrapper.
        img_tensor = torch.from_numpy(prepare_img(img).transpose(2, 0, 1)[None]).float().to(device)
        # The forward method returns the old heads (plus backbone output) and the new depth
        head1, depth2 = hydranet(img_tensor)
        segm, depth1 = head1[0], head1[1]
        segm = cv2.resize(segm[0, :NUM_CLASSES].cpu().numpy().transpose(1, 2, 0),
                          dsize,
                          interpolation=cv2.INTER_CUBIC)
        depth1 = cv2.resize(depth1[0, 0].cpu().numpy(),
                            dsize,
                            interpolation=cv2.INTER_CUBIC)
        depth2 = cv2.resize(depth2[0, 0].cpu().numpy(),
                            dsize,
                            interpolation=cv2.INTER_CUBIC)
    segm = CMAP[segm.argmax(axis=2)].astype(np.uint8)
    depth1 = np.abs(depth1)
    depth2 = np.abs(depth2)
    return depth1, depth2, segm

Display the output before training.

In [None]:
depth1, depth2, segm = pipeline(img)

# Sanity-check output sizes, one shape per line
print(*(arr.shape for arr in (depth1, depth2, segm)), sep="\n")

depth1_rgb = depth_to_rgb(depth1)
depth2_rgb = depth_to_rgb(depth2)

# Keep this stacked pre-training panel for a later side-by-side comparison
untrained_img = np.concatenate([img, segm, depth1_rgb, depth2_rgb], axis=0)
plt.imshow(untrained_img)
plt.show()

In [None]:
video_files = sorted(glob.glob("data/*.png"))
if not video_files:
    raise FileNotFoundError("no frames matched data/*.png")

# Set the Model to Eval on GPU
if torch.cuda.is_available():
    _ = hydranet.cuda()
_ = hydranet.eval()

# Run the pipeline: input / segmentation / base depth / new-head depth, stacked vertically
result_video = []
for img_path in video_files:
    image = np.array(Image.open(img_path))
    h, w, _ = image.shape
    depth1, depth2, seg = pipeline(image)
    stacked = cv2.vconcat([image, seg, depth_to_rgb(depth1), depth_to_rgb(depth2)])
    # Frames are RGB (decoded by PIL) but cv2.VideoWriter expects BGR
    result_video.append(cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))

os.makedirs('output', exist_ok=True)

# Frame size is (width, height), four images tall; assumes all frames share a size
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('output/out2.mp4', fourcc, 15, (w, 4 * h))
try:
    for frame in result_video:
        out.write(frame)
finally:
    out.release()

# Train Extended Model
The `train` and `validate` functions have been modified to work with the new output structure of the forward method. Training now optimizes only the new depth head, and validation computes the depth metric from the new head's output.

In [None]:
# Starting "best" values and improvement tests for (mean IoU, RMSE).
# Mean IoU comes from the base segmentation head, which is not trained here, so it is
# not expected to improve. With `gt` and save_several_mode=all, a new best would
# presumably never be recorded after the first validation (confirm against
# Saver.maybe_save). `ge` lets an RMSE improvement alone count.
init_vals = (0.0, 10000.0)
comp_fns = [operator.ge, operator.lt]
ckpt_dir = "./extended_model"
# Resume from the directory the saver writes into. The previous "./extended/..." path
# never pointed inside ckpt_dir, so training could not resume.
ckpt_path = os.path.join(ckpt_dir, "checkpoint.pth.tar")

saver = Saver(
    args={"items": None},  # args=locals() causes issues in vscode
    ckpt_dir=ckpt_dir,
    best_val=init_vals,
    condition=comp_fns,
    save_several_mode=all,
)

n_epochs = 200

# Move to the training device: train()/validate() send inputs to CUDA when available.
hydranet = nn.DataParallel(Extended_Model()).to(device)

print("Model has {} parameters".format(sum([p.numel() for p in hydranet.parameters()])))

start_epoch, _, state_dict = saver.maybe_load(ckpt_path=ckpt_path, keys_to_load=["epoch", "best_val", "state_dict"],)
load_state_dict(hydranet, state_dict)

# No checkpoint found: start from the first epoch
if start_epoch is None:
    start_epoch = 0

def train(model, opts, crits, dataloader, loss_coeffs=(1.0,), grad_norm=0.0):
    """Run one training epoch that optimises only against the new depth output.

    The loss is ``crits[0]`` between the model's last output (upsampled bilinearly to
    the target's spatial size) and the last target in ``dataloader.dataset.masks_names``.
    Only ``opts[0]`` is stepped. ``loss_coeffs`` is accepted for call compatibility but
    not used. When ``grad_norm > 0`` gradients are clipped to that norm.
    """
    model.train()

    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    optimizer, criterion = opts[0], crits[0]
    loss_meter = AverageMeter()
    progress = tqdm(dataloader)

    for batch in progress:
        images = batch["image"].float().to(device)
        targets = [batch[name].to(device) for name in dataloader.dataset.masks_names]
        depth_target = targets[-1]
        prediction = model(images)[-1]

        upsampled = F.interpolate(
            prediction, size=depth_target.size()[1:], mode="bilinear", align_corners=False
        )
        loss = criterion(upsampled.squeeze(dim=1), depth_target.squeeze(dim=1))

        optimizer.zero_grad()
        loss.backward()
        if grad_norm > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
        optimizer.step()

        loss_value = loss.item()
        loss_meter.update(loss_value)
        progress.set_description(
            "Loss {:.3f} | Avg. Loss {:.3f}".format(loss_value, loss_meter.avg)
        )

def validate(model, metrics, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return the final metric values.

    Segmentation is scored with the base model's head (first base output) and depth
    with the new head. Predictions are paired with targets in
    ``dataloader.dataset.masks_names`` order and with ``metrics`` in list order.

    Returns:
        tuple of ``m.val()`` for each metric, in ``metrics`` order.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    model.eval()
    for metric in metrics:
        metric.reset()

    pbar = tqdm(dataloader)

    def get_val(metrics):
        """Return (tuple of metric values, progress-bar description)."""
        results = [(m.name, m.val()) for m in metrics]
        vals = tuple(val for _, val in results)
        out = ["{} : {:4f}".format(name, val) for name, val in results]
        return vals, " | ".join(out)

    with torch.no_grad():
        for sample in pbar:
            # Get the data; targets are converted to numpy for the metric updates
            input = sample["image"].float().to(device)
            targets = [sample[k].to(device) for k in dataloader.dataset.masks_names]
            targets = [target.squeeze(dim=1).cpu().numpy() for target in targets]

            # Forward: base outputs (segm, depth, backbone) and the new depth head
            base_outputs, depth2 = model(input)
            # Score segmentation with the base head and depth with the new head.
            # Building a new list (instead of pop()/item assignment on the returned
            # container) also works if the model returns a tuple.
            outputs = [base_outputs[0], depth2]

            # Update the metrics (no backward pass during validation)
            for out, target, metric in zip(outputs, targets, metrics):
                metric.update(
                    F.interpolate(out, size=target.shape[1:], mode="bilinear", align_corners=False)
                    .squeeze(dim=1)
                    .cpu()
                    .numpy(),
                    target,
                )
            pbar.set_description(get_val(metrics)[1])
    vals, _ = get_val(metrics)
    print("----" * 5)
    return vals

In [None]:
print(start_epoch)
batch_size = 16
val_batch_size = 16
val_every = 5


def _with_batch_size(loader, new_batch_size):
    """Return a DataLoader over ``loader.dataset`` that uses ``new_batch_size``.

    Shuffling (inferred from a RandomSampler), worker count, collate_fn, pin_memory
    and drop_last are carried over from ``loader``.
    """
    return DataLoader(
        loader.dataset,
        batch_size=new_batch_size,
        shuffle=isinstance(loader.sampler, torch.utils.data.RandomSampler),
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )


# The loaders built earlier keep the batch size they were created with, so the
# larger batch sizes above only take effect once the loaders are rebuilt.
trainloader = _with_batch_size(trainloader, batch_size)
valloader = _with_batch_size(valloader, val_batch_size)

# Try a different optimizer and loss function; only the new head (net2) is optimised
optims = [torch.optim.Adam(hydranet.module.net2.parameters(), lr=lr_decoder)]
# NOTE(review): nn.MSELoss has no ignore mechanism, so pixels whose target equals
# `ignore_depth` (which the RMSE metric skips) still contribute to the loss. Confirm
# this is intended.
crit_depth = nn.MSELoss()

opt_scheds = [
    torch.optim.lr_scheduler.MultiStepLR(opt, np.arange(start_epoch + 1, n_epochs, 100), gamma=0.1)
    for opt in optims
]

for i in range(start_epoch, n_epochs + 1):
    print("Epoch {:d}".format(i))
    train(hydranet, optims, [crit_depth], trainloader, loss_coeffs)

    for sched in opt_scheds:
        sched.step()

    if i % val_every == 0:
        metrics = [MeanIoU(num_classes[0]), RMSE(ignore_val=ignore_depth)]
        # validate() disables gradient tracking itself
        vals = validate(hydranet, metrics, valloader)
        saver.maybe_save(new_val=vals, dict_to_save={"state_dict": hydranet.state_dict(), "epoch": i})

# Visualize Results
Now that it's trained, let's see whether the output looks better. Notice that the outputs of the base model's heads are unchanged.

In [None]:
# The training loop can end in train mode (validation only runs every val_every epochs)
_ = hydranet.eval()

depth1, depth2, segm = pipeline(img)
depth1_rgb = depth_to_rgb(depth1)
depth2_rgb = depth_to_rgb(depth2)

trained_img = np.vstack((img, segm, depth1_rgb, depth2_rgb))
# Left: before training, right: after. np.hstack takes ONE sequence of arrays; passing
# the two arrays as separate positional arguments raises a TypeError.
new_img = np.hstack((untrained_img, trained_img))
plt.imshow(new_img)
plt.show()

In [None]:
video_files = sorted(glob.glob("data/*.png"))
if not video_files:
    raise FileNotFoundError("no frames matched data/*.png")

# Set the Model to Eval on GPU
if torch.cuda.is_available():
    _ = hydranet.cuda()
_ = hydranet.eval()

# Run the pipeline: input / segmentation / base depth / new-head depth, stacked vertically
result_video = []
for img_path in video_files:
    image = np.array(Image.open(img_path))
    h, w, _ = image.shape
    depth1, depth2, seg = pipeline(image)
    stacked = cv2.vconcat([image, seg, depth_to_rgb(depth1), depth_to_rgb(depth2)])
    # Frames are RGB (decoded by PIL) but cv2.VideoWriter expects BGR
    result_video.append(cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))

os.makedirs('output', exist_ok=True)

# Frame size is (width, height), four images tall; assumes all frames share a size
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('output/out3.mp4', fourcc, 15, (w, 4 * h))
try:
    for frame in result_video:
        out.write(frame)
finally:
    out.release()

<video width="800" controls>
    <source src="./output/out3.mp4" type="video/mp4">
</video>