In [4]:
%pip install "drawsvg~=2.0"

import csv
import itertools
import random
import os
import drawsvg as dw
import numpy as np

Note: you may need to restart the kernel to use updated packages.


## Define sampling params here

In [5]:
sample_size = 50
length_variation_max = 10
variation_step = 0.5
fixed_vars = ["line_color"]

w, h = 600, 600

save_path = "../jnd_images_svg/bar_length"
os.makedirs(save_path, exist_ok=True)

## Define hyperparam space

In [12]:
hparam_dict = {
    "line_color": ["black", "none"],
    "fill_color": ["red", "green", "blue"],
    "shape_location": [0.25, 0.5, 0.75],
    "shape_dist": [0.03, 0.05, 0.07],
    "reference_width": [20, 40, 60],
    "reference_height": [5, 10, 15],
}

selected_hparams = [key for key in hparam_dict if key not in fixed_vars]
selected_hparam_vals = [hparam_dict[key] for key in selected_hparams]
fixed_hparam_vals = [hparam_dict[key] for key in fixed_vars]
print(selected_hparams)

all_hparams = list(itertools.product(*selected_hparam_vals))
fixed_hparams_exhaustive = list(itertools.product(*fixed_hparam_vals))

print(len(all_hparams))
print(len(fixed_hparams_exhaustive))

['fill_color', 'shape_location', 'shape_dist', 'reference_width', 'reference_height']
243
2


## Sample

In [13]:
random_hparam_configs = random.choices(all_hparams, k=sample_size)

## Generate & save

In [16]:
oob_count = 0
overlapping_count = 0
total_count = 0
param_d_list = []

# Iterate over randomly selected hparam configurations minus the fixed hparam
for i, config in enumerate(random_hparam_configs):
    config_d = dict(zip(selected_hparams, config))
    
    # Iterate over all values of a fixed variable
    for j, fixed_config in enumerate(fixed_hparams_exhaustive):
        config_d.update(dict(zip(fixed_vars, fixed_config)))
        
        # Iterate over target manipulation (length of the bar: implemented as width for now)
        for k, w_diff in enumerate(np.arange(0, length_variation_max, variation_step)):
            location = config_d["shape_location"]
            ref_bar_w = config_d["reference_width"]
            ref_bar_h = config_d["reference_height"]
            shape_dist = config_d["shape_dist"]
            line_color = config_d["line_color"]
            fill_color = config_d["fill_color"]
            
            ref_x, ref_y = w*location, h*location
            x, y = w*location, h*(location+shape_dist)

            bar_w = ref_bar_w+w_diff
            
            max_ref_x = ref_x + ref_bar_w
            max_ref_y = ref_y + ref_bar_h
            max_x = x + bar_w
            max_y = y + ref_bar_h

            # Count cases where shapes go out of bounds
            total_count += 1
            oob = False
            overlapping = False
            if max_ref_x > w or max_x > w or max_ref_y > h or max_y > h:
                oob = True
                oob_count += 1
                
            # Count cases where shapes are overlapping with each other
            if max_ref_y > y:
                overlapping = True
                overlapping_count += 1
                
            if oob or overlapping:
                continue

            d = dw.Drawing(
                w, h, origin=(0, 0),
                context=None,
                animation_config=None,
                id_prefix='d'
            )
            ref_bar = dw.Rectangle(
                ref_x,
                ref_y,
                ref_bar_w,
                ref_bar_h,
                stroke=line_color,
                stroke_width=0.5,
                fill=fill_color
            )
            bar = dw.Rectangle(
                x,
                y,
                bar_w,
                ref_bar_h,
                stroke=line_color,
                stroke_width=0.5,
                fill=fill_color
            )
            d.append(ref_bar)
            d.append(bar)

#             display(d)
            idx = f"{i}_{j}_{k}"
            filename = f"{idx}.svg"
            print(filename)
            d.save_svg(os.path.join(save_path, filename))
        
            param_d = {
                "idx": idx,
                "filename": filename,
                "ref_x": ref_x,
                "ref_y": ref_y,
                "ref_bar_w": ref_bar_w,
                "ref_bar_h": ref_bar_h,
                "x": x,
                "y": y,
                "bar_w": bar_w,
                "line_color": line_color,
                "fill_color": fill_color,
                "shape_location": location,
                "shape_dist": shape_dist,
                "w_diff": w_diff
            }

            param_d_list.append(param_d)

print("Total # images: ", total_count)
print("Total out of bound images: ", oob_count)
print("Total overlapping: ", overlapping_count)

0_0_0.svg
0_0_1.svg
0_0_2.svg
0_0_3.svg
0_0_4.svg
0_0_5.svg
0_0_6.svg
0_0_7.svg
0_0_8.svg
0_0_9.svg
0_0_10.svg
0_0_11.svg
0_0_12.svg
0_0_13.svg
0_0_14.svg
0_0_15.svg
0_0_16.svg
0_0_17.svg
0_0_18.svg
0_0_19.svg
0_1_0.svg
0_1_1.svg
0_1_2.svg
0_1_3.svg
0_1_4.svg
0_1_5.svg
0_1_6.svg
0_1_7.svg
0_1_8.svg
0_1_9.svg
0_1_10.svg
0_1_11.svg
0_1_12.svg
0_1_13.svg
0_1_14.svg
0_1_15.svg
0_1_16.svg
0_1_17.svg
0_1_18.svg
0_1_19.svg
1_0_0.svg
1_0_1.svg
1_0_2.svg
1_0_3.svg
1_0_4.svg
1_0_5.svg
1_0_6.svg
1_0_7.svg
1_0_8.svg
1_0_9.svg
1_0_10.svg
1_0_11.svg
1_0_12.svg
1_0_13.svg
1_0_14.svg
1_0_15.svg
1_0_16.svg
1_0_17.svg
1_0_18.svg
1_0_19.svg
1_1_0.svg
1_1_1.svg
1_1_2.svg
1_1_3.svg
1_1_4.svg
1_1_5.svg
1_1_6.svg
1_1_7.svg
1_1_8.svg
1_1_9.svg
1_1_10.svg
1_1_11.svg
1_1_12.svg
1_1_13.svg
1_1_14.svg
1_1_15.svg
1_1_16.svg
1_1_17.svg
1_1_18.svg
1_1_19.svg
2_0_0.svg
2_0_1.svg
2_0_2.svg
2_0_3.svg
2_0_4.svg
2_0_5.svg
2_0_6.svg
2_0_7.svg
2_0_8.svg
2_0_9.svg
2_0_10.svg
2_0_11.svg
2_0_12.svg
2_0_13.svg
2_0_14.svg
2_0_1

