In [2]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch

In [11]:
def generate_color_image(n_colors, n_means, size=512, lf_noise=0.3, mf_noise=0.3, hf_noise=0.3, seed=None):
    """Render a random piecewise-constant colour image with multi-scale additive noise.

    ``n_means`` points are scattered at random integer coordinates over a
    ``size`` x ``size`` grid. Each point gets a random multiplicative weight in
    [0.5, 1) and a label assigned round-robin from ``n_colors`` labels. Every
    pixel takes the label of the point with the smallest weighted Euclidean
    distance (a multiplicatively weighted Voronoi partition). Labels are mapped
    through matplotlib's ``gist_ncar`` colormap, and standard-normal noise at
    three spatial scales is added.

    Args:
        n_colors: Number of distinct labels (>= 1).
        n_means: Number of points (>= 1). If fewer than ``n_colors``, only the
            first ``n_means`` labels occur.
        size: Side length in pixels of the square output image (>= 1).
        lf_noise: Multiplier on low-frequency noise. The noise is drawn on a
            ``max(1, size // 16)`` grid and bilinearly upsampled.
        mf_noise: Multiplier on mid-frequency noise. The noise is drawn on a
            ``max(1, size // 4)`` grid and bilinearly upsampled.
        hf_noise: Multiplier on independent per-pixel noise.
        seed: Optional integer seed.
            - If given, a private ``torch.Generator`` is used, so the result is
              reproducible without touching the global RNG.
            - If ``None`` (default), the global torch RNG is used, as before.

    Returns:
        float ndarray of shape (size, size, 3) with values clipped to [0, 1].

    Raises:
        ValueError: if ``n_colors``, ``n_means`` or ``size`` is less than 1.
    """
    if n_colors < 1 or n_means < 1 or size < 1:
        raise ValueError(
            f"n_colors, n_means and size must all be >= 1, got "
            f"n_colors={n_colors}, n_means={n_means}, size={size}"
        )

    # None -> torch functions fall back to the global RNG (original behaviour).
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # Points: integer coordinates stored as float for cdist.
    # Weights: uniform in [0.5, 1).
    points = torch.randint(0, size, size=(n_means, 2), dtype=torch.float32, generator=generator)
    influences = torch.rand(n_means, generator=generator) * 0.5 + 0.5

    # Round-robin labels: 0, 1, ..., n_colors-1, 0, 1, ...
    labels = torch.arange(n_means, dtype=torch.int64) % n_colors

    # 'ij' is the legacy default; passing it explicitly silences the
    # torch >= 1.10 warning without changing the grid.
    xs, ys = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
    pixels = torch.stack([xs, ys], dim=-1).float().view(-1, 2)

    # (size*size, n_means) weighted distances.
    # Each pixel takes the label of its nearest point.
    distances = torch.cdist(pixels, points) * influences
    closest = distances.argmin(dim=-1)
    image = labels[closest].reshape(size, size).numpy()

    # Normalise by the largest label that can occur, min(n_means, n_colors) - 1.
    # It is clamped to 1 so a single-label image does not compute 0/0, which
    # produced NaN and therefore the colormap's "bad" colour.
    max_label = max(min(n_means, n_colors) - 1, 1)
    cmap = plt.get_cmap('gist_ncar')
    rgb = cmap(image.astype(np.float32) / max_label)[:, :, :3]

    # max(1, ...) keeps the coarse grids non-empty for images smaller than 16 px.
    # The draw order (points, weights, low, mid, high) matches the original, so
    # global-RNG results are unchanged.
    low_freq_noise = _upsampled_noise(max(1, size // 16), size, generator)
    med_freq_noise = _upsampled_noise(max(1, size // 4), size, generator)
    high_freq_noise = torch.randn(3, size, size, generator=generator).permute(1, 2, 0).numpy()

    rgb = rgb + low_freq_noise * lf_noise + med_freq_noise * mf_noise + high_freq_noise * hf_noise
    return np.clip(rgb, 0, 1)


def _upsampled_noise(grid, size, generator=None):
    """Draw 3-channel standard-normal noise on a grid x grid lattice and bilinearly upsample it.

    Returns an ndarray of shape (size, size, 3).
    """
    coarse = torch.randn(1, 3, grid, grid, generator=generator)
    fine = torch.nn.functional.interpolate(coarse, size=(size, size), mode='bilinear', align_corners=False)
    return fine[0].permute(1, 2, 0).numpy()

# Render a reproducible batch of synthetic images to disk.
OUTPUT_DIR = Path('./synthetic_color_images')
NUM_IMAGES = 10
SEED = 0

# plt.imsave does not create missing directories; without this a fresh
# Restart & Run All fails with FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# generate_color_image draws from the global torch RNG, so seeding it here
# makes the written dataset identical on every full re-run.
torch.manual_seed(SEED)

for i in range(NUM_IMAGES):
    im = generate_color_image(n_colors=5, n_means=5, size=512, lf_noise=0.3, mf_noise=0.3, hf_noise=0.3)
    plt.imsave(OUTPUT_DIR / f'{i:04d}.png', im)