In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
from skimage.draw import ellipse, polygon

%matplotlib inline

def generate_mask(seed: int, size_factor: float, image_size: int = 128) -> np.ndarray:
    """Draw one random binary shape into a square mask.

    The shape type ('circle', 'ellipse', 'square', 'hexagon' or 'irregular'),
    its size and its placement are drawn from a private
    ``np.random.RandomState`` seeded with *seed*. The same ``(seed,
    size_factor)`` pair therefore always yields the same mask, and NumPy's
    global random state is left untouched.

    Parameters
    ----------
    seed : int
        Seed for all random choices; pass the same value to regenerate an
        identical mask.
    size_factor : float
        Scale of the shape. For circle, square and hexagon it is used as the
        approximate area in pixels (each branch inverts that shape's area
        formula). Ellipse and irregular shapes are scaled from it more loosely
        and come out smaller.
    image_size : int, optional
        Height and width of the square output mask (default 128).

    Returns
    -------
    numpy.ndarray
        ``(image_size, image_size)`` uint8 array, 1 inside the shape, 0 elsewhere.

    Raises
    ------
    ValueError
        Raised by ``randint`` when the shape is too large to be centred inside
        an ``image_size`` x ``image_size`` mask.
    """
    # A private RandomState produces exactly the draws that
    # np.random.seed(seed) followed by np.random.* would, but it does not
    # re-seed the global generator that callers use to pick their seeds.
    rng = np.random.RandomState(seed)
    mask = np.zeros((image_size, image_size), dtype=np.uint8)

    shape_type = rng.choice(['circle', 'ellipse', 'square', 'hexagon', 'irregular'])

    if shape_type == 'circle':
        # pi * radius**2 ~= size_factor
        radius = int(np.sqrt(size_factor / np.pi))
        center = rng.randint(radius, image_size - radius, size=2)
        rr, cc = ellipse(center[0], center[1], radius, radius, shape=mask.shape)
    elif shape_type == 'ellipse':
        size = int(np.sqrt(size_factor / np.pi))
        center = rng.randint(size, image_size - size, size=2)
        # Semi-axes are drawn from [size // 2, size). The bounds are clamped to
        # [1, 2) when size <= 1; without the clamp the range collapses to 0 and
        # the ellipse (and the mask) comes out empty.
        radii = rng.randint(max(1, size // 2), max(2, size), size=2)
        rotation = rng.rand() * 180
        rr, cc = ellipse(center[0], center[1], radii[0], radii[1],
                         shape=mask.shape, rotation=np.deg2rad(rotation))
    elif shape_type == 'square':
        # side ~= sqrt(size_factor); vertices sit half a side from the centre
        half = int(np.sqrt(size_factor)) // 2
        center = rng.randint(half, image_size - half, size=2)
        top, bottom = center[0] - half, center[0] + half
        left, right = center[1] - half, center[1] + half
        rr, cc = polygon([top, bottom, bottom, top],
                         [left, left, right, right], shape=mask.shape)
    elif shape_type == 'hexagon':
        # Regular hexagon area is (3 * sqrt(3) / 2) * r**2 for circumradius r.
        size = int(np.sqrt(size_factor / (3 * np.sqrt(3) / 2)))
        center = rng.randint(size, image_size - size, size=2)
        # Six evenly spaced vertices with a random rotation in [0, pi/3).
        angles = np.linspace(0, 2 * np.pi, 7)[:-1] + rng.rand() * np.pi / 3
        rr, cc = polygon(center[0] + size * np.cos(angles),
                         center[1] + size * np.sin(angles), shape=mask.shape)
    else:  # irregular
        num_points = rng.randint(5, 10)
        # Sorted random fractions of a full turn, plus a closing vertex at 2*pi
        # that reuses the first vertex's radius.
        angles = np.concatenate((np.sort(rng.rand(num_points)), [1])) * 2 * np.pi
        avg_radius = int(np.sqrt(size_factor / num_points))
        radii = rng.randint(max(1, avg_radius // 2), max(2, avg_radius), size=num_points)
        radii = np.concatenate((radii, [radii[0]]))
        margin = max(1, avg_radius)
        center = rng.randint(margin, image_size - margin, size=2)
        rr, cc = polygon(center[0] + radii * np.cos(angles),
                         center[1] + radii * np.sin(angles), shape=mask.shape)

    mask[rr, cc] = 1
    return mask

# Render num_masks random masks to PNG files. Each file name embeds the seed
# and size factor, so any mask can be rebuilt with
# generate_mask(seed, size_factor).
output_dir = '../data_augmentation_folder/synthetic_data'
num_masks = 2000

# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

start_time = time.perf_counter()
for _ in range(num_masks):
    seed = np.random.randint(1000000)
    size_factor = np.random.randint(10, 1000)

    mask = generate_mask(seed, size_factor)

    fig = plt.figure(figsize=(4, 4))
    plt.imshow(mask, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
    # Negative pad_inches presumably crops the remaining white border; kept as-is.
    plt.savefig(f'{output_dir}/segmentation_mask_seed_{seed}_size_{size_factor}.png',
                bbox_inches='tight', pad_inches=-0.1)
    # Close each figure. Otherwise all of them stay open, growing memory and
    # triggering matplotlib's too-many-figures warning, and the inline backend
    # renders every one of them into the cell output.
    plt.close(fig)

    print(f"Seed to regenerate the mask: {seed}")
    print(f"Size factor used: {size_factor}")

print(f"that took {time.perf_counter() - start_time} seconds")