In [4]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

""" goal: randomly generate a dataset of images with the following properties:
         training set:
             class 0: two cross symbols as the main predictor (add noise and random cropping of 
                                                                 crosses to make less predictive)
                      single solid rectangle in lower left corner as a distractor
             class 1: one cross symbol as main predictor
                      single solid rectangle in lower right corner as distractor
        
         test set:
             as per training but with distractors removed in all cases or most cases...
         
"""

def add_cross(img, class_label, cross_dims=1, add_noise=False):
    """ Add the predictive cross(es) to the image arbirarily (may accidentally intersect the distractor, but
            that's ok). Class label defines how many times the cross is placed to ensure no overlap
    
        Args: 
            class_label: 0 or 1
            add_noise: True/False - if True, add blurring, randomly truncate cross limbs proportional to how
                        long they are
    """

    x_centres = np.random.choice(range(cross_dims, 28-cross_dims), size=[class_label + 1], replace=False)
    y_centres = np.random.choice(range(cross_dims, 28-cross_dims), size=[class_label + 1], replace=False)
    for i in range(class_label+1):

        centre_x = x_centres[i]
        centre_y = y_centres[i]
        img[centre_x-cross_dims:centre_x+cross_dims+1, centre_y] = 1
        img[centre_x, centre_y-cross_dims:centre_y+cross_dims+1] = 1

def add_distractor(img, class_label, tag_dims=[2,3], add_noise=False):
    """ Add a distractor for prediction based on the class label - zero puts it on the left, one puts it on the
            right
        
        Args:
            tag_dims: list of the sizes for the rectangular distractor tag
            add_noise: True/False - if True, blurs the distractor and also shifts its location by some
                        random amount fixed within 2-5 pixels from the edges of the image
    
    """
    tag_buffer = 3
    size_x, size_y = img.shape
    if class_label == 0:
        # buffer of 3 from the border
        img[size_y-tag_buffer-tag_dims[0]:size_y-tag_buffer, tag_buffer:tag_buffer+tag_dims[1]] = 1
        tag_centre_x = (tag_buffer+tag_dims[1]) // 2
        tag_centre_y = 1
    else:
        img[size_y-tag_buffer-tag_dims[0]:size_y-tag_buffer, size_x-tag_buffer-tag_dims[1]:size_x-tag_buffer] = 1
    
    
# img_size = 28
# label = 1
# img_base = np.zeros([img_size,img_size])
# add_cross(img_base, label, np.random.randint(1,3))
# # capture the original image before the distractor gets added and this forms the segmentation
# img_seg = img_base 
# add_distractor(img_base, label)
# plt.imshow(img_base)
# plt.show()

In [28]:
""" Dataset Builder
    Parameters: 
        length: how many images should be generated
        mode: string of either train or test
"""
def make_synthetic_dataset(length, mode, root="/network/data1/GM", img_size=28):
    labels = {}
    labels_dict = pd.DataFrame(columns=["file","class"])
    
    for n in range(length):
        print("making ", n, " of ", length, " files")
        if n < length//2:
            label = 0
        else:
            label = 1 # even data split

        img_base = np.zeros([img_size,img_size])
        add_cross(img_base, label, np.random.randint(1,3))
        img_seg = img_base
        
        if mode == 'train':
            add_distractor(img_base, label)
        
        # save image and segmentation map to file
        np.save("{}/{}_img_{}.npy".format(root, mode, n), img_base)
        np.save("{}/{}_seg_{}.npy".format(root, mode, n), img_seg)

        labels["{}/{}_img_{}.npy".format(root, mode, n)] = label

    labels_dict["file"] = labels.keys()
    labels_dict["class"] = labels.values()
    labels_dict = labels_dict.set_index("file")
    labels_dict.to_csv("{}/{}_labels.csv".format(root, mode))

In [29]:
make_synthetic_dataset(500, "train")
make_synthetic_dataset(500, "valid")
make_synthetic_dataset(500, "test")

making  0  of  500  files
making  1  of  500  files
making  2  of  500  files
making  3  of  500  files
making  4  of  500  files
making  5  of  500  files
making  6  of  500  files
making  7  of  500  files
making  8  of  500  files
making  9  of  500  files
making  10  of  500  files
making  11  of  500  files
making  12  of  500  files
making  13  of  500  files
making  14  of  500  files
making  15  of  500  files
making  16  of  500  files
making  17  of  500  files
making  18  of  500  files
making  19  of  500  files
making  20  of  500  files
making  21  of  500  files
making  22  of  500  files
making  23  of  500  files
making  24  of  500  files
making  25  of  500  files
making  26  of  500  files
making  27  of  500  files
making  28  of  500  files
making  29  of  500  files
making  30  of  500  files
making  31  of  500  files
making  32  of  500  files
making  33  of  500  files
making  34  of  500  files
making  35  of  500  files
making  36  of  500  files
making  37 

