# CSprites Dataset Generator


In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local csprites package) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Python
import timeit
from datetime import datetime
import pickle as pkl
from pathlib import Path

# Installed
from PIL import Image
from tqdm import tqdm
import pprint
import matplotlib.pyplot as plt
import numpy as np

# Local
from csprites.configs import be_config
from csprites.masks import get_shapes
from csprites.colors import get_colors, apply_color
from csprites.scales import get_shape_sizes_evenly_scaled
from csprites.angles import get_angles, apply_angle
from csprites.positions import get_max_positions, centered_position_idcs, get_position_idcs_from_center
from csprites.backgrounds import get_bg_func
from csprites.utils import MeanStdTracker, shape_sizes_to_html, masks_to_html_animation


# Configs

In [None]:
# Optional override of the dataset output root (disabled by default).
#be_config["p_base"] = Path("/mnt/data/csprites-models")
# Print the backend configuration; the limits and file names it holds are
# used by the validity checks and path setup further down.
pprint.pprint(be_config)

## Load Config

In [None]:
# Read the requested dataset configuration from data/config.pkl.
p_data = Path("data")
p_ds_config = p_data / "config.pkl"
# NOTE(review): unpickling can execute arbitrary code. The checks below mention
# webserver usage, so confirm this file only ever comes from a trusted source.
with open(p_ds_config, "rb") as file:
    config = pkl.load(file)
pprint.pprint(config)

## Validity checks

In [None]:
# For webserver usage the requested configuration has to be checked before
# anything is generated:
#    - existence checks
#    - type checks
# Each value is read from `config` and compared with the limits in `be_config`.

img_size = config["img_size"]
assert img_size in be_config.img_sizes

# shapes: every requested name must be a known shape
shape_names = config['shapes']
assert all(shape_name in be_config.all_shape_names for shape_name in shape_names)
n_shapes = len(shape_names)

# colors
n_colors = config['n_colors']
assert be_config.n_colors_min <= n_colors <= be_config.n_colors_max

# angles
n_angles = config['n_angles']
assert be_config.n_angles_min <= n_angles <= be_config.n_angles_max

# backgrounds: np.inf is accepted as-is and skips the range check
n_bg = config['n_bg']
if n_bg != np.inf:
    assert be_config.n_bgs_min <= n_bg <= be_config.n_bgs_max

bg_style = config["bg_style"]
assert bg_style in be_config.bg_styles

# optional targets (flags, not range checked)
target_bbox = config["target_bbox"]
target_segm = config["target_segm"]

# scales
n_scales = config["n_scales"]
assert be_config.n_scales_min <= n_scales <= be_config.n_scales_max

# positions
n_positions = config["n_positions"]
assert be_config.n_positions_min <= n_positions <= be_config.n_positions_max

# fill rates: both strictly inside (0, 1) and min <= max
min_mask_fill_rate = config["min_mask_fill_rate"]
assert 0 < min_mask_fill_rate < 1

max_mask_fill_rate = config["max_mask_fill_rate"]
assert min_mask_fill_rate <= max_mask_fill_rate < 1

In [None]:
# Further checks against limits that depend on the chosen shapes and fill rates.
shapes = get_shapes(shape_names)

# Allowed mask area in pixels, expressed as a fraction of the full image area.
img_area = img_size ** 2
min_mask_area = img_area * min_mask_fill_rate
max_mask_area = img_area * max_mask_fill_rate

# Note: if the fill rates were limited to steps of 0.1 or so, max_mask_size and
# the resulting max positions etc. could be precalculated.
shape_sizes, max_mask_size = get_shape_sizes_evenly_scaled(shapes, min_mask_area, max_mask_area)

# Number of available scales, taken from the first shape's size list
# (presumably every shape has the same number of sizes; verify in csprites.scales).
sizes_of_first_shape = list(shape_sizes.values())[0]
n_scales_max = len(sizes_of_first_shape)

# Number of positions available per axis for the largest mask.
n_positions_max = get_max_positions(img_size, max_mask_size)

assert n_scales <= n_scales_max
assert n_positions <= n_positions_max

In [None]:
# subset
generate_subset = config["subset"]
n_train = config["n_train"]
n_valid = config["n_valid"]
n_samples = n_train + n_valid

assert be_config.n_samples_min <= n_samples
assert n_samples < be_config.n_samples_max
train_rate = n_train / n_samples

n_masks = n_shapes * n_colors * n_angles * n_scales
n_states = n_masks * n_positions**2
assert n_states >= n_samples

