In [21]:
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib

In [22]:
# Load network.

# Base URL of the pretrained snapshots; exactly one `network_pkl` assignment
# below should be left uncommented to choose the model.
model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'
# network_pkl = f'{model_root}/edm-cifar10-32x32-cond-vp.pkl'
# network_pkl = f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl'
network_pkl = f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl'
# network_pkl = f'{model_root}/edm-imagenet-64x64-cond-adm.pkl'

# Device the network is moved to; later cells create their tensors on it too.
device=torch.device('cpu')
print(f'Loading network from {network_pkl}')
# SECURITY: pickle.load can execute arbitrary code contained in the file, so
# `network_pkl` must only ever point at a source you trust.
with dnnlib.util.open_url(network_pkl) as f:
    # The unpickled object is indexed with 'ema' to obtain the network that
    # every later cell refers to as `net`.
    net = pickle.load(f)['ema'].to(device)

# Prints 0 for the unconditional AFHQv2 snapshot selected above; the
# class-label cell treats a zero `label_dim` as "no labels".
print(f'CLASSES: {net.label_dim}')

Loading network from https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-afhqv2-64x64-uncond-vp.pkl
CLASSES: 0


In [23]:
# Output grid layout: save_image() tiles the batch as gridh rows x gridw columns.
gridw=2
gridh=2
# Sampler settings, read as globals by denoise().
# num_steps: number of sampling steps (length of the noise schedule before the final 0).
num_steps=18
# Requested noise range; denoise() clamps it to [net.sigma_min, net.sigma_max].
my_sigma_min=0.002
my_sigma_max=90
# Exponent used by the time step discretization in denoise().
rho=7
# Controls for the temporary noise increase in denoise(). With S_churn=0 the
# `gamma` computed there is 0, i.e. no extra noise level increase is requested.
S_churn=0
S_min=0
S_max=float('inf')
S_noise=1

# One image per grid tile.
batch_size = gridw * gridh
# `seed` also appears in the output file names. The global torch RNG is seeded
# once here; re-running a later cell without re-running this one continues the
# RNG stream instead of restarting it.
seed = 1
torch.manual_seed(seed)

# `class_idx` is only used when the network is class-conditional
# (net.label_dim != 0). It indexes a tensor with net.label_dim columns, so it
# must be smaller than net.label_dim for the selected network.
class_idx = 417
class_labels = None
if net.label_dim:
    # Random one-hot rows: rows of the identity matrix picked by random class ids.
    # Note that this draw consumes RNG state even if it is overwritten below.
    class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
if class_labels is not None and class_idx is not None:
    # Force every image in the batch to the single class `class_idx`.
    class_labels[:, :] = 0
    class_labels[:, class_idx] = 1
# Last expression of the cell: displays the labels (None for unconditional networks).
class_labels

In [24]:
def save_image(arr, path):
    """Tile a batch of images into a single grid and write it to `path`.

    Reads the notebook globals `gridh`, `gridw` and `net`.

    Args:
        arr: tensor holding gridh * gridw images laid out as
            (N, channels, height, width). Values are mapped with
            x * 127.5 + 128 and clipped to [0, 255] before saving.
        path: destination file name handed to PIL.
    """
    print(f'Saving image grid to {path}')
    side = net.img_resolution

    # Scale to the 8-bit range; anything outside it saturates.
    pixels = (arr * 127.5 + 128).clip(0, 255).to(torch.uint8)

    # (N, C, H, W) -> (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C)
    tiles = pixels.reshape(gridh, gridw, *pixels.shape[1:]).permute(0, 3, 1, 4, 2)

    # Merge tile rows/columns with pixel rows/columns into one H x W x C canvas.
    canvas = tiles.reshape(gridh * side, gridw * side, net.img_channels)

    # NOTE(review): the 'RGB' mode presumes a 3-channel network - confirm
    # before using this with other channel counts.
    PIL.Image.fromarray(canvas.cpu().numpy(), 'RGB').save(path)

