In [48]:
%load_ext autoreload
%autoreload 2
import matplotlib
import matplotlib.pyplot as plt
import pathlib
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image
import numpy as np
from tqdm import tqdm

import sys
sys.path.append('./')
from dataset.dstl_dataset import Dstl

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
# Lookup of every directory this notebook reads from. Values are plain strings
# with a trailing '/', because callers build file paths by concatenation
# (e.g. dirs['imgs'] + name). Key order matches the original literal.
dirs = dict([
    # Adversarial runs (exp3 sweeps over sample count, plus exp8)
    ('adv_10k', './result/dstl/exp3/dstl_10000/'),
    ('adv_5k', './result/dstl/exp3/dstl_5000/'),
    ('adv_7.5k', './result/dstl/exp3/dstl_7500/'),
    ('adv_2.5k', './result/dstl/exp3/dstl_2500/'),
    ('adv_13k', './result/dstl/exp8/'),
    # "memreg" runs
    ('memreg_3k_on10k', './result/dstl/exp5/'),
    ('memreg_2.5k_on7.5k', './result/dstl/exp6/'),
    # Segmentation-only runs
    ('seg_overfit', './result/dstl/last_conv_unfreeze_1_overfit/'),
    ('seg_generalize', './result/dstl/vary_dropout_2_reduce_variance/'),
    ('seg_final', './result/dstl/aug_img_2_reduce_variance/'),
    # Dataset: ground-truth masks (.npy) and input images (.png)
    ('gt', './data/dstl/masks/'),
    ('imgs', './data/dstl/imgs/'),
])

def display_grid(img_names, img_dir, exp_name, figsize=(20, 500)):
    """Save an (image | ground truth | prediction) grid, one row per image.

    Each row shows the input image from ``dirs['imgs']``, the decoded
    ground-truth mask from ``dirs['gt']`` and the decoded prediction read
    from ``img_dir``. Wherever the ground truth equals 255 (background), the
    ground-truth value is copied into the prediction before decoding. Those
    pixels therefore render identically in both panels, so only the non-bkgd
    classifications are compared. The figure is written to
    ``result/<exp_name>.png`` and closed.

    Args:
        img_names: file names (with ``.png`` extension) present in both
            ``dirs['imgs']`` and ``img_dir``; the matching mask is
            ``<text before first '.'>.npy`` in ``dirs['gt']``.
        img_dir: directory of prediction PNGs. It must end with a path
            separator because paths are built by string concatenation.
        exp_name: stem of the output file under ``result/``.
        figsize: matplotlib figure size in inches. The default reproduces
            the original fixed size; pass a smaller height for short lists.
    """
    nrows = len(img_names)
    if nrows == 0:
        # plt.subplots rejects a row count of 0; there is nothing to draw.
        print("display_grid: no images to display for", exp_name)
        return

    # squeeze=False keeps `axs` 2-D even for a single image, so `row[i]` works.
    fig, axs = plt.subplots(nrows, 3, figsize=figsize, squeeze=False)

    for row, name in tqdm(zip(axs, img_names), total=nrows):
        im_path = dirs['imgs'] + name
        gt_path = dirs['gt'] + name.split('.')[0] + '.npy'
        pr_path = img_dir + name

        # Read pixel data eagerly inside `with` so the file handles get closed.
        with Image.open(im_path) as im_file:
            img = im_file.copy()
        with Image.open(pr_path) as pr_file:
            pr = np.asarray(pr_file)
        gt = np.load(gt_path)

        # Overlay bkgd, so we only look at non-bkgd classifications
        red_ind = gt == 255
        pr = np.where(red_ind, gt, pr)

        gt = Dstl.decode_target(gt)
        pr = Dstl.decode_target(pr)

        row[0].set_title(name)
        row[0].imshow(img)
        row[1].imshow(gt)
        row[2].imshow(pr)

    fig.savefig("result/" + exp_name + ".png")
    plt.close(fig)


In [None]:
# Collect the distinct prediction image names for one experiment.
EXP_NAME = 'seg_final'
EXP_DIR = dirs[EXP_NAME]
# The text before the first '.' identifies the image; the set drops repeats.
# Sorting gives a stable row order: iteration order of a set of strings
# changes between kernel restarts because str hashing is randomized.
img_names = sorted({p.name.split('.')[0] + '.png'
                    for p in pathlib.Path(EXP_DIR).glob('*.png')})