sampling_rate = n_samples / n_states
mem_usage_bytes = img_size**2 * 3 * n_samples
assert mem_usage_bytes * 1e-9 <= be_config.mem_usage_gb_max

In [None]:
# Print a summary table of the dataset dimensions computed above.
rule = "*"*40
count_rows = (
    ("#Masks:        {:>10}", n_masks),
    ("#States:       {:>10}", n_states),
    ("#Samples:      {:>10}", n_samples),
    ("#Train:        {:>10}", n_train),
    ("#Valid:        {:>10}", n_valid),
)
rate_rows = (
    ("Train rate:     {:.2f}", train_rate),
    ("Sampling rate:  {:.2f}", sampling_rate),
    ("Mem usage [mb]: {:.2f}", mem_usage_bytes * 1e-6),
    ("Mem usage [Gb]: {:.2f}", mem_usage_bytes * 1e-9),
)

print(rule)
print("{:^40}".format("Dataset Stats"))
print(rule)
for template, value in count_rows:
    print(template.format(value))
print(rule)
for template, value in rate_rows:
    print(template.format(value))

# Generate Classes

In [None]:
# Concrete values for every generative factor.
# One extra color is requested and the first entry is dropped
# (presumably the first colormap entry is unwanted; see csprites.colors).
colors = get_colors(n_colors + 1, cmap=be_config.cmap_colors)[1:]
angles = get_angles(n_angles)
positions = centered_position_idcs(n_positions, n_positions_max, max_mask_size)

# Background generator for RGB images of the configured size.
bg_shape = (img_size, img_size, 3)
bg_func = get_bg_func(bg_shape, n_bg, bg_style)

# Lookup tables from class index to the underlying value; stored in the
# dataset config so class labels can be decoded later.
shape_map = {idx: shape.name for idx, shape in enumerate(shapes)}
angle_map = dict(enumerate(angles))
color_map = {idx: list(color) for idx, color in enumerate(colors)}
sizes_map = {shape_key.name: sizes for shape_key, sizes in shape_sizes.items()}
posit_map = dict(enumerate(positions))

In [None]:
def _draw_uniform(n_options):
    """Draw n_samples indices from range(n_options), each with equal probability."""
    return np.random.choice(a=n_options, size=n_samples, replace=True, p=[1/n_options] * n_options)


# Every row of class_targets is (shape, scale, color, angle, py, px).
if generate_subset:
    # Random subset of the state space [good for large state space].
    # Sampling is with replacement, so duplicate rows are possible.
    s_shapes = _draw_uniform(n_shapes)
    s_scales = _draw_uniform(n_scales)
    s_colors = _draw_uniform(n_colors)
    s_angles = _draw_uniform(n_angles)
    s_px = _draw_uniform(n_positions)
    s_py = _draw_uniform(n_positions)
    #
    class_targets = np.stack([s_shapes, s_scales, s_colors, s_angles, s_py, s_px]).T
else:
    # Full enumeration of the state space [fine for small state space].
    # Iteration order is shape, scale, angle, color, py, px (outer to inner);
    # the columns are written in the (shape, scale, color, angle, py, px) layout.
    state_grid = np.ndindex(n_shapes, n_scales, n_angles, n_colors, n_positions, n_positions)
    class_targets = np.array([
        (shape_idx, scale_idx, color_idx, angle_idx, py_idx, px_idx)
        for shape_idx, scale_idx, angle_idx, color_idx, py_idx, px_idx in state_grid
    ])
# Number of distinct parameter combinations actually present.
unique_classes = np.unique(class_targets, axis=0).shape[0]

In [None]:
# Inspect the size table: shape name -> sizes available for that shape.
sizes_map

# Generate Dataset

In [None]:
# Dataset name built from the backend template and the main dataset parameters.
csprices_type = "single"
name_fields = (
    csprices_type, img_size, img_size, n_shapes, n_colors,
    n_angles, n_positions, n_scales, n_bg, bg_style, n_samples,
)
ds_name = be_config.ds_name_tmp.format(*name_fields)

# Root directory of this dataset (replaces the earlier p_data input directory).
p_data = be_config.p_base / ds_name
p_data.mkdir(exist_ok=True, parents=True)

# Train split files
p_X_train = p_data / be_config["p_X_train"]
p_Y_train_clas = p_data / be_config["p_Y_train_clas"]
p_Y_train_segm = p_data / be_config["p_Y_train_segm"]
p_Y_train_bbox = p_data / be_config["p_Y_train_bbox"]

