In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
%%javascript
IPython.OutputArea.auto_scroll_threshold = 9999

In [None]:
from PULSE import PULSE
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import DataParallel
from pathlib import Path
from PIL import Image
import torchvision
from math import log10, ceil

In [None]:
import numpy as np
import PIL
import PIL.Image
import sys
import os
import glob
import scipy
import scipy.ndimage
import dlib
from drive import open_url
from pathlib import Path
from bicubic import BicubicDownSample
import torchvision
from shape_predictor import align_face

In [None]:
import matplotlib.pyplot as plt
import collections
import stylegan
import tempfile
import subprocess

In [None]:
# Base keyword arguments for a PULSE run. Later cells copy this dict, merge in
# per-configuration overrides, and pass the result to `model(ref_im=..., **new_kwargs)`.
# The directory, `duplicates` and `batch_size` entries are also read directly by the
# dataset / dataloader setup cell.
kwargs = dict(
    input_dir="aligned_faces/",  # input data directory (aligned low-resolution faces)
    output_dir='runs',  # output data directory
    cache_dir='cache',  # cache directory for model weights
    duplicates=1,  # how many HR images to produce for every image in the input directory
    batch_size=1,  # batch size to use during optimization
# seed=0,  # manual seed to use (disabled, so runs are not seeded from here)
    loss_str="100*L2+0.05*GEOCROSS",  # loss function to use
    eps=2e-3,  # target for downscaling loss (L2)
    noise_type='trainable',  # zero, fixed, or trainable
    num_trainable_noise_layers=5,  # number of noise layers to optimize
    tile_latent=False,  # whether to forcibly tile the same latent 18 times
    bad_noise_layers="17",  # list of noise layers to zero out to improve image quality
    opt_name='custom',  # optimizer to use in projected gradient descent
    learning_rate=0.4,  # learning rate to use during optimization
    steps=200,  # number of optimization steps
    lr_schedule='linear1cycledrop',  # fixed, linear1cycledrop, or linear1cycle
    save_intermediate=False,  # whether to store and save intermediate HR and LR images during optimization
)

### align_face.py

In [None]:
%%time
# cache_dir = Path(kwargs["cache_dir"])
# cache_dir.mkdir(parents=True, exist_ok=True)

# print("Downloading Shape Predictor")
# f=open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True)
# predictor = dlib.shape_predictor(f)

# Detect and align every face found in `unaligned_path`, optionally downsample it
# to `aligned_faces_size` pixels, and store it as a PNG in kwargs["input_dir"].
unaligned_path = "unaligned_faces/"
aligned_faces_size = 32  # set to None to keep the aligned faces at full resolution
f = "shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(f)

# PIL's save() does not create missing directories, so make sure the target exists.
aligned_dir = Path(kwargs["input_dir"])
aligned_dir.mkdir(parents=True, exist_ok=True)

# The downsampler only depends on the (constant) target size: build it once
# instead of once per face.
# NOTE(review): the 1024 assumes align_face returns 1024x1024 crops -- confirm.
if aligned_faces_size is not None:
    factor = 1024 // aligned_faces_size
    assert aligned_faces_size * factor == 1024
    D = BicubicDownSample(factor=factor)

for im in Path(unaligned_path).glob("*.*"):
    print(im)
    faces = align_face(str(im), predictor)

    for i, face in enumerate(faces):
        target_path = aligned_dir / (im.stem + f"_{i}.png")
        # Existing outputs are kept, which makes the cell cheap to re-run.
        if not target_path.exists():
            # Public PIL attribute (the original read the private `_size`).
            print(face.size)
            if aligned_faces_size is not None:
                face_tensor = torchvision.transforms.ToTensor()(face).unsqueeze(0).cuda()
                face_tensor_lr = D(face_tensor)[0].cpu().detach().clamp(0, 1)
                face = torchvision.transforms.ToPILImage()(face_tensor_lr)

            face.save(target_path)


