In [None]:
try:
    import jax.numpy as jnp
    from jax import random
    import matplotlib.pyplot as plt
except ModuleNotFoundError:
    # Install into the interpreter that runs this kernel, then retry the imports.
    # Installing alone is not enough: without the re-import below, `jnp`, `random`
    # and `plt` would remain undefined and every later cell would fail.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "jax", "matplotlib"])
    # Let the import system notice packages that were just written to site-packages.
    importlib.invalidate_caches()

    import jax.numpy as jnp
    from jax import random
    import matplotlib.pyplot as plt

In [None]:
# Build a labelled dataset of synthetic noise images:
# label 0 = Gaussian noise, label 1 = Poisson noise (max-normalized).
noise_key, perlin_key, init_key, shuffle_key = random.split(random.PRNGKey(0), 4)
image_shape = (1, 64, 64, 3)
n_samples = 64  # images generated per noise type

gaussian_samples = []
poisson_samples = []

for _ in range(n_samples):
    # Split off fresh subkeys every iteration; noise_key is carried forward.
    noise_key, gaussian_key, poisson_key, intensity_key = random.split(noise_key, 4)

    # Per-sample noise level drawn uniformly from [0, 1).
    intensity = random.uniform(intensity_key)

    # Gaussian: zero-mean noise whose standard deviation is `intensity`.
    gaussian_samples.append(random.normal(gaussian_key, image_shape) * intensity)

    # Poisson: integer counts with rate `intensity`, scaled so the brightest pixel is 1.
    poisson_img = random.poisson(poisson_key, intensity, image_shape)
    max_val = poisson_img.max()
    # Counts are integers, so max_val is either 0 (all-dark image) or >= 1.
    if max_val > 0:
        poisson_img_normalized = poisson_img / max_val
    else:
        # Keep the all-zero image, but as float so every sample shares one dtype.
        poisson_img_normalized = poisson_img.astype(jnp.float32)
    poisson_samples.append(poisson_img_normalized)

gaussian_samples = jnp.stack(gaussian_samples)
poisson_samples = jnp.stack(poisson_samples)

all_samples = jnp.concatenate((gaussian_samples, poisson_samples), axis=0)
all_labels = jnp.concatenate((jnp.zeros(n_samples), jnp.ones(n_samples)), axis=0)

# Sanity checks so a Restart & Run All surfaces shape/range regressions immediately.
assert all_samples.shape == (2 * n_samples, *image_shape)
assert all_labels.shape == (2 * n_samples,)
assert float(poisson_samples.min()) >= 0.0 and float(poisson_samples.max()) <= 1.0

# Show a few paired examples: one Gaussian and one Poisson sample per row.
n_show = min(3, n_samples)  # number of example rows to display
# Size the grid to the rows actually drawn; sizing it by n_samples left 61 empty rows.
n_rows = n_show
n_cols = 2


def _to_unit_range(img):
    """Min-max rescale an image to [0, 1] for display only.

    imshow clips float RGB data outside [0, 1]. The Gaussian samples are
    zero-mean, so without rescaling every negative pixel would render black.
    A constant image maps to all zeros.
    """
    lo, hi = img.min(), img.max()
    return (img - lo) / (hi - lo) if hi > lo else jnp.zeros_like(img)


# squeeze=False keeps `axes` 2-D even when n_rows == 1.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False)

for i, (ax_gauss, ax_poisson) in enumerate(axes):
    ax_gauss.imshow(_to_unit_range(gaussian_samples[i].squeeze()))
    ax_gauss.set_title(f'Gaussian Sample {i+1}')
    ax_gauss.axis('off')

    # Poisson samples are already max-normalized into [0, 1].
    ax_poisson.imshow(poisson_samples[i].squeeze())
    ax_poisson.set_title(f'Poisson Sample {i+1}')
    ax_poisson.axis('off')

plt.tight_layout()
plt.show()