# Validation split files
p_X_valid = p_data / be_config["p_X_valid"]
p_Y_valid_clas = p_data / be_config["p_Y_valid_clas"]
p_Y_valid_segm = p_data / be_config["p_Y_valid_segm"]
p_Y_valid_bbox = p_data / be_config["p_Y_valid_bbox"]

# Generated dataset config
p_config = p_data / be_config["p_config"]

# Preview animations
p_gifs = p_data / be_config["p_gifs"]
p_gifs.mkdir(exist_ok=True)
p_gifs_shapes_dir = p_gifs / "shapes"
p_gifs_shapes_dir.mkdir(exist_ok=True)
p_gif_colors = p_gifs / "colors.gif"
p_gif_bg = p_gifs / "backgrounds.gif"

# Rendered images
p_imgs = p_data / be_config["p_imgs"]
p_imgs.mkdir(exist_ok=True)

# Segmentation maps are only written when requested, so p_segs exists only then.
if target_segm:
    p_segs = p_data / be_config["p_segs"]
    p_segs.mkdir(exist_ok=True)

In [None]:
# Render one image per row of class_targets, save it to disk and collect the
# file names plus the optional bounding-box / segmentation targets.
debug = False   # stop early, shortly after n_debug samples
n_debug = 10
plot = False    # show a debug figure for every sample
imgs = []
#
targets_bbox = []
targets_segm = []
#
tracker = MeanStdTracker()
#
# The debug figure marks the bounding-box corners on the segmentation map, so it
# needs both targets. Fail early with a clear message instead of a NameError /
# IndexError in the middle of the run.
if plot and not (target_bbox and target_segm):
    raise ValueError("plot=True requires target_bbox and target_segm to be enabled")
