# Fixed Point Diffusion Models (FPDM)
This notebook shows how to run image sampling with FPDM.

## Set Up
We provide an `environment.yml` file that can be used to create a Conda environment. See `README.md` for instructions on installing all required packages.

In [None]:
# Standard library imports
import json
import math
import random
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

# Third-party imports
import torch
import torch.nn as nn
from PIL import Image
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers.models import AutoencoderKL
from jaxtyping import Float, Shaped
from tap import Tap
from tqdm import trange
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp

# Local module imports
sys.path.append("..")
from diffusion import create_diffusion
from download import find_model
from models import DiT_models


## Hyperparameters

In [None]:
class Args(Tap):
    """
    A class to define and store hyperparameters and configurations for the demo.

    The annotated class attributes hold the default values. `process_args` validates
    them, derives `latent_size` / `H_lat` / `W_lat`, and rewrites `output_dir` to a
    run-specific sub-directory, which it also creates.
    """

    # File and directory paths.
    output_dir: str = 'demo_samples'

    # Dataset configuration.
    dataset_name: str = "imagenet256"

    # Model specific parameters.
    model: str = "DiT-XL/2"
    vae: str = "mse"
    num_classes: int = 1000
    image_size: int = 256
    predict_v: bool = False
    use_zero_terminal_snr: bool = False
    unsupervised: bool = False
    dino_supervised: bool = False
    dino_supervised_dim: int = 768
    flow: bool = False
    debug: bool = False

    # Fixed Point settings.
    fixed_point: bool = False
    fixed_point_pre_depth: int = 2
    fixed_point_post_depth: int = 2
    fixed_point_iters: Optional[int] = None
    fixed_point_pre_post_timestep_conditioning: bool = False
    fixed_point_reuse_solution: bool = False

    # Sampling configuration.
    ddim: bool = False
    cfg_scale: float = 4.0
    num_sampling_steps: int = 250
    batch_size: int = 4
    ckpt: str = '/work/xingjian/diff-deq-inference/pretrained/DiT-XL-2/checkpoints/0500000.pt'  # replace it with the Path to your checkpoint.
    global_seed: int = 0

    # Parallelization settings.
    sample_index_start: int = 0
    sample_index_end: Optional[int] = 32

    def process_args(self):
        """
        Method for additional argument processing and validation.

        Fills in the default number of fixed-point iterations, derives the latent
        resolution from `image_size`, validates the guidance / supervision settings and
        the checkpoint path, then builds and creates the final output directory.

        Note: `output_dir` is overwritten with the run-specific path, so calling this
        method a second time on the same instance nests the directory once more.

        Raises:
            AssertionError: if `image_size` is not divisible by 8, or if `unsupervised`
                is set while `cfg_scale` differs from 1.0.
            ValueError: if `cfg_scale` is below 1.0 or `ckpt` is not an existing file.
            NotImplementedError: if `dino_supervised` is set (and `unsupervised` is not).
        """
        # Debug mode configuration.
        if self.debug:
            self.log_with = 'tensorboard'
            self.name = 'debug'

        # Set default values and validate image size.
        # Only an unset (None) value is replaced by the default, so an explicit 0 is kept.
        # NOTE(review): 28 presumably matches the depth of the default DiT-XL model - confirm
        # before using this default with other model sizes.
        if self.fixed_point_iters is None:
            self.fixed_point_iters = 28 - self.fixed_point_pre_depth - self.fixed_point_post_depth
        assert self.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
        self.latent_size = self.H_lat = self.W_lat = self.image_size // 8

        # Additional checks and validations.
        if self.cfg_scale < 1.0:
            raise ValueError("In almost all cases, cfg_scale should be >= 1.0")

        if self.unsupervised:
            assert self.cfg_scale == 1.0, "cfg_scale must be 1.0 when unsupervised is set."
            self.num_classes = 1
        elif self.dino_supervised:
            raise NotImplementedError()

        if not Path(self.ckpt).is_file():
            raise ValueError(f"Checkpoint file not found: {self.ckpt}")

        # Creating the output directory.
        # The parent is named after the directory two levels above the checkpoint file.
        output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name
        if self.debug:
            output_dirname = 'debug'
        else:
            output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'
            if self.fixed_point:
                output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'
        if self.ddim:
            output_dirname += '--ddim'
        self.output_dir = str(output_parent / output_dirname)
        Path(self.output_dir).mkdir(exist_ok=True, parents=True)


