In [1]:
from pathlib import Path
import os
import torch
import torchvision as tv
import transformer_flow 
import utils
import pathlib
utils.set_random_seed(0)
notebook_output_path = pathlib.Path('runs/notebook')

In [3]:
# specify the following parameters to match the model config
dataset = 'imagenet'
num_classes = {'imagenet': 1000, 'imagenet64': 0, 'afhq': 3}[dataset]
img_size = 128
channel_size = 3

batch_size = 16
patch_size = 4
channels = 768
blocks = 8
layers_per_block = 8
noise_std = 0.07

device = 'cuda'

model_name = f'{patch_size}_{channels}_{blocks}_{layers_per_block}_{noise_std:.2f}'
ckpt_file = Path('models') / f'{dataset}_model_{model_name}.pth'
ckpt_file = 'models/imagenet_model_converted.pth'
# we can download a pretrained model, comment this out if testing your own checkpoints
# os.system(f'wget https://ml-site.cdn-apple.com/models/tarflow/afhq256/afhq_model_8_768_8_8_0.07.pth -q -P {notebook_output_path}')

sample_dir = notebook_output_path / f'{dataset}_samples_{model_name}'
sample_dir.mkdir(exist_ok=True, parents=True)

fixed_noise = torch.randn(batch_size, (img_size // patch_size)**2, channel_size * patch_size ** 2, device=device)
if num_classes:
    fixed_y = torch.randint(num_classes, (batch_size,), device=device)
else:
    fixed_y = None

model = transformer_flow.Model(in_channels=channel_size, img_size=img_size, patch_size=patch_size, 
              channels=channels, num_blocks=blocks, layers_per_block=layers_per_block,
             num_classes=num_classes).to(device)
model.load_state_dict(torch.load(ckpt_file, weights_only=True))
print('checkpoint loaded!')

checkpoint loaded!


In [4]:
# now let's generate samples
guided_samples = {}
with torch.no_grad():
    for guidance in [0, 1]:
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            samples = model.reverse(fixed_noise, fixed_y, guidance)
            guided_samples[guidance] = samples
        tv.utils.save_image(samples, sample_dir / f'samples_guidance_{guidance:.2f}.png', normalize=True, nrow=4)
        print(f'guidance {guidance} sampling complete')

guidance 0 sampling complete
guidance 1 sampling complete


In [6]:
# finally we denoise the samples
for p in model.parameters():
    p.requires_grad = False
    
# remember the loss is mean, whereas log prob is sum
lr = batch_size * img_size ** 2 * channel_size * noise_std ** 2
for guidance, sample in guided_samples.items():
    x = torch.clone(guided_samples[guidance]).detach()
    x.requires_grad = True
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        z, outputs, logdets = model(x, fixed_y)
    loss = model.get_loss(z, logdets)
    grad = torch.autograd.grad(loss, [x])[0]
    x.data = x.data - lr * grad
    samples = x
    print(f'guidance {guidance} denoising complete')
    tv.utils.save_image(samples, sample_dir / f'samples_guidance_{guidance:.2f}_denoised.png', normalize=True, nrow=4)

OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 39.49 GiB of which 17.62 MiB is free. Process 2050464 has 420.00 MiB memory in use. Process 3051587 has 17.14 GiB memory in use. Process 3786240 has 1.18 GiB memory in use. Including non-PyTorch memory, this process has 20.54 GiB memory in use. Of the allocated memory 19.81 GiB is allocated by PyTorch, and 235.92 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)