# MaskTrack - ResnetUNET

## Required Format data

In [1]:
import glob
import random

im_dir = 'D:/INTERVIEWS/stryker/Pytorch-UNet/data/DAVIS2017-master/JPEGImages'
an_dir = 'D:/INTERVIEWS/stryker/Pytorch-UNet/data/DAVIS2017-master/Annotations'

im_files = glob.glob(im_dir + '/**/*.jpg', recursive=True)
an_files = glob.glob(an_dir + '/**/*.png', recursive=True)

im_files = [i.split('/DAVIS2017-master/')[1] for i in im_files]
an_files = [i.split('/DAVIS2017-master/')[1] for i in an_files]

conc = [str('%s %s\n') %(i,j) for i,j in zip(im_files, an_files)]

random.shuffle(conc)

tr_len = int(0.80 * len(conc))
va_len = len(conc) - int(0.80 * len(conc))

file1 = open("train.txt","w")#write mode 
file1.writelines(conc[:tr_len]) 
file1.close()

file2 = open("val.txt","w")#write mode 
file2.writelines(conc[tr_len:]) 
file2.close()

file3 = open("trainval.txt","w")#write mode 
file3.writelines(conc) 
file3.close()

## Data Generation - Augmented Masks

In [None]:
import numpy as np 
import cv2 
from PIL import Image
import random
import os
import datetime

class Masktrack_aug():
    def __init__(self,Davis_path=None):
        self.base_path = Davis_path
        self.annotation_path = os.path.join(self.base_path,'Annotations/')
        self.deformation_path = os.path.join(self.base_path,'Deformations/')
        self.gt_path = os.path.join(self.base_path,'Annotations_binary/')
        if not os.path.exists(self.deformation_path):
            os.makedirs(self.deformation_path)
        if not os.path.exists(self.gt_path):
            os.makedirs(self.gt_path)
        with open(os.path.join(self.base_path,'ImageSets','train.txt')) as f:
            trainset = [v.split(' ')[0].split('JPEGImages\\')[1].replace('jpg','png') for v in f.readlines()]
        with open(os.path.join(self.base_path,'ImageSets','val.txt')) as f:
            valset = [v.split(' ')[0].split('JPEGImages\\')[1].replace('jpg','png') for v in f.readlines()]  
        self.videos = trainset+valset  
        
        #self.videos = os.listdir(self.annotation_path)
        print('total {} videos'.format(len(self.videos)))
        
    def augment_image_and_mask(self,gt_arr,gt_path=None,affine_transformation_path=None, non_rigid_deform_path=None):
        #gt_arr shape (H,W) and binary(0,1)

        # let us do non-rigid deformation
        N = 5
        Delta = 0.05
        H,W = gt_arr.shape
        #get the target boundary
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        boundary = cv2.dilate(gt_arr, kernel)-gt_arr
        boundindex = np.where(boundary==1)
        num_index = boundindex[0].shape[0]
        if num_index>N:
            maxH,minH = max(boundindex[0]),min(boundindex[0])
            tarH = maxH - minH
            maxW,minW = max(boundindex[1]),min(boundindex[1])
            tarW = maxW - minW

            # thin plate spline coord num    
            randindex = [random.randint(0,num_index-1) for _ in range(N)]
            sourcepoints=[]
            targetpoints = []
            for i in range(N):
                sourcepoints.append((boundindex[1][randindex[i]],boundindex[0][randindex[i]]))
                x = boundindex[1][randindex[i]]+int(random.uniform(-Delta,Delta)*tarW)
                y = boundindex[0][randindex[i]]+int(random.uniform(-Delta,Delta)*tarH)
                targetpoints.append((x,y))
        
            sourceshape = np.array(sourcepoints,np.int32)
            sourceshape=sourceshape.reshape(1,-1,2)
            targetshape = np.array(targetpoints,np.int32)
            targetshape=targetshape.reshape(1,-1,2)

            matches =[]
            for i in range(0,N):
                matches.append(cv2.DMatch(i,i,0))
            tps= cv2.createThinPlateSplineShapeTransformer()
            tps.estimateTransformation(targetshape, sourceshape,matches)
            no_grid_img=tps.warpImage(gt_arr)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            no_grid_img = cv2.dilate(no_grid_img,kernel)
            gt_out = gt_arr*255
            no_grid_img = no_grid_img*255
            gt_out = Image.fromarray(gt_out)
            no_grid_out = Image.fromarray(no_grid_img)

            scale=0.98
            randScale = random.uniform(scale,1/scale)
            M = cv2.getRotationMatrix2D(((maxH+minH)*0.5, (maxW-minW)*0.5), 0, randScale)
            
            dx = round(random.uniform(-0.05,0.05)*tarW)
            dy = round(random.uniform(-0.05,0.05)*tarH)
            M[0,2]+=dx
            M[1,2]+=dy
            affine_out = cv2.warpAffine(gt_arr, M, (W, H))*255
            affine_out = Image.fromarray(affine_out)
        else:
            gt_out = Image.fromarray(gt_arr*255)
            no_grid_out = Image.fromarray(gt_arr*255)
            affine_out = Image.fromarray(gt_arr*255)

        gt_out.save(gt_path)
        no_grid_out.save(non_rigid_deform_path)
        affine_out.save(affine_transformation_path)

    def script(self):
        mask_gt_path = self.annotation_path

        frames = self.videos
        
        no_objects = 1
        for k,frame in enumerate(frames):
            frame_path = os.path.join(mask_gt_path,frame)
            frame_gt_image = Image.open(frame_path)
            frame_gt_image = np.array(frame_gt_image)
            frame_index = frame.split('\\')[1][:-4]
            video = frame.split('\\')[0]

            label_folder_path = os.path.join(self.gt_path,video)
            deform_folder_path = os.path.join(self.deformation_path,video)
            if not os.path.exists(label_folder_path):
                os.makedirs(label_folder_path)

            if not os.path.exists(deform_folder_path):
                os.makedirs(deform_folder_path)

            for object_id in range(1,no_objects+1):
                temp = frame_gt_image.copy()
                
                m1 = temp==object_id
                m0 = temp!=object_id
                
                temp[m1]=1
                temp[m0]=0
                
                gt_path = os.path.join(label_folder_path,frame_index+'_'+str(object_id)+'.png')
                aff_path = os.path.join(deform_folder_path,frame_index+'_'+str(object_id)+'_d1.png')
                non_path = os.path.join(deform_folder_path,frame_index+'_'+str(object_id)+'_d2.png')
                self.augment_image_and_mask(temp,gt_path=gt_path,
                                            affine_transformation_path=aff_path,
                                            non_rigid_deform_path=non_path)       

