# Sample

In [None]:
"""
Sample new images from a pre-trained DiT.
"""
# NOTE(review): the star-import presumably supplies every otherwise-unimported
# name used below (torch, os, np, DiT_models, find_model, create_diffusion,
# ArgsDict) -- none of them are imported explicitly in this file; confirm
# against the DiT module.
from DiT import *
# Allow TF32 for CUDA matmuls and cuDNN ops (global, process-wide switches).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
import argparse
import logging

# Changing to "./" leaves the working directory where it already is; relative
# paths used later (checkpoint, result directory) resolve against it.
os.chdir("./")



def sample_main(args):
    """Draw samples from a DiT checkpoint with classifier-free guidance and save them.

    The raw sampler output is written as a NumPy array to
    ``{args.result_path}/sample_{args.name}_{args.sample_size}.npy``; no VAE
    decoding is applied.

    Args:
        args: Attribute-style configuration. Fields read here: ``seed``,
            ``ckpt``, ``model``, ``image_size``, ``num_classes``,
            ``num_sampling_steps``, ``sample_size``, ``seq_length``,
            ``cfg_scale``, ``result_path`` and ``name``.

    Raises:
        AssertionError: If ``args.ckpt`` is ``None`` and the configuration is
            not one of the auto-downloadable DiT-XL/2 setups.

    Side effects:
        Seeds the global torch RNG, disables autograd globally, creates
        ``args.result_path`` if it is missing, writes the ``.npy`` file and
        prints progress information.
    """
    # Local import: `os` is otherwise only reachable through the star-import
    # at the top of the file.
    import os

    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.ckpt is None:
        assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000

    # Make sure the output directory exists *before* the (potentially long)
    # sampling loop, so the final save cannot fail on a missing folder.
    if args.result_path:
        os.makedirs(args.result_path, exist_ok=True)
    file_path = args.result_path + f'/sample_{args.name}_{args.sample_size}.npy'

    # Load model:
    model = DiT_models[args.model](
        input_size=args.image_size,
        num_classes=args.num_classes,
    ).to(device)
    # Auto-download a pre-trained model or load a custom DiT checkpoint:
    ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!
    diffusion = create_diffusion(str(args.num_sampling_steps))

    # Labels to condition the model with (feel free to change):
    class_labels = [0] * args.sample_size
    n = len(class_labels)

    # Create sampling noise, one (seq_length, image_size) tensor per sample:
    z = torch.randn(n, args.seq_length, args.image_size, device=device)
    y = torch.tensor(class_labels, device=device)

    # Setup classifier-free guidance: the batch is duplicated and the second
    # half is conditioned on the "null" label.
    # NOTE(review): the null label is taken to be the extra index right after
    # the real classes (num_classes), as in the reference DiT label embedder;
    # this equals the previously hard-coded 1000 when num_classes == 1000.
    # Confirm against the DiT module in use.
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([args.num_classes] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)

    # Run the reverse diffusion process:
    samples = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
    print('sample', samples.shape)

    samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
    print('sample2', samples.shape)

    # Save the raw samples:
    np.save(file_path, samples.cpu().numpy())

    print(f"sample results saved to {file_path}")
    



In [None]:
# Module-level device string; sample_main() picks its own device internally.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Configuration consumed by sample_main(). Setting "ckpt" to None instead of a
# path switches sample_main() to its DiT-XL/2 auto-download branch.
sample_args_dict = {
    # Model / data geometry
    "model": "DiT-PD/2_N=240",
    "vae": "mse",
    "seq_length": 240,
    "image_size": 2,
    # Sampling setup
    "sample_size": 1000,
    "num_classes": 1000,
    "cfg_scale": 0,
    "num_sampling_steps": 1000,
    "seed": 0,
    # Input checkpoint and output location
    "ckpt": "./video_pts/ep49_0001950.pt",
    "result_path": "video_samples",
    "name": "real-cfg0-testQKV_49epochs_bounce_240",
}

sample_args = ArgsDict(sample_args_dict)
sample_main(sample_args)