In [1]:
from dataset import generate_dataset, load_dataset

## CIFAR Base Example

In [7]:
# Settings for the CIFAR-backed example; the cells below read these names.
dataset_type = 'cifar'
# image_size is forwarded as image_size= when loading and size= when generating;
# max_images is forwarded to load_dataset as max_images=.
image_size, max_images = 256, 100
# Labels 1 through 6 are requested from the loader, although the shape maps
# in this section only reference labels 1-3.
class_labels = list(range(1, 7))

In [8]:
# Load the base images for the configured dataset type and class labels.
# presumably the result is keyed by class label (going by the name) - TODO confirm
base_images_by_class = load_dataset(
    dataset_type,
    class_labels,
    max_images=max_images,
    image_size=image_size,
)

# Shape name -> list of class labels. Training pairs circle/square/triangle
# with labels 1/2/3; the test map rotates them to 3/1/2, so no shape keeps
# the label it was paired with in the training map.
shape_class_map_train = dict(zip(('circle', 'square', 'triangle'), ([1], [2], [3])))
shape_class_map_test = dict(zip(('circle', 'square', 'triangle'), ([3], [1], [2])))

Files already downloaded and verified


In [9]:
# Both CIFAR splits use the same size= and padding= values; only the
# shape -> class map and the output directory differ between the two calls.
cifar_render_kwargs = dict(size=image_size, padding=50)

generate_dataset(
    base_images_by_class,
    shape_class_map_train,
    output_dir='./dataset_cifar_base/train',
    **cifar_render_kwargs
)
generate_dataset(
    base_images_by_class,
    shape_class_map_test,
    output_dir='./dataset_cifar_base/test',
    **cifar_render_kwargs
)

## Custom Base Dataset Example

In [2]:
# Settings for the custom-image example.
dataset_type = 'custom'
image_size = 256
# NOTE(review): train_images is not read by any cell visible in this notebook;
# only test_images is passed on (as max_images= when loading) - confirm intent.
train_images, test_images = 10000, 1000

# Integer class label -> relative path under custom_images/
# (presumably a folder of images per class - TODO confirm against load_dataset).
class_labels = dict([
    (1, 'custom_images/zigzag'),
    (2, 'custom_images/chequered'),
    (3, 'custom_images/dots'),
])

In [3]:
# Load the custom base images for every label in class_labels.
# NOTE(review): the cap passed as max_images= is test_images, while
# train_images is defined but unused in the visible cells - confirm that
# limiting the pool by the test count is what was intended.
images_by_class = load_dataset(
    dataset_type,
    class_labels,
    max_images=test_images,
    image_size=image_size,
)

# Shape name -> list of class labels. Training pairs circle/square/triangle
# with labels 1/2/3; the test map rotates them to 3/1/2, so no shape keeps
# the label it was paired with in the training map.
shape_class_map_train = dict(zip(('circle', 'square', 'triangle'), ([1], [2], [3])))
shape_class_map_test = dict(zip(('circle', 'square', 'triangle'), ([3], [1], [2])))

In [4]:
# Arguments shared by all three custom-dataset splits; each call below only
# varies the shape -> class map, the output directory and num_samples.
# NOTE(review): no random seed is set in the visible cells; if generate_dataset
# samples scale/position randomly (the argument names suggest so - confirm),
# the generated files will differ from run to run.
# NOTE(review): all three splits draw from the same images_by_class pool -
# confirm that sharing base images between train and test is intended.
custom_render_kwargs = dict(
    size=image_size,
    padding=50,
    shape_scale_range=(0.3, 1.0),
    position_jitter=0.5,
)

# Training split: training shape -> class pairing, 2000 samples.
generate_dataset(
    images_by_class,
    shape_class_map_train,
    output_dir='./dataset_custom_base/train',
    num_samples=2000,
    **custom_render_kwargs
)
# Test split that reuses the training pairing, 200 samples.
generate_dataset(
    images_by_class,
    shape_class_map_train,
    output_dir='./dataset_custom_base/test_same_backgrounds',
    num_samples=200,
    **custom_render_kwargs
)
# Test split that uses the rotated (test) pairing, 200 samples.
generate_dataset(
    images_by_class,
    shape_class_map_test,
    output_dir='./dataset_custom_base/test_swapped_backgrounds',
    num_samples=200,
    **custom_render_kwargs
)

## Pattern Example

In [12]:
# Two-colour palette shared by all three patterns.
# NOTE(review): channel order (RGB vs BGR) is not visible here - confirm.
pattern_colors = ((0, 0, 255), (255, 255, 255))

# Label name -> settings forwarded to load_dataset via pattern_settings=.
pattern_settings = {
    'stripes_label': {
        'pattern_type': 'stripes',
        'colors': pattern_colors,
        'stripe_width': 15,
    },
    'dots_label': {
        'pattern_type': 'dots',
        'colors': pattern_colors,
        'dot_radius': 3,
        'spacing': 10,
    },
    'horizontal_stripes_label': {
        'pattern_type': 'horizontal_stripes',
        'colors': pattern_colors,
        'stripe_width': 15,
    },
}

# The requested labels are exactly the keys of pattern_settings, in
# insertion order: stripes_label, dots_label, horizontal_stripes_label.
images_by_class = load_dataset(
    'pattern',
    list(pattern_settings),
    max_images=100,
    image_size=256,
    pattern_settings=pattern_settings,
)

In [13]:
# Shape name -> single-element list of pattern labels.
# Training pairing: circle/stripes, triangle/dots, square/horizontal stripes.
train_options = {
    shape: [label]
    for shape, label in (
        ('circle', 'stripes_label'),
        ('triangle', 'dots_label'),
        ('square', 'horizontal_stripes_label'),
    )
}

# Test pairing: every shape is given a pattern it was not paired with above.
test_options = {
    shape: [label]
    for shape, label in (
        ('circle', 'dots_label'),
        ('triangle', 'horizontal_stripes_label'),
        ('square', 'stripes_label'),
    )
}

In [14]:
# Arguments shared by all three pattern-dataset splits; each call below only
# varies the shape -> pattern map, the output directory and num_samples.
pattern_render_kwargs = dict(
    size=256,
    padding=40,
    margin=16,
    shape_scale_range=(0.3, 1.0),
    position_jitter=0.5,
)

# Training split: training pairing, 1000 samples.
generate_dataset(
    images_by_class,
    train_options,
    output_dir='./dataset_pattern_base/train',
    num_samples=1000,
    **pattern_render_kwargs
)
# Test split that reuses the training pairing, 200 samples.
generate_dataset(
    images_by_class,
    train_options,
    output_dir='./dataset_pattern_base/test_matched_patterns',
    num_samples=200,
    **pattern_render_kwargs
)
# Test split that uses the reassigned (test) pairing, 200 samples.
generate_dataset(
    images_by_class,
    test_options,
    output_dir='./dataset_pattern_base/test_swapped_patterns',
    num_samples=200,
    **pattern_render_kwargs
)