# In [6]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import colorsys 
import os

# In [11]:
def HSVToRGB(h, s, v):
    """Convert HSV components to an ``[R, G, B]`` list of ints.

    Inputs follow :func:`colorsys.hsv_to_rgb` conventions (floats in [0, 1]).
    Each channel is scaled by 255 and truncated toward zero with ``int``
    rather than rounded.
    """
    return [int(255 * channel) for channel in colorsys.hsv_to_rgb(h, s, v)]

def get_distinct_colors(k):
    """Return ``k`` fully saturated, full-value RGB colors with spread-out hues.

    Color ``idx`` (``0 <= idx < k``) has hue ``idx / (k + 1)``. The hue is
    converted with ``colorsys.hsv_to_rgb`` and each channel is scaled by 255
    and truncated to an int.

    Each color is then post-processed so that channels within the same color
    differ. The channel pairs (R, G), (G, B) and (B, R) are checked in that
    order. When a pair is equal, its first channel is decremented by one, or
    incremented if it is 0. The reason for this adjustment is not stated here.

    Returns:
        A list of ``k`` mutable ``[R, G, B]`` int lists (empty when ``k == 0``).
    """
    step = 1.0 / (k + 1)
    palette = [
        [int(255 * channel) for channel in colorsys.hsv_to_rgb(step * idx, 1.0, 1.0)]
        for idx in range(k)
    ]
    for rgb in palette:
        # Same order as checking R==G, then G==B, then B==R.
        for first, second in ((0, 1), (1, 2), (2, 0)):
            if rgb[first] == rgb[second]:
                rgb[first] += -1 if rgb[first] > 0 else 1
    return palette


def map_array_to_colors(array, colors):
    """Paint a label array with a palette, producing an RGB image.

    Args:
        array: array of labels, compared elementwise against each palette
            index (assumed to be an integer numpy array — confirm at callers).
        colors: iterable of RGB triples; label ``i`` is painted ``colors[i]``.

    Returns:
        A ``uint8`` array of shape ``(*array.shape, 3)``. Positions whose label
        has no palette entry keep the initial value (0, 0, 0).
    """
    canvas = np.zeros(tuple(array.shape) + (3,), dtype=np.uint8)
    for label, rgb in enumerate(colors):
        mask = array == label
        canvas[mask] = rgb
    return canvas

def _upsampled_noise(size, grid_size):
    """Return smooth RGB noise as a ``(size, size, 3)`` numpy array.

    A ``(3, grid_size, grid_size)`` standard-normal tensor is drawn from torch's
    global RNG and bilinearly resized to ``(size, size)``. ``grid_size`` is
    clamped to at least 1 so that small images do not produce an empty tensor,
    which ``interpolate`` cannot resize.
    """
    grid_size = max(1, grid_size)
    coarse = torch.randn(3, grid_size, grid_size)
    fine = torch.nn.functional.interpolate(
        coarse[None, ...], size=(size, size), mode='bilinear', align_corners=False
    )
    return fine.squeeze(0).permute(1, 2, 0).numpy()


def generate_color_image(n_colors, n_means, size=512, lf_noise=0.3, mf_noise=0.3, hf_noise=0.3, seed=None):
    """Generate a noisy synthetic image made of flat-colored regions.

    ``n_means`` points are placed at random integer coordinates on a
    ``size`` x ``size`` grid. Every pixel takes the label of its nearest point,
    and point ``j`` has label ``j % n_colors``. Labels are painted with
    ``get_distinct_colors(n_colors)``. Standard-normal noise at three spatial
    scales is then added, and the result is clipped to [0, 1].

    Args:
        n_colors: number of distinct colors / labels.
        n_means: number of random seed points (regions).
        size: side length in pixels of the square output image.
        lf_noise: multiplier for low-frequency noise (drawn on a
            ``size // 16`` grid and bilinearly upsampled).
        mf_noise: multiplier for mid-frequency noise (drawn on a
            ``size // 4`` grid and bilinearly upsampled).
        hf_noise: multiplier for independent per-pixel noise.
        seed: if not None, passed to ``torch.manual_seed``, which reseeds
            torch's *global* RNG.

    Returns:
        A numpy array of shape ``(size, size, 3)`` with values in [0, 1].

    Raises:
        ValueError: if ``n_colors`` or ``n_means`` is less than 1.
    """
    if n_colors < 1:
        raise ValueError(f"n_colors must be >= 1, got {n_colors}")
    if n_means < 1:
        raise ValueError(f"n_means must be >= 1, got {n_means}")

    if seed is not None:
        torch.manual_seed(seed)
    points = torch.randint(0, size, size=(n_means, 2), dtype=torch.float32)

    # Per-point distance weights are all 1, so the scaling below is a no-op and
    # the partition is a plain nearest-point (Voronoi) partition. The discarded
    # uniform draw only advances torch's RNG. It is kept so that a given `seed`
    # keeps producing the same subsequent noise.
    torch.rand(n_means)
    influences = torch.ones(n_means)

    # Round-robin labels: point j -> j % n_colors.
    labels = torch.arange(n_means, dtype=torch.int64) % n_colors

    # Explicit 'ij' matches torch's historical default and avoids its warning.
    rows, cols = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
    pixels = torch.stack([rows, cols], dim=-1).float()

    # (size*size, n_means) pixel-to-point distances, then nearest point per pixel.
    distances = torch.cdist(pixels.view(-1, 2), points) * influences
    _, closest = torch.min(distances, dim=-1)
    image = labels[closest].reshape(size, size).numpy()

    colors = get_distinct_colors(n_colors)
    rgb = map_array_to_colors(image, colors) / 255.0

    # Draw order (low, mid, high) is fixed so seeded outputs are reproducible.
    low_freq_noise = _upsampled_noise(size, size // 16)
    med_freq_noise = _upsampled_noise(size, size // 4)
    high_freq_noise = torch.randn(3, size, size).permute(1, 2, 0).numpy()
    rgb = (
        rgb
        + low_freq_noise * lf_noise
        + med_freq_noise * mf_noise
        + high_freq_noise * hf_noise
    )
    return np.clip(rgb, 0, 1)

# Render 10 seeded images per noise level 0.0, 0.1, ..., 1.0 (the same level is
# used for all three noise bands) into ./synthetic_color_images/<level>/NNNN.png.
# linspace gives exactly 11 levels; arange with a float step can gain or lose
# an endpoint through rounding.
for noise_level in np.linspace(0.0, 1.0, 11):
    out_dir = f'./synthetic_color_images/{noise_level:.1f}'
    # plt.imsave does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for i in range(10):
        im = generate_color_image(n_colors=5, n_means=5, size=512, lf_noise=noise_level, mf_noise=noise_level, hf_noise=noise_level, seed=i)
        plt.imsave(f'{out_dir}/{i:04d}.png', im)