In [1]:
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
import tifffile

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained network (the EMA weights stored in the pickle).

model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'
network_pkl = f'{model_root}/edm-cifar10-32x32-cond-vp.pkl'
# Other checkpoints under the same root that can be swapped in:
# network_pkl = f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl'
# network_pkl = f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl'
# network_pkl = f'{model_root}/edm-imagenet-64x64-cond-adm.pkl'

device = torch.device('cpu')
print(f'Loading network from {network_pkl}')
# NOTE(security): pickle.load can execute arbitrary code while deserializing,
# so only point network_pkl at a source you trust.
with dnnlib.util.open_url(network_pkl) as stream:
    net = pickle.load(stream)['ema']
net = net.to(device)

print(f'CLASSES: {net.label_dim}')

Loading network from https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
CLASSES: 10


In [3]:
# Image grid layout; the batch holds exactly one image per grid cell.
gridw = 2
gridh = 2
batch_size = gridw * gridh
seed = 1
torch.manual_seed(seed)

# One-hot class labels, built only when the network reports a non-zero label_dim.
class_idx = 4
class_labels = None
if net.label_dim:
    # The random draw is kept even though class_idx overrides it below:
    # it advances the global RNG, which later torch.randn calls depend on.
    random_classes = torch.randint(net.label_dim, size=[batch_size], device=device)
    class_labels = torch.eye(net.label_dim, device=device)[random_classes]
    if class_idx is not None:
        # Force every sample in the batch to the same class.
        class_labels[:, :] = 0
        class_labels[:, class_idx] = 1
class_labels

tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

In [4]:
def save_png(arr, path):
    """Write a batch of images to ``{path}.png`` as one tiled RGB grid.

    ``arr`` is reshaped to (gridh, gridw, C, H, W), so its first dimension must
    equal ``gridh * gridw`` and the remaining ones must match the network's
    channel count and resolution. Values are scaled by 127.5, offset by 128
    and clipped to [0, 255] before conversion to uint8. The image is written
    in 'RGB' mode. Relies on the notebook globals ``gridh``, ``gridw`` and ``net``.
    """
    out_file = f'{path}.png'
    print(f'Saving image grid to {out_file}')
    pixels = (arr * 127.5 + 128).clip(0, 255).to(torch.uint8)
    tiles = pixels.reshape(gridh, gridw, *pixels.shape[1:])
    # (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C): rows of tiles become image rows.
    tiles = tiles.permute(0, 3, 1, 4, 2)
    res, channels = net.img_resolution, net.img_channels
    grid = tiles.reshape(gridh * res, gridw * res, channels)
    PIL.Image.fromarray(grid.cpu().numpy(), 'RGB').save(out_file)

def load_png(path):
    """Read ``{path}.png`` and split the tiled grid back into a batch.

    Inverse layout of ``save_png``: the (gridh*H, gridw*W, C) pixel array is
    viewed as (gridh, H, gridw, W, C) and rearranged into a float64 tensor of
    shape (gridh*gridw, C, H, W). Pixel values are mapped back with
    ``(v - 128) / 127.5``. Relies on the notebook globals ``gridh``, ``gridw``
    and ``net``. The image is used in whatever mode PIL opens it in, so its
    channel count must already match ``net.img_channels``.
    """
    src_file = f'{path}.png'
    print(f'Loading image grid from {src_file}')
    pixels = torch.tensor(np.array(PIL.Image.open(src_file)), dtype=torch.float64)
    res, channels = net.img_resolution, net.img_channels
    # (gridh, H, gridw, W, C) -> (gridh, gridw, C, H, W)
    tiles = pixels.view(gridh, res, gridw, res, channels).permute(0, 2, 4, 1, 3)
    batch = tiles.reshape(gridh * gridw, channels, res, res)
    return (batch - 128) / 127.5

# a = load_png('./gold')
# save_png(a, 'saved')

In [5]:
def save_tiff(arr, path):
    """Write a batch of images to ``{path}.tiff`` as one tiled float32 grid.

    ``arr`` is mapped with ``v * 0.5 + 0.5`` and cast to float32 without any
    clipping, then tiled exactly like ``save_png``: reshaped to
    (gridh, gridw, C, H, W) and flattened to (gridh*H, gridw*W, C). Relies on
    the notebook globals ``gridh``, ``gridw`` and ``net``.
    """
    out_file = f'{path}.tiff'
    print(f'Saving image grid to {out_file}')
    values = (arr * 0.5 + 0.5).to(torch.float32)
    tiles = values.reshape(gridh, gridw, *values.shape[1:])
    # (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C)
    tiles = tiles.permute(0, 3, 1, 4, 2)
    res, channels = net.img_resolution, net.img_channels
    grid = tiles.reshape(gridh * res, gridw * res, channels)
    tifffile.imwrite(out_file, grid.cpu().numpy(), photometric='rgb')

def load_tiff(path):
    """Read ``{path}.tiff`` and split the tiled grid back into a batch.

    Inverse layout of ``save_tiff``: the (gridh*H, gridw*W, C) array is viewed
    as (gridh, H, gridw, W, C) and rearranged into a float64 tensor of shape
    (gridh*gridw, C, H, W). Values are mapped back with ``(v - 0.5) / 0.5``.
    Relies on the notebook globals ``gridh``, ``gridw`` and ``net``.
    """
    src_file = f'{path}.tiff'
    print(f'Loading image grid from {src_file}')
    values = torch.tensor(np.array(tifffile.imread(src_file)), dtype=torch.float64)
    res, channels = net.img_resolution, net.img_channels
    # (gridh, H, gridw, W, C) -> (gridh, gridw, C, H, W)
    tiles = values.view(gridh, res, gridw, res, channels).permute(0, 2, 4, 1, 3)
    batch = tiles.reshape(gridh * gridw, channels, res, res)
    return (batch - 0.5) / 0.5