making  305  of  500  files
making  306  of  500  files
making  307  of  500  files
making  308  of  500  files
making  309  of  500  files
making  310  of  500  files
making  311  of  500  files
making  312  of  500  files
making  313  of  500  files
making  314  of  500  files
making  315  of  500  files
making  316  of  500  files
making  317  of  500  files
making  318  of  500  files
making  319  of  500  files
making  320  of  500  files
making  321  of  500  files
making  322  of  500  files
making  323  of  500  files
making  324  of  500  files
making  325  of  500  files
making  326  of  500  files
making  327  of  500  files
making  328  of  500  files
making  329  of  500  files
making  330  of  500  files
making  331  of  500  files
making  332  of  500  files
making  333  of  500  files
making  334  of  500  files
making  335  of  500  files
making  336  of  500  files
making  337  of  500  files
making  338  of  500  files
making  339  of  500  files
making  340  of  500

making  109  of  500  files
making  110  of  500  files
making  111  of  500  files
making  112  of  500  files
making  113  of  500  files
making  114  of  500  files
making  115  of  500  files
making  116  of  500  files
making  117  of  500  files
making  118  of  500  files
making  119  of  500  files
making  120  of  500  files
making  121  of  500  files
making  122  of  500  files
making  123  of  500  files
making  124  of  500  files
making  125  of  500  files
making  126  of  500  files
making  127  of  500  files
making  128  of  500  files
making  129  of  500  files
making  130  of  500  files
making  131  of  500  files
making  132  of  500  files
making  133  of  500  files
making  134  of  500  files
making  135  of  500  files
making  136  of  500  files
making  137  of  500  files
making  138  of  500  files
making  139  of  500  files
making  140  of  500  files
making  141  of  500  files
making  142  of  500  files
making  143  of  500  files
making  144  of  500

making  407  of  500  files
making  408  of  500  files
making  409  of  500  files
making  410  of  500  files
making  411  of  500  files
making  412  of  500  files
making  413  of  500  files
making  414  of  500  files
making  415  of  500  files
making  416  of  500  files
making  417  of  500  files
making  418  of  500  files
making  419  of  500  files
making  420  of  500  files
making  421  of  500  files
making  422  of  500  files
making  423  of  500  files
making  424  of  500  files
making  425  of  500  files
making  426  of  500  files
making  427  of  500  files
making  428  of  500  files
making  429  of  500  files
making  430  of  500  files
making  431  of  500  files
making  432  of  500  files
making  433  of  500  files
making  434  of  500  files
making  435  of  500  files
making  436  of  500  files
making  437  of  500  files
making  438  of  500  files
making  439  of  500  files
making  440  of  500  files
making  441  of  500  files
making  442  of  500

making  205  of  500  files
making  206  of  500  files
making  207  of  500  files
making  208  of  500  files
making  209  of  500  files
making  210  of  500  files
making  211  of  500  files
making  212  of  500  files
making  213  of  500  files
making  214  of  500  files
making  215  of  500  files
making  216  of  500  files
making  217  of  500  files
making  218  of  500  files
making  219  of  500  files
making  220  of  500  files
making  221  of  500  files
making  222  of  500  files
making  223  of  500  files
making  224  of  500  files
making  225  of  500  files
making  226  of  500  files
making  227  of  500  files
making  228  of  500  files
making  229  of  500  files
making  230  of  500  files
making  231  of  500  files
making  232  of  500  files
making  233  of  500  files
making  234  of  500  files
making  235  of  500  files
making  236  of  500  files
making  237  of  500  files
making  238  of  500  files
making  239  of  500  files
making  240  of  500

In [155]:
from torch.utils.data import Dataset

class SytheticDataset(Dataset):
    def __init__(self, mode, root="data"):
        self.labels_set = pd.read_csv("{}/labels.csv")
        self.root = root
        self.mode = mode

    def __len__(self):
        return self.labels_set.shape[0]

    def __getitem__(self, index):
        filename = "{}/{}_img_{}.npy".format(self.root, self.mode, index)
        img = np.load(filename)
        img_seg = np.load("{}/{}_seg_{}.npy".format(self.root, self.mode, index))
        
        label = self.labels_set[filename]
        return (img, img_seg), label

In [31]:
labels = pd.read_csv("/network/data1/GM/train_labels.csv")

In [39]:
files = labels['file'].loc[labels['class'] == 0]

In [55]:
np.random.seed(10)
nsamples = 100
class0 = labels["file"].loc[labels["class"] == 0]
class1 = labels["file"].loc[labels["class"] == 1]
class0_files = np.random.choice(class0.values, nsamples//2, replace=False)
class1_files = np.random.choice(class1.values, nsamples//2, replace=False)
        
# get the corresponding segmentation files
class0_seg = [f.replace("img","seg") for f in class0]
class1_seg = [f.replace("img","seg") for f in class1]

idx = np.append(class1, class0)
mask_idx = np.append(class1_seg, class0_seg)
labels = np.append(np.ones(len(class1)), np.zeros(len(class0)))