In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import *
import pathlib
import os

# Seed NumPy's global RNG once so every position/size/colour draw made by the
# get_* helpers below (all of which use np.random) is reproducible across runs.
np.random.seed(seed=37)

def get_xy(max_h=1.0, max_w=1.0):
    """Draw a random (x, y) anchor point inside the central 60% of the canvas.

    Parameters
    ----------
    max_h : float
        Canvas height in data units; y is drawn uniformly from
        [0.2 * max_h, 0.8 * max_h).
    max_w : float
        Canvas width in data units; x is drawn uniformly from
        [0.2 * max_w, 0.8 * max_w).

    Returns
    -------
    tuple of float
        (x, y). With the default 1.0 bounds this is exactly the same as drawing
        both coordinates from [0.2, 0.8).
    """
    # x is drawn before y so the global RNG sequence matches earlier versions
    # of this helper (the bounds previously ignored max_h / max_w).
    x = np.random.uniform(low=0.2 * max_w, high=0.8 * max_w)
    y = np.random.uniform(low=0.2 * max_h, high=0.8 * max_h)
    return x, y

def get_radius(r_max=0.1):
    """Return a random radius drawn uniformly from [0.05, r_max)."""
    r_min = 0.05
    radius = np.random.uniform(r_min, r_max)
    return radius

def get_height(h_max=0.4):
    """Return a random rectangle height drawn uniformly from [0.1, h_max)."""
    lower_bound = 0.1
    height = np.random.uniform(low=lower_bound, high=h_max)
    return height

def get_width(h_max=0.4):
    """Return a random rectangle width drawn uniformly from [0.1, h_max).

    Note: the upper bound keeps the name ``h_max`` (same as ``get_height``)
    for backward compatibility with keyword callers, but here it caps the width.
    """
    lower_bound = 0.1
    width = np.random.uniform(low=lower_bound, high=h_max)
    return width

def get_color():
    """Pick one of seven matplotlib single-letter colour codes uniformly at random.

    White ('w') is not among the choices -- presumably so shapes stay visible
    on the default white figure background (NOTE(review): confirm intent).
    """
    palette = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    return np.random.choice(palette)

def get_circle():
    """Build a randomly placed, sized and coloured matplotlib ``Circle`` patch.

    Random draws happen in the order: centre, radius, face colour.
    """
    center = get_xy()
    radius = get_radius()
    face_color = get_color()
    return Circle(xy=center, radius=radius, fc=face_color)

def get_rect():
    """Build a randomly placed, sized and coloured matplotlib ``Rectangle`` patch.

    Random draws happen in the order: anchor ``xy``, width, height, face colour.
    With the default ranges (anchor up to 0.8, width/height up to 0.4) the
    rectangle can extend past 1.0 on either axis.
    """
    anchor = get_xy()
    rect_w = get_width()
    rect_h = get_height()
    face_color = get_color()
    return Rectangle(xy=anchor, width=rect_w, height=rect_h, fc=face_color)

def get_poly():
    """Build a randomly placed, sized and coloured triangle (3-vertex ``RegularPolygon``).

    Random draws happen in the order: centre ``xy``, radius, face colour.
    """
    center = get_xy()
    radius = get_radius()
    face_color = get_color()
    return RegularPolygon(xy=center, numVertices=3, radius=radius, fc=face_color)

def get_plot(width=2, height=2):
    """Create a ``width`` x ``height`` inch figure holding one axes that fills it.

    The axes spans the whole figure (rect ``[0, 0, 1, 1]`` in figure fractions)
    and has its axis lines, ticks and labels turned off, so only added artists
    are rendered.

    Returns
    -------
    (Figure, Axes)
    """
    fig = plt.figure(figsize=(width, height))
    ax = fig.add_axes([0., 0., 1., 1.])
    ax.set_axis_off()
    return fig, ax

def do_plot(func, shape, f=0, d='./shapes', s='train'):
    """Render one random shape and save it as a JPEG file.

    Parameters
    ----------
    func : callable
        Zero-argument factory returning a matplotlib artist (e.g. ``get_circle``).
    shape : str
        Class label; used as the leaf directory name.
    f : int
        Image index; zero-padded to at least 4 characters for the file name.
    d : str
        Dataset root directory.
    s : str
        Split name (e.g. ``'train'``, ``'test'``, ``'valid'``).

    Side effects
    ------------
    Creates ``{d}/{s}/{shape}/`` if it does not exist and writes
    ``<zero-padded f>.jpg`` into it. The figure is always closed afterwards.
    """
    fig, ax = get_plot()
    try:
        ax.add_artist(func())

        out_dir = os.path.join(d, s, shape)
        # savefig does not create directories; without this a fresh run (no
        # ./shapes tree yet) fails with FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)

        file = str(f).zfill(4)
        path = os.path.join(out_dir, f'{file}.jpg')
        # JPEG encoder options must be passed through pil_kwargs; the former
        # top-level quality=/optimize= keywords were deprecated in Matplotlib 3.3
        # and later removed.
        fig.savefig(path, pil_kwargs={'quality': 100, 'optimize': True})
    finally:
        # Close exactly this figure (even if saving failed) so the generation
        # loop does not accumulate open figures.
        plt.close(fig)

In [3]:
# Remove every previously generated JPEG under ./shapes so the run starts clean.
# The glob is materialised with list() first so files are not deleted while the
# lazy directory walk is still in progress.
jpg_files = list(pathlib.Path('./shapes').glob('**/*.jpg'))
for f in jpg_files:
    os.remove(f)

In [4]:
# Generate the dataset: for each split, render `n_per_shape` images of every
# shape class. `counter` runs across all splits and classes, so file names are
# unique dataset-wide (0000.jpg ... 0059.jpg with the sizes below: (10+5+5) * 3).
split_sizes = {'train': 10, 'test': 5, 'valid': 5}
shape_factories = {'circle': get_circle, 'poly': get_poly, 'rect': get_rect}

counter = 0
for split, n_per_shape in split_sizes.items():
    for label, factory in shape_factories.items():
        for _ in range(n_per_shape):
            do_plot(func=factory, shape=label, f=counter, s=split)
            counter += 1

In [7]:
# Delete any JPEGs that ended up inside Jupyter's .ipynb_checkpoints folders
# under ./shapes, printing each path just before it is removed.
jpg_files = list(pathlib.Path('./shapes').glob('**/*.jpg'))
checkpoint_jpgs = [p for p in jpg_files if '.ipynb_checkpoints' in p.parts]
for stale in checkpoint_jpgs:
    print(stale)
    os.remove(stale)