In [None]:
%load_ext autoreload 
%autoreload 2 

import torch
from torch.utils.data import DataLoader
from train import Trainer
from generator import *
from discriminator import GAN
from dataset import CocoStuffDataSet
import os, argparse, datetime, json

from utils import *
# Experiment configuration.
NUM_CLASSES = 11  # presumably the animal categories plus a trailing 'background' class
batch_size = 64
use_bn = True  # forwarded to SegNet16(use_bn=...)

# Checkpoint paths are relative to the working directory, so this notebook
# assumes it is launched from the code/ subfolder.
SAVE_DIR = "../checkpoints"
experiment_name = 'baseline'
gan_name = 'gan_low_reg'
experiment_dir, gan_dir = (os.path.join(SAVE_DIR, run) for run in (experiment_name, gan_name))
%matplotlib inline

In [None]:
HEIGHT, WIDTH = 128, 128
# Both splits are restricted to the 'animal' supercategory at the same resolution;
# only the training split is shuffled. The validation loader is built first.
val_loader, train_loader = (
    DataLoader(
        CocoStuffDataSet(mode=split, supercategories=['animal'], height=HEIGHT, width=WIDTH),
        batch_size=batch_size,
        shuffle=shuffle,
    )
    for split, shuffle in (('val', False), ('train', True))
)

In [None]:
# Two generators with identical SegNet16 architecture: the baseline run and
# (presumably) the GAN-trained run. Construction order is kept as-is.
generator = SegNet16(NUM_CLASSES, use_bn=use_bn)
gan_generator = SegNet16(NUM_CLASSES, use_bn=use_bn)

# Per-sample shapes laid out as (channels/classes, HEIGHT, WIDTH).
image_shape = (3, HEIGHT, WIDTH)
segmentation_shape = (NUM_CLASSES, HEIGHT, WIDTH)

# No discriminator is passed to either Trainer in this notebook.
discriminator = None

# Each trainer is created with resume=True against its own run directory.
trainer = Trainer(
    generator, discriminator, train_loader, val_loader,
    experiment_dir=experiment_dir, resume=True, load_iter=None,
)
gan_trainer = Trainer(
    gan_generator, discriminator, train_loader, val_loader,
    experiment_dir=gan_dir, resume=True, load_iter=None,
)

In [None]:
def _plot_prediction_row(display_image, gt_mask, pred_mask, gan_pred_mask, cmap, norm):
    """Draw one figure: the image, then GT / baseline / (optional) GAN masks overlaid on it.

    `display_image` is the image transposed to (H, W, C) for imshow. `pred_mask`
    and `gan_pred_mask` are argmax class-index maps; `gt_mask` is presumably the
    same kind of map (TODO confirm against the dataset). When `gan_pred_mask` is
    None the fourth panel is left empty.
    """
    plt.figure(figsize=(20, 20))

    plt.subplot(141)
    plt.imshow(display_image)
    plt.axis('off')
    plt.title('original image')

    overlays = [(gt_mask, 'real mask'), (pred_mask, 'predicted mask')]
    if gan_pred_mask is not None:
        overlays.append((gan_pred_mask, 'GAN predicted mask'))
    for panel, (mask, title) in enumerate(overlays, start=2):
        plt.subplot(1, 4, panel)
        plt.imshow(display_image)
        plt.imshow(mask, alpha=0.8, cmap=cmap, norm=norm)
        plt.axis('off')
        plt.title(title)
    plt.show()


def _save_background_masks(idx, img, gt_mask, pred_mask, gan_pred_mask, background_class):
    """Save the image and binary background masks to ./saved_images_and_masks/<idx>/ for style transfer.

    Each saved mask is 1 where the class equals `background_class` and 0 elsewhere.
    `gan_mask.pk` is written only when `gan_pred_mask` is not None.
    """
    print ("Image {}".format(idx))
    savedir = os.path.join('./saved_images_and_masks', str(idx))
    os.makedirs(savedir, exist_ok=True)

    def to_background(mask):
        # Only the background class matters for the downstream style-transfer step.
        return torch.from_numpy(np.where(mask == background_class, 1, 0))

    torch.save(torch.from_numpy(img), os.path.join(savedir, 'img.pk'))
    torch.save(to_background(gt_mask), os.path.join(savedir, 'gt_mask.pk'))
    torch.save(to_background(pred_mask), os.path.join(savedir, 'baseline_mask.pk'))
    if gan_pred_mask is not None:
        torch.save(to_background(gan_pred_mask), os.path.join(savedir, 'gan_mask.pk'))


