In [1]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import glob
import itertools
import pickle
from enum import Enum
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from sklearn.model_selection import train_test_split

from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from collections import OrderedDict, defaultdict

import networks
import data_loader

In [2]:
data_root = "/datasets/ee285f-public/gt_for_cv/"
models = ['dual_gans_semi', 'dual_gans_un', 'cycle_gan_semi', 'cycle_gan_un', 'semantic']
model_dict = defaultdict.fromkeys(models)
saved_images = 'saved_test_images/'
image_size = 256
batch_size = 8
num_images = 800
ngpu = torch.cuda.device_count()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
for i in range(len(models)):
    path = 'saved_models/test_' + models[i] + '_generator_b.pth'
    model_dict[models[i]] = path
    print("{} - {}".format(models[i], path))

dual_gans_semi - saved_models/test_dual_gans_semi_generator_b.pth
dual_gans_un - saved_models/test_dual_gans_un_generator_b.pth
cycle_gan_semi - saved_models/test_cycle_gan_semi_generator_b.pth
cycle_gan_un - saved_models/test_cycle_gan_un_generator_b.pth
semantic - saved_models/test_semantic_generator_b.pth


In [4]:
def dice_coef(label, output):
    label = torch.round(label.view(-1))
    output = torch.round(output.view(-1))
    isct = torch.sum(torch.mul(label,output))
    return (2 *(isct) / (torch.sum(label) + torch.sum(output)))

In [5]:
def test_model(model_id):
    print('Testing model - {} from path - {}'.format(models[model_id], model_dict[models[model_id]]))
    
    data = data_loader.DataLoader(data_root, image_size, batch_size, train = False, folder_A = "images/", folder_B = "labels/")
    if model_id == 0 or model_id == 1:
        model_generator = networks.DualGansGenerator().to(device)
    else:
        model_generator = networks.CycleGanResnetGenerator().to(device) 
    
    model_generator.load_state_dict(torch.load(model_dict[models[model_id]]))
    
    semantic_model = networks.GeneratorUNet().to(device)
    semantic_model.load_state_dict(torch.load(model_dict[models[-1]]))
    
    print('Models loaded...')
    iou = []
    for i in range(num_images // batch_size):
        x,y = next(data.data_generator(0, train = False))
        img = Variable(x, requires_grad = False).to(device)
        lbl = Variable(y, requires_grad = False).to(device)
        
        img = (img - torch.min(img))/(torch.max(img - torch.min(img)))
        lbl = (lbl - torch.min(lbl))/(torch.max(lbl - torch.min(lbl)))
        
        cityscapes = model_generator(img)
        cityscapes = (cityscapes - torch.min(cityscapes))/(torch.max(cityscapes - torch.min(cityscapes)))
        
        semantic_sample = semantic_model(cityscapes)
        semantic_sample = (semantic_sample - torch.min(semantic_sample))/(torch.max(semantic_sample - torch.min(semantic_sample)))
        
        print (i, 'IOU: {}'.format(dice_coef(lbl, semantic_sample)))
        iou.append(dice_coef(lbl, semantic_sample).cpu().detach().numpy())
        for j in range(img.size()[0]):
            img_sample = torch.cat((img[j,:,:,:].data,lbl[j,:,:,:].data, cityscapes[j,:,:,:], semantic_sample[j,:,:,:]),-1)
            save_image(img_sample, saved_images + models[model_id] + '_%d.png' % ((batch_size*i) + j), nrow=2)
            
    print ('Mean IOU for model: {} is {}'.format(models[model_id], np.sum(np.asarray(iou))/(i+1)))       

In [6]:
def test_all():
    for i in range(len(models) - 1):
        test_model(i)

In [8]:
test_model(3)

Testing model - cycle_gan_un from path - saved_models/test_cycle_gan_un_generator_b.pth
Models loaded...
0 IOU: 0.7263473868370056
1 IOU: 0.7816410660743713
2 IOU: 0.7819533348083496
3 IOU: 0.7845762968063354
4 IOU: 0.8085094094276428
5 IOU: 0.7842686176300049
6 IOU: 0.755906879901886
7 IOU: 0.8017826676368713
8 IOU: 0.7657239437103271
9 IOU: 0.7063268423080444
10 IOU: 0.7726572155952454
11 IOU: 0.7768825888633728
12 IOU: 0.7564979195594788
13 IOU: 0.7851009368896484
14 IOU: 0.7112932205200195
15 IOU: 0.7425234913825989
16 IOU: 0.77988600730896
17 IOU: 0.7924004793167114
18 IOU: 0.781908392906189
19 IOU: 0.7860972881317139
20 IOU: 0.7418174147605896
21 IOU: 0.8260592222213745
22 IOU: 0.8109201788902283
23 IOU: 0.7798981070518494
24 IOU: 0.7548923492431641
25 IOU: 0.7967679500579834
26 IOU: 0.7910661101341248
27 IOU: 0.6983416676521301
28 IOU: 0.7823969721794128
29 IOU: 0.7040587663650513
30 IOU: 0.7567679286003113
31 IOU: 0.8539625406265259
32 IOU: 0.7422598600387573
33 IOU: 0.75613027