In [None]:
class Args(Tap):
    """
    Hyperparameters and configuration for sampling.

    This definition rebinds the name `Args`, so once this cell has run it replaces any
    earlier `Args` class defined in the notebook. The annotated class attributes are the
    defaults; `process_args` validates them, derives `latent_size` / `H_lat` / `W_lat`
    and rewrites `output_dir` to a run-specific sub-directory, which it also creates.
    """

    # Paths
    output_dir: str = 'samples'

    # Dataset
    dataset_name: str = "imagenet256"

    # Model
    model: str = "DiT-XL/2"
    vae: str = "mse"
    num_classes: int = 1000
    image_size: int = 256
    predict_v: bool = False
    use_zero_terminal_snr: bool = False
    unsupervised: bool = False
    dino_supervised: bool = False
    dino_supervised_dim: int = 768
    flow: bool = False
    debug: bool = False

    # Fixed Point settings
    fixed_point: bool = False
    fixed_point_pre_depth: int = 2
    fixed_point_post_depth: int = 2
    fixed_point_iters: Optional[int] = None
    fixed_point_pre_post_timestep_conditioning: bool = False
    fixed_point_reuse_solution: bool = False

    # Sampling
    ddim: bool = False
    cfg_scale: float = 4.0
    num_sampling_steps: int = 250
    batch_size: int = 4
    ckpt: str = '/work/xingjian/diff-deq-inference/pretrained/DiT-XL-2/checkpoints/0500000.pt' # replace with path to checkpoint
    global_seed: int = 0

    # Parallelization
    sample_index_start: int = 0
    sample_index_end: Optional[int] = 32

    def process_args(self) -> None:
        """Additional argument processing.

        Fills in defaults, validates the settings and creates the output directory.
        `output_dir` is overwritten with the run-specific path, so a second call on the
        same instance nests the directory once more.

        Raises:
            AssertionError: if `image_size` is not divisible by 8, or if `unsupervised`
                is set while `cfg_scale` differs from 1.0.
            ValueError: if `cfg_scale` is below 1.0 or `ckpt` is not an existing file.
            NotImplementedError: if `dino_supervised` is set (and `unsupervised` is not).
        """
        if self.debug:
            self.log_with = 'tensorboard'
            self.name = 'debug'

        # Defaults
        # Replace only an unset (None) value, so that an explicit 0 is respected.
        # NOTE(review): 28 presumably matches the depth of the default DiT-XL model - confirm.
        if self.fixed_point_iters is None:
            self.fixed_point_iters = 28 - self.fixed_point_pre_depth - self.fixed_point_post_depth
        assert self.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
        self.latent_size = self.H_lat = self.W_lat = self.image_size // 8

        # Checks
        if self.cfg_scale < 1.0:
            raise ValueError("In almost all cases, cfg_scale should be >= 1.0")
        if self.unsupervised:
            assert self.cfg_scale == 1.0, "cfg_scale must be 1.0 when unsupervised is set."
            self.num_classes = 1
        elif self.dino_supervised:
            raise NotImplementedError()
        if not Path(self.ckpt).is_file():
            raise ValueError(f"Checkpoint file not found: {self.ckpt}")

        # Create output directory
        # The parent is named after the directory two levels above the checkpoint file.
        output_parent = Path(self.output_dir) / Path(self.ckpt).parent.parent.name
        if self.debug:
            output_dirname = 'debug'
        else:
            output_dirname = f'num_sampling_steps-{self.num_sampling_steps}--cfg_scale-{self.cfg_scale}'
            if self.fixed_point:
                output_dirname += f'--fixed_point_iters-{self.fixed_point_iters}--fixed_point_reuse_solution-{self.fixed_point_reuse_solution}--fixed_point_pptc-{self.fixed_point_pre_post_timestep_conditioning}'
        if self.ddim:
            output_dirname += '--ddim'
        self.output_dir = str(output_parent / output_dirname)
        Path(self.output_dir).mkdir(exist_ok=True, parents=True)

