In [None]:
# Import necessary libraries
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from random import choice

In [None]:
# Output paths for the generated dataset (relative to the notebook's working directory)
image_dir = "../data/synthetic/images"
mask_dir = "../data/synthetic/masks"

# Create both folders up front; exist_ok=True makes re-running this cell harmless.
for _out_dir in (image_dir, mask_dir):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# Function to draw random axons (circles) on an image
def generate_synthetic_rgb_image(image_size=512, num_blobs=15, radius_range=(10, 30), rng=None):
    """Draw a synthetic 3-channel image of filled circles ("axons") plus its mask.

    The background is one random light gray level (200-255) with small,
    symmetric per-pixel noise (+/-10). Each of ``num_blobs`` filled circles
    gets a random bluish/purplish colour. Circles may overlap; later ones are
    drawn over earlier ones.

    Parameters
    ----------
    image_size : int
        Side length, in pixels, of the square output.
    num_blobs : int
        Number of circles to draw.
    radius_range : tuple of int
        ``(low, high)`` radius bounds passed to ``randint`` (``high`` exclusive).
        Every circle is placed fully inside the image, so ``randint`` raises
        ``ValueError`` if a drawn radius does not satisfy ``2 * r < image_size``.
    rng : object with a numpy-legacy ``randint(low, high, size=None)`` method, optional
        Source of randomness, e.g. ``np.random.RandomState(seed)`` for a
        reproducible image. Defaults to numpy's global legacy RNG (``np.random``),
        so ``np.random.seed(...)`` still controls the output.

    Returns
    -------
    image : np.ndarray
        ``(image_size, image_size, 3)`` uint8 array in OpenCV BGR channel order.
    mask : np.ndarray
        ``(image_size, image_size)`` uint8 array: 255 inside circles, 0 elsewhere.
    """
    if rng is None:
        rng = np.random

    # Lightly stained background: a single gray level plus symmetric +/-10 noise.
    # randint's upper bound is exclusive, hence 256 / 11.
    base_color = rng.randint(200, 256)
    noise = rng.randint(-10, 11, (image_size, image_size, 3))
    # Do the arithmetic in a signed integer type, then cast back to uint8:
    # uint8 + int64 promotes to int64, which OpenCV cannot draw on in place
    # and which cv2.imwrite / cv2.cvtColor reject.
    image = np.clip(base_color + noise, 0, 255).astype(np.uint8)
    mask = np.zeros((image_size, image_size), dtype=np.uint8)

    for _ in range(num_blobs):
        r = int(rng.randint(*radius_range))
        # Centre chosen so the whole circle stays inside the image.
        x = int(rng.randint(r, image_size - r))
        y = int(rng.randint(r, image_size - r))

        # Random axon color (bluish, purplish shades); OpenCV colours are BGR.
        color = (
            int(rng.randint(80, 120)),   # B
            int(rng.randint(60, 100)),   # G
            int(rng.randint(100, 160)),  # R
        )

        cv2.circle(image, (x, y), r, color, thickness=-1)  # filled circle
        cv2.circle(mask, (x, y), r, 255, thickness=-1)     # binary mask

    return image, mask


In [None]:
# Preview one freshly generated sample: the image (converted from OpenCV's BGR
# order to the RGB order matplotlib expects) next to its binary mask.
img, msk = generate_synthetic_rgb_image()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    ("RGB Synthetic Image", cv2.cvtColor(img, cv2.COLOR_BGR2RGB), None),
    ("Binary Mask", msk, 'gray'),
)
for ax, (title, pixels, cmap) in zip(axes, panels):
    ax.imshow(pixels, cmap=cmap)
    ax.set_title(title)
plt.show()

In [None]:
# Generate and save dataset
num_samples = 100  # feel free to increase later
# Fixed seed so Restart & Run All reproduces the same dataset regardless of how
# many times earlier cells drew random numbers. Set to None for fresh randomness.
seed = 42

np.random.seed(seed)
for i in tqdm(range(num_samples), desc="Generating"):
    img, msk = generate_synthetic_rgb_image()

    img_path = os.path.join(image_dir, f"img_{i:03}.png")
    msk_path = os.path.join(mask_dir, f"mask_{i:03}.png")
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(img_path, img):
        raise IOError(f"Failed to write {img_path}")
    if not cv2.imwrite(msk_path, msk):
        raise IOError(f"Failed to write {msk_path}")


In [None]:
# Confirm save: reload one random image/mask pair from disk, sanity-check it,
# and display it in colour.
# Only consider generated files; ignore stray entries like .ipynb_checkpoints.
image_files = sorted(
    f for f in os.listdir(image_dir) if f.startswith("img_") and f.endswith(".png")
)
assert image_files, f"No img_*.png files found in {image_dir}"
image_file = choice(image_files)
sample_file = image_file.replace("img_", "mask_", 1)  # matching mask filename

# cv2.imread returns None (no exception) when a file is missing or unreadable.
img = cv2.imread(os.path.join(image_dir, image_file), cv2.IMREAD_COLOR)
msk = cv2.imread(os.path.join(mask_dir, sample_file), cv2.IMREAD_GRAYSCALE)
assert img is not None, f"Could not read {image_file}"
assert msk is not None, f"Could not read {sample_file}"
assert img.shape[:2] == msk.shape, "image/mask size mismatch"
assert np.isin(msk, (0, 255)).all(), "mask is not binary"

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
fig.suptitle(image_file)
# Saved images are BGR; convert so the colours display correctly.
axes[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
axes[0].set_title("Saved Image")
axes[1].imshow(msk, cmap='gray')
axes[1].set_title("Saved Mask")
plt.show()