In [None]:
from typing import List

import numpy as np
from matplotlib import pyplot as plt

from src.models.geometric_figure import GeometricFigure
from src.services.geometric_figure import (get_geometric_figures,
                                           plot_geometric_figures,
                                           preprocess_input)
from src.utils import preprocess
from src.utils import image_filters

In [None]:
IMAGE_SIZE = (128, 128)
DATA_VERSION = '2023-04-12'

# Seed numpy's global RNG: `damage` and the random figure sampling both draw
# from np.random, so a fixed seed makes Restart & Run All reproduce the same
# output. NOTE(review): the src.utils image filters may also draw from other
# RNGs (e.g. the `random` module) — confirm and seed those too if so.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [None]:
# Load the figures of the pinned dataset version; IMAGE_SIZE is forwarded to the
# loader (presumably the target image size — confirm in get_geometric_figures).
data_directory = f'data/{DATA_VERSION}'
geometric_figures: List[GeometricFigure] = get_geometric_figures(data_directory, IMAGE_SIZE)
figure_count = len(geometric_figures)
print(f'Loaded {figure_count} geometric figures')

In [None]:
def _pixel_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the fraction of elements of `a` that are exactly equal to `b`."""
    return np.sum(a == b) / a.size


def damage(image_array: np.ndarray,
           max_similarity: float = 0.998,
           max_attempts: int = 100) -> np.ndarray:
    """Return a randomly damaged, preprocessed version of `image_array`.

    Primary strategy: preprocess a copy of the image, paint 20-29 random
    polygons on it, run `preprocess.remove_stain`, and with probability 0.5 add
    noise. The attempt is retried until the result matches the preprocessed
    original in fewer than `max_similarity` of its pixels.

    Fallback: if that result contains no pixel equal to 1 (presumably white
    after `preprocess_input` — TODO confirm), it is discarded. Noise is then
    added to a raw copy of the image, which is preprocessed. This is retried
    until the result has at least one pixel equal to 1 and is also below
    `max_similarity`.

    Args:
        image_array: Raw image as passed to `preprocess_input`.
        max_similarity: Fraction of identical pixels at or above which a
            damaged image is considered too close to the original.
        max_attempts: Maximum retries for each strategy.

    Returns:
        The damaged, preprocessed image.

    Raises:
        RuntimeError: If either strategy fails `max_attempts` times in a row.
            The original implementation instead recursed without bound
            (RecursionError) or looped forever.
    """
    # Reference image the damaged result is compared against; computed once
    # rather than on every retry.
    image_array_preprocessed = preprocess_input(image_array)

    # Strategy 1: polygon damage. Bounded loop instead of recursion so an image
    # that keeps coming out nearly undamaged cannot blow the recursion limit.
    for _attempt in range(max_attempts):
        x = preprocess_input(image_array.copy())
        for _ in range(np.random.randint(20, 30)):
            x = image_filters.add_random_polygon(x, np.random.randint(10, 20, 2))
        # NOTE(review): second argument is 1% of the pixel count — presumably a
        # size threshold for stains to remove; confirm against remove_stain.
        x = preprocess.remove_stain(x, x.size * 0.01)
        if np.random.rand() < 0.5:
            x = image_filters.add_noise(x, 0.01, np.random.randint(2, 4))
        if _pixel_similarity(x, image_array_preprocessed) < max_similarity:
            break
    else:
        raise RuntimeError(
            f'damage: polygon damage left the image at least {max_similarity:.1%} '
            f'unchanged in {max_attempts} consecutive attempts')

    if np.any(x == 1):
        return x

    # Strategy 2: noise only, applied to the raw image before preprocessing.
    # Bounded so a figure that never yields a pixel equal to 1 (e.g. a blank
    # image) raises instead of hanging the notebook.
    for _attempt in range(max_attempts):
        x = preprocess_input(
            image_filters.add_noise(image_array.copy(), 0.01, np.random.randint(2, 4)))
        if np.any(x == 1) and _pixel_similarity(x, image_array_preprocessed) < max_similarity:
            return x
    raise RuntimeError(
        f'damage: noise-only fallback found no image with a pixel equal to 1 and '
        f'below {max_similarity:.1%} similarity in {max_attempts} attempts')

In [None]:
# Preview a grid of randomly chosen figures after applying `damage`.
# To compare against the undamaged figures, call plot_geometric_figures
# without the `preprocess=` argument.
rows = 10
columns = 7
number_of_plots = rows * columns

assert geometric_figures, 'No geometric figures loaded — check DATA_VERSION / data directory'
# Sample without replacement so the grid shows distinct figures. Fall back to
# sampling with replacement only when the dataset is smaller than the grid.
sample_indices = np.random.choice(
    len(geometric_figures),
    size=number_of_plots,
    replace=len(geometric_figures) < number_of_plots,
)
random_geometric_figures = [geometric_figures[i] for i in sample_indices]
plot_geometric_figures(random_geometric_figures, columns, preprocess=damage, cmap='gray')
plt.show()