In [1]:
!pip install imageio
!pip install imageio-ffmpeg
!wget https://people.eecs.berkeley.edu/~efros/img/Efros__photo_Peter_Badge_crop.jpg
!ls 

--2021-08-30 22:37:46--  https://people.eecs.berkeley.edu/~efros/img/Efros__photo_Peter_Badge_crop.jpg
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.244.190
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 456031 (445K) [image/jpeg]
Saving to: ‘Efros__photo_Peter_Badge_crop.jpg’


2021-08-30 22:37:47 (4.34 MB/s) - ‘Efros__photo_Peter_Badge_crop.jpg’ saved [456031/456031]

calc_metrics.py
dataset_tool.py
dnnlib
Dockerfile
docker_run.sh
docs
e6MuEwjO9Xm9aILQc_UJg3pneDuV87-xIws_FgeiSIfgd9i510kbWwrvxgGgdroMqPyWQotseeQJX2vCA8XfU1tGzTN8bw
efros.jpg
Efros__photo_Peter_Badge_crop.jpg
environment.yml
generate.py
legacy.py
LICENSE.txt
malik.jpg
malik.jpg.1
metrics
out
playground.ipynb
projector.py
__pycache__
README.md
style_mixing.py
test.ipynb
torch_utils
training
train.py


In [None]:
!mv Efros__photo_Peter_Badge_crop.jpg efros.jpg

calc_metrics.py
dataset_tool.py
dnnlib
Dockerfile
docker_run.sh
docs
e6MuEwjO9Xm9aILQc_UJg3pneDuV87-xIws_FgeiSIfgd9i510kbWwrvxgGgdroMqPyWQotseeQJX2vCA8XfU1tGzTN8bw
efros.jpg
environment.yml
generate.py
legacy.py
LICENSE.txt
malik.jpg
malik.jpg.1
metrics
out
playground.ipynb
projector.py
__pycache__
README.md
style_mixing.py
test.ipynb
torch_utils
training
train.py


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import os
from time import perf_counter
import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

import dnnlib
import legacy

In [52]:
def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    device: torch.device
):
    target = (torch.cat((torch.zeros(2, 1024, 1024), 255*torch.ones(1, 1024, 1024)), 0)).to(device)
    print(target.shape,(G.img_channels, G.img_resolution, G.img_resolution))
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)       # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])

In [96]:
def project_blue(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    device: torch.device
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)       # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    # url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    # with dnnlib.util.open_url(url) as f:
    #     vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    # target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    # optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
    optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        print("\t", lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # # Features for synth images.
        # synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        # print(target_images.size())
        # print(target_images.max())
        # print(synth_images.size())
        # print(synth_images.max())
        # print(synth_images.mean())
        dist = (target_images - synth_images).square().sum()

        # Noise regularization.
        # reg_loss = 0.0
        # for v in noise_bufs.values():
        #     noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
        #     while True:
        #         reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
        #         reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
        #         if noise.shape[2] <= 8:
        #             break
        #         noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist #+ reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])

In [97]:
network_pkl = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"
target_fname = "efros.jpg"
outdir = "out"
save_video = True
seed = 202
num_steps = 200

In [98]:
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f5fc8129c70>

In [99]:
# Load networks.
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as fp:
    G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

Loading networks from "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"...


In [100]:
# Load target image.
target_pil = PIL.Image.open(target_fname).convert('RGB')
w, h = target_pil.size
s = min(w, h)
target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
target_uint8 = np.array(target_pil, dtype=np.uint8)

In [101]:
# Optimize projection.
start_time = perf_counter()
projected_w_steps = project_blue(
    G,
    target=torch.cat((torch.zeros(2, 1024, 1024), 255*torch.ones(1, 1024, 1024)), 0), # pylint: disable=not-callable
    num_steps=num_steps,
    device=device,
    verbose=True,
    # initial_learning_rate=0
)

# projected_w_steps = project(
#     G,
#     target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
#     num_steps=num_steps,
#     device=device,
#     verbose=True,
#     # initial_learning_rate=0
# )
print (f'Elapsed: {(perf_counter()-start_time):.1f} s')

Computing W midpoint and stddev using 10000 samples...
	 0.0
step    1/200: dist 3466543104.00 loss 3466543104.00
	 0.01
step    2/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.02
step    3/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.03
step    4/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.04
step    5/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.05
step    6/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.06
step    7/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.07
step    8/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.08
step    9/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.09
step   10/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.1
step   11/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.1
step   12/200: dist 15651530841522176.00 loss 15651530841522176.00
	 0.1
step   13/200: dist 15651530841522176.00 loss 15651530841522176.00
	

In [59]:
# Render debug output: optional video and projected image and W vector.
os.makedirs(outdir, exist_ok=True)
if save_video:
    video = imageio.get_writer(f'{outdir}/projblue.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
    print (f'Saving optimization progress video "{outdir}/proj.mp4"')
    for projected_w in projected_w_steps:
        synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
        synth_image = (synth_image + 1) * (255/2)
        synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
    video.close()

Saving optimization progress video "out/proj.mp4"


In [60]:
# Save final projected frame and W vector.
target_pil.save(f'{outdir}/target.png')
projected_w = projected_w_steps[-1]
synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
synth_image = (synth_image + 1) * (255/2)
synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/projblue.png')
np.savez(f'{outdir}/projected_wblue.npz', w=projected_w.unsqueeze(0).cpu().numpy())

In [None]:
print("Done")

Done


In [61]:
from IPython.display import Video

Video("out/projblue.mp4")