In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from mnist_generators_simple import gen_texture_mnist

In [None]:
# Dataset configurations for gen_texture_mnist. The biased config is derived
# from the unbiased one, so every shared setting (seed, resolution, sample
# counts, batch size, texture path) is defined exactly once and the two
# configs cannot drift apart. Only the bias-related keys differ.
unbiased_config = {
    'background_split': 2,
    'dataset_seed': 0,
    'tex_res': 400,
    'tile_size': 64,
    'train_samples': 50000,
    'test_samples': 1000,
    'exclude_bias_textures': False,
    'fix_test_set': True,
    'batch_size': 100,
    'bias': None,  # no bias specification (contrast with biased_config)
    'textures_path': 'textures/',
}

# Overriding existing keys in a dict display keeps their original position,
# so biased_config has the same key order as unbiased_config.
biased_config = {
    **unbiased_config,
    'exclude_bias_textures': True,
    # NOTE(review): the outer key 2 presumably selects the digit class that
    # receives the bias -- confirm against gen_texture_mnist.
    'bias': {2: {
        # NOTE(review): both ids begin with a stray "'" character, which looks
        # like a copy-paste artifact from a printed Python list. Left unchanged
        # because the generator's id matching is not visible here -- confirm
        # these match the texture file names under textures_path.
        "source_1_id": "'feeccd96.png",
        "source_2_id": "'f135d029.png",
        "source_1_bias": 0.0,
        "source_2_bias": 1.0,
    }},
}

In [None]:
def hide_ticks(ax):
    """Remove tick labels and tick marks from a matplotlib Axes, in place.

    plot_samples calls this on its image panels, where pixel-coordinate
    ticks would only add clutter.

    Args:
        ax: Axes to modify. Its current x and y tick labels are set invisible,
            and every tick mark (major and minor, both axes) gets zero length.
    """
    for get_labels in (ax.get_xticklabels, ax.get_yticklabels):
        plt.setp(get_labels(), visible=False)
    # Zero-length ticks hide the marks themselves, not just their labels.
    ax.tick_params(axis='both', which='both', length=0)
    
def _plot_example(x, y, m):
    """Draw one example as a 1x4 row of panels, display it, then free it."""
    fig, axes = plt.subplots(1, 4)
    panels = (
        ('Image', x[..., 0], dict(cmap='gray')),
        # Argmax over the last axis -- presumably per-pixel class scores --
        # gives a label map; fixed vmin/vmax keep each label's colour the same
        # across examples.
        ('Segmentation', np.argmax(y, -1), dict(cmap='tab20', vmin=0, vmax=10)),
        ('Background', m['background'][..., 0], dict(cmap='gray', vmin=0, vmax=1)),
        ('Biased Tile', m['biased_tile'][..., 0], dict(cmap='gray', vmin=0, vmax=1)),
    )
    for ax, (title, image, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        hide_ticks(ax)
    plt.show()
    # One figure is created per example; close it so they don't accumulate.
    plt.close(fig)


def plot_samples(sample_generator, num_samples):
    """Plot the first ``num_samples`` examples from a batched sample generator.

    Each example is shown as one row of four panels: the input image (channel
    0), the argmax of the segmentation target, the 'background' mask and the
    'biased_tile' mask.

    Args:
        sample_generator: Iterable of ``(xs, ys, masks)`` batches, where
            ``xs``, ``ys`` and ``masks`` are parallel per-example sequences and
            each mask entry is indexable by ``'background'`` and
            ``'biased_tile'``. It is consumed lazily: no batch is requested
            after the last plotted example.
        num_samples: Maximum number of examples to plot. Values <= 0 plot
            nothing and do not touch the generator.
    """
    # Check before pulling anything, so a non-positive count plots nothing.
    if num_samples <= 0:
        return
    shown = 0
    for xs, ys, masks in sample_generator:
        for x, y, m in zip(xs, ys, masks):
            _plot_example(x, y, m)
            shown += 1
            if shown >= num_samples:
                # Returning exits both loops without requesting another batch.
                return

NUM_SAMPLES = 20

# Preview the 'test' split of each configuration: biased first, then unbiased.
for heading, config in (('Biased Test Set', biased_config),
                        ('Unbiased Test Set', unbiased_config)):
    print(heading)
    sample_generator = gen_texture_mnist(config, 'test')
    plot_samples(sample_generator, NUM_SAMPLES)