19_0_17.svg
19_0_18.svg
19_0_19.svg
19_1_0.svg
19_1_1.svg
19_1_2.svg
19_1_3.svg
19_1_4.svg
19_1_5.svg
19_1_6.svg
19_1_7.svg
19_1_8.svg
19_1_9.svg
19_1_10.svg
19_1_11.svg
19_1_12.svg
19_1_13.svg
19_1_14.svg
19_1_15.svg
19_1_16.svg
19_1_17.svg
19_1_18.svg
19_1_19.svg
20_0_0.svg
20_0_1.svg
20_0_2.svg
20_0_3.svg
20_0_4.svg
20_0_5.svg
20_0_6.svg
20_0_7.svg
20_0_8.svg
20_0_9.svg
20_0_10.svg
20_0_11.svg
20_0_12.svg
20_0_13.svg
20_0_14.svg
20_0_15.svg
20_0_16.svg
20_0_17.svg
20_0_18.svg
20_0_19.svg
20_1_0.svg
20_1_1.svg
20_1_2.svg
20_1_3.svg
20_1_4.svg
20_1_5.svg
20_1_6.svg
20_1_7.svg
20_1_8.svg
20_1_9.svg
20_1_10.svg
20_1_11.svg
20_1_12.svg
20_1_13.svg
20_1_14.svg
20_1_15.svg
20_1_16.svg
20_1_17.svg
20_1_18.svg
20_1_19.svg
21_0_0.svg
21_0_1.svg
21_0_2.svg
21_0_3.svg
21_0_4.svg
21_0_5.svg
21_0_6.svg
21_0_7.svg
21_0_8.svg
21_0_9.svg
21_0_10.svg
21_0_11.svg
21_0_12.svg
21_0_13.svg
21_0_14.svg
21_0_15.svg
21_0_16.svg
21_0_17.svg
21_0_18.svg
21_0_19.svg
21_1_0.svg
21_1_1.svg
21_1_2.svg
21_1_3.svg


38_1_0.svg
38_1_1.svg
38_1_2.svg
38_1_3.svg
38_1_4.svg
38_1_5.svg
38_1_6.svg
38_1_7.svg
38_1_8.svg
38_1_9.svg
38_1_10.svg
38_1_11.svg
38_1_12.svg
38_1_13.svg
38_1_14.svg
38_1_15.svg
38_1_16.svg
38_1_17.svg
38_1_18.svg
38_1_19.svg
39_0_0.svg
39_0_1.svg
39_0_2.svg
39_0_3.svg
39_0_4.svg
39_0_5.svg
39_0_6.svg
39_0_7.svg
39_0_8.svg
39_0_9.svg
39_0_10.svg
39_0_11.svg
39_0_12.svg
39_0_13.svg
39_0_14.svg
39_0_15.svg
39_0_16.svg
39_0_17.svg
39_0_18.svg
39_0_19.svg
39_1_0.svg
39_1_1.svg
39_1_2.svg
39_1_3.svg
39_1_4.svg
39_1_5.svg
39_1_6.svg
39_1_7.svg
39_1_8.svg
39_1_9.svg
39_1_10.svg
39_1_11.svg
39_1_12.svg
39_1_13.svg
39_1_14.svg
39_1_15.svg
39_1_16.svg
39_1_17.svg
39_1_18.svg
39_1_19.svg
40_0_0.svg
40_0_1.svg
40_0_2.svg
40_0_3.svg
40_0_4.svg
40_0_5.svg
40_0_6.svg
40_0_7.svg
40_0_8.svg
40_0_9.svg
40_0_10.svg
40_0_11.svg
40_0_12.svg
40_0_13.svg
40_0_14.svg
40_0_15.svg
40_0_16.svg
40_0_17.svg
40_0_18.svg
40_0_19.svg
40_1_0.svg
40_1_1.svg
40_1_2.svg
40_1_3.svg
40_1_4.svg
40_1_5.svg
40_1_6.svg
40_

## Save metadata

In [17]:
print(param_d_list[-1])

with open(os.path.join(save_path, "bar_length_jnd_images_svg.csv"), 'w') as wf:
    writer = csv.DictWriter(wf, fieldnames=param_d_list[-1].keys())
    writer.writeheader()
    for d in param_d_list:
        writer.writerow(d)

{'idx': '49_1_19', 'filename': '49_1_19.svg', 'ref_x': 450.0, 'ref_y': 450.0, 'ref_bar_w': 60, 'ref_bar_h': 10, 'x': 450.0, 'y': 492.00000000000006, 'bar_w': 69.5, 'line_color': 'none', 'fill_color': 'red', 'shape_location': 0.75, 'shape_dist': 0.07, 'w_diff': 9.5}