# Choose which to display
print("Number of available images: ", len(img_names))
display_grid(img_names, EXP_DIR, EXP_NAME)

In [50]:
# Image ids (no extension) whose progression across checkpoints is compared.
# dict.fromkeys drops duplicates while keeping this listed order, so the grid
# rows are the same on every run (set order is not stable across runs).
imgs_interest = ['6120_2_2_2201_2174', '6060_2_3_2201_1174',
                 '6120_2_2_1201_1174', '6110_4_0_1198_1174',
                 '6120_2_2_201_174', '6120_2_2_2201_1174',
                 '6110_3_1_198_174', '6100_2_2_195_174',
                 '6120_2_0_2201_1174', '6060_2_3_201_2174',
                 '6060_2_3_1201_1174', '6110_1_2_1198_1174',
                 '6100_2_3_195_1174']
imgs_interest = list(dict.fromkeys(imgs_interest))

def display_progress(img_names, dir_names, output_name):
    """Save a grid comparing several checkpoint directories image by image.

    Rows are images. Columns are: the input image, the ground truth, then one
    prediction column per entry of ``dir_names`` (in that order). Pixels
    whose ground-truth value is 255 (background) are painted with the value
    200 after decoding, in both the ground-truth and prediction panels. The
    figure is written to ``result/progression_<output_name>.png`` and closed.

    Args:
        img_names: image ids without extension. ``<id>.png`` is read from
            ``dirs['imgs']`` and from every prediction dir, and
            ``<id>.npy`` from ``dirs['gt']``.
        dir_names: keys of ``dirs`` whose prediction PNGs are shown.
        output_name: suffix of the output file name.

    Raises:
        AssertionError: if any entry of ``dir_names`` is not a key of ``dirs``.
    """
    unknown = [d for d in dir_names if d not in dirs]
    assert not unknown, f"unknown experiment keys: {unknown}"

    # Setup axes of num_imgs x (image, ground truth, one per checkpoint dir)
    nrows = len(img_names)
    ncols = len(dir_names) + 2
    if nrows == 0:
        # plt.subplots rejects a row count of 0; there is nothing to draw.
        print("display_progress: no images to display for", output_name)
        return
    # squeeze=False keeps `axs` 2-D even for a single image.
    fig, axs = plt.subplots(nrows, ncols, figsize=(10 * ncols, 10 * nrows),
                            squeeze=False)

    for row, name in zip(axs, img_names):
        im_path = dirs['imgs'] + name + '.png'
        gt_path = dirs['gt'] + name + '.npy'

        # Get image (read inside `with` so the file handle is closed)
        with Image.open(im_path) as im_file:
            img = np.asarray(im_file)
        row[0].set_title(name)
        row[0].imshow(img)

        # Get ground truth label, find background pixels, and convert to grey
        gt = np.load(gt_path)
        red_ind = gt == 255
        gt = Dstl.decode_target(gt)
        gt[red_ind] = 200
        row[1].set_title('ground truth')
        row[1].imshow(gt)

        # Get image pred from each dir; the remaining axes map 1:1 onto dir_names
        for ax, ckpt_dir in zip(row[2:], dir_names):
            pr_path = dirs[ckpt_dir] + name + '.png'
            with Image.open(pr_path) as pr_file:
                pr = np.asarray(pr_file)

            # Background pixels are zeroed before decoding and painted grey
            # after; the zeroing is kept from the original in case
            # decode_target cannot handle every raw value there.
            pr = np.where(red_ind, 0, pr)
            pr = Dstl.decode_target(pr)
            pr[red_ind] = 200

            ax.set_title(ckpt_dir)
            ax.imshow(pr)

    fig.savefig("result/progression_" + output_name + ".png")
    plt.close(fig)

# Render the progression panel. Columns run left to right: segmentation-only
# baselines, adversarial runs in increasing sample count, then memreg runs.
PANEL_NAME = "with13k"
exp_names = (
    ['seg_overfit', 'seg_generalize', 'seg_final']
    + ['adv_2.5k', 'adv_5k', 'adv_7.5k', 'adv_10k', 'adv_13k']
    + ['memreg_2.5k_on7.5k', 'memreg_3k_on10k']
)
display_progress(imgs_interest, exp_names, output_name=PANEL_NAME)