In [1]:
import sys
sys.path.insert(0,'./')
from cityscapes import CityScapes

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np

import os
import os.path as osp
import logging
import time
import datetime
from shelfnet import ShelfNet

In [2]:
# Some helper functions to view network parameters
from torchinfo import summary

def view_network_shapes(model, input_shape):
    """Print the torchinfo summary of ``model`` for an input of ``input_shape``.

    Args:
        model: the network to summarise.
        input_shape: full input shape, forwarded to ``summary`` as ``input_size``.
    """
    model_stats = summary(model, input_size=input_shape)
    print(model_stats)


def view_network_parameters(model):
    """Print the size of every ``state_dict`` entry and the trainable parameter total.

    The per-entry listing walks ``model.state_dict()``, so it also shows
    buffers (e.g. BatchNorm running statistics). The final total, however,
    only counts tensors returned by ``model.parameters()`` that have
    ``requires_grad=True``; buffers and frozen parameters are excluded so the
    "Trainable" label is accurate.

    Args:
        model: a ``torch.nn.Module``.
    """
    print("Model Summary\n")
    for layer_tensor_name, tensor in model.state_dict().items():
        print("{}: {} elements".format(layer_tensor_name, torch.numel(tensor)))

    # Count only what an optimizer would actually update.
    total_parameters = sum(
        int(param.numel()) for param in model.parameters() if param.requires_grad
    )
    print(f"\nTotal Trainable Parameters: {total_parameters}!")


In [3]:
# training split configuration
n_classes = 19
batch = 4
n_workers = 4
cropsize = [1024, 1024]  # crop size as [height, width]