## New Folder Structure - Generated Coarsened and Deformed Masks

## Common Import

In [1]:
import torch
import torch.nn as nn
from torchvision import models
import numpy as np

## Task 1 - Base Network

### Resnet 18 + Unet

In [2]:
def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, num_input_channel, n_class):
        super().__init__()
        self.num_input_channel = num_input_channel
        self.base_model = models.resnet18(pretrained=True)
        self.base_model.conv1 = nn.Conv2d(num_input_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(num_input_channel, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

In [3]:
from torchsummary import summary
model = ResNetUNet(4,2)
summary(model, input_size=(4, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           2,368
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]          12,544
            Conv2d-6         [-1, 64, 112, 112]          12,544
       BatchNorm2d-7         [-1, 64, 112, 112]             128
       BatchNorm2d-8         [-1, 64, 112, 112]             128
              ReLU-9         [-1, 64, 112, 112]               0
             ReLU-10         [-1, 64, 112, 112]               0
        MaxPool2d-11           [-1, 64, 56, 56]               0
        MaxPool2d-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 64, 56, 56]          36,864
           Conv2d-14           [-1, 64,

In [4]:
# Sanity Check, Checking Output on newly defined model

trial = np.random.randn(1,4,224,224)
model(torch.from_numpy(trial).float())

tensor([[[[0.0570, 0.0828, 0.0702,  ..., 0.0777, 0.0553, 0.0574],
          [0.0873, 0.0657, 0.0686,  ..., 0.0560, 0.0699, 0.0735],
          [0.0791, 0.0901, 0.0764,  ..., 0.0738, 0.0609, 0.0802],
          ...,
          [0.0489, 0.0682, 0.0908,  ..., 0.0823, 0.0701, 0.0622],
          [0.0641, 0.0653, 0.0792,  ..., 0.0656, 0.0791, 0.0769],
          [0.0620, 0.0727, 0.0715,  ..., 0.0852, 0.0739, 0.0726]],

         [[0.1194, 0.1143, 0.1075,  ..., 0.1045, 0.1238, 0.1104],
          [0.1138, 0.1106, 0.1166,  ..., 0.0958, 0.1116, 0.1217],
          [0.0950, 0.1110, 0.1144,  ..., 0.1291, 0.1236, 0.1331],
          ...,
          [0.0888, 0.1190, 0.1361,  ..., 0.1296, 0.1068, 0.1052],
          [0.1135, 0.1174, 0.1167,  ..., 0.0933, 0.1054, 0.1188],
          [0.1053, 0.1256, 0.1232,  ..., 0.1132, 0.1030, 0.1151]]]],
       grad_fn=<MkldnnConvolutionBackward>)

## Dataset davis17_offline_dataset

In [5]:
from torch.utils.data import Dataset
import numpy as np
import os
import glob
from PIL import Image
import cv2

class DAVIS17Offline(Dataset):
    def __init__(self, train=True, mini=False, mega=False,
                 inputRes=None,
                 db_root_dir='DAVIS17',
                 transform=None):

        self.train = train
        self.mini = mini
        self.mega = mega
        self.inputRes = inputRes
        self.db_root_dir = db_root_dir
        self.transform = transform

        if self.mini == False and self.mega == False:
            if self.train:
                fname = 'train'
                # fname = 'train_seqs'
            else:
                fname = 'val'
                # fname = 'val_seqs'
        elif self.mini == True:
            if self.train:
                fname = 'train_mini'
                # fname = 'train_seqs'
            else:
                fname = 'val_mini'
                # fname = 'val_seqs'
        elif self.mega == True:
            if self.train:
                fname = 'train_mega'
            else:
                fname = 'val_mega'

        img_list = []
        labels = []
        deformations = []

        # Initialize the original DAVIS splits for training the parent network
        with open(os.path.join(db_root_dir, 'ImageSets/' + fname + '.txt')) as f:
            seqs = f.readlines()

            for seq in seqs:
                seq = seq.split(' ')[0]
                image = seq
                no_objects = 1

                # for image in images:
                image_id = image.split('\\')[-1][:-4]
                video = image.split('\\')[1]

                for object_id in range(1,no_objects+1):
                    for df in [1,2]:
                        img_list.append(image)
                        labels.append(os.path.join('Annotations_binary', video, image_id + '_' + str(object_id) + '.png'))
                        deformations.append(os.path.join('Deformations', video, image_id + '_' + str(object_id) + '_d' + str(df) + '.png'))

        assert (len(labels) == len(img_list))
        assert (len(labels) == len(deformations))

        self.img_list = img_list
        self.deformations = deformations
        self.labels = labels

        print('Done initializing ' + fname + ' Dataset')
        print(self.__len__(),'images')

    def __getitem__(self, idx):

        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[idx]))
        label = Image.open(os.path.join(self.db_root_dir, self.labels[idx]))
        deformation = Image.open(os.path.join(self.db_root_dir, self.deformations[idx]))
        #print(os.path.join(self.db_root_dir, self.deformations[idx]))
        img,label,deformation = self.transform(np.array(img), np.array(label), np.array(deformation), self.inputRes)
        sample = {'image': img, 'gt': label, 'deformation': deformation}
        return sample

    def __len__(self):
        return len(self.img_list)

## Task 2 - Offline Training

In [6]:
import timeit
import datetime
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable

from utility_functions import *
from path import Path

In [7]:
# Setting of parameters
NoLabels = 2
debug_mode = False
use_cuda = False
weight_decay = float(0.001)
base_lr = float(0.0005)
resume_epoch = int(0)  # Default is 0, change if want to resume
nEpochs = int(1)  # 1 epoch 
batch_size = int(1)
vbatch_size = 3
db_root_dir = 'DAVIS2017-master'
nAveGrad = 4  # keep it even

save_dir = os.path.join(db_root_dir, 'lr_' + str(base_lr) + '_wd_' + str(weight_decay))

if not os.path.exists(save_dir):
    os.makedirs(os.path.join(save_dir))

learnRate = base_lr

"""Initialise the network"""

net = ResNetUNet(4,int(NoLabels))

net.float()

if use_cuda:
    torch.cuda.set_device(0)

if use_cuda:
    print('use_cuda')
    net.cuda()
else:
    print('CUDA not available')

optimizer = optim.SGD(net.parameters(),lr=base_lr, momentum=0.9, weight_decay=weight_decay)

if os.path.exists(os.path.join(save_dir, 'logs')) == False:
    os.mkdir(os.path.join(save_dir, 'logs'))
    
file_offline_loss = open(os.path.join(save_dir, 'logs/logs_offline_training_start_epoch_' + str(resume_epoch) + '.txt'), 'w+')
file_offline_val_loss = open(os.path.join(save_dir, 'logs/logs_offline_training_val_start_epoch_' + str(resume_epoch) + '.txt'), 'w+')

loss_array = []
loss_minibatch_array = []
precision_train_array  = []
recall_train_array = []

loss_val_array = []
precision_val_array = []
recall_val_array = []

aveGrad = 0


CUDA not available


In [8]:
dataset17_train = DAVIS17Offline(train=True, mini=False, mega=False, db_root_dir=db_root_dir, transform=apply_custom_transform, inputRes=(224,224))
dataloader17_train = DataLoader(dataset17_train, batch_size=batch_size, shuffle=True, num_workers=3)

dataset17_val = DAVIS17Offline(train=False, mini=False, mega=False, db_root_dir=db_root_dir, transform=apply_val_custom_transform,inputRes=(224,224))
dataloader17_val = DataLoader(dataset17_val, batch_size=vbatch_size, shuffle=False, num_workers=3)

lr_factor_array = [1,1,1,0.1,1,1,1,0.1,1,1,1,1,1,0.1,1,1,1,1]

Done initializing train Dataset
9932 images
Done initializing val Dataset
2484 images


In [None]:
next(iter(dataloader17_train))

In [None]:
for epoch in range(1, nEpochs+1):

    trainingDataSetSize = 0
    epochLoss = 0
    epochTrainIOU = 0

    temp_Iou=0

    valDataSetSize = 0
    epochValLoss = 0
    epochValIOU = 0


    start_time = timeit.default_timer()
    epoch_start_time = datetime.datetime.now()

    total_train_batch = len(dataloader17_train)
    print('Training phase')
    print('len of loader: ' + str(total_train_batch))

    net.train()
    optimizer.zero_grad()
    aveGrad = 0

    for data_id, sample in enumerate(dataloader17_train):

        dic = net.state_dict()

        image = sample['image']
        anno = sample['gt']
        deformation = sample['deformation']

        # Making sure the mask input is similar to RGB values
        deformation[deformation==0] = -100
        deformation[deformation==1] = 100


        prev_frame_mask = Variable(deformation).float()
        inputs, gts = Variable(image), Variable(anno)

        if use_cuda:
            inputs, gts, prev_frame_mask = inputs.cuda(), gts.cuda(), prev_frame_mask.cuda()

        input_rgb_mask = torch.cat([inputs, prev_frame_mask], 1)
        noImages, noChannels, height, width = input_rgb_mask.shape



        output_mask = net(input_rgb_mask)

        upsampler = torch.nn.Upsample(size=(height, width), mode='bilinear')
        output_mask = upsampler(output_mask)

        if debug_mode:
            temp_out = np.zeros(output_mask[0][0].shape)
            temp_out[output_mask.data.cpu().numpy()[0][1] > output_mask.data.cpu().numpy()[0][0]] = 1
            cv2.imwrite('output.png',temp_out*255)

        loss1 = cross_entropy_loss(output_mask, gts)

        Iou_t = calculate_IOU(output_mask, gts)
        if data_id==0:
            temp_Iou = Iou_t
        else:
            temp_Iou = temp_Iou*0.999 + Iou_t*0.001
        epochTrainIOU += Iou_t
        now_time = datetime.datetime.now()
        remain_time = (now_time-epoch_start_time)*((total_train_batch-data_id-1)/(data_id+1))
        print('{} time remain {} epoch {} {}/{} train loss:{:.5f} Iou:{:.4f} lr:{} aveIou {:.4f}'.format(
            now_time,remain_time,epoch,(data_id+1)*batch_size,(total_train_batch)*batch_size,loss1.item(),Iou_t,learnRate,temp_Iou))

        loss_minibatch_array.append(loss1.item())

        epochLoss += loss1.item()
        trainingDataSetSize += 1

        # Backward the averaged gradient
        loss1 /= nAveGrad
        loss1.backward()
        aveGrad += 1

        # Update the weights once in nAveGrad forward passes
        if aveGrad % nAveGrad == 0:
            optimizer.step()
            optimizer.zero_grad()
            aveGrad = 0
        if (data_id+1)%10000==0:
            torch.save(net.state_dict(), os.path.join(save_dir, modelName + '_epoch-' + str(epoch) + '.pth'))
            print('stage saved')

    epochLoss = epochLoss / trainingDataSetSize
    epochTrainIOU = epochTrainIOU / trainingDataSetSize

    print('Epoch: ' + str(epoch) + ', Training Loss: ' + str(epochLoss) + '\n')
    print('Epoch: ' + str(epoch) + ', Training IOU: ' + str(epochTrainIOU) + '\n')

    file_offline_loss.write(str(datetime.datetime.now())
                            +' Epoch: ' + str(epoch)
                            +', Loss: ' + str(epochLoss)
                            +', IOU: ' +str(epochTrainIOU)
                            +', lr: ' +str(learnRate)
                            + '\n')
    loss_array.append(epochLoss)

    file_offline_loss.flush()
    torch.save(net.state_dict(), os.path.join(save_dir, modelName + '_epoch-' + str(epoch) + '.pth'))

    print('Validation phase')
    total_val_batch = len(dataloader17_val)
    aveGrad = 0
    net.eval()
    with torch.no_grad():
        for data_id, sample in enumerate(dataloader17_val):

            image = sample['image']
            anno = sample['gt']
            deformation = sample['deformation']

            deformation[deformation==0] = -100
            deformation[deformation==1] = 100

            prev_frame_mask = Variable(deformation, volatile=True).float()
            inputs, gts = Variable(image, volatile=True), Variable(anno, volatile=True)

            if use_cuda:
                inputs, gts, prev_frame_mask = inputs.cuda(), gts.cuda(), prev_frame_mask.cuda()

            input_rgb_mask = torch.cat([inputs, prev_frame_mask], 1)

            noImages, noChannels, height, width = input_rgb_mask.shape

            output_mask = net(input_rgb_mask)

            upsampler = torch.nn.Upsample(size=(height, width), mode='bilinear')
            output_mask = upsampler(output_mask)


            loss1 = cross_entropy_loss(output_mask, gts)

            Iou_t = calculate_IOU(output_mask, gts)
            epochValIOU += Iou_t
            print('{} epoch {}  {}/{} val loss:{:5f} Iou:{:5f} lr:{}'.format(datetime.datetime.now(),epoch,(data_id+1)*vbatch_size,(total_val_batch)*vbatch_size,loss1.item(),Iou_t,learnRate))

            epochValLoss += loss1.item()
            valDataSetSize += 1



    epochValLoss = epochValLoss / valDataSetSize
    epochValIOU = epochValIOU / valDataSetSize



    print('Epoch: ' + str(epoch) + ', Val Loss: ' + str(epochValLoss) + '\n')
    print('Epoch: ' + str(epoch) + ', Val IOU: ' + str(epochValIOU) + '\n')


    file_offline_val_loss.write(str(datetime.datetime.now())
                                 +' Epoch: ' + str(epoch)
                                 + ', Loss: ' + str(epochValLoss) 
                                 +', IOU: ' +str(epochValIOU)
                                 +', lr: ' +str(learnRate)
                                 + '\n')

    loss_val_array.append(epochValLoss)

    file_offline_val_loss.flush()

    epochLoss = 0


    stop_time = timeit.default_timer()

    epoch_secs = stop_time - start_time
    epoch_mins = epoch_secs / 60
    epoch_hr = epoch_mins / 60

    print('This epoch took: ' + str(epoch_hr) + ' hours')

    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr']*lr_factor_array[epoch-1]
    learnRate = learnRate*lr_factor_array[epoch-1]


    plot_loss1(loss_array, resume_epoch, epoch , save_dir)
    plot_loss1(loss_val_array, resume_epoch, epoch , save_dir, val=True)
    plot_loss_minibatch(loss_minibatch_array, save_dir)

file_offline_loss.close()

file_offline_val_loss.close()


Training phase
len of loader: 9932


## Task 3 - Online Training

In [None]:
# Load saved model from Task 2

In [None]:
# Write APIs to perform online training for given test data in next few cells

In [19]:
# Implement a top level function for inference on test data
# Input 1: Path to directory containing 'N' ordered images for given test sample
# Input 2: Path to directory containing corresponding 'N' masks with filename as that of images 
#          Here, use only mask[0] with images for online training where as mask[1:N] are ground truth for evaluating predictions
# Output 1: Quality metric for every single mask prediction from time (t=1 to t = T) and average for the same
# Output 2: Display all predicted masks along with original RGB image and ground truth masks in order

## Variants - Extra Credit