In [6]:
import os
import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import Optional
from sklearn.datasets import make_moons, make_circles
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp

# ============ 分布定义 ============
class Sampleable(ABC):
    """Abstract interface for a distribution that points can be drawn from.

    Concrete subclasses implement :meth:`sample`; the implementations in
    this file all return one 2-D point per row.
    """

    @abstractmethod
    def sample(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` points from the distribution.

        Args:
            num_samples: Number of points to draw.

        Returns:
            A tensor holding the drawn points.
        """
        ...

class StandardGaussian(Sampleable):
    """Isotropic 2-D normal distribution with zero mean and unit variance."""

    def sample(self, num_samples: int) -> torch.Tensor:
        """Return a ``(num_samples, 2)`` tensor of standard-normal draws."""
        shape = (num_samples, 2)
        return torch.randn(shape)

class StretchedGaussian(Sampleable):
    """Zero-mean 2-D Gaussian whose principal axes are scaled and rotated."""

    def __init__(self, angle_degrees=45, scale_x=2.5, scale_y=0.5):
        """Build the mean and covariance from a rotation angle and axis scales.

        Args:
            angle_degrees: Rotation applied to the principal axes, in degrees.
            scale_x: Standard deviation along the first principal axis.
            scale_y: Standard deviation along the second principal axis.
        """
        theta = np.deg2rad(angle_degrees)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

        # The squared scales are the variances along the unrotated axes.
        axis_variances = np.diag([scale_x**2, scale_y**2])

        self.mean = np.zeros(2)
        # Full covariance R · D · Rᵀ, stored in single precision.
        self.cov = (rotation @ axis_variances @ rotation.T).astype(np.float32)

    def sample(self, num_samples: int) -> torch.Tensor:
        """Return a float32 ``(num_samples, 2)`` tensor drawn with NumPy's global RNG."""
        draws = np.random.multivariate_normal(self.mean, self.cov, size=num_samples)
        as_tensor = torch.from_numpy(draws)
        return as_tensor.float()

class MoonsSampleable(Sampleable):
    """Scaled samples from scikit-learn's ``make_moons`` generator."""

    def __init__(self, noise=0.05, scale=5.0):
        """Store ``noise`` (forwarded to ``make_moons``) and the output multiplier ``scale``."""
        self.noise, self.scale = noise, scale

    def sample(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` points and multiply them by ``self.scale``.

        The labels returned by ``make_moons`` are discarded.
        """
        coords = make_moons(n_samples=num_samples, noise=self.noise)[0]
        unscaled = torch.from_numpy(coords).float()
        return self.scale * unscaled

class CirclesSampleable(Sampleable):
    """Scaled samples from scikit-learn's ``make_circles`` generator."""

    def __init__(self, noise=0.05, scale=5.0):
        """Store ``noise`` (forwarded to ``make_circles``) and the output multiplier ``scale``."""
        self.noise, self.scale = noise, scale

    def sample(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` points and multiply them by ``self.scale``.

        ``make_circles`` is called with ``factor=0.5``; its labels are discarded.
        """
        coords = make_circles(n_samples=num_samples, noise=self.noise, factor=0.5)[0]
        unscaled = torch.from_numpy(coords).float()
        return self.scale * unscaled

class CheckerboardSampleable(Sampleable):
    """Uniform distribution over alternating cells of a square grid.

    The square spanning ``-scale`` to ``scale`` on both axes is divided into
    ``grid_size x grid_size`` cells; a point is kept only when the column and
    row index of its cell have the same parity.
    """

    def __init__(self, grid_size: int = 3, scale=5.0):
        """Store the number of cells per side and the half-width of the square."""
        self.grid_size = grid_size
        self.scale = scale

    def sample(self, num_samples: int) -> torch.Tensor:
        """Rejection-sample ``num_samples`` points from the checkerboard.

        Each round proposes ``num_samples`` uniform candidates and keeps the
        ones that fall in an accepted cell, until enough points are collected.

        Returns:
            A ``(num_samples, 2)`` tensor of accepted points.
        """
        half_width = self.scale
        cell = 2 * half_width / self.grid_size

        # Seed the list with an empty (0, 2) tensor so the final cat is valid
        # even when no round runs (num_samples == 0).
        batches = [torch.zeros(0, 2)]
        kept = 0
        while kept < num_samples:
            candidates = (torch.rand(num_samples, 2) - 0.5) * 2 * half_width
            col_even = torch.floor((candidates[:, 0] + half_width) / cell) % 2 == 0
            row_even = torch.floor((candidates[:, 1] + half_width) / cell) % 2 == 0
            # Keep the cells where both indices are even or both are odd.
            accepted = candidates[torch.eq(col_even, row_even)]
            batches.append(accepted)
            kept += accepted.shape[0]
        return torch.cat(batches, dim=0)[:num_samples]

# ============ 渲染 + 保存 ============
def render_points_to_image(points: torch.Tensor, image_size=64, scale=6.0) -> torch.Tensor:
    """Rasterise a 2-D point cloud into a square grayscale image.

    Args:
        points: Tensor whose columns 0 and 1 hold the x and y coordinates.
        image_size: Side length of the output image in pixels.
        scale: Half-width of the plotted window; both axes span
            ``[-scale, scale]``.

    Returns:
        A float32 tensor of shape ``(1, image_size, image_size)`` with values
        in ``[0, 1]``.
    """
    # A 1x1 inch figure at dpi=image_size renders to image_size x image_size pixels.
    figure = Figure(figsize=(1, 1), dpi=image_size)
    agg_canvas = FigureCanvas(figure)

    axes = figure.add_subplot(111)
    axes.set_xlim(-scale, scale)
    axes.set_ylim(-scale, scale)
    axes.axis('off')
    xs, ys = points[:, 0].numpy(), points[:, 1].numpy()
    axes.scatter(xs, ys, s=0.3, c='black', alpha=0.1)

    agg_canvas.draw()
    pixel_buffer = agg_canvas.get_renderer().buffer_rgba()
    rgba = np.asarray(pixel_buffer, dtype=np.uint8).reshape(image_size, image_size, 4)

    # Average the RGB channels (alpha is dropped), truncate to 8 bits,
    # then rescale to [0, 1].
    gray_u8 = rgba[:, :, :3].mean(axis=2).astype(np.uint8)
    gray = gray_u8.astype(np.float32) / 255.0
    return torch.from_numpy(gray).unsqueeze(0)

def save_tensor_as_png(tensor: torch.Tensor, filepath: str):
    """Write a single-channel image tensor to ``filepath`` as 8-bit grayscale.

    Args:
        tensor: Image with values nominally in ``[0, 1]``. A leading
            dimension of size 1 is squeezed away before conversion.
        filepath: Destination path handed to ``PIL.Image.Image.save``.
    """
    plane = tensor.detach().cpu().squeeze(0).numpy()
    # Round to nearest and clip before the cast. A bare astype(np.uint8)
    # truncates, so a value that lands just below an integer after the
    # float round trip loses a gray level, and values outside [0, 1] wrap
    # around modulo 256.
    levels = np.clip(np.rint(plane * 255.0), 0, 255).astype(np.uint8)
    img = Image.fromarray(levels, mode='L')
    img.save(filepath)

# ============ 单图生成 ============
def generate_one_image(distribution_name: str, index: int, num_points: int, image_size: int, save_dir: str):
    """Sample one point cloud, render it, and save it as a PNG.

    The image is written to ``<save_dir>/<distribution_name>/<index:04d>.png``;
    the directory is created if it does not exist.

    Args:
        distribution_name: One of ``"standard_gaussian"``,
            ``"stretched_gaussian"``, ``"moons"``, ``"circles"`` or
            ``"checkerboard"``.
        index: Image number, used to build the file name.
        num_points: Number of points to sample and draw.
        image_size: Side length of the rendered image in pixels.
        save_dir: Root directory under which the per-distribution folder lives.

    Raises:
        KeyError: If ``distribution_name`` is not one of the names above.
    """
    # Map names to classes, not instances, so that only the requested
    # sampler is constructed for each image.
    sampler_classes = {
        "standard_gaussian": StandardGaussian,
        "stretched_gaussian": StretchedGaussian,
        "moons": MoonsSampleable,
        "circles": CirclesSampleable,
        "checkerboard": CheckerboardSampleable,
    }
    sampler = sampler_classes[distribution_name]()
    points = sampler.sample(num_points)
    tensor_img = render_points_to_image(points, image_size=image_size)

    out_dir = os.path.join(save_dir, distribution_name)
    os.makedirs(out_dir, exist_ok=True)
    save_tensor_as_png(tensor_img, os.path.join(out_dir, f"{index:04d}.png"))

# ============ 多进程任务展开器 ============
def unpack_and_generate(args_tuple):
    """Call ``generate_one_image`` with the positional arguments in ``args_tuple``.

    ``Executor.map`` passes a single item per task, so the arguments travel
    as one tuple and are expanded here.
    """
    outcome = generate_one_image(*args_tuple)
    return outcome

# ============ 分布批量生成 ============
def _init_worker():
    """Pool initializer: give the current worker process its own RNG state."""
    # Forked workers start with a copy of the parent's NumPy and torch RNG
    # state; without reseeding, different workers would draw the same
    # "random" points and write duplicate images.
    seed = int.from_bytes(os.urandom(4), "little")
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The pool already runs one process per CPU. Letting torch use several
    # intra-op threads in each of them would likely oversubscribe the cores.
    torch.set_num_threads(1)


def generate_all_images(distribution_name: str, num_images: int, num_points: int, image_size: int = 64, save_root='./data'):
    """Render ``num_images`` images of one distribution in parallel.

    Images are written to ``<save_root>/<distribution_name>/`` by a process
    pool with one worker per CPU, with a tqdm progress bar.

    Args:
        distribution_name: Name understood by ``generate_one_image``.
        num_images: Number of images to generate; indices run from 0.
        num_points: Number of points sampled for each image.
        image_size: Side length of each image in pixels.
        save_root: Root output directory.
    """
    os.makedirs(os.path.join(save_root, distribution_name), exist_ok=True)
    tasks = [
        (distribution_name, i, num_points, image_size, save_root)
        for i in range(num_images)
    ]
    with ProcessPoolExecutor(max_workers=mp.cpu_count(), initializer=_init_worker) as executor:
        # list(...) drains the lazy map so every task finishes (and any worker
        # exception is re-raised) before the pool is shut down.
        list(tqdm(executor.map(unpack_and_generate, tasks), total=num_images, desc=f"Generating {distribution_name}"))

# ============ 统一入口 ============
def generate_all_distributions():
    """Generate 1000 images of 100000 points each for every distribution."""
    images_per_distribution = 1000
    points_per_image = 100000
    for name in ("standard_gaussian", "stretched_gaussian", "moons", "circles", "checkerboard"):
        generate_all_images(
            distribution_name=name,
            num_images=images_per_distribution,
            num_points=points_per_image,
        )

# ============ 运行入口 ============
# Script entry point: generate the dataset only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    generate_all_distributions()

Generating standard_gaussian: 100%|██████████| 1000/1000 [00:01<00:00, 804.93it/s]
Generating stretched_gaussian: 100%|██████████| 1000/1000 [00:09<00:00, 102.88it/s]
Generating moons: 100%|██████████| 1000/1000 [00:16<00:00, 61.31it/s]
Generating circles: 100%|██████████| 1000/1000 [00:16<00:00, 61.18it/s]
Generating checkerboard: 100%|██████████| 1000/1000 [03:59<00:00,  4.17it/s]
