In [19]:
import csv
import os
from collections import Counter
from dataclasses import dataclass
from itertools import chain, product
from pathlib import Path
from random import randint

from PIL import Image, ImageDraw

# RGB fill colours: polygons are painted in `white`; `black` is used for the
# trial paint in the overlap check and equals Image.new('RGB', ...)'s default fill.
white = (255,) * 3
black = (0,) * 3

# CSV file that (filename, label) rows are appended to, and the default
# output directory for generated images (both relative to the CWD).
labels_file = Path('labels.csv')
data_dir = Path('./data')

@dataclass()
class PolygonData:
    """Placement of a single regular polygon on an image.

    ``(x, y, r)`` is passed as the bounding circle to
    ``ImageDraw.regular_polygon`` and ``rot`` as its ``rotation`` argument.
    """

    # Centre of the bounding circle, in pixel coordinates.
    x: int
    y: int
    # Radius of the bounding circle, in pixels.
    r: int
    # Rotation forwarded to regular_polygon (degrees).
    rot: int

def random_polygon_data(offset, image_size):
    """Return a random PolygonData whose bounding circle fits a square image.

    The centre is at least ``offset`` pixels from every edge. The radius lies
    between ``offset`` and the distance to the nearest edge, so the circle
    stays on the canvas. The rotation is an integer in [0, 360].

    Raises ValueError (from ``randint``) when ``2 * offset > image_size``,
    because the centre range is then empty.
    """
    low, high = offset, image_size - offset
    cx = randint(low, high)
    cy = randint(low, high)
    # Largest radius that keeps the bounding circle inside the image.
    max_radius = min(cx, cy, image_size - cx, image_size - cy)
    return PolygonData(
        x=cx,
        y=cy,
        r=randint(offset, max_radius),
        rot=randint(0, 360),
    )

def draw_polygon(n: int, image, polygon_data, color):
    """Return a copy of ``image`` with one filled regular ``n``-gon drawn on it.

    ``polygon_data`` supplies the bounding-circle centre/radius and the
    rotation; ``color`` is the fill. The input ``image`` is not modified.
    """
    bounding_circle = (polygon_data.x, polygon_data.y, polygon_data.r)
    result = image.copy()
    canvas = ImageDraw.Draw(result)
    canvas.regular_polygon(
        bounding_circle,
        n,
        rotation=polygon_data.rot,
        fill=color,
    )
    return result

def draw_polygons(n: list, image, offset, *, max_attempts=1000):
    """Draw non-overlapping filled regular polygons on a copy of ``image``.

    Args:
        n: side counts, one entry per polygon to draw (e.g. ``[3, 3, 4]``).
        image: base image; its width is used as the (square) canvas size.
            It is not modified.
        offset: edge margin / minimum radius, forwarded to random_polygon_data.
        max_attempts: random placements to try for each polygon after the
            first before giving up.

    Returns:
        A new image with every polygon filled in ``white``. If ``n`` is empty,
        an unmodified copy of ``image``.

    Raises:
        ValueError: if a polygon cannot be placed without overlap within
            ``max_attempts`` tries.
    """
    if not n:
        return image.copy()

    size = image.size[0]
    # The first polygon cannot overlap anything, so it is drawn directly.
    result = draw_polygon(n[0], image, random_polygon_data(offset, size), white)
    for sides in n[1:]:
        # Raw bytes of the current canvas; it only changes when a shape is added.
        baseline = result.tobytes()
        for _ in range(max_attempts):
            candidate = random_polygon_data(offset, size)
            # Painting the candidate in black leaves the canvas unchanged only if
            # every pixel it covers is already black, i.e. it overlaps no polygon
            # drawn so far (assumes a black background, as generate_images
            # creates). Comparing bytes avoids building two per-pixel lists.
            probe = draw_polygon(sides, result, candidate, black)
            if probe.tobytes() == baseline:
                result = draw_polygon(sides, result, candidate, white)
                break
        else:
            raise ValueError("Did not manage to draw enough shapes")
    return result

# Human-readable names for the supported side counts.
_POLYGON_NAMES = {
    3: "triangle",
    4: "square",
    5: "pentagon",
    6: "hexagon",
    7: "heptagon",
    8: "octagon",
}

def generate_filename(sides):
    """Build a filename stem describing how many polygons of each kind ``sides`` holds.

    Kinds are ordered by side count. Each contributes ``<name>x<count>``, and
    the parts are concatenated without a separator. Any iterable of ints is
    accepted.

    Raises:
        KeyError: if a side count has no entry in ``_POLYGON_NAMES`` (outside 3-8).

    >>> generate_filename([4, 3, 3])
    'trianglex2squarex1'
    """
    counts = Counter(sides)
    return "".join(
        f"{_POLYGON_NAMES[side_count]}x{counts[side_count]}"
        for side_count in sorted(counts)
    )

def generate_images(sides_list, labels, image_number, data_dir, *, image_size=100, offset=10):
    """Render polygon images into ``data_dir`` and append their labels to ``labels_file``.

    For every pair from ``product(sides_list, labels)``, draws ``image_number``
    images with draw_polygons. Each is saved as ``<stem>_<i>.png`` inside
    ``data_dir``, and one ``(filename, label)`` row per image is appended to
    the module-level ``labels_file``.

    Args:
        sides_list: iterable of side-count lists, each passed to draw_polygons.
        labels: iterable of labels written to the CSV.
        image_number: number of images generated per (sides, label) pair.
        data_dir: output directory (Path or str); created if missing.
        image_size: width and height of each square image, in pixels.
        offset: edge margin / minimum radius forwarded to draw_polygons.

    Raises:
        ValueError: propagated from draw_polygons when shapes cannot be placed.
    """
    data_dir = Path(data_dir)
    # exist_ok avoids the check-then-create race; parents allows nested paths.
    data_dir.mkdir(parents=True, exist_ok=True)
    with open(labels_file, 'a', newline='', encoding='utf-8') as csvfile:
        # csv.writer takes no `fieldnames` argument (that belongs to DictWriter);
        # passing it raised TypeError before anything was written.
        # Each row is written as (file, label).
        writer = csv.writer(csvfile, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        # NOTE(review): product() pairs every `sides` with every label, so the
        # same filenames are overwritten once per label, and the CSV lists them
        # under each label. If one label per sides config was intended, this
        # should be zip() — confirm with callers.
        for sides, label in product(sides_list, labels):
            stem = generate_filename(sides)
            for i in range(image_number):
                filename = f'{stem}_{i}.png'
                # Image.new defaults to a black fill, which draw_polygons' overlap
                # check relies on.
                image = Image.new('RGB', (image_size, image_size))
                # Save inside the directory created above instead of a hard-coded
                # absolute path that ignored the working directory.
                draw_polygons(sides, image, offset).save(data_dir / filename)
                writer.writerow([filename, label])