In [None]:
# Build the configuration from the class-level defaults (no command-line parsing happens
# here), then explicitly run the validation / derived-value step. `process_args` also
# creates the output directory.
args = Args()
args.process_args()

## Network architecture
We modify the original DiT class to support fixed-point blocks.

In [None]:

from models import TimestepEmbedder, LabelEmbedder, DiTBlock, FinalLayer, get_2d_sincos_pos_embed
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone. It includes methods for the forward pass
    and initialization of weights. The model can operate in both standard and fixed-point modes.

    Standard mode: the embedded input passes through `depth` DiT blocks.
    Fixed-point mode: the embedded input passes through `fixed_point_pre_depth` pre blocks,
    then a projection followed by one shared DiT block is applied repeatedly (the fixed-point
    iteration), then `fixed_point_post_depth` post blocks. Both modes end with the final
    layer and `unpatchify`.
    """
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        learn_sigma: bool = True,
        use_cfg_embedding: bool = True,
        use_gradient_checkpointing: bool = True,
        is_label_continuous: bool = False,

        # below are Fixed Point-specific arguments.
        fixed_point: bool = False,

        # size
        fixed_point_pre_depth: int = 1,
        fixed_point_post_depth: int = 1,

        # iteration counts
        fixed_point_no_grad_min_iters: int = 0,
        fixed_point_no_grad_max_iters: int = 0,
        fixed_point_with_grad_min_iters: int = 28,
        fixed_point_with_grad_max_iters: int = 28,

        # solution recycle
        fixed_point_reuse_solution: bool = False,

        # pre_post_timestep_conditioning
        fixed_point_pre_post_timestep_conditioning: bool = True,
    ):
        """
        Build the embedders, the transformer blocks and the final layer, then initialize weights.

        Args:
            input_size, patch_size, in_channels, hidden_size: forwarded to `PatchEmbed`.
            depth: number of DiT blocks in standard mode (not used in fixed-point mode).
            num_heads, mlp_ratio: forwarded to every `DiTBlock`.
            class_dropout_prob, num_classes, use_cfg_embedding, is_label_continuous:
                forwarded to `LabelEmbedder`.
            learn_sigma: if True the number of output channels is `2 * in_channels`,
                otherwise `in_channels`.
            use_gradient_checkpointing: if True, block calls outside the fixed-point
                iteration go through `torch.utils.checkpoint.checkpoint`.
            fixed_point: selects fixed-point mode.
            fixed_point_pre_depth, fixed_point_post_depth: number of blocks before / after
                the fixed-point iteration.
            fixed_point_no_grad_min_iters, fixed_point_no_grad_max_iters: inclusive bounds
                for the randomly drawn number of iterations whose result is detached.
            fixed_point_with_grad_min_iters, fixed_point_with_grad_max_iters: inclusive
                bounds for the randomly drawn number of iterations that are not detached.
            fixed_point_reuse_solution: if True, the iterate obtained in one forward call
                is stored and used as the starting point of the next call (until `reset`).
            fixed_point_pre_post_timestep_conditioning: if True, the pre / post blocks and
                the final layer are conditioned on timestep + label; otherwise on the
                label embedding only.
        """
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.use_gradient_checkpointing = use_gradient_checkpointing

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob, 
            use_cfg_embedding=use_cfg_embedding, continuous=is_label_continuous)
        num_patches = self.x_embedder.num_patches

        # Will use fixed sin-cos embedding:
        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
        self.register_buffer('pos_embed', torch.zeros(1, num_patches, hidden_size))

        # New: Fixed Point
        self.fixed_point = fixed_point
        if self.fixed_point:
            # NOTE: the fixed_point_* attributes below exist only in fixed-point mode.
            self.fixed_point_no_grad_min_iters = fixed_point_no_grad_min_iters
            self.fixed_point_no_grad_max_iters = fixed_point_no_grad_max_iters
            self.fixed_point_with_grad_min_iters = fixed_point_with_grad_min_iters
            self.fixed_point_with_grad_max_iters = fixed_point_with_grad_max_iters
            self.fixed_point_pre_post_timestep_conditioning = fixed_point_pre_post_timestep_conditioning
            self.blocks_pre = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_pre_depth)])
            # Projects the output of the pre blocks into the signal injected at every iteration.
            self.block_pre_projection = nn.Linear(hidden_size, hidden_size)
            # Two-layer MLP mapping [current iterate, injection] (2 * hidden) back to hidden.
            self.block_fixed_point_projection_fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size)
            self.block_fixed_point_projection_act = nn.GELU(approximate="tanh")
            self.block_fixed_point_projection_fc2 = nn.Linear(2 * hidden_size, hidden_size)
            # Single block whose weights are shared across all fixed-point iterations.
            self.block_fixed_point = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            self.blocks_post = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(fixed_point_post_depth)])
            # Plain Python list (not nn.ModuleList): the contained modules are already
            # registered through the attributes above. The list is iterated by
            # initialize_weights. Presumably kept unregistered so the modules are not
            # registered a second time under another name - confirm before changing.
            self.blocks = [*self.blocks_pre, self.block_fixed_point, *self.blocks_post]
            self.fixed_point_reuse_solution = fixed_point_reuse_solution
            self.last_solution = None
        else:
            self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize all weights: Xavier-uniform for linear layers, sin-cos values for the
        positional embedding, normal(std=0.02) for the label and timestep embedders, and
        zeros for the adaLN modulation outputs and the final linear layer.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        if self.y_embedder.continuous:
            nn.init.normal_(self.y_embedder.embedding_projection.weight, std=0.02)
            nn.init.constant_(self.y_embedder.embedding_projection.bias, 0)
        else:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        # (self.blocks covers pre, fixed-point and post blocks in fixed-point mode.)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        Reshapes the patches back to image format.
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)

        T must be a perfect square (the patch grid is assumed to be square); C is
        `self.out_channels`.
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # h == w here, so (h * p, h * p) is the full spatial size.
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def ckpt_wrapper(self, module):
        """
        Wrapper function for gradient checkpointing.

        Returns a plain function that calls `module` with the given positional inputs,
        for use as the callable passed to `checkpoint`.
        """
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs
        return ckpt_forward

    def _forward_dit(self, x, t, y):
        """
        Standard (non-fixed-point) forward pass: embed, run all blocks, final layer, unpatchify.
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = checkpoint(self.ckpt_wrapper(block), x, c) if self.use_gradient_checkpointing else block(x, c)  # (N, T, D)
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def _forward_fixed_point_blocks(
        self, x: Float[Tensor, "b t d"], x_input_injection: Float[Tensor, "b t d"], c: Float[Tensor, "b d"], num_iterations: int
    ) -> Float[Tensor, "b t d"]:
        """
        Apply the fixed-point update `num_iterations` times and return the last iterate.

        Each iteration concatenates the current iterate with `x_input_injection`, maps the
        result back to width D through the two-layer projection, and applies the shared
        fixed-point block conditioned on `c`. With `num_iterations == 0`, `x` is returned
        unchanged.
        """
        for _ in range(num_iterations):
            x = torch.cat((x, x_input_injection), dim=-1)  # (N, T, D * 2)
            x = self.block_fixed_point_projection_fc1(x)  # (N, T, D * 2)
            x = self.block_fixed_point_projection_act(x)  # (N, T, D * 2)
            x = self.block_fixed_point_projection_fc2(x)  # (N, T, D)
            x = self.block_fixed_point(x, c)  # (N, T, D)
        return x

    def _check_inputs(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> None:
        """
        Validate inputs for the fixed-point forward pass.

        When solution reuse is enabled, every element of `t` must equal `t[0]`;
        otherwise `ValueError(t)` is raised. `x` and `y` are not inspected.
        """
        if self.fixed_point_reuse_solution:
            if not torch.all(t[0] == t).item():
                raise ValueError(t)

    def _forward_fixed_point(self, x: Float[Tensor, "b c h w"], t: Shaped[Tensor, "b"], y: Shaped[Tensor, "b"]) -> Float[Tensor, "b c h w"]:
        """
        Fixed-point forward pass: embed, pre blocks, fixed-point iteration, post blocks,
        final layer, unpatchify.

        The iteration counts are drawn with `random.randint` (inclusive bounds) on every
        call. When solution reuse is enabled, the final iterate is stored in
        `self.last_solution` and used as the initial iterate of the next call.
        """
        self._check_inputs(x, t, y)
        x: Float[Tensor, "b t d"] = self.x_embedder(x) + self.pos_embed
        t_emb: Float[Tensor, "b d"] = self.t_embedder(t)
        y: Float[Tensor, "b d"] = self.y_embedder(y, self.training)
        # The fixed-point block is always conditioned on timestep + label ...
        c: Float[Tensor, "b d"] = t_emb + y
        # ... while the pre / post blocks and the final layer optionally see the label only.
        c_pre_post_fixed_point: Float[Tensor, "b d"] = (t_emb + y) if self.fixed_point_pre_post_timestep_conditioning else y

        # Pre-Fixed Point
        # Note: If using DDP with find_unused_parameters=True, checkpoint causes issues. For more 
        # information, see https://github.com/allenai/longformer/issues/63#issuecomment-648861503
        for block in self.blocks_pre:
            x: Float[Tensor, "b t d"] = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)
        condition = x.clone()

        # Whether to reuse the previous solution at the next iteration
        init_solution = self.last_solution if (self.fixed_point_reuse_solution and self.last_solution is not None) else x.clone()

        # Fixed Point (we have condition and init_solution)
        x_input_injection = self.block_pre_projection(condition)

        # NOTE: This section of code should have no_grad, but cannot due to a DDP bug. See
        # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
        # for more information
        with nullcontext():  # we use x.detach() in place of torch.no_grad due to DDP issue
            num_iterations_no_grad = random.randint(self.fixed_point_no_grad_min_iters, self.fixed_point_no_grad_max_iters)
            x = self._forward_fixed_point_blocks(x=init_solution.detach(), x_input_injection=x_input_injection.detach(), c=c, num_iterations=num_iterations_no_grad)
            x = x.detach()  # no grad
        # Second phase: continue from the detached iterate, this time without detaching.
        num_iterations_with_grad = random.randint(self.fixed_point_with_grad_min_iters, self.fixed_point_with_grad_max_iters)
        x = self._forward_fixed_point_blocks(x=x, x_input_injection=x_input_injection, c=c, num_iterations=num_iterations_with_grad)

        # Save solution for reuse at next step
        if self.fixed_point_reuse_solution:
            self.last_solution = x.clone()

        # Post-Fixed Point
        for block in self.blocks_post:
            x = checkpoint(self.ckpt_wrapper(block), x, c_pre_post_fixed_point) if self.use_gradient_checkpointing else block(x, c_pre_post_fixed_point)

        # Output
        x: Float[Tensor, "b t p2c"] = self.final_layer(x, c_pre_post_fixed_point)  # p2c = patch_size ** 2 * out_channels
        x: Float[Tensor, "b c h w"] = self.unpatchify(x)
        return x

    def reset(self) -> None:
        """Forget the stored fixed-point solution so the next forward call starts afresh."""
        self.last_solution = None

    def forward(self, x, t, y):
        """
        General forward pass method which handles both standard and fixed point modes.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        if self.fixed_point:
            return self._forward_fixed_point(x, t, y)
        else:
            return self._forward_dit(x, t, y)

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.

        Only the first half of `x` is used; it is duplicated so that both halves of the
        batch share the same input. The first half of the model output is treated as the
        conditional prediction and the second half as the unconditional one, so `y` is
        expected to hold the conditional labels followed by the null labels (assumption
        about the caller - TODO confirm). The guided result is written to both halves of
        the returned batch.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        # Guidance: move from the unconditional towards the conditional prediction by cfg_scale.
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

In [None]:
# Set up the Accelerator (it selects the GPU/CPU device) and build the DiT model on it.
accelerator = Accelerator()

# Model constructor registered under the configured model name.
_dit_factory = DiT_models[args.model]

# Constructor arguments, taken from the hyperparameters. Sampling uses no label dropout
# and no gradient checkpointing, and a single fixed number of fixed-point iterations
# (min == max, with zero detached iterations).
_dit_kwargs = dict(
    input_size=args.latent_size,
    # With DINO supervision the label is continuous, so its dimensionality is passed instead.
    num_classes=(args.dino_supervised_dim if args.dino_supervised else args.num_classes),
    is_label_continuous=args.dino_supervised,
    class_dropout_prob=0,
    learn_sigma=(not args.flow),  # TODO: Implement learned variance for flow-based models
    use_gradient_checkpointing=False,
    fixed_point=args.fixed_point,
    fixed_point_pre_depth=args.fixed_point_pre_depth,
    fixed_point_post_depth=args.fixed_point_post_depth,
    fixed_point_no_grad_min_iters=0,
    fixed_point_no_grad_max_iters=0,
    fixed_point_with_grad_min_iters=args.fixed_point_iters,
    fixed_point_with_grad_max_iters=args.fixed_point_iters,
    fixed_point_reuse_solution=args.fixed_point_reuse_solution,
    fixed_point_pre_post_timestep_conditioning=args.fixed_point_pre_post_timestep_conditioning,
)
model = _dit_factory(**_dit_kwargs).to(accelerator.device)

## Load Model

In [None]:
# Restore the pre-trained weights from the checkpoint, then switch to evaluation mode.
state_dict = find_model(args.ckpt)
model.load_state_dict(state_dict)
model.eval()

# Build the diffusion process. The number of sampling steps is passed as a string;
# the remaining flags come straight from the hyperparameters.
diffusion = create_diffusion(
    str(args.num_sampling_steps),
    use_flow=args.flow,
    predict_v=args.predict_v,
    use_zero_terminal_snr=args.use_zero_terminal_snr,
)

# Fetch the pre-trained VAE, move it to the accelerator device and put it in evaluation mode.
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}")
vae = vae.to(accelerator.device)
vae = vae.eval()


## Inference

In [None]:
# Create generator, class labels, and latents.
# N_images is the size of the pseudorandom pool that the sample indices select from.
N_images = 32
generator = torch.Generator(device=accelerator.device)

# The generator is re-seeded before each draw, so labels and latents each depend only on
# `global_seed`. Passing the generator to randint makes the labels reproducible; they are
# part of the output filenames, which the sampling loop uses to skip finished images.
generator.manual_seed(args.global_seed)
class_labels = torch.randint(0, args.num_classes, size=(N_images,), device=accelerator.device, generator=generator)
generator.manual_seed(args.global_seed)
latents = torch.randn(N_images, model.in_channels, args.H_lat, args.W_lat, device=accelerator.device, generator=generator)

# Select this run's share of the pool. The indices are produced by applying the very same
# slice to range(N_images), so they always line up with the selected labels / latents,
# including when `sample_index_end` is None or lies beyond the pool.
class_labels = class_labels[args.sample_index_start:args.sample_index_end]
latents = latents[args.sample_index_start:args.sample_index_end]
indices = list(range(N_images)[args.sample_index_start:args.sample_index_end])
print(f'Using pseudorandom class labels and latents (start={args.sample_index_start} and end={args.sample_index_end})')

# Create output path
output_dir = Path(args.output_dir)
# Classifier-free guidance is only used for scales above 1.0
using_cfg = args.cfg_scale > 1.0

In [None]:
# Human-readable label names, used to build descriptive output filenames.
# Only ImageNet-256 and the unsupervised setting are handled.
if args.dataset_name != 'imagenet256' and not args.unsupervised:
    raise NotImplementedError()

if args.dataset_name == 'imagenet256':
    with open("../utils/imagenet-labels.json", "r") as f:
        label_names: list[str] = json.load(f)
    # Normalize each name: lower-case, spaces replaced by dashes, apostrophes removed.
    label_names = [name.lower().replace(' ', '-').replace('\'', '') for name in label_names]
else:
    # Unsupervised: a single placeholder label; guidance must be disabled.
    assert args.cfg_scale == 1.0
    label_names = ["unlabeled"]

In [None]:
# Generate all requested samples batch by batch and write each one to disk as a PNG.
with torch.inference_mode():
    # Sample loop
    num_batches = math.ceil(len(class_labels) / args.batch_size)
    for batch_idx in trange(num_batches, disable=(not accelerator.is_main_process)):

        # Get pre-sampled inputs (the last batch may hold fewer than batch_size items)
        batch_slice = slice(batch_idx * args.batch_size, (batch_idx + 1) * args.batch_size)
        z = latents[batch_slice]
        y = class_labels[batch_slice]
        idxs = indices[batch_slice]
        output_paths = [output_dir / f'{idx:05d}--{y_i:03d}--{label_names[y_i]}.png' for y_i, idx in zip(y.tolist(), idxs)]

        # Skip files that already exist
        if all(output_path.is_file() for output_path in output_paths):
            print(f'Files already exist (batch {batch_idx}). Skipping.')
            continue

        # Setup classifier-free guidance
        if using_cfg:
            # One null label per item actually in this batch, so that labels and latents
            # stay the same length even for a short final batch. The null label index is
            # `num_classes` (previously hard-coded as 1000, the default value).
            # NOTE(review): presumably this is the extra embedding slot the label embedder
            # reserves for guidance - confirm against the model definition.
            y_null = torch.full_like(y, args.num_classes)
            y = torch.cat([y, y_null], 0)
            z = torch.cat([z, z], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            sample_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            sample_fn = model.forward

        # Sample latent images
        sample_kwargs = dict(model=sample_fn, shape=z.shape, noise=z, clip_denoised=False, model_kwargs=model_kwargs, 
            progress=False, device=accelerator.device)
        if args.ddim:
            samples = diffusion.ddim_sample_loop(**sample_kwargs)
        else:
            samples = diffusion.p_sample_loop(**sample_kwargs)

        # With guidance the batch was doubled; keep only the first half
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)

        # Reset model (resets the initial solution to None)
        model.reset()

        # Decode latents, then map the decoder output to uint8 pixels in (N, H, W, C) layout
        samples = vae.decode(samples / vae.config.scaling_factor).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for sample, output_path in zip(samples, output_paths):
            Image.fromarray(sample).save(output_path)

## Visualization

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import json

# get all files in output_dir
output_dir = Path(args.output_dir)
files = list(output_dir.glob('*.png'))
if not files:
    # Fail with a clear message instead of an opaque "empty range" error from the RNG.
    raise ValueError(f'No .png samples found in {output_dir}; run the sampling cell first.')

# visualize up to four distinct random samples on a 2x2 grid
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
axes = axs.flatten()
for ax in axes:
    ax.axis('off')  # also hides the panels left empty when fewer than four images exist

# Sample without replacement so the same image is not shown twice.
chosen_files = random.sample(files, k=min(len(files), len(axes)))
for ax, image_path in zip(axes, chosen_files):
    # The context manager closes the file handle once the image has been drawn.
    with Image.open(image_path) as img:
        ax.imshow(img)
plt.tight_layout()
plt.show()