# 0. Imports

In [None]:
from sampling.conditional_probability_path import ConditionalProbabilityPath, GaussianConditionalProbabilityPath
from sampling.noise_scheduling import LinearAlpha, LinearBeta
from sampling.sampleable import PixelArtSampler

import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def sample_and_plot(path: ConditionalProbabilityPath, num_rows: int = 3, num_cols: int = 3, num_timesteps: int = 5):
    """Plot grids of conditional-path samples at evenly spaced times in [0, 1].

    Draws ``num_rows * num_cols`` samples from ``path.p_data``, then for each
    of ``num_timesteps`` times ``t`` calls ``path.sample_conditional_path`` and
    shows the result as one image grid per time step, left to right.

    Args:
        path: Probability path exposing ``p_data.sample(n)`` and
            ``sample_conditional_path(z, t)``.
        num_rows: Number of image rows in each grid (only affects the sample
            count and figure height; ``make_grid`` lays rows out from
            ``num_cols``).
        num_cols: Number of images per grid row.
        num_timesteps: Number of time steps (and subplots) between 0 and 1
            inclusive.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    num_samples = num_cols * num_rows
    # Only the first element of the sampled pair is used here.
    z, _ = path.p_data.sample(num_samples)
    z = z.view(-1, 3, 128, 128)

    # squeeze=False keeps `axes` two-dimensional even when num_timesteps == 1,
    # so indexing below never hits a bare Axes object.
    _, axes = plt.subplots(
        1,
        num_timesteps,
        figsize=(6 * num_cols * num_timesteps, 6 * num_rows),
        squeeze=False,
    )

    ts = torch.linspace(0, 1, num_timesteps, device=device)

    for ax, t in zip(axes[0], ts):
        tt = t.view(1, 1, 1, 1).expand(num_samples, 1, 1, 1)  # (num_samples, 1, 1, 1)
        xt = path.sample_conditional_path(z, tt)  # (num_samples, 3, 128, 128)
        # normalize with value_range=(-1, 1) rescales pixel values into [0, 1] for display.
        grid = make_grid(xt, nrow=num_cols, normalize=True, value_range=(-1, 1))
        # The grid is already an RGB image (C, H, W) -> (H, W, C); no colormap
        # applies, and "rgb" is not a registered matplotlib colormap name.
        ax.imshow(grid.permute(1, 2, 0).cpu())
        ax.axis("off")

    plt.show()

In [None]:
# Gaussian conditional path over the pixel-art data with linear alpha/beta schedules.
# NOTE(review): p_simple_shape has 1 channel while sample_and_plot reshapes the
# data to (3, 128, 128) -- confirm the mismatch is intended.
path = GaussianConditionalProbabilityPath(
    p_data=PixelArtSampler(),
    p_simple_shape=[1, 128, 128],
    alpha=LinearAlpha(),
    beta=LinearBeta(),
)
path = path.to(device)

# Visualise samples along the path with the default 3x3 grid and 5 time steps.
sample_and_plot(path)