def visualize_mask(trainer, loader, number, save=False, gan_trainer=None):
    """Plot (and optionally save) predicted segmentation masks for the first `number` images.

    Args:
        trainer: object whose `_gen` is the baseline generator; it is called on CUDA batches.
        loader: iterable yielding (data, mask_gt, gt_visual) batches; mask_gt is unused here.
        number: maximum number of images to visualize, counted across batches.
        save: if True, also write each image and its binary background masks to
            ./saved_images_and_masks/<index>/.
        gan_trainer: optional second trainer. Its generator's prediction is shown in a
            fourth panel and, when saving, written as gan_mask.pk.

    Returns:
        List of (img, gt_mask, pred_mask) numpy tuples, one per visualized image.
    """
    # The colormap and normalisation are the same for every image, so build them once.
    cmap = discrete_cmap(NUM_CLASSES, 'Paired')
    norm = colors.NoNorm(vmin=0, vmax=NUM_CLASSES)
    # Background is the last class index (10 when NUM_CLASSES == 11).
    background_class = NUM_CLASSES - 1

    to_return = []
    total = 0
    for data, _mask_gt, gt_visual in loader:
        if total >= number:
            break
        data = data.cuda()
        # Inference only, so no autograd graph is needed.
        # NOTE(review): the generators are not switched to eval mode here. With
        # use_bn=True in training mode, outputs depend on the batch and BN running
        # stats may be updated; confirm what Trainer leaves the modules in.
        with torch.no_grad():
            mask_pred = convert_to_mask(trainer._gen(data))
            gan_pred = convert_to_mask(gan_trainer._gen(data)) if gan_trainer is not None else None

        n_batch = data.size()[0]
        # Take only as many images from this batch as are still needed to reach `number`.
        for i in range(min(n_batch, number - total)):
            img = data[i].detach().cpu().numpy()
            gt_mask = gt_visual[i].detach().cpu().numpy()
            pred_mask = np.argmax(mask_pred[i].detach().cpu().numpy(), axis=0)
            gan_pred_mask = None
            if gan_pred is not None:
                gan_pred_mask = np.argmax(gan_pred[i].detach().cpu().numpy(), axis=0)
            to_return.append((img, gt_mask, pred_mask))

            _plot_prediction_row(np.transpose(img, (1, 2, 0)), gt_mask, pred_mask,
                                 gan_pred_mask, cmap, norm)
            if save:
                _save_background_masks(total + i, img, gt_mask, pred_mask,
                                       gan_pred_mask, background_class)
        total += n_batch
    return to_return

In [None]:
_ = visualize_mask(trainer, train_loader, number=30, gan_trainer=gan_trainer)  # preview only: save defaults to False

In [None]:
_ = visualize_mask(trainer, val_loader, number=600, save=True, gan_trainer=gan_trainer)  # also writes images + background masks to disk

In [None]:
# Validation metrics (ignore_background=True) for the baseline, then the GAN generator.
# After the loop the val_* names hold the GAN's results.
for model_trainer in (trainer, gan_trainer):
    val_pixel_acc, val_mIOU, per_class_accuracy = model_trainer.evaluate(val_loader, 0, ignore_background=True)

    print ("Pixel accuracy {}".format(val_pixel_acc))
    print ("Val mIOU {}".format(val_mIOU))
    print ("per_class_accuracy {}".format(per_class_accuracy))

In [None]:
# Training-set metrics (ignore_background=True) for the baseline, then the GAN generator.
# Both passes write only the train_* names, so the validation cell's val_* results
# stay intact, and each printout reports the mIOU of the model it belongs to.
for model_name, model_trainer in (('Baseline', trainer), ('GAN', gan_trainer)):
    train_pixel_acc, train_mIOU, per_class_accuracy = model_trainer.evaluate(train_loader, 0, ignore_background=True)

    print ("{} (train)".format(model_name))
    print ("Pixel accuracy {}".format(train_pixel_acc))
    print ("Train mIOU {}".format(train_mIOU))
    print ("per_class_accuracy {}".format(per_class_accuracy))

In [None]:
## CONFUSION MATRIX ##
# Confusion matrix of the baseline generator (`trainer`, not `gan_trainer`) on the
# validation split; it is passed to `visualize_conf` together with the class names.
confusion_matrix = trainer.get_confusion_matrix(val_loader)


In [None]:
# Resolve the validation dataset's category ids to COCO category names, append the
# trailing 'background' class, and use them as labels for the confusion-matrix plot.
dataset = val_loader.dataset
coco = dataset.coco
all_cats_ids = coco.getCatIds()
cats = coco.loadCats(all_cats_ids)
nms = [cat['name'] for cat in cats]

animal_cat_names = []
for cat_id in dataset.catIds:
    animal_cat_names.append(nms[all_cats_ids.index(cat_id)])
animal_cat_names.append('background')

print (animal_cat_names)
visualize_conf(confusion_matrix, animal_cat_names)