In [16]:
import collections
import inspect
import json
import ntpath
import os
import sys

import h5py
import matplotlib.pyplot as plt
import medpy
import numpy as np
import pandas as pd

from datasets.synth import SyntheticDataset

Make Dataset
--------------------

In [None]:
"""
goal: randomly generate a dataset of images with the following properties:
    training set:
        class 0: one cross symbol as the main predictor (add_cross draws class_label + 1
                 crosses). Single solid rectangle in the lower-left corner as a distractor.
        class 1: two cross symbols as the main predictor. Single solid rectangle in the
                 lower-right corner as a distractor.
        TODO: adding noise and random cropping of the crosses (to make them less
              predictive) is planned but not yet implemented.
    test set:
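        Same as training, but the relationship between the distractor and the class
        label differs. (TODO: the original sentence was cut off here; confirm the
        intended test-set construction.)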
"""

def add_cross(img, class_label, cross_dims=5, add_noise=False):
    """Draw the predictive cross(es) onto ``img`` in place.

    ``class_label + 1`` crosses are drawn, so class 0 gets one cross and class 1
    gets two. Each cross has a vertical and a horizontal limb of
    ``2 * cross_dims + 1`` pixels set to 1, centred on a randomly chosen pixel.
    Centres are kept at least ``cross_dims`` pixels from every edge, so the whole
    cross always fits inside the image.

    Row centres are drawn without replacement, and so are column centres. This
    means no two crosses share a centre row or a centre column. Their limbs can
    still intersect each other, and a cross may overlap the distractor tag.

    Randomness comes from NumPy's global RNG (``np.random``). Results are
    reproducible only if the caller seeds it; ``make_synthetic_dataset`` does.

    Args:
        img: 2-D array, modified in place.
        class_label: 0 or 1; the number of crosses drawn is ``class_label + 1``.
        cross_dims: half-length of each limb, in pixels.
        add_noise: currently ignored. TODO: intended to add blurring and to
            randomly truncate cross limbs proportionally to their length.

    Raises:
        ValueError: raised by ``np.random.choice`` if the image is too small to
            hold ``class_label + 1`` distinct centres along either axis.
    """
    n_crosses = class_label + 1
    n_rows, n_cols = img.shape
    # Candidate centres keep the full limb (centre +/- cross_dims) inside the image.
    # The ranges come from img.shape, so image sizes other than 28 are supported.
    # For 28x28 images the ranges, and therefore the random draws, are the same
    # as with a fixed size of 28.
    row_centres = np.random.choice(range(cross_dims, n_rows - cross_dims), size=[n_crosses], replace=False)
    col_centres = np.random.choice(range(cross_dims, n_cols - cross_dims), size=[n_crosses], replace=False)
    for centre_row, centre_col in zip(row_centres, col_centres):
        img[centre_row - cross_dims:centre_row + cross_dims + 1, centre_col] = 1
        img[centre_row, centre_col - cross_dims:centre_col + cross_dims + 1] = 1

        
def add_distractor(img, class_label, tag_dims=(2, 3), add_noise=False, tag_buffer=5):
    """Draw the solid rectangular distractor tag onto ``img`` in place.

    The tag sits ``tag_buffer`` pixels above the bottom edge. Its horizontal
    position encodes the class. For class 0 it sits ``tag_buffer`` pixels from
    the left edge; for any other label it sits ``tag_buffer`` pixels from the
    right edge.

    Args:
        img: 2-D array, modified in place.
        class_label: 0 for the left tag; any other value for the right tag.
        tag_dims: (height, width) of the rectangle in pixels.
        add_noise: currently ignored. TODO: intended to blur the distractor and
            to shift it by a random 2-5 pixels from the image edges.
        tag_buffer: gap in pixels between the tag and the bottom / side edge.
    """
    n_rows, n_cols = img.shape
    tag_height, tag_width = tag_dims[0], tag_dims[1]
    # Rows index axis 0 and columns index axis 1, so non-square images place
    # the tag correctly.
    row_slice = slice(n_rows - tag_buffer - tag_height, n_rows - tag_buffer)
    if class_label == 0:
        col_slice = slice(tag_buffer, tag_buffer + tag_width)
    else:
        col_slice = slice(n_cols - tag_buffer - tag_width, n_cols - tag_buffer)
    img[row_slice, col_slice] = 1

        
def make_synthetic_dataset(length, mode, folder, root="../data/synth3", img_size=28, seed=0):
    """Generate ``length`` synthetic images and save them, with labels, under ``root``.

    The first ``length // 2`` images are class 0 and the rest are class 1. For
    each image ``n`` two files are written:

    * ``{root}/{folder}_img_{n}.npy``: the image, i.e. the crosses, plus the
      distractor tag when ``mode == "distractor"``.
    * ``{root}/{folder}_seg_{n}.npy``: the crosses only, without the distractor.

    A final ``{root}/{folder}_labels.csv`` is then written, indexed by ``file``
    (the image file name without ``root``), with a ``class`` column.

    Parameters:
        length: how many images to generate.
        mode: ``"distractor"`` adds the class-dependent distractor tag; any other
            value produces images with crosses only.
        folder: prefix for every file written (e.g. ``"distractor1"``).
        root: output directory; created if it does not already exist.
        img_size: side length of the square images.
        seed: seed for NumPy's global RNG, so each call is reproducible.
    """
    os.makedirs(root, exist_ok=True)
    np.random.seed(seed)
    records = []
    for n in range(length):
        print("making ", n, " of ", length, " files")
        label = 0 if n < length // 2 else 1  # even class split

        img_base = np.zeros([img_size, img_size])
        add_cross(img_base, label, 5)
        # The segmentation map is captured before the distractor is drawn, so it
        # contains the crosses only.
        img_seg = img_base.copy()

        if mode == 'distractor':
            add_distractor(img_base, label)

        # Build the file name directly instead of stripping root/slashes from the
        # full path afterwards.
        img_name = "{}_img_{}.npy".format(folder, n)
        np.save("{}/{}".format(root, img_name), img_base)
        np.save("{}/{}_seg_{}.npy".format(root, folder, n), img_seg)
        records.append((img_name, label))

    labels = pd.DataFrame(records, columns=["file", "class"]).set_index("file")
    labels.to_csv("{}/{}_labels.csv".format(root, folder))

In [None]:
# Generate three independently seeded copies of the distractor dataset.
for dataset_folder, dataset_seed in [("distractor1", 1), ("distractor2", 2), ("distractor3", 3)]:
    make_synthetic_dataset(512, "distractor", dataset_folder, seed=dataset_seed)

View Dataset
-------------------

In [17]:
# Load two of the generated splits for visual inspection.
# NOTE(review): make_synthetic_dataset's default root is "../data/synth3", but this
# cell reads from "../../data/synth2/". Confirm which directory is intended.
DATAROOT = "../../data/synth2/"
train = SyntheticDataset(
    dataroot=DATAROOT,
    mode="distractor1",
    blur=3,
    nsamples=10,
    distract_noise=0,
)
valid = SyntheticDataset(
    dataroot=DATAROOT,
    mode="distractor2",
    blur=3,
    nsamples=10,
    distract_noise=1,
)

In [1]:
# Show two samples from each dataset in a 2x2 grid, one row per dataset.
dataloaders = [train, valid]
# images[i][j] is the sample index taken from dataloaders[i] and shown in column j.
images = [[1, 5], [4, 6]]
titles = [
    ["D_train, Class 0", "D_train, Class 1"],
    ["D_valid, Class 0", "D_valid, Class 1"],
]

fig, axs = plt.subplots(nrows=2, ncols=2, squeeze=True)

for dataset, sample_indices, row_titles, ax_row in zip(dataloaders, images, titles, axs):
    for sample_idx, title, ax in zip(sample_indices, row_titles, ax_row):
        sample = dataset[sample_idx]
        # NOTE(review): assumes sample[0][0][0] is the 2-D image to display. This
        # depends on SyntheticDataset.__getitem__'s return layout; confirm.
        ax.imshow(sample[0][0][0], interpolation='none', cmap='Greys_r')
        ax.set_title(title)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

# Figure.set_tight_layout is discouraged since Matplotlib 3.6; apply the tight layout directly.
fig.tight_layout()

NameError: name 'train' is not defined