# Smoke-test both writers on random data.
# This draws from the global RNG, so it shifts the noise sampled by later cells.
test_shape = [batch_size, net.img_channels, net.img_resolution, net.img_resolution]
test = torch.randn(test_shape, device=device).to(torch.float32)
save_png(test, 'test')
save_tiff(test, 'test')

Saving image grid to test.png
Saving image grid to test.tiff


In [6]:
# Pick the starting latents (unit Gaussian noise, one sample per grid cell).
print(f'Generating {batch_size} images...')
latent_shape = [batch_size, net.img_channels, net.img_resolution, net.img_resolution]
x_orig0 = torch.randn(latent_shape, device=device).to(torch.float32)

# Keep a PNG preview and a float TIFF copy of the starting noise.
save_png(x_orig0, 'noise')
save_tiff(x_orig0, 'noise')
x_orig = x_orig0
# To restart from the noise stored on disk instead, and inspect the round-trip error:
# x_orig = load_tiff(f'noise')
# save_tiff(x_orig - x_orig0, 'noise.diff')

Generating 4 images...
Saving image grid to noise.png
Saving image grid to noise.tiff


In [7]:
def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=33, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    dest_path='generated', save_every=3
):
    """Run the sampling loop on ``latents`` and snapshot intermediate states.

    Each iteration optionally raises the noise level (controlled by ``S_churn``,
    ``S_min``, ``S_max`` and ``S_noise``), takes an Euler step from ``t_hat`` to
    ``t_next`` and, on every step but the last, applies a 2nd-order correction
    that averages the slopes at both ends of the step.

    Args:
        net: Callable as ``net(x, sigma, class_labels)``; must also expose
            ``sigma_min``, ``sigma_max`` and ``round_sigma``.
        latents: Starting noise; scaled by the first noise level before sampling.
        class_labels: Forwarded unchanged to ``net`` on every call.
        randn_like: Noise source used when the noise level is raised.
        num_steps: Number of noise levels in the schedule (a final level of 0
            is appended). ``num_steps=1`` yields a single step from the
            maximum noise level to 0.
        sigma_min, sigma_max: Requested noise range, clamped to what ``net`` supports.
        rho: Exponent shaping the noise-level schedule.
        S_churn, S_min, S_max, S_noise: Parameters of the temporary noise increase.
        dest_path: Prefix for snapshot files written through ``save_tiff`` as
            ``{dest_path}-{step:02d}``. Pass ``None`` to write no files at all.
        save_every: Snapshot the state after every ``save_every``-th step.
            A falsy value keeps only the initial ``-00`` snapshot.

    Returns:
        The final state as a float64 tensor.
    """
    def _snapshot(state, step):
        # Snapshots are optional so the sampler can run without touching the disk.
        if dest_path is not None:
            save_tiff(state, f'{dest_path}-{step:02d}')

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    # Guard the denominator: with num_steps == 1 the plain (num_steps - 1)
    # divides 0 by 0 and turns the whole schedule (and the result) into NaN.
    ramp = step_indices / max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    _snapshot(latents, 0)
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next is 0).
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        # Periodic snapshot of the state after step j.
        j = i + 1
        if save_every and j % save_every == 0:
            _snapshot(x_next, j)

    return x_next

# Pass the one-hot labels built earlier. Without them the sampler forwards
# class_labels=None to the network and the chosen class_idx has no effect.
x = edm_sampler(net, x_orig, class_labels, dest_path='tre')

Saving image grid to tre-00.tiff


  9%|▉         | 3/33 [00:01<00:17,  1.75it/s]

Saving image grid to tre-03.tiff


 18%|█▊        | 6/33 [00:03<00:16,  1.66it/s]

Saving image grid to tre-06.tiff


 27%|██▋       | 9/33 [00:05<00:14,  1.66it/s]

Saving image grid to tre-09.tiff


 36%|███▋      | 12/33 [00:07<00:12,  1.69it/s]

Saving image grid to tre-12.tiff


 45%|████▌     | 15/33 [00:08<00:10,  1.68it/s]

Saving image grid to tre-15.tiff


 55%|█████▍    | 18/33 [00:10<00:08,  1.68it/s]

Saving image grid to tre-18.tiff


 64%|██████▎   | 21/33 [00:12<00:07,  1.66it/s]

Saving image grid to tre-21.tiff


 73%|███████▎  | 24/33 [00:14<00:05,  1.64it/s]

Saving image grid to tre-24.tiff


 82%|████████▏ | 27/33 [00:16<00:03,  1.62it/s]

Saving image grid to tre-27.tiff


 91%|█████████ | 30/33 [00:18<00:01,  1.63it/s]

Saving image grid to tre-30.tiff


100%|██████████| 33/33 [00:19<00:00,  1.69it/s]

Saving image grid to tre-33.tiff





In [8]:
# Write the final sampler output as a PNG preview grid.
save_png(x, 'tre-99')

Saving image grid to tre-99.png