### setup everything

In [None]:
class Images(Dataset):
    """Dataset of low-resolution reference faces read from a directory of PNGs.

    Each file is served `duplicates` times so that several high-resolution
    results can be produced per input image.

    Args:
        root_dir: directory that is searched (non-recursively) for ``*.png``.
        duplicates: number of dataset entries generated per image file.
        hr_size: side length of full-resolution inputs that still need to be
            downsampled (default 1024, the previously hard-coded value).
        lr_size: side length of the low-resolution images returned
            (default 32, the previously hard-coded value).

    Raises:
        ValueError: if ``hr_size`` is not a multiple of ``lr_size``.
    """

    def __init__(self, root_dir, duplicates, hr_size=1024, lr_size=32):
        if hr_size % lr_size != 0:
            raise ValueError(
                f"hr_size ({hr_size}) must be a multiple of lr_size ({lr_size})")
        self.root_path = Path(root_dir)
        # glob() order depends on the filesystem; sort for reproducible indexing.
        self.image_list = sorted(self.root_path.glob("*.png"))
        self.duplicates = duplicates # Number of times to duplicate the image in the dataset to produce multiple HR images
        self.hr_size = hr_size
        self.lr_size = lr_size
        self.D = BicubicDownSample(factor=hr_size // lr_size)

    def __len__(self):
        return self.duplicates*len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, name)`` for entry `idx`.

        `image` is a (3, lr_size, lr_size) tensor; `name` is the file stem,
        suffixed with ``_<k>`` (1-based) when ``duplicates > 1``.

        Raises:
            ValueError: if the image is neither (3, hr_size, hr_size) nor
                (3, lr_size, lr_size); the offending shape is the message.
        """
        img_path = self.image_list[idx//self.duplicates]
        # Close the file handle deterministically; ToTensor copies the pixel data.
        with Image.open(img_path) as img:
            image = torchvision.transforms.ToTensor()(img)
        # HACK: accept both full-resolution and already downsampled inputs.
        if image.shape == (3, self.hr_size, self.hr_size):
            image = self.D(image.unsqueeze(0).cuda())[0].cpu().detach().clamp(0, 1)
        elif image.shape == (3, self.lr_size, self.lr_size):
            # already resized (:
            pass
        else:
            raise ValueError(image.shape)
        if(self.duplicates == 1):
            return image,img_path.stem
        else:
            return image,img_path.stem+f"_{(idx % self.duplicates)+1}"

# Notebook-global setup: dataset / dataloader over the aligned faces, the output
# directory, the PULSE model and a tensor -> PIL converter used by later cells.
dataset = Images(kwargs["input_dir"], duplicates=kwargs["duplicates"])
out_path = Path(kwargs["output_dir"])
out_path.mkdir(parents=True, exist_ok=True)

dataloader = DataLoader(dataset, batch_size=kwargs["batch_size"])

model = PULSE(cache_dir=kwargs["cache_dir"])
# DataParallel wrapping is disabled because it failed with:
# TypeError: forward() missing 1 required positional argument: 'ref_im'
# model = DataParallel(model)

# Shared converter; used by results_to_grid and imgs_to_animation.
toPIL = torchvision.transforms.ToPILImage()

In [None]:
%%time
# Load every reference image onto the GPU, keyed by the name the dataset reports.
# Each entry keeps a leading batch dimension of 1. Iterating over the whole batch
# (instead of only element 0) keeps this correct when kwargs["batch_size"] > 1;
# for batch_size == 1 the stored tensors are the same as before.
ref_imgs = {}
for ref_im, ref_im_name_tuple in dataloader:
    for batch_idx, ref_im_name in enumerate(ref_im_name_tuple):
        ref_imgs[ref_im_name] = ref_im[batch_idx:batch_idx + 1].cuda()
print(ref_imgs.keys())

In [None]:
%%time
# Build the `stylegan.D_basic` network and load its weights from a local checkpoint
# file. The scores it produces are used by filter_results_with_D / sortby_D.
# NOTE(review): the checkpoint path is relative to the working directory.
d_basic = stylegan.D_basic()
d_basic.load_state_dict(torch.load("karras2019stylegan-ffhq-1024x1024.for_d_basic.pt"))


In [None]:
def perpendicular(v):
    """Draw a random array orthogonal to `v` with the same shape and norm.

    A standard-normal sample is taken, its component along `v` (both viewed as
    flat vectors) is subtracted, and the remainder is rescaled to ``||v||``.
    """
    flat_v = v.flatten()
    candidate = np.random.randn(*v.shape)
    # Remove the projection of the candidate onto v.
    overlap = candidate.flatten().dot(flat_v)
    candidate -= overlap * v / flat_v.dot(flat_v)
    # Match the length of v.
    candidate *= np.linalg.norm(v) / np.linalg.norm(candidate)
    return candidate

def perpendicular_multi(vs, initial=None):
    """Return an array orthogonal to every array in `vs`.

    The arrays in `vs`, followed by `initial` (or a fresh standard-normal sample
    of the same shape when `initial` is None), are orthogonalised one after the
    other (Gram-Schmidt on the flattened arrays). The last one, now orthogonal
    to the span of `vs`, is returned. Every projection step also rescales the
    projected array to the norm of the array it was projected against.

    All work happens on private copies: neither the arrays in `vs` nor `initial`
    are modified. (The original operated in place, which silently overwrote the
    caller's arrays -- and any tensors sharing their memory.)

    Args:
        vs: sequence of equally shaped numpy arrays.
        initial: optional starting array to be made orthogonal to `vs`.

    Returns:
        A new numpy array with the shape of the inputs.
    """
    def _perpendicular(v, t):
        # makes t perpendicular to v (in place) and gives it the norm of v
        t -= t.flatten().dot(v.flatten()) * v / v.flatten().dot(v.flatten())
        t *= np.linalg.norm(v) / np.linalg.norm(t)
        return t

    # Deep-copy the inputs; _perpendicular mutates its second argument.
    vs = [np.array(v) for v in vs]
    # this final one should be perpendicular to all
    if initial is None:
        vs.append(np.random.randn(*vs[0].shape))
    else:
        vs.append(np.array(initial))
    for i in range(len(vs) - 1):
        for j in range(i + 1, len(vs)):
            vs[j] = _perpendicular(vs[i], vs[j])
    return vs[-1]

def negation_init(prev_results):
    """Initial values pointing away from the mean of the previous results.

    For each position in ``result["var_list"]`` the previous values are
    averaged, negated and rescaled to the mean norm of those previous values.

    Args:
        prev_results: list of result dicts, each holding a ``"var_list"`` of tensors.

    Returns:
        A list with one tensor per variable, or None if `prev_results` is empty.
    """
    if not prev_results:
        return None

    def _negated_mean(arrays):
        # Target length: average norm of the individual arrays.
        mean_norm = np.mean([np.linalg.norm(arr) for arr in arrays])
        negated = -(sum(arrays) / len(arrays))
        negated *= mean_norm / np.linalg.norm(negated)
        return negated

    num_vars = len(prev_results[0]["var_list"])
    initial_values = []
    for var_idx in range(num_vars):
        history = []
        for result in prev_results:
            history.append(result["var_list"][var_idx].detach().cpu().numpy())
        initial_values.append(torch.tensor(_negated_mean(history)))
    return initial_values

def perpendicular_init(prev_results):
    """Initial values orthogonal to all previous results, variable by variable.

    For each position in ``result["var_list"]`` the previous values are handed
    to `perpendicular_multi`, which supplies a random array orthogonal to them.

    Returns:
        A list with one tensor per variable, or None if `prev_results` is empty.
    """
    if not prev_results:
        return None

    def _history(var_idx):
        # Values of one variable across all previous results, as numpy arrays.
        return [res["var_list"][var_idx].detach().cpu().numpy()
                for res in prev_results]

    num_vars = len(prev_results[0]["var_list"])
    return [torch.tensor(perpendicular_multi(_history(var_idx)))
            for var_idx in range(num_vars)]

def farthest_sampled_init(prev_results, n_samples=10000):
    """Pick, per variable, the random sample least aligned with previous results.

    For each position in ``result["var_list"]``, `n_samples` standard-normal
    candidates are drawn. The candidate whose mean dot product with the
    previous values is smallest (most negative) is kept and rescaled to the
    norm of the first previous value.

    Args:
        prev_results: list of result dicts, each holding a ``"var_list"`` of tensors.
        n_samples: number of random candidates per variable. Defaults to 10000,
            the value that used to be hard-coded; note that the candidates are
            materialised as one ``(n_samples, *shape)`` float64 array.

    Returns:
        A list with one tensor per variable, or None if `prev_results` is empty.

    Raises:
        ValueError: if `n_samples` is smaller than 1.
    """
    if not prev_results:
        return None
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    var_list_initial_values = []
    for var_idx in range(len(prev_results[0]["var_list"])):
        prev_vars = [result["var_list"][var_idx].detach().cpu().numpy()
                     for result in prev_results]

        prev_shape = prev_vars[0].shape
        prev_size = np.prod(prev_shape)
        num_prev = len(prev_vars)
        samples = np.random.randn(n_samples, *prev_shape)
        # Columns are the flattened previous values: shape (prev_size, num_prev).
        prev_matrix = np.array(prev_vars).reshape(num_prev, prev_size).T
        # Mean dot product with the previous values -- a similarity, so the
        # smallest value marks the candidate pointing farthest away.
        avg_similarity = samples.reshape(-1, prev_size).dot(prev_matrix).mean(axis=1)
        best_sample = samples[np.argmin(avg_similarity)]

        sample = best_sample * np.linalg.norm(prev_vars[0]) / np.linalg.norm(best_sample)

        var_list_initial_values.append(torch.tensor(sample))
    return var_list_initial_values

def make_postprocess_perpendicular_projection(prev_results, only_latents=False):
    """Build a hook that projects parameters away from previous results.

    The returned ``step_postprocess(params)`` overwrites each selected entry of
    `params` in place (via ``copy_``) with the output of `perpendicular_multi`,
    called with the matching values from `prev_results` and the entry's current
    value as the starting point.

    Args:
        prev_results: list of result dicts, each holding a ``"var_list"`` of tensors.
        only_latents: when True only ``params[0]`` is projected; otherwise
            every entry is.

    Returns:
        The ``step_postprocess`` callable; it returns None.
    """
    def step_postprocess(params):
        for var_idx, param in enumerate(params):
            if only_latents and var_idx != 0:
                continue
            current = param.detach().cpu().numpy()
            history = []
            for result in prev_results:
                history.append(result["var_list"][var_idx].detach().cpu().numpy())
            projected = perpendicular_multi(history, current)
            param.copy_(torch.from_numpy(projected))

    return step_postprocess


In [None]:
def filter_results_with_D(results, num_keep):
    """Keep the `num_keep` results with the smallest `d_basic` scores.

    Each result's ``"HR"`` entry is scored with the global `d_basic` network
    (no gradients are tracked). The survivors are returned in their original
    relative order.

    Selection is done by index rather than by score value, so tied scores can
    no longer make the function return more than `num_keep` results; among
    equal scores the earlier result wins.

    Args:
        results: list of result dicts with an ``"HR"`` entry.
        num_keep: maximum number of results to return.

    Returns:
        A new list containing at most `num_keep` of the input dicts.
    """
    with torch.no_grad():
        d_scores = [d_basic(res["HR"]).item() for res in results]
    # want smallest num_keep scores (sorted() is stable, so ties keep input order)
    ranked = sorted(range(len(d_scores)), key=d_scores.__getitem__)
    keep_indices = set(ranked[:num_keep])
    return [res for idx, res in enumerate(results) if idx in keep_indices]

def sortby_D(results):
    """Return `results` reordered by ascending `d_basic` score.

    Each result's ``"HR"`` entry is scored with the global `d_basic` network
    without tracking gradients; the input list itself is left unchanged.
    """
    scores = []
    with torch.no_grad():
        for entry in results:
            scores.append(d_basic(entry["HR"]).item())
    ranking = np.argsort(scores)
    return [results[position] for position in ranking]

In [None]:
# Names of all sampling strategies compared in this notebook. Each name is
# translated into extra PULSE keyword arguments by configuration_to_extra_kwargs,
# which raises ValueError for any name it does not know.
configurations = [
    "vanilla_pulse",
    "iterative_negation_initialization",
    "farthest_sampled_initialization",
    "iterative_perpendicular_initialization",
    "perpendicular_projection_optimizer",
    "perpendicular_projection_optimizer_only_latents",
    "perpendicular_projection_optimizer_psi0.7",
    "perpendicular_projection_optimizer_discloss",
    "perpendicular_projection_optimizer_psi0.7_discloss",
]

In [None]:
def configuration_to_extra_kwargs(c, prev_results):
    """Translate a configuration name into extra keyword arguments for the model.

    Args:
        c: one of the names listed in `configurations`.
        prev_results: results produced so far for the current image; forwarded
            to the initialisation / projection helpers.

    Returns:
        A dict that always has ``"var_list_initial_values"`` and
        ``"step_postprocess"`` (each possibly None), plus ``"psi"`` and/or
        ``"loss_str"`` for the configurations that override them.

    Raises:
        ValueError: if `c` is not a known configuration name.
    """
    # Configurations that re-project every variable after each optimizer step.
    project_all_vars = {
        "perpendicular_projection_optimizer",
        "perpendicular_projection_optimizer_psi0.7",
        "perpendicular_projection_optimizer_discloss",
        "perpendicular_projection_optimizer_psi0.7_discloss",
    }
    project_latents_only = "perpendicular_projection_optimizer_only_latents"
    truncated_psi = {
        "perpendicular_projection_optimizer_psi0.7",
        "perpendicular_projection_optimizer_psi0.7_discloss",
    }
    with_disc_loss = {
        "perpendicular_projection_optimizer_discloss",
        "perpendicular_projection_optimizer_psi0.7_discloss",
    }

    # --- starting point of the optimisation ---
    if c == "vanilla_pulse":
        initial_values = None
    elif c == "iterative_negation_initialization":
        initial_values = negation_init(prev_results)
    elif c == "farthest_sampled_initialization":
        initial_values = farthest_sampled_init(prev_results)
    elif (c == "iterative_perpendicular_initialization"
          or c == project_latents_only
          or c in project_all_vars):
        initial_values = perpendicular_init(prev_results)
    else:
        raise ValueError(c)

    # --- optional projection applied after every step ---
    if c in project_all_vars:
        postprocess = make_postprocess_perpendicular_projection(prev_results, only_latents=False)
    elif c == project_latents_only:
        postprocess = make_postprocess_perpendicular_projection(prev_results, only_latents=True)
    else:
        postprocess = None

    extra_kwargs = {
        "var_list_initial_values": initial_values,
        "step_postprocess": postprocess,
    }

    # --- per-configuration overrides ---
    if c in truncated_psi:
        extra_kwargs["psi"] = 0.7
    if c in with_disc_loss:
        extra_kwargs["loss_str"] = "100*L2+0.05*GEOCROSS+0.01*DISC"

    return extra_kwargs

In [None]:
def results_to_grid(all_results, grid_shape, img_prefix, out_dir="runs"):
    """Tile the ``"HR"`` images of `all_results` into one grid image and save it.

    Images are laid out row by row, so ``all_results[r * cols + c]`` ends up at
    grid cell ``(r, c)``. The result is written to
    ``{out_dir}/{img_prefix}__grid.png``; `out_dir` is created if it is missing.

    The channel count and image size are taken from the data instead of being
    fixed to 3x1024x1024.

    Args:
        all_results: list of result dicts whose ``"HR"`` entry is a CPU tensor
            whose last three dimensions are (channels, height, width).
        grid_shape: ``(rows, cols)`` of the grid.
        img_prefix: file name prefix of the saved image.
        out_dir: target directory.

    Raises:
        ValueError: if the number of images does not equal ``rows * cols``, or
            if the stacked images have fewer than three dimensions.
    """
    rows, cols = grid_shape
    stacked = np.array([r["HR"].numpy() for r in all_results])
    if stacked.ndim < 3:
        raise ValueError(
            f"expected images with (channels, height, width) axes, got array of shape {stacked.shape}")
    channels, height, width = stacked.shape[-3:]
    num_images = stacked.size // (channels * height * width)
    if num_images != rows * cols:
        raise ValueError(
            f"grid_shape {tuple(grid_shape)} needs {rows * cols} images, got {num_images}")

    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
    grid = stacked.reshape(
        rows, cols, channels, height, width
    ).transpose(2, 0, 3, 1, 4).reshape(
        channels, rows * height, cols * width
    )

    # PIL's save() does not create missing directories.
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    toPIL(torch.from_numpy(grid)).save(f"{out_dir}/{img_prefix}__grid.png")

In [None]:
def lerp(t1, t2, alpha):
    """Linear blend of `t1` and `t2`: `alpha` weights `t1`, ``1 - alpha`` weights `t2`.

    ``alpha == 1`` yields `t1` and ``alpha == 0`` yields `t2`.
    """
    first_part = t1 * alpha
    second_part = t2 * (1 - alpha)
    return first_part + second_part

def spherp(t1, t2, alpha):
    """Blend two tensors by interpolating direction and length separately.

    The unit directions of `t1` and `t2` are mixed linearly (`alpha` weights
    `t1`) and re-normalised; the result is scaled to the linearly interpolated
    norm. A small epsilon guards every division. The output is moved to the
    GPU via ``.cuda()``.
    """
    eps = 1e-8
    arr1, arr2 = t1.cpu().numpy(), t2.cpu().numpy()
    len1 = np.linalg.norm(arr1) + eps
    len2 = np.linalg.norm(arr2) + eps

    # Interpolated length and (not yet normalised) interpolated direction.
    blended_len = len1 * alpha + len2 * (1 - alpha)
    blended_dir = arr1 / len1 * alpha + arr2 / len2 * (1 - alpha)

    unit_dir = blended_dir / (np.linalg.norm(blended_dir) + eps)
    return torch.from_numpy(unit_dir * blended_len).cuda()

def generate_interpolations(latent_noise_pair0,
                            latent_noise_pair1,
                            num_points,
                            interpolation_type="spherical",
                            **kwargs):
    """Synthesize images along a path between two (latent, noise) pairs.

    Args:
        latent_noise_pair0: ``(latent, noise_list)`` at the start of the path.
        latent_noise_pair1: ``(latent, noise_list)`` at the end of the path.
        num_points: number of images; the blend weight runs from 1 (first
            pair) down to 0 (second pair), both ends included.
        interpolation_type: ``"spherical"`` (uses `spherp`) or ``"linear"``
            (uses `lerp`).
        **kwargs: forwarded unchanged to ``model.synthesize``.

    Returns:
        List with the output of ``model.synthesize`` for every point.

    Raises:
        ValueError: for any other `interpolation_type`.
    """
    if interpolation_type not in ("spherical", "linear"):
        raise ValueError(interpolation_type)
    interp_fn = spherp if interpolation_type == "spherical" else lerp

    (latent0, noise0), (latent1, noise1) = latent_noise_pair0, latent_noise_pair1

    def _frame(alpha):
        # Blend the latent first, then every noise tensor, then render.
        latent_mix = interp_fn(latent0, latent1, alpha)
        noise_mix = [interp_fn(n0, n1, alpha) for n0, n1 in zip(noise0, noise1)]
        return model.synthesize(latent_mix, noise_mix, **kwargs)

    return [_frame(alpha) for alpha in np.linspace(1, 0, num_points)]

In [None]:
def imgs_to_animation(imgs, output_path):
    """Encode a sequence of image batches as an mp4 video using ffmpeg.

    The first element of every entry in `imgs` is written as a numbered PNG
    into a temporary directory, ffmpeg turns the sequence into `output_path`
    (overwriting an existing file, slowed down by a factor of two), and the
    temporary directory is removed again -- also when saving or encoding fails.

    Args:
        imgs: iterable of indexable image batches; ``img[0]`` must be
            convertible by the global `toPIL`.
        output_path: target file name (str); must end in ``.mp4``.

    Raises:
        AssertionError: if `output_path` does not end in ``.mp4``.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    assert output_path.endswith(".mp4")

    # Context manager guarantees the frames are deleted once encoding is done.
    with tempfile.TemporaryDirectory() as tempdir:
        for idx, img in enumerate(imgs):
            toPIL(img[0]).save(f"{tempdir}/img_{idx:04d}.png")

        subprocess.check_output(["ffmpeg",
                                 "-y",  # overwrite output file
                                 "-f", "image2",  # read the input as an image sequence
                                 # -start_number is an option of the image sequence
                                 # input, so it has to precede the -i it applies to.
                                 "-start_number", "0",  # what number to start at
                                 "-i", f"{tempdir}/img_%04d.png",  # input files
                                 "-filter:v", "setpts=2.0*PTS",  # slow down video
                                 output_path,  # output file
                                ])

### make grid and interpolations for single image + configuration

In [None]:
%%time

# Sample `num_samples` results for one image with one configuration. When
# filtering is enabled, twice as many candidates are generated and the best half
# (lowest d_basic score) is kept.
num_samples = 25
# rows * cols must equal num_samples, otherwise results_to_grid cannot reshape
# the images into a grid (the previous (2, 5) did not match 25 samples).
grid_shape = (5, 5)
filter_results = True
sort_results = False
configuration = "vanilla_pulse"
# configuration = "perpendicular_projection_optimizer_psi0.7_discloss"
img_name = "oprah_0"

ref_im = ref_imgs[img_name]

all_results = []
for _ in range(num_samples * (2 if filter_results else 1)):
    # The extra kwargs depend on everything sampled so far, so rebuild them
    # for every run.
    extra_kwargs = configuration_to_extra_kwargs(configuration, all_results)
    new_kwargs = dict(kwargs)  # make a copy
    new_kwargs.update(extra_kwargs)
    for j, results in enumerate(model(ref_im=ref_im,
                                      **new_kwargs)):
        # exactly one result is expected per call
        assert j == 0
        all_results.append(results)

if filter_results:
    all_results = filter_results_with_D(all_results, num_keep=num_samples)
if sort_results:
    all_results = sortby_D(all_results)

In [None]:
%%time
# Save the sampled results as one grid image named
# "<img_name>__<configuration>__tmp__grid.png" in results_to_grid's default out_dir.
results_to_grid(all_results, 
                grid_shape=grid_shape,
                img_prefix=f"{img_name}__{configuration}__tmp")

In [None]:
%%time

# Build the frames of a looping animation that walks through all sampled results.
# NOTE: `new_kwargs` is the dict left over from the last iteration of the sampling
# cell, so that cell has to be run first.

all_imgs = []

# One (latent, noise) pair per result, as unpacked by generate_interpolations.
latent_noise_pairs = [
    model.var_list_to_latent_and_noise(res["var_list"], **new_kwargs)
    for res in all_results
]

for idx in range(len(all_results)):
    # Wrap around so the last result interpolates back to the first one.
    idx2 = (idx + 1) % len(all_results)
    imgs = generate_interpolations(
        latent_noise_pairs[idx],
        latent_noise_pairs[idx2],
        num_points=30,
        interpolation_type="spherical",
        **new_kwargs
    )
    all_imgs.extend(imgs)

In [None]:
%%time

# Encode the interpolation frames as "<img_name>__<configuration>__tmp.mp4"
# (relative path, i.e. in the current working directory).
imgs_to_animation(all_imgs,
                  f"{img_name}__{configuration}__tmp.mp4")

### run all configurations for one image

In [None]:
%%time

# Run every configuration for a single image and save one grid per configuration.
num_samples = 25
grid_shape = (5, 5)  # rows * cols must equal num_samples
filter_results = False
sort_results = False
all_all_results = collections.defaultdict(dict)

# Create the output directory up front: saving fails if it is missing, and that
# would only surface after a full (expensive) optimisation run.
final_out_dir = Path("final_runs")
final_out_dir.mkdir(parents=True, exist_ok=True)

for img_name in ["oprah_0"]:
    for configuration in configurations:
        img_prefix = f"{img_name}__{configuration}"
        print("Starting: " + img_prefix)

        ref_im = ref_imgs[img_name]

        all_results = []
        for _ in range(num_samples):
            # The extra kwargs depend on everything sampled so far.
            extra_kwargs = configuration_to_extra_kwargs(configuration, all_results)
            new_kwargs = dict(kwargs)  # make a copy
            new_kwargs.update(extra_kwargs)
            for j, results in enumerate(model(ref_im=ref_im,
                                              **new_kwargs)):
                # exactly one result is expected per call
                assert j == 0
                all_results.append(results)

        if filter_results:
            all_results = filter_results_with_D(all_results, num_keep=num_samples)
        if sort_results:
            all_results = sortby_D(all_results)

        all_all_results[img_name][configuration] = all_results
        results_to_grid(all_results,
                        grid_shape=grid_shape,
                        img_prefix=img_prefix,
                        out_dir=str(final_out_dir))

### run all configurations for all images

In [None]:
%%time

# Run every configuration for every loaded reference image and save one grid per
# (image, configuration) pair.
num_samples = 25
grid_shape = (5, 5)  # rows * cols must equal num_samples
filter_results = False
sort_results = False
all_all_results = collections.defaultdict(dict)

# Create the output directory up front: saving fails if it is missing, and that
# would only surface after a full (expensive) optimisation run.
final_out_dir = Path("final_runs")
final_out_dir.mkdir(parents=True, exist_ok=True)

for img_name in ref_imgs.keys():
    for configuration in configurations:
        img_prefix = f"{img_name}__{configuration}"
        print("Starting: " + img_prefix)

        ref_im = ref_imgs[img_name]

        all_results = []
        for _ in range(num_samples):
            # The extra kwargs depend on everything sampled so far.
            extra_kwargs = configuration_to_extra_kwargs(configuration, all_results)
            new_kwargs = dict(kwargs)  # make a copy
            new_kwargs.update(extra_kwargs)
            for j, results in enumerate(model(ref_im=ref_im,
                                              **new_kwargs)):
                # exactly one result is expected per call
                assert j == 0
                all_results.append(results)

        if filter_results:
            all_results = filter_results_with_D(all_results, num_keep=num_samples)
        if sort_results:
            all_results = sortby_D(all_results)

        all_all_results[img_name][configuration] = all_results
        results_to_grid(all_results,
                        grid_shape=grid_shape,
                        img_prefix=img_prefix,
                        out_dir=str(final_out_dir))