ds_train = CityScapes('data/', cropsize=cropsize, mode='train')
dl_train = DataLoader(
    ds_train,
    batch_size=batch,
    num_workers=n_workers,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

# sanity check: pull the first batch only and show the tensor shapes
for i, (imgs, label) in enumerate(dl_train):
    print(imgs.shape, label.shape)
    break

torch.Size([4, 3, 1024, 1024]) torch.Size([4, 1, 1024, 1024])


The goal is to simplify the model even further by reducing the number of channels.

Below is the definition of the real-time ShelfNet-18 model:

```python
class ShelfNet(nn.Module):
    """Reference segmentation network: backbone, three transition layers,
    a Decoder followed by a LadderBlock, and one output head per scale.
    """
    def __init__(self, n_classes, *args, **kwargs):
        """Build the sub-modules; ``n_classes`` is passed to every NetOutput head.

        Extra ``*args`` / ``**kwargs`` are accepted but not used here.
        """
        super(ShelfNet, self).__init__()
        self.backbone = Resnet18()

        self.decoder = Decoder(planes=64,layers=3,kernel=3)
        self.ladder = LadderBlock(planes=64,layers=3, kernel=3)

        # one head per scale; first argument presumably the input channel count (64 / 128 / 256) - TODO confirm
        self.conv_out = NetOutput(64, 64, n_classes)
        self.conv_out16 = NetOutput(128, 64, n_classes)
        self.conv_out32 = NetOutput(256, 64, n_classes)

        # ks=1 transition layers applied to the three backbone outputs
        # (presumably in/out channels 128->64, 256->128, 512->256 - TODO confirm)
        self.trans1 = ConvBNReLU(128,64,ks=1,stride=1,padding=0)
        self.trans2 = ConvBNReLU(256, 128, ks=1, stride=1, padding=0)
        self.trans3 = ConvBNReLU(512, 256, ks=1, stride=1, padding=0)
    def forward(self, x, aux = True):
        """Run the network on ``x``.

        Returns the main output resized to the spatial size of ``x``; when
        ``aux`` is true, returns a 3-tuple that also contains the two
        auxiliary outputs, resized the same way.
        """
        # spatial size of the input; every output is interpolated back to it
        H, W = x.size()[2:]

        feat8, feat16, feat32 = self.backbone(x)

        feat8 = self.trans1(feat8)
        feat16 = self.trans2(feat16)
        feat32 = self.trans3(feat32)

        out = self.decoder([feat8, feat16, feat32])

        out2 = self.ladder(out)

        # the ladder output is indexed from the end: last entry feeds the main head
        feat_cp8, feat_cp16, feat_cp32 = out2[-1], out2[-2], out2[-3]

        feat_out = self.conv_out(feat_cp8)
        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)

        if aux:
            feat_out16 = self.conv_out16(feat_cp16)
            feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)

            feat_out32 = self.conv_out32(feat_cp32)
            feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)

            return feat_out, feat_out16, feat_out32
        else:
            return feat_out

    def get_params(self):
        """Collect the children's parameter lists into four groups.

        Each child's ``get_params()`` returns two lists. For LadderBlock,
        NetOutput, Decoder and ConvBNReLU children they are appended to the
        ``lr_mul_*`` groups; for any other child (here the backbone) they go
        to the plain groups.

        Returns:
            (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, LadderBlock) or isinstance(child, NetOutput) or isinstance(child, Decoder)\
                    or isinstance(child, ConvBNReLU):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
```

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from resnet2 import resnet18
from modules.bn import InPlaceABNSync
from ShelfBlock import Decoder, LadderBlock
from shelfnet import ConvBNReLU, NetOutput

class SimpleShelfNet(nn.Module):
    """ShelfNet variant whose decoder-side channel widths are scaled by ``channel_ratio``.

    The backbone is left untouched; only the transition layers, the Decoder,
    the LadderBlock and the inputs of the output heads are narrowed. The
    sub-module attribute names match the reference ``ShelfNet`` layout.
    """

    def __init__(self, n_classes, channel_ratio=0.25):
        """Build the network.

        Args:
            n_classes: number of output classes, passed to every NetOutput head.
            channel_ratio: multiplier applied to the reference base width of 64.
                The default of 0.25 gives widths 16 / 32 / 64.

        Raises:
            ValueError: if ``channel_ratio`` would give a base width below 1.
        """
        super(SimpleShelfNet, self).__init__()

        base_planes = int(64 * channel_ratio)
        if base_planes < 1:
            raise ValueError(
                "channel_ratio={} gives a base width of {} channels; "
                "it must yield at least 1".format(channel_ratio, base_planes)
            )
        # Derive the deeper widths from the base width instead of truncating
        # each one independently: int(128 * r) is not always 2 * int(64 * r)
        # (e.g. r=0.26 -> 16 / 33 / 66), which breaks the 1:2:4 progression of
        # the reference configuration (64 / 128 / 256).
        # NOTE(review): assumes Decoder/LadderBlock double the width per level
        # starting from ``planes`` - confirm against ShelfBlock.
        planes8 = base_planes
        planes16 = 2 * base_planes
        planes32 = 4 * base_planes

        self.backbone = resnet18()

        self.decoder = Decoder(planes=planes8, layers=3, kernel=3)
        self.ladder = LadderBlock(planes=planes8, layers=3, kernel=3)

        self.conv_out = NetOutput(planes8, 64, n_classes)
        self.conv_out16 = NetOutput(planes16, 64, n_classes)
        self.conv_out32 = NetOutput(planes32, 64, n_classes)

        # ks=1 transitions from the backbone outputs (128 / 256 / 512) to the reduced widths
        self.trans1 = ConvBNReLU(128, planes8, ks=1, stride=1, padding=0)
        self.trans2 = ConvBNReLU(256, planes16, ks=1, stride=1, padding=0)
        self.trans3 = ConvBNReLU(512, planes32, ks=1, stride=1, padding=0)

    @staticmethod
    def _resize(logits, size):
        """Bilinearly resize ``logits`` to the spatial ``size`` (H, W)."""
        return F.interpolate(logits, size, mode='bilinear', align_corners=True)

    def forward(self, x, aux=True):
        """Run the network on ``x``.

        Args:
            x: input batch; its last two dimensions are used as the output size.
            aux: when true, also return the two auxiliary head outputs.

        Returns:
            The main output resized to the input's spatial size, or a 3-tuple
            ``(main, aux16, aux32)`` when ``aux`` is true.
        """
        H, W = x.size()[2:]

        feat8, feat16, feat32 = self.backbone(x)

        feat8 = self.trans1(feat8)
        feat16 = self.trans2(feat16)
        feat32 = self.trans3(feat32)

        out = self.decoder([feat8, feat16, feat32])
        out2 = self.ladder(out)

        # the ladder output is indexed from the end: last entry feeds the main head
        feat_cp8, feat_cp16, feat_cp32 = out2[-1], out2[-2], out2[-3]

        feat_out = self._resize(self.conv_out(feat_cp8), (H, W))
        if not aux:
            return feat_out

        feat_out16 = self._resize(self.conv_out16(feat_cp16), (H, W))
        feat_out32 = self._resize(self.conv_out32(feat_cp32), (H, W))
        return feat_out, feat_out16, feat_out32

    def get_params(self):
        """Collect the children's parameter lists into four groups.

        Each child's ``get_params()`` returns two lists. For LadderBlock,
        NetOutput, Decoder and ConvBNReLU children they are appended to the
        ``lr_mul_*`` groups; for any other child (here the backbone) they go
        to the plain groups.

        Returns:
            (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        head_types = (LadderBlock, NetOutput, Decoder, ConvBNReLU)
        for child in self.children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, head_types):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params

In [5]:
# simplified realtime model: SimpleShelfNet built with channel_ratio=.25 (not the original ShelfNet)
simple_model = SimpleShelfNet(n_classes=n_classes, channel_ratio=.25).cuda()

view_network_shapes(simple_model, torch.Size([1, 3, cropsize[0], cropsize[1]])) # summary traces a single image (batch dimension 1) at the crop size; training uses `batch` images per step

Layer (type:depth-idx)                             Output Shape              Param #
SimpleShelfNet                                     [1, 19, 1024, 1024]       --
├─ResNet: 1-1                                      [1, 128, 128, 128]        513,000
│    └─Conv2d: 2-1                                 [1, 64, 512, 512]         9,408
│    └─BatchNorm2d: 2-2                            [1, 64, 512, 512]         128
│    └─MaxPool2d: 2-3                              [1, 64, 256, 256]         --
│    └─Sequential: 2-4                             [1, 64, 256, 256]         --
│    │    └─BasicBlock: 3-1                        [1, 64, 256, 256]         73,984
│    │    └─BasicBlock: 3-2                        [1, 64, 256, 256]         73,984
│    └─Sequential: 2-5                             [1, 128, 128, 128]        --
│    │    └─BasicBlock: 3-3                        [1, 128, 128, 128]        230,144
│    │    └─BasicBlock: 3-4                        [1, 128, 128, 128]        295,424
│    └─S

In [6]:
# Benchmark inference latency of `simple_model` on random input.
# No weights are loaded here: the model keeps whatever weights it was built with.
simple_model.eval()
run_time = []
n_warmup = 10  # untimed passes so one-off CUDA / cuDNN initialisation is not averaged in

with torch.no_grad():
    for _ in range(n_warmup):
        simple_model(torch.randn(1, 3, 1024, 1024).cuda(), aux=False)

    # measure inference time via random input
    for i in range(0, 100):
        dummy_input = torch.randn(1, 3, 1024, 1024).cuda()
        # CUDA kernels are launched asynchronously: make sure the input
        # generation and transfer have finished before the clock starts
        torch.cuda.synchronize()
        start = time.perf_counter()

        output = simple_model(dummy_input, aux=False)

        torch.cuda.synchronize()  # wait for the forward pass to finish
        end = time.perf_counter()

        run_time.append(end - start)

print('Mean running time is ', np.mean(run_time))

Mean running time is  0.011665486700003384


The following is a modified version of the original ShelfNet training loop.

In [7]:
import gc

# Run a garbage-collection pass before training starts.
# Note: `simple_model` from the benchmark cell is still referenced, so its memory is not freed by this.
gc.collect()

# Ask PyTorch to release unused cached GPU memory held by its allocator.
torch.cuda.empty_cache()

In [8]:
from loss import OhemCELoss
from optimizer import Optimizer

respth = './res'
# torch.save does not create missing directories
os.makedirs(respth, exist_ok=True)

## model
ignore_idx = 255
net = SimpleShelfNet(n_classes=n_classes, channel_ratio=.25)
net.cuda()
net.train()
score_thres = 0.7
n_min = batch*cropsize[0]*cropsize[1]//16
# one OHEM cross-entropy criterion per network output (main, aux16, aux32)
LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)

## optimizer
max_iter = 80000
learning_rate = 0.0001
weight_decay = 5e-4
optim = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

## train loop
msg_iter = 50     # log every `msg_iter` iterations
save_iter = 1000  # write an intermediate checkpoint every `save_iter` iterations
loss_avg = []
st = glob_st = time.time()
diter = iter(dl_train)
epoch = 0
for it in range(max_iter):
    # The loader is finite: when it is exhausted (or yields a short final
    # batch, which `n_min` is not sized for) start the next epoch and train
    # on its first batch instead of raising StopIteration or skipping a step.
    batch_data = next(diter, None)
    if batch_data is None or batch_data[0].size(0) != batch:
        epoch += 1
        diter = iter(dl_train)
        batch_data = next(diter)
    im, lb = batch_data

    im = im.cuda()
    lb = lb.cuda()
    # drop the singleton dimension 1 of the label tensor (loader yields [B, 1, H, W])
    lb = torch.squeeze(lb, 1)

    optim.zero_grad()
    out, out16, out32 = net(im)
    lossp = LossP(out, lb)
    loss2 = Loss2(out16, lb)
    loss3 = Loss3(out32, lb)
    loss = lossp + loss2 + loss3
    loss.backward()
    optim.step()

    loss_avg.append(loss.item())

    ## print training log message
    if (it+1)%msg_iter==0:
        loss_mean = sum(loss_avg) / len(loss_avg)
        ed = time.time()
        t_intv, glob_t_intv = ed - st, ed - glob_st
        # it+1 iterations are done; using it+1 also avoids dividing by zero at it == 0
        done = it + 1
        eta = int((max_iter - done) * (glob_t_intv / done))
        eta = str(datetime.timedelta(seconds=eta))
        msg = ', '.join([
                'it: {it}/{max_it}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it = done,
                max_it = max_iter,
                loss = loss_mean,
                time = t_intv,
                eta = eta
            )
        print(msg)
        loss_avg = []
        st = ed

    if (it+1) % save_iter == 0:
        ## dump the models in between (never at iteration 0, before any useful training)
        save_pth = osp.join(respth, 'shelfnet_model_it_%d.pth'%(it+1))
        torch.save(net.state_dict(), save_pth)

## dump the final model
save_pth = osp.join(respth, 'model_final.pth')
net.cpu()
# unwrap a DataParallel-style wrapper if one was used
state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
torch.save(state, save_pth)
print('training done, model saved to: {}'.format(save_pth))

data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
it: 10/1000, loss: 8.7320, eta: 0:12:13, time: 6.6570
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
data loaded
outputs done
losses calculated
backwards step done
d

KeyboardInterrupt: 

In [3]:
from tqdm import tqdm

# validation split: fixed order, and the final partial batch is kept
batchsize, n_workers = 2, 2
ds_val = CityScapes('data/', mode='val')
dl_val = DataLoader(
    ds_val,
    batch_size=batchsize,
    num_workers=n_workers,
    shuffle=False,
    drop_last=False,
)

In [4]:
from evaluate import MscEval

# rebuild the 19-class network at channel_ratio=.25 and restore the saved weights
net = SimpleShelfNet(n_classes=19, channel_ratio=.25)
net.load_state_dict(torch.load('res/model_final.pth'))
net = net.cuda().eval()

# score the model on the validation loader without flip augmentation
mEval = MscEval(net, dl_val, flip=False)
mIoU, hist = mEval.evaluate()
print(f'mIoU: {mIoU}')

100%|██████████| 134/134 [01:21<00:00,  1.65it/s]

mIoU: 0.004969619019889352