#
start = timeit.default_timer()
for sample_idx, (shape_idx, scale_idx, color_idx, angle_idx, py_idx, px_idx) in enumerate(tqdm(class_targets)):
    # Translate class indices into concrete factor values.
    shape = shapes[shape_idx]
    size = shape_sizes[shape][scale_idx]
    angle = angles[angle_idx]
    color = colors[color_idx]
    px = positions[px_idx]  # center position width
    py = positions[py_idx]  # center position height
    #
    mask = shape.create(size)
    mask = apply_angle(mask, angle)
    #
    h_mask, w_mask = mask.shape
    #
    # Odd mask dimensions are required (presumably so the mask has a single
    # center pixel that can be placed at (py, px)).
    assert h_mask % 2 == 1
    assert w_mask % 2 == 1

    #mask = pad_mask(mask, max_mask_size)
    if target_bbox:
        # corners: (upper left, lower right)
        # Extent of the shape = number of columns / rows containing any mask pixel.
        w_shape = (mask.sum(axis=0) > 0).sum()
        h_shape = (mask.sum(axis=1) > 0).sum()
        #
        # Round even extents up to the next odd number.
        if w_shape % 2 == 0:
            w_shape+=1
        if h_shape % 2 == 0:
            h_shape+=1
        #
        assert w_shape % 2 == 1
        assert h_shape % 2 == 1

        # Box around the center with a one pixel margin, clipped to the image.
        y0 = max(0, py - h_shape // 2 - 1)
        x0 = max(0, px - w_shape // 2 - 1)
        y1 = min(py + h_shape // 2 + 1, img_size)
        x1 = min(px + w_shape // 2 + 1, img_size)
        targets_bbox.append((y0, x0 , y1, x1))

    # From here on x0/y0/x1/y1 are the slice bounds of the mask inside the image.
    x0, y0, x1, y1 = get_position_idcs_from_center(h_mask, w_mask, px, py)
    #
    if target_segm:
        seg_map = np.zeros((img_size, img_size), dtype=np.uint8)
        seg_map[y0: y1, x0: x1] = mask
        p_seg = p_segs / be_config["seg_name"].format(sample_idx)
        #
        Image.fromarray(seg_map).save(p_seg)
        targets_segm.append(p_seg.name)
    #
    # Keep the uncolored mask to select which pixels are overwritten.
    mask_wo_color = np.copy(mask)
    mask = apply_color(mask, color)
    #
    img = bg_func()
    img[y0: y1, x0: x1,:][mask_wo_color > 0] = mask[mask_wo_color > 0]
    #
    p_img = p_imgs / be_config["img_name"].format(sample_idx)
    Image.fromarray(img).save(p_img)
    tracker.add(img/255)  # statistics are tracked on values scaled by 1/255
    imgs.append(p_img.name)

    if debug and sample_idx > n_debug:
        break
    if plot:
        # Debug view only: the files are already saved, so drawing markers on
        # img / seg_map does not affect the dataset.
        fig, axes = plt.subplots(1, 3)
        # mark center
        img[py, px] = 0
        seg_map[py, px] = 0

        # mark bboxes
        y0, x0, y1, x1 = targets_bbox[-1]
        # The lower-right corner may equal img_size (exclusive bound); clamp it
        # to the last valid pixel so the marker indexing cannot go out of bounds.
        y1 = min(y1, img_size - 1)
        x1 = min(x1, img_size - 1)
        seg_map[y0,x0] = 1
        seg_map[y1,x0] = 1
        seg_map[y0,x1] = 1
        seg_map[y1,x1] = 1

        img[y0,x0] = 0
        img[y1,x0] = 0
        img[y0,x1] = 0
        img[y1,x1] = 0
        #
        axes[0].imshow(img)
        axes[1].imshow(seg_map)
        axes[2].imshow(img * (1 - seg_map.reshape((img.shape[0], img.shape[1], 1))))
        plt.show()

elapsed  = timeit.default_timer() - start
print("{:.3f}".format(elapsed))

In [None]:
# Fetch the statistics accumulated over all generated images (each image was
# added as img/255); both values are stored in the dataset config below.
means, stds = tracker.get()

# Generate GIFs

In [None]:
# Preview of all shapes and sizes, written below the gifs/shapes directory.
shape_sizes_to_html(shape_sizes, angles, max_mask_size, p_data=p_gifs_shapes_dir)

# Color preview: the last shape at its last size, rendered once per color.
preview_shape = shapes[-1]
mask = preview_shape.create(shape_sizes[preview_shape][-1])
masks = []
for preview_color in colors:
    masks.append(apply_color(mask, preview_color))
html_str = masks_to_html_animation(masks, p_gif_colors)

# Background preview: 20 backgrounds drawn from the configured generator.
n_bg_frames = 20
masks = [bg_func() for _ in range(n_bg_frames)]
html_str = masks_to_html_animation(masks, p_gif_bg, interval=300)

# Generate Config

In [None]:
# Description of the generated dataset; pickled next to the data so that
# consumers can locate the files and decode the class indices.
gen_config = {
    # General DS Information
    'date': datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
    'version': 1.0,
    'type': csprices_type,
    'n_states': n_states,
    'n_masks': n_masks,
    'n_train': n_train,
    'n_valid': n_valid,
    'n_samples': n_samples,
    'n_unique': unique_classes,
    'sampling_rate': sampling_rate,
    'train_rate': train_rate,
    'subset': generate_subset,  # was listed twice in the literal; one entry is enough
    'img_size': img_size,
    'name': ds_name,
    'memory_usage': "{:.3f} GB".format(mem_usage_bytes * 1e-9),
    'color_cmap': be_config.cmap_colors,

    # Targets
    'target_bbox': target_bbox,
    'target_segm': target_segm,

    # Generative Factors
    # Column order of the class target arrays.
    'classes': ['shape', 'scale', 'color', 'angle', 'py', 'px'],
    'n_classes': {
        'shape': n_shapes,
        'scale': n_scales,
        'color': n_colors,
        'angle': n_angles,
        'px': n_positions,
        'py': n_positions,
    },
    # Index -> value lookup tables for each factor.
    'class_maps': {
        'shape': shape_map,
        'angle': angle_map,
        'color': color_map,
        'scale': sizes_map,
        'position': posit_map
    },
    'n_bg': n_bg,
    'bg_style': bg_style,
    'shapes': shape_names,
    'min_mask_fill_rate': min_mask_fill_rate,
    'max_mask_fill_rate': max_mask_fill_rate,
    'max_mask_size': max_mask_size,

    # PATHS (names only, relative to the dataset directory)
    'p_X_train': p_X_train.name,
    'p_X_valid': p_X_valid.name,
    'p_Y_train_clas': p_Y_train_clas.name,
    'p_Y_valid_clas': p_Y_valid_clas.name,
    'p_imgs': p_imgs.name,
    'p_gifs': p_gifs.name,

    # stds & means
    'means': means,
    'stds': stds,

}
# Paths of optional targets are only recorded when those targets were generated.
if target_bbox:
    gen_config['p_Y_train_bbox'] = p_Y_train_bbox.name
    gen_config['p_Y_valid_bbox'] = p_Y_valid_bbox.name
if target_segm:
    gen_config['p_Y_train_segm'] = p_Y_train_segm.name
    gen_config['p_Y_valid_segm'] = p_Y_valid_segm.name
    gen_config["p_segs"] = p_segs.name

pprint.pprint(gen_config)

# Train / Validation Split

In [None]:
# Inputs are the image file names, in generation order.
X = np.asarray(imgs)
# Every requested sample must have been rendered.
assert X.shape[0] == n_samples

In [None]:
# Class targets: one row of factor indices per sample.
Y_clas = class_targets
print(Y_clas.shape)

# Optional targets are converted only if any were collected during generation.
if targets_bbox:
    Y_bbox = np.asarray(targets_bbox)
    print(Y_bbox.shape)

if targets_segm:
    Y_segm = np.asarray(targets_segm)
    print(Y_segm.shape)

In [None]:
# Random split: the first n_train shuffled indices form the training set,
# the remaining ones the validation set.
idcs = np.arange(n_samples)
np.random.shuffle(idcs)
train_idcs, valid_idcs = idcs[:n_train], idcs[n_train:]

X_train, X_valid = X[train_idcs], X[valid_idcs]
Y_train_clas, Y_valid_clas = Y_clas[train_idcs], Y_clas[valid_idcs]

# Optional targets are split with the same indices and saved right away.
if target_bbox:
    Y_train_bbox, Y_valid_bbox = Y_bbox[train_idcs], Y_bbox[valid_idcs]
    np.save(p_Y_train_bbox, Y_train_bbox)
    np.save(p_Y_valid_bbox, Y_valid_bbox)

if target_segm:
    Y_train_segm, Y_valid_segm = Y_segm[train_idcs], Y_segm[valid_idcs]
    np.save(p_Y_train_segm, Y_train_segm)
    np.save(p_Y_valid_segm, Y_valid_segm)

# Class targets and inputs are always written.
always_saved = (
    (p_Y_train_clas, Y_train_clas),
    (p_Y_valid_clas, Y_valid_clas),
    (p_X_train, X_train),
    (p_X_valid, X_valid),
)
for target_path, array in always_saved:
    np.save(target_path, array)

# Finally store the dataset description next to the arrays.
with open(p_config, "wb") as config_file:
    pkl.dump(gen_config, config_file)

In [None]:
# Show the directory the dataset was written to.
p_data

In [None]:
# Preview the segmentation map of sample 0. The path is built from the same
# directory and file-name template the generator wrote with, and the preview is
# skipped when no segmentation maps were generated.
if target_segm:
    plt.imshow(np.array(Image.open(p_segs / be_config["seg_name"].format(0))))

# Show all colors

In [None]:
# Inspect the color values used for the sprites.
colors

In [None]:
# Draw one small 5x5 swatch per sprite color.
for rgb in colors:
    plt.figure(figsize=(2, 2))
    swatch = np.ones((5, 5, 3), dtype=np.uint8) * rgb
    plt.imshow(swatch)
    plt.show()

# Show random samples


In [None]:
# Number of samples to display; the index array is reshuffled in place so
# that its first show_n entries are a random selection.
show_n = 100
np.random.shuffle(idcs)

In [None]:
# Load the selected images back from disk and display them one by one.
for img_name in X[idcs[:show_n]]:
    plt.imshow(np.array(Image.open(p_imgs / img_name)))
    plt.show()

# Color

In [None]:
# Rebuild the mask and its placement for the color demo below.
# NOTE: shape, size, angle, px and py are not set in this cell; they are the
# values left over from the last iteration of the generation loop.
mask = shape.create(size)
mask = apply_angle(mask, angle)
#
h_mask, w_mask = mask.shape
#
# Slice bounds of the mask inside the image, centered at (py, px).
x0, y0, x1, y1 = get_position_idcs_from_center(h_mask, w_mask, px, py)

In [None]:
# Keep an untouched mask and a single background that the recoloring loop
# below starts from on every iteration.
mask_o = np.copy(mask)
img_o = bg_func()

# Show both: the mask in its own figure, then the background.
plt.imshow(mask_o)
plt.show()
plt.imshow(img_o)

In [None]:
# Paste the same mask in every color onto a copy of the same background.
shape_pixels = mask_o > 0  # pixels that belong to the shape
for color in colors:
    colored = apply_color(np.copy(mask_o), color)
    #
    img = np.copy(img_o)
    img[y0: y1, x0: x1,:][shape_pixels] = colored[shape_pixels]
    plt.imshow(img)
    plt.show()

In [None]:
# Show the dataset directory once more for reference.
p_data