In [1]:
# GPU setting

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"

In [2]:
!nvidia-smi

Thu Jan 30 17:58:46 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 1080    Off  | 00000000:03:00.0 Off |                  N/A |
| 58%   63C    P0    59W / 210W |      0MiB /  8119MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX TIT...  Off  | 00000000:04:00.0 Off |                  N/A |
| 39%   82C    P0    96W / 250W |      0MiB / 12210MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, models, transforms
from matplotlib.pyplot import imshow, imsave
from PIL import Image
from collections import OrderedDict

import copy

import cv2

In [4]:
# utility function to measure time

import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [5]:
# utility function to measure time

import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [6]:
device = ("cuda:0" if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [7]:
# KMNet_vgg16bn: atros, index pooling(segnet), blur pooling, until pool4


class KMNet_vgg16bn(nn.Module):
    def __init__(self, in_channels=3, n_classes=34):
        super(KMNet_vgg16bn, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        
        model = models.vgg16_bn(pretrained=True)
                
        #cityscape input (1, 3,1024, 2048)
              
        self.encoder1 = model.features[0:6]      # maxpool: size / 2
        self.encoder2 = model.features[7:13]     # maxpool: size / 4
        self.encoder3 = model.features[14:23]     # maxpool: size / 8
        # atrous
        self.encoder3[0] = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3,3), stride=(1,1), padding=(2,2), dilation=(2,2))
        self.encoder4 = model.features[24:33]     # maxpool: size / 16
        # atrous
        self.encoder4[0] = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3,3), stride=(1,1), padding=(4,4), dilation=(4,4))
        
        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, return_indices=True)
        self.upsample = nn.MaxUnpool2d(kernel_size=2, stride=2, padding=0)   # input size * 2
        self.blurpool = BlurPooling()
        
        self.bridge = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn1', nn.BatchNorm2d(1024)),
            ('relu1', nn.ReLU(inplace=True)),
            ('conv2', nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn2', nn.BatchNorm2d(512)),
            ('relu2', nn.ReLU(inplace=True)),
        ]))


        
        self.decoder1 = nn.Sequential(OrderedDict([
            ('conv1',nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn1', nn.BatchNorm2d(512)),
            ('relu1',nn.ReLU(inplace=True)),
            ('conv2', nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn2', nn.BatchNorm2d(256)),
            ('relu2',nn.ReLU(inplace=True)),
        ]))
        
        self.decoder2 = nn.Sequential(OrderedDict([
            ('conv1',nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn1',nn.BatchNorm2d(256)),
            ('relu1',nn.ReLU(inplace=True)),
            ('conv2',nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn2', nn.BatchNorm2d(128)),
            ('relu2', nn.ReLU(inplace=True)),
        ]))
        
        self.decoder3 = nn.Sequential(OrderedDict([
            ('conv1',nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn1',nn.BatchNorm2d(128)),
            ('relu1', nn.ReLU(inplace=True)),
            ('conv2',nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn2', nn.BatchNorm2d(64)),
            ('relu2',nn.ReLU(inplace=True)),
        ]))
        
        self.decoder4 = nn.Sequential(OrderedDict([
            ('conv1',nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)),
            ('bn1',nn.BatchNorm2d(64)),
            ('relu1',nn.ReLU(inplace=True)),
        ]))
        
        self.last_conv = nn.Sequential(OrderedDict([
            ('conv1',nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding=(1,1), bias=False)),
            ('bn1',nn.BatchNorm2d(64)),
            ('relu1',nn.ReLU(inplace=True)),
            ('conv2',nn.Conv2d(in_channels=64, out_channels=n_classes, kernel_size=(1,1))),
        ]))
        
    def forward(self, x):
        #y_, idx_= self.in_layer(x)
        y_1 = self.encoder1(x)
        #print('y1')
        #print(y_1.shape)
        y_2,idx1 = self.downsample(y_1)
        #print('y_2_down')
        #print(y_2.shape)
#         print('idx1')
#         print(idx1.shape)
        
        y_2 = self.encoder2(y_2)
#         print('y2')
#         print(y_2.shape)
        y_3, idx2 = self.downsample(y_2)
#         print('y3_down')
#         print(y_3.shape)
#         print('idx2')
#         print(idx2.shape)
        
        y_3 = self.encoder3(y_3)
#         print('y3')
#         print(y_3.shape)
        y_4, idx3 = self.downsample(y_3)
#         print('y4_down')
#         print(y_4.shape)
#         print('idx3')
#         print(idx3.shape)
        
        y_4 = self.encoder4(y_4)
#         print('y_4')
#         print(y_4.shape)
        y_5, idx4 = self.downsample(y_4)
#         print('y5_down')
#         print(y_5.shape)
#         print('idx4')
#         print(idx4.shape)
        
        y_5 = self.bridge(y_5)
#         print('y_5')
#         print(y_5.shape)
        d_4 = self.upsample(y_5, idx4)
#         print('d_4_up')
#         print(d_4.shape)
        
        y_4 = self.blurpool(y_4)
        d_4 = torch.cat((d_4,y_4), dim=1)
#         print('1st concat ok')
#         print('d4_concat')
#         print(d_4.shape)
        d_4 = self.decoder1(d_4)
#         print('1st deco ok')
#         print('d4_deco')
#         print(d_4.shape)
        d_4 = self.upsample(d_4, idx3)
#         print('d4_up')
#         print(d_4.shape)
        
        y_3 = self.blurpool(y_3)
        d_3 = torch.cat((d_4, y_3), dim=1)
#         print('2nd concat ok')
#         print('d3_concat')
#         print(d_3.shape)
        d_3 = self.decoder2(d_3)
#         print('2nd deco ok')
#         print('d3_deco')
#         print(d_3.shape)
        d_3 = self.upsample(d_3, idx2)
#         print('d3_up')
#         print(d_3.shape)
        
#         print('8')
        
        y_2 = self.blurpool(y_2)
        d_2 = torch.cat((d_3,y_2), dim=1)
        d_2 = self.decoder3(d_2)
        d_2 = self.upsample(d_2, idx1)
        
        
        y_1 = self.blurpool(y_1)
        d_1 = torch.cat((d_2,y_1), dim=1)
        d_1 = self.decoder4(d_1)      
        
        y_ = self.last_conv(d_1)
#         print('14')
        y_ = F.interpolate(y_, (x.size(-2), x.size(-1)), mode='bilinear', align_corners=True)
#         print('15')
#         _, y_ = torch.max(y_, dim=1)
        
        return y_

Blur Pooling

: Concatenation helps the decoder to recover positional data. In this case, the boundary data of original image is critical. For the concatenating images' boundary to be more specific, we introduce Blur Pooling. It generates average blurred images from concatenation image. From average pooling, the boundary data of original image is blurred well. After that, original image is subtracted by blurred image so that the boundary data is emphasized.

In [8]:
class BlurPooling(nn.Module):
    def __init__(self):
        super(BlurPooling, self).__init__()
        
        
    def forward(self, x):
        blur = torch.zeros(x.size(0), x.size(1), x.size(2), x.size(3))
        blur = blur.detach().numpy()
        img_np = x.cpu().detach().numpy()
        #blur = np.zeros([img_np.size(0),img_np.size(1), img_np.size(2), img_np.size(3)])
        #print('input_img')
        #print(img_np.shape)
        for i in range(x.size(0)):
            for j in range(x.size(1)):
                blur[i][j] = cv2.blur(img_np[i][j], (8,16))
        #blur = cv2.blur(img_np, (8,16))
        y_ = img_np - 1.05*blur
        y_ = torch.from_numpy(y_).float().to(device)
    
        return y_

In [9]:
a = torch.zeros([4,2,3])
#print(a[0][0][0])
b = np.zeros([4,2,3])
print(b[0].shape)

(2, 3)


In [10]:
from torchsummary import summary

model = KMNet_vgg16bn(n_classes=34).to(device)
print(model)

#summary(model, input_size=(3,1024,2048))

KMNet_vgg16bn(
  (encoder1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (encoder2): Sequential(
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
  )
  (encoder3): Sequential(
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (15): BatchNorm2d(256, ep

In [11]:
class MaskToTensor(object):
    def __call__(self, img):
        return torch.from_numpy(np.array(img, dtype=np.int32)).long()

In [12]:
img_transform = transforms.Compose([
    transforms.RandomCrop(size=(1000,2000), padding_mode='constant'),
    transforms.Resize(size=(512,1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.28689554, 0.32513303, 0.28389177], std=[0.18696375, 0.19017339, 0.18720214]),
])

In [13]:
target_transform = transforms.Compose([
    transforms.Resize(size=(512,1024)),
    MaskToTensor(),
])

In [14]:
def getMeanIoU(label, yhat, n_classes=34, ignore_class=[0,1,2,3,4,5,6,9,10,14,15,16,18,29,30]):
    if ignore_class==None:
        ignore_len = 0
    else:
        ignore_len = len(ignore_class)
    iou = np.zeros(n_classes-ignore_len)
    minus = 0
    for i in range(n_classes):
        if np.isin(i, ignore_class):
            minus+=1
            continue
#         print(i-minus)
        tfpn = torch.eq(yhat, i)

        tp = len(tfpn[(tfpn==True) & (label==i) &(yhat==i)])
        fp = len(tfpn[(tfpn==True) & (label!=i) & (yhat==i)])
        fn = len(tfpn[(tfpn==False) & (label==i) & (yhat!=i)])
        tn = len(tfpn[(tfpn==False) & (label!=i) & (yhat!=i)])
        
        if tp==0:
            iou[i-minus] = 0.0
        else:
            iou[i-minus] = tp/(tp+fp+fn)
#         print(iou[i-minus])
            
    mIoU = np.mean(iou)
#     print(iou)
    
    return mIoU

In [15]:
train_set = datasets.Cityscapes(root='/home/km/data/cityscapes', split='train', mode='fine', target_type='semantic', transform=img_transform, target_transform=target_transform)
val_set = datasets.Cityscapes(root='/home/km/data/cityscapes', split='val', mode='fine', target_type='semantic', transform=img_transform, target_transform=target_transform)

In [16]:
len(train_set)

2975

In [17]:
train_set[0][0].shape

torch.Size([3, 512, 1024])

In [18]:
val_set[0][0].shape

torch.Size([3, 512, 1024])

## Using multi-GPUs

In [19]:
print(torch.cuda.device_count())

2


In [20]:
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

Let's use 2 GPUs!


    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


DataParallel(
  (module): KMNet_vgg16bn(
    (encoder1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (encoder2): Sequential(
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace=True)
    )
    (encoder3): Sequential(
      (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), 

In [21]:
batch_size = 2
learning_rate = 0.01

In [22]:
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset=val_set, batch_size=10, shuffle=False, drop_last=False)

In [23]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [24]:
max_iter = 90000
max_epoch = max_iter//(len(train_set)//batch_size)
step = 0

In [25]:
from torch.optim.lr_scheduler import _LRScheduler

class PolyLR(_LRScheduler):
    def __init__(self, optimizer, max_iter, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)
        
    def get_lr(self):
        return [base_lr*(1-self.last_epoch/self.max_iter)**self.power 
                for base_lr in self.base_lrs]
    


In [26]:
from torch.optim.lr_scheduler import StepLR
#scheduler = StepLR(optim, step_size=50, gamma=0.9)
scheduler = PolyLR(optim, max_iter = max_iter)

## Train

In [27]:
totalLoss = 0
trainLoss = []

best_model_wts = copy.deepcopy(model.state_dict())
best_miou = 0

model.train()
for epoch in range(max_epoch):
    for idx, (images, labels) in enumerate(train_loader):
        start = time.time()
        x, y = images.to(device), labels.to(device)
        yhat = model(x)
#         print(epoch,step)
        
        loss = criterion(yhat, y)
        totalLoss += loss.item()
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()
        
        if (step+1)%5 == 0:
            print('{}th Epoch, {}th Step, learning rate = {} - Loss: {}'.format(epoch+1, step+1, scheduler.get_lr()[0], loss.item()))
            trainLoss.append(totalLoss/5)
            totalLoss = 0
            
        if (step+1)%100 == 0:
            model.eval()
            mIoU = 0
            curr = 0
            with torch.no_grad():
                for idx, (images, labels) in enumerate(val_loader):
                    x, y = images.to(device), labels.to(device)
                    yhat = model(x)
                    
                    loss = criterion(yhat, y)

                    _, y_pred = torch.max(yhat.cpu(), dim=1)
                    mIoU += getMeanIoU(y.cpu(), y_pred)*len(y)
                    curr += len(y)
                    print(curr, end=' ')
                    if curr==300:
                        print()
                    
            mIoU /= len(val_set)
            print()
            print('*'*27, 'Test', '*'*27)
            print('time:{}, {}th Step, Loss: {}, Mean IoU = {:.3f}%'.format(timeSince(start), step+1, loss.item(), mIoU*100))
            print('*'*60)
            
            model.train()
        
        step+=1
        

1th Epoch, 5th Step, learning rate = 0.009999499998611083 - Loss: 2.8305823802948
1th Epoch, 10th Step, learning rate = 0.00999899999444422 - Loss: 2.11367130279541
1th Epoch, 15th Step, learning rate = 0.009998499987499236 - Loss: 2.1761856079101562
1th Epoch, 20th Step, learning rate = 0.009997999977775967 - Loss: 1.970829963684082
1th Epoch, 25th Step, learning rate = 0.00999749996527424 - Loss: 1.8596738576889038
1th Epoch, 30th Step, learning rate = 0.009996999949993889 - Loss: 2.6632227897644043
1th Epoch, 35th Step, learning rate = 0.009996499931934738 - Loss: 1.6601877212524414
1th Epoch, 40th Step, learning rate = 0.009995999911096622 - Loss: 1.935294270515442
1th Epoch, 45th Step, learning rate = 0.009995499887479371 - Loss: 1.6238504648208618
1th Epoch, 50th Step, learning rate = 0.00999499986108281 - Loss: 1.5821216106414795
1th Epoch, 55th Step, learning rate = 0.009994499831906777 - Loss: 1.956046462059021
1th Epoch, 60th Step, learning rate = 0.009993999799951093 - Loss:

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-7-b2fe7438f891>", line 159, in forward
    d_1 = self.decoder4(d_1)
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
    exponential_average_factor, self.eps)
  File "/home/km/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 7.93 GiB total capacity; 6.53 GiB already allocated; 390.06 MiB free; 558.71 MiB cached)


In [None]:
!nvidia-smi

## model save

In [None]:
PATH = '/root/km/models/KMNet/KMNet_vgg16bn.pth'
torch.save(model.state_dict(), PATH)


## validation

In [None]:
model.eval()   #  model load 후 이어서 학습하기 위해서는 model.eval() 한번 호출이 필요
mIoU = 0
curr = 0
with torch.no_grad():
    for idx, (images, labels) in enumerate(val_loader):
        x, y = images.to(device), labels.to(device)
        yhat = model(x)

        loss = criterion(yhat, y)

        _, y_pred = torch.max(yhat.cpu(), dim=1)
        mIoU += getMeanIoU(y.cpu(), y_pred)*len(y)
        curr += len(y)
        print(curr, end=' ')
        if curr==300:
            print()

mIoU /= len(val_set)
print()
print('*'*27, 'Test', '*'*27)
print('time:{}, {}th Step, Loss: {}, Mean IoU = {:.3f}%'.format(timeSince(start), step+1, loss.item(), mIoU*100))
print('*'*60)

In [None]:
plt.plot(range(len(trainLoss)), trainLoss, marker='.')
plt.xlabel("steps")
plt.ylabel("train loss")
plt.show()

## test

In [None]:
test_data = datasets.Cityscapes(root='/root/km/data/cityscapes', split='test', mode='fine', target_type='semantic', transform=img_transform, target_transform=target_transform)
test_loader = DataLoader(dataset=test_data, batch_size=10, shuffle=False, drop_last=False)

In [None]:
print(len(test_data))

In [None]:
KMNet_model = model
KMNet_model.load_state_dict(torch.load(PATH))

In [None]:
KMNet_model.eval()
print('<{}th step>'.format(step+1))

curr = 0
y_preds = torch.Tensor([]).type(torch.long)
with torch.no_grad():
    for idx, (images, labels) in enumerate(test_loader):
        x, y =images.to(device), labels.to(device)
        yhat = KMNet_model(x)
        
        loss = criterion(yhat, y)
        
        _, y_pred = torch.max(yhat.cpu(), dim=1)
        y_preds = torch.cat([y_preds, y_pred], dim=0)
        curr += len(y)
        if(curr%50)==0 or curr==len(test_data):
            print(curr, end=' ')
            
print("Prediction Ended")

In [None]:
test_img = datasets.Cityscapes(root='/root/km/data/cityscapes', split='test', mode='fine', target_type='semantic')

In [None]:
class TensorToMask(object):
    def __call__(self, tensor):
        mask = tensor.type(torch.uint8)
        
        return mask

In [None]:
class TestTransform(object):
    def __call__(self, tensor, size=(1024, 2048)):
        test_transform = transforms.Compose([
            TensorToMask(),
            transforms.ToPILImage(),
            transforms.Resize(size=size),
        ])
        
        return test_transform(tensor)

In [None]:
test_transform = transforms.Compose([
    TensorToMask(),
    transforms.ToPILImage(),
    transforms.Resize(size=(1024,2048)),
])


In [None]:
a, b = 15, 30

In [None]:
fig = plt.figure(figsize=(a,b))

for i in range(10):
    ax = fig.add_subplot(5, 2, i+1)
    ax.imshow(test_img[i][0])

In [None]:
fig = plt.figure(figsize=(a,b))

for i in range(10):
    ax = fig.add_subplot(5, 2, i+1)
    y_img = test_transform(y_preds[i])
    ax.imshow(y_img)