In [25]:
def load_image(path, gridh, gridw, img_resolution, img_channels):
    """Read an image grid from disk and split it back into a batch of tiles.

    This undoes the tiling and scaling performed by `save_image`: the grid is
    cut into gridh x gridw square tiles and pixel values are mapped with
    (x - 128) / 127.5.

    Args:
        path: image file to read.
        gridh: number of tile rows in the grid.
        gridw: number of tile columns in the grid.
        img_resolution: side length of one (square) tile, in pixels.
        img_channels: number of channels per pixel expected in the result.

    Returns:
        A float64 tensor of shape
        (gridh * gridw, img_channels, img_resolution, img_resolution),
        tiles ordered row by row.

    Raises:
        RuntimeError: from `view` when the image size does not equal
            (gridh * img_resolution) x (gridw * img_resolution).
    """
    print(f'Loading image grid from {path}')

    # Pillow mode that decodes to exactly `img_channels` channels. Without the
    # conversion a PNG stored with an alpha channel, a palette or in grayscale
    # yields an array whose size does not match the `view` below.
    mode = {1: 'L', 3: 'RGB', 4: 'RGBA'}.get(img_channels)

    # The context manager releases the file handle once the pixels are copied.
    with PIL.Image.open(path) as src:
        decoded = src if mode is None or src.mode == mode else src.convert(mode)
        pixels = np.array(decoded)

    image = torch.tensor(pixels, dtype=torch.float64)
    # (gridh*H, gridw*W, C) -> (gridh, H, gridw, W, C) -> (gridh, gridw, C, H, W)
    image = image.view(gridh, img_resolution, gridw, img_resolution, img_channels)
    image = image.permute(0, 2, 4, 1, 3)
    image = image.reshape(gridh * gridw, img_channels, img_resolution, img_resolution)
    # Inverse of the `x * 127.5 + 128` scaling used when saving.
    image = (image - 128) / 127.5
    return image

# a = load_image('./gold.png', 2, 2, 64, 3)
# save_image(a, 'saved.png')

In [28]:
# Draw one random latent per grid tile.
print(f'Generating {batch_size} images...')
latents = torch.randn(
    batch_size, net.img_channels, net.img_resolution, net.img_resolution,
    device=device,
)

# Keep a float64 copy as the sampler's starting point and snapshot it to disk.
x_orig = latents.double()
save_image(x_orig, f'noise-{seed}.png')

Generating 4 images...
Saving image grid to noise-1.png


In [29]:
def denoise(dest_path, save_every=3):
    """Run the sampler on `x_orig` and save intermediate image grids.

    Reads the notebook globals `net`, `x_orig`, `class_labels`, `device`,
    `seed`, `num_steps`, `my_sigma_min`, `my_sigma_max`, `rho`, `S_churn`,
    `S_min`, `S_max` and `S_noise`.

    Grids are written as '{dest_path}-{seed}-{step:02d}.png': step 00 is the
    starting noise, then every `save_every`-th step, and always the last step
    so the final result is never lost when `num_steps` is not a multiple of
    `save_every`.

    Args:
        dest_path: prefix of the output file names.
        save_every: save an intermediate grid every this many steps. A falsy
            value (0 or None) saves only the final step.

    Returns:
        None. Results are written to disk via `save_image`.
    """
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(my_sigma_min, net.sigma_min)
    sigma_max = min(my_sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    # max(..., 1) only matters for num_steps == 1, where the plain
    # `num_steps - 1` divisor would produce 0 / 0 = NaN.
    last_index = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / last_index * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    save_image(x_orig, f'{dest_path}-{seed}-00.png')
    x_next = x_orig * t_steps[0]
    schedule = list(enumerate(zip(t_steps[:-1], t_steps[1:]))) # 0, ..., N-1

    # Sampling needs no gradients; this keeps autograd from recording the
    # network evaluations.
    with torch.no_grad():
        for i, (t_cur, t_next) in tqdm.tqdm(schedule, unit='step'):
            x_cur = x_next

            # Increase noise temporarily.
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
            t_hat = net.round_sigma(t_cur + gamma * t_cur)
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

            # Euler step.
            denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order correction: average the slopes at both ends of
            # the step. Skipped on the last step, where t_next is 0 and the
            # division below would be by zero.
            if i < num_steps - 1:
                denoised = net(x_next, t_next, class_labels).to(torch.float64)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

            # Periodic snapshot, plus an unconditional one for the final step.
            j = i + 1
            if (save_every and j % save_every == 0) or j == num_steps:
                save_image(x_next, f'{dest_path}-{seed}-{j:02d}.png')

In [30]:
# Run the sampler. 'cifar' is only the file name prefix of the saved grids;
# the network being sampled is whichever `network_pkl` was loaded earlier.
denoise('cifar')

Saving image grid to cifar-1-00.png


 17%|█▋        | 3/18 [00:03<00:19,  1.30s/step]

Saving image grid to cifar-1-03.png


 33%|███▎      | 6/18 [00:07<00:15,  1.29s/step]

Saving image grid to cifar-1-06.png


 50%|█████     | 9/18 [00:11<00:11,  1.27s/step]

Saving image grid to cifar-1-09.png


 67%|██████▋   | 12/18 [00:15<00:07,  1.29s/step]

Saving image grid to cifar-1-12.png


 83%|████████▎ | 15/18 [00:19<00:03,  1.28s/step]

Saving image grid to cifar-1-15.png


100%|██████████| 18/18 [00:22<00:00,  1.25s/step]

Saving image grid to cifar-1-18.png



