<a href="https://colab.research.google.com/github/koolamusic/snapterest/blob/master/ALAE_Inference_Wandb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Clone the Repo

In [None]:
!git clone https://github.com/podgorskiy/ALAE.git

fatal: destination path 'ALAE' already exists and is not an empty directory.


## Set Path Variables

In [None]:
%cd ALAE
%set_env PYTHONPATH=/project/pylib/src:/env/python

/content/ALAE
env: PYTHONPATH=/project/pylib/src:/env/python


## Install Requirements

In [None]:
%pip install -r requirements.txt
%pip install dareblopy
%pip install wandb -q --upgrade



## Download Pre-trained Models

In [None]:
!python training_artifacts/download_all.py

Downloading: model_submitted.pth
File training_artifacts/ffhq/model_submitted.pth already exists, skipping
Downloading: model_157.pth
File training_artifacts/ffhq/model_157.pth already exists, skipping
Downloading: model_194.pth
File training_artifacts/ffhq/model_194.pth already exists, skipping
Downloading: model_final.pth
File training_artifacts/celeba/model_final.pth already exists, skipping
Downloading: model_final.pth
File training_artifacts/bedroom/model_final.pth already exists, skipping
Downloading: model_262r.pth
File training_artifacts/celeba-hq256/model_262r.pth already exists, skipping
Downloading: model_580r.pth
File training_artifacts/celeba-hq256/model_580r.pth already exists, skipping


## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os

import torch.utils.data
from torchvision.utils import make_grid
from net import *
from model import Model
from launcher import run
from checkpointer import Checkpointer
from dlutils.pytorch import count_parameters
from defaults import get_cfg_defaults
import lreq
import logging
from PIL import Image
import PIL
import bimpy
import cv2
import random 
from skimage.transform import resize
# Ignore excessive warnings
import logging
# NOTE(review): this assigns an attribute on the `logging` module itself, not on a
# Logger instance, so it does not change how any logger propagates its records.
logging.propagate = False 
# Set the root logger's level to ERROR.
logging.getLogger().setLevel(logging.ERROR)

import matplotlib.pyplot as plt
%matplotlib inline

lreq.use_implicit_lreq.set(True)

# WandB – Import the wandb library
import wandb

In [None]:
# WandB – Login to your wandb account so you can log all your metrics
wandb.login()

True

In [None]:
# WandB – Initialize a new run
wandb.init(project="alae")

## General Disclaimer
I've taken the code from individual files in the ALAE repo and adapted it in this notebook to enable logging with Weights & Biases. No other changes have been made.

### Helper to load a config file and set up logging

In [None]:
def parse_configs(config_file='configs/ffhq.yaml', write_log=False):
    """Load a config file, (re)configure the shared logger and select the CUDA device.

    Args:
        config_file: Path to the config file. If it has no extension, ``.yaml`` is
            appended; if the resulting path does not exist but the same name exists
            under ``configs/``, that file is used instead.
        write_log: Falsy to log to stdout only. If truthy, records are also written
            to a file: a string value is used as the file path, any other truthy
            value writes to ``<cfg.OUTPUT_DIR>/log.txt``.

    Returns:
        Tuple ``(cfg, logger, device_)``: the frozen config, the logger named
        ``"logger"``, and the value of ``torch.cuda.current_device()``.

    Side effects:
        Creates ``cfg.OUTPUT_DIR`` if missing and sets the default torch tensor type
        to ``torch.cuda.FloatTensor``.
    """
    import sys

    cfg = get_cfg_defaults()
    world_size = 1
    if len(os.path.splitext(config_file)[1]) == 0:
        config_file += '.yaml'
    if not os.path.exists(config_file) and os.path.exists(os.path.join('configs', config_file)):
        config_file = os.path.join('configs', config_file)
    cfg.merge_from_file(config_file)
    cfg.freeze()

    logger = logging.getLogger("logger")
    logger.setLevel(logging.DEBUG)

    # logging.getLogger returns the same Logger object on every call, so handlers
    # attached by a previous parse_configs() call are still present. Remove them
    # first; otherwise every record is emitted once per earlier call.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    output_dir = cfg.OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if write_log:
        # A string selects an explicit log path; any other truthy value uses the default.
        filepath = write_log if isinstance(write_log, str) else os.path.join(output_dir, 'log.txt')
        fh = logging.FileHandler(filepath)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    logger.info("World size: {}".format(world_size))

    logger.info("Loaded configuration file {}".format(config_file))
    with open(config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    device_ = torch.cuda.current_device()
    print("Running on ", torch.cuda.get_device_name(device_))

    return cfg, logger, device_

## Style Mixing

In [None]:
# Number of source images read from <STYLE_MIX_PATH>/src (one grid column each).
src_len = 5
# Number of destination images read from <STYLE_MIX_PATH>/dst (one grid row each).
dst_len = 6


def place(canvas, image, x, y):
    """Write `image` into grid cell (column `x`, row `y`) of `canvas`, in place.

    `image` is converted to a NumPy array via `.cpu().detach().numpy()` and mapped
    through `v * 0.5 + 0.5` before being stored. The cell edge length is taken
    from the array's second axis, so cell (x, y) starts at pixel
    (x * edge, y * edge).
    """
    tile = image.cpu().detach().numpy()
    edge = tile.shape[1]
    row_span = slice(y * edge, (y + 1) * edge)
    col_span = slice(x * edge, (x + 1) * edge)
    canvas[:, row_span, col_span] = tile * 0.5 + 0.5


@torch.no_grad()
def mix(cfg, logger):
    """Run the style-mixing routine `_main(cfg, logger)` with autograd disabled."""
    _main(cfg, logger)


def _load_style_images(directory, count, im_size):
    """Load images ``0 .. count-1`` from `directory` and stack them into one tensor.

    For each index, ``<i>.png`` is tried first and ``<i>.jpg`` is used if the PNG
    file does not exist. Pixel values are mapped with ``v / 127.5 - 1``, a fourth
    channel is dropped, and the image is average-pooled by an integer factor so
    that its width equals `im_size`.

    Raises:
        AssertionError: if the width after pooling is not `im_size`.
    """
    originals = []
    for i in range(count):
        try:
            im = np.asarray(Image.open(os.path.join(directory, '%d.png' % i)))
        except FileNotFoundError:
            im = np.asarray(Image.open(os.path.join(directory, '%d.jpg' % i)))
        im = im.transpose((2, 0, 1))
        # Inference only: the input does not need to track gradients.
        x = torch.tensor(np.asarray(im, dtype=np.float32)).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        originals.append(x)
    return torch.stack(originals)


def _main(cfg, logger):
    """Build the style-mixing grid for `cfg` and log it to wandb.

    Steps:
      1. Build `Model` from `cfg.MODEL`, move it to GPU 0 and load weights through
         `Checkpointer`.
      2. Load `src_len` source and `dst_len` destination images from
         ``cfg.DATASET.STYLE_MIX_PATH`` (sub-folders ``src`` and ``dst``).
      3. Encode both sets, then for every destination row replace the style rows
         listed in `style_ranges[row]` with those of each source image and decode.
      4. Place originals on the first row/column and the mixed results in the
         remaining cells of one canvas, and log that canvas to wandb.

    Args:
        cfg: Frozen config as returned by `parse_configs`.
        logger: Logger used for progress messages and passed to `Checkpointer`.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Keys under which the sub-modules are registered with the checkpointer.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    extra_checkpoint_data = checkpointer.load()
    last_epoch = list(extra_checkpoint_data['auxiliary']['scheduler'].values())[0]['last_epoch']
    logger.info("Model trained for %d epochs" % last_epoch)

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode the images in `x` one at a time and repeat each latent once per style layer."""
        zlist = []
        for i in range(x.shape[0]):
            Z, _ = model.encode(x[i][None, ...], layer_count - 1, 1)
            zlist.append(Z)
        Z = torch.cat(zlist)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        """Decode the style tensors in `x` one at a time and concatenate the images."""
        decoded = []
        for i in range(x.shape[0]):
            r = model.decoder(x[i][None, ...], layer_count - 1, 1, noise=True)
            decoded.append(r)
        return torch.cat(decoded)

    path = cfg.DATASET.STYLE_MIX_PATH
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    src_originals = _load_style_images(os.path.join(path, 'src'), src_len, im_size)
    dst_originals = _load_style_images(os.path.join(path, 'dst'), dst_len, im_size)

    src_latents = encode(src_originals)
    src_images = decode(src_latents)

    dst_latents = encode(dst_originals)
    dst_images = decode(dst_latents)

    # One extra row and column hold the original images along the top and left edges.
    canvas = np.zeros([3, im_size * (dst_len + 1), im_size * (src_len + 1)])

    os.makedirs('style_mixing/output/%s/' % cfg.NAME, exist_ok=True)

    # Lists for Wandb logging
    source_images = []
    coarse_images = []
    recons_images = []

    for i in range(src_len):
        img = src_originals[i] * 0.5 + 0.5
        source_images.append(wandb.Image(img,caption="{} Source_{}".format(cfg.NAME, i)))
        place(canvas, src_originals[i], 1 + i, 0)

    for i in range(dst_len):
        img = dst_originals[i] * 0.5 + 0.5
        coarse_images.append(wandb.Image(img,caption="{} Coarse_{}".format(cfg.NAME, i)))
        place(canvas, dst_originals[i], 0, 1 + i)

    # Style rows copied from the source, indexed by destination row; the list has
    # six entries, so it must cover every row in range(dst_len).
    style_ranges = [range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, layer_count * 2)]

    def mix_styles(style_src, style_dst, r):
        """Return a copy of `style_dst` whose rows `r` are taken from `style_src`."""
        style = style_dst.clone()
        style[:, r] = style_src[:, r]
        return style

    for row in range(dst_len):
        row_latents = torch.stack([dst_latents[row]] * src_len)
        style = mix_styles(src_latents, row_latents, style_ranges[row])
        rec = model.decoder(style, layer_count - 1, 1, noise=True)
        for j in range(rec.shape[0]):
            img = rec[j] * 0.5 + 0.5
            recons_images.append(wandb.Image(img,caption="{} Source_{}_{}".format(cfg.NAME, row, j)))
            place(canvas, rec[j], 1 + j, 1 + row)

    #wandb.log({"Source {} Images".format(cfg.NAME): source_images})
    #wandb.log({"Coarse {} Images".format(cfg.NAME): coarse_images})
    #wandb.log({"Reconstructed {} Images".format(cfg.NAME): recons_images})
    wandb.log({"Style Mixed Output from {} Images".format(cfg.NAME): [wandb.Image(torch.Tensor(canvas),caption="{} Style Mixed Output".format(cfg.NAME))]})

In [None]:
all_configs = ['configs/bedroom.yaml','configs/celeba.yaml','configs/ffhq.yaml']
# The [1:] slice skips the first entry, so style mixing runs for celeba and ffhq only.
for config in all_configs[1:]:
  cfg, logger, device = parse_configs(config)
  mix(cfg, logger)

2020-06-11 18:48:49,892 logger INFO: World size: 1
2020-06-11 18:48:49,893 logger INFO: Loaded configuration file configs/celeba.yaml
2020-06-11 18:48:49,893 logger INFO: 
# Config for training ALAE on CelebA at resolution 128x128

NAME: celeba
PPL_CELEBA_ADJUSTMENT: True
DATASET:
  PART_COUNT: 16
  SIZE: 182637
  SIZE_TEST: 202576 - 182637
  PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d
  PATH_TEST: /data/datasets/celeba-test/tfrecords/celeba-r%02d.tfrecords.%03d
  MAX_RESOLUTION_LEVEL: 7

  SAMPLES_PATH: dataset_samples/faces/realign128x128
  STYLE_MIX_PATH: style_mixing/test_images/set_celeba
MODEL:
  LATENT_SPACE_SIZE: 256
  LAYER_COUNT: 6
  MAX_CHANNEL_COUNT: 256
  START_CHANNEL_COUNT: 64
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 8
OUTPUT_DIR: training_artifacts/celeba
TRAIN:
  BASE_LEARNING_RATE: 0.002
  EPOCHS_PER_LOD: 6
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 80
  #                    4       8       16       32       64      

## Multi Scale Reconstruction

In [None]:
def place(canvas, image, x, y):
    """Paste `image` into `canvas` with its top-left corner at pixel (`x`, `y`), in place.

    The edge length is read from `image.shape[2]` before a 4-D input is reduced to
    its first element. Values are mapped through `v * 0.5 + 0.5` before being
    stored.
    """
    edge = image.shape[2]
    tile = image[0] if len(image.shape) == 4 else image
    canvas[:, y: y + edge, x: x + edge] = tile * 0.5 + 0.5

def sample(cfg, logger):
    """Reconstruct the sample images of `cfg` and log a multi-resolution collage to wandb.

    Builds `Model` from `cfg.MODEL`, loads its weights through `Checkpointer`, reads
    every file in `cfg.DATASET.SAMPLES_PATH`, encodes and decodes each image, and
    arranges original/reconstruction pairs on four canvases that are concatenated
    side by side and logged as a single wandb image.

    Args:
        cfg: Frozen config as returned by `parse_configs`.
        logger: Logger used for progress messages and passed to `Checkpointer`.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # NOTE(review): `arguments` is not read anywhere else in this function.
    arguments = dict()
    arguments["iteration"] = 0

    # Keys under which the sub-modules are registered with the checkpointer.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    # The returned extra data is not used further in this function.
    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode `x` and repeat the latent once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        """Decode style tensor `x` into an image."""
        # NOTE(review): layer_idx/ones/coefs only feed the commented-out lerp below,
        # so they currently have no effect on the returned image.
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, 1.0 * ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    # path = 'dataset_samples/faces/realign1024x1024_paper'

    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    paths = list(os.listdir(path))

    # Sort first so the seeded shuffle starts from a fixed order. Note that
    # random.seed(5) reseeds the global `random` module.
    paths = sorted(paths)
    random.seed(5)
    random.shuffle(paths)

    # NOTE(review): the parameter name `list` shadows the builtin inside this helper,
    # and the helper is only referenced by the commented-out calls below.
    def move_to(list, item, new_index):
        list.remove(item)
        list.insert(new_index, item)

    # move_to(paths, '00026.png', 0)
    # move_to(paths, '00074.png', 1)
    # move_to(paths, '00134.png', 2)
    # move_to(paths, '00036.png', 3)

    def make(paths):
        """Load the files in `paths`; return (input tensors, reconstructions as NumPy arrays)."""
        src = []
        for filename in paths:
            img = np.asarray(Image.open(path + '/' + filename))
            # Drop a fourth channel if present.
            if img.shape[2] == 4:
                img = img[:, :, :3]
            im = img.transpose((2, 0, 1))
            x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1.
            if x.shape[0] == 4:
                x = x[:3]
            # Average-pool by an integer factor so the width equals im_size.
            factor = x.shape[2] // im_size
            if factor != 1:
                x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
            assert x.shape[2] == im_size
            src.append(x)

        with torch.no_grad():
            reconstructions = []
            for s in src:
                latents = encode(s[None, ...])
                reconstructions.append(decode(latents).cpu().detach().numpy())
        return src, reconstructions

    def chunker_list(seq, size):
        """Split `seq` into `size` interleaved sub-lists (element i goes to list i % size)."""
        return list((seq[i::size] for i in range(size)))

    # The unpacking below requires exactly four chunks.
    final = chunker_list(paths, 4)
    path0, path1, path2, path3 = final

    # make_part consumes its lists with pop(), i.e. from the end; reversing here
    # makes it consume each chunk in its original order.
    path0.reverse()
    path1.reverse()
    path2.reverse()
    path3.reverse()

    src0, rec0 = make(path0)
    src1, rec1 = make(path1)
    src2, rec2 = make(path2)
    src3, rec3 = make(path3)

    initial_resolution = im_size

    # Number of times the resolution is halved; make_part draws lods_down + 1 levels.
    lods_down = 1
    padding_step = 4

    width = 0
    height = 0

    current_padding = 0

    final_resolution = initial_resolution
    for _ in range(lods_down):
        final_resolution /= 2

    # Accumulate the padding contribution of each level; after this loop
    # current_padding equals padding_step * (lods_down + 1).
    for i in range(lods_down + 1):
        width += current_padding * 2 ** (lods_down - i)
        height += current_padding * 2 ** (lods_down - i)
        current_padding += padding_step

    width += 2 ** (lods_down + 1) * final_resolution
    height += (lods_down + 1) * initial_resolution

    # final_resolution is a float after the division above, so cast back to int.
    width = int(width)
    height = int(height)

    def make_part(current_padding, src, rec):
        """Draw one collage column group from `src`/`rec`, consuming both lists via pop().

        Level `i` holds a 2**i by 2**i block of (original, reconstruction) pairs at
        resolution initial_resolution / 2**i. Returns the canvas early, partially
        filled, as soon as an IndexError is raised inside the loop body (for
        example when `src` or `rec` is exhausted).
        """
        # Canvas starts white (ones) with a fixed extra margin.
        canvas = np.ones([3, height + 20, width + 10])

        # NOTE(review): `padd` is updated but never read.
        padd = 0

        initial_padding = current_padding

        height_padding = 0

        for i in range(lods_down + 1):
            for x in range(2 ** i):
                for y in range(2 ** i):
                    try:
                        ims = src.pop()
                        imr = rec.pop()[0]
                        ims = ims.cpu().detach().numpy()
                        imr = imr

                        res = int(initial_resolution / 2 ** i)

                        ims = resize(ims, (3, initial_resolution / 2 ** i, initial_resolution / 2 ** i))
                        imr = resize(imr, (3, initial_resolution / 2 ** i, initial_resolution / 2 ** i))

                        # Original on the left of the pair ...
                        place(canvas, ims,
                              current_padding + x * (2 * res + current_padding),
                              i * initial_resolution + height_padding + y * (res + current_padding))

                        # ... reconstruction directly to its right (offset by res).
                        place(canvas, imr,
                              current_padding + res + x * (2 * res + current_padding),
                              i * initial_resolution + height_padding + y * (res + current_padding))

                    except IndexError:
                        return canvas

            height_padding += initial_padding * 2

            current_padding -= padding_step
            padd += padding_step
        return canvas

    canvas = [make_part(current_padding, src0, rec0), make_part(current_padding, src1, rec1),
              make_part(current_padding, src2, rec2), make_part(current_padding, src3, rec3)]

    # Join the four parts side by side (axis 2 is the canvas width).
    canvas = np.concatenate(canvas, axis=2)

    #print('Saving image')
    #save_path = 'make_figures/output/%s/reconstructions_multiresolution.png' % cfg.NAME
    #os.makedirs(os.path.dirname(save_path), exist_ok=True)
    #save_image(torch.Tensor(canvas), save_path)
    wandb.log({"Multiresolution Reconstruction from {} Images".format(cfg.NAME): [wandb.Image(torch.Tensor(canvas),caption="{} Multi Res Output".format(cfg.NAME))]})

In [None]:
# Multi-scale reconstruction is run for the bedroom and celeba configs only.
all_configs = ['configs/bedroom.yaml','configs/celeba.yaml']
for config in all_configs:
  cfg, logger, device = parse_configs(config)
  sample(cfg, logger)

2020-06-11 18:50:10,639 logger INFO: World size: 1
2020-06-11 18:50:10,639 logger INFO: World size: 1
2020-06-11 18:50:10,639 logger INFO: World size: 1
2020-06-11 18:50:10,641 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:50:10,641 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:50:10,641 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:50:10,643 logger INFO: 
 # Config for training ALAE on lsun-bedroom at resolution 256x256

NAME: bedroom
DATASET:
  PART_COUNT: 4
  SIZE: 758260
  FFHQ_SOURCE: /data/datasets/lsun-bedroom-full/lsun-bedroom-full-r%02d.tfrecords
  PATH: /data/datasets/lsun-bedroom-full/splitted/lsun-bedroom-full-r%02d.tfrecords.%03d
  MAX_RESOLUTION_LEVEL: 8

  SAMPLES_PATH: dataset_samples/bedroom256x256
  STYLE_MIX_PATH: style_mixing/test_images/set_bedroom
MODEL:
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 7
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 32
  DLATENT_AVG_BETA: 0.995
  MAPPING_

## Interpolation

In [None]:
# Same call as in the imports cell; repeated here at the top of this cell.
lreq.use_implicit_lreq.set(True)


def place(canvas, image, x, y):
    """Paste `image` into `canvas` with its top-left corner at pixel (`x`, `y`), in place.

    The edge length comes from `image.shape[2]`, read before a 4-D input is
    reduced to its first element. Stored values are `image * 0.5 + 0.5`.
    """
    edge = image.shape[2]
    if len(image.shape) == 4:
        tile = image[0]
    else:
        tile = image
    rows = slice(y, y + edge)
    cols = slice(x, x + edge)
    canvas[:, rows, cols] = tile * 0.5 + 0.5

def sample(cfg, logger):
    """Log a 7x7 grid of images interpolated between four encoded samples to wandb.

    Builds `Model` from `cfg.MODEL`, loads its weights through `Checkpointer`,
    encodes four fixed files from `cfg.DATASET.SAMPLES_PATH`, blends their latents
    with bilinear weights over a height x width grid, decodes every blend and logs
    the resulting image grid.

    Args:
        cfg: Frozen config as returned by `parse_configs`.
        logger: Logger used for progress messages and passed to `Checkpointer`.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # NOTE(review): `arguments` is not read anywhere else in this function.
    arguments = dict()
    arguments["iteration"] = 0

    # Keys under which the sub-modules are registered with the checkpointer.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    # The returned extra data is not used further in this function.
    extra_checkpoint_data = checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode `x` and repeat the latent once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        """Decode style tensor `x` into an image."""
        # NOTE(review): layer_idx/ones/coefs only feed the commented-out lerp below,
        # so they currently have no effect on the returned image.
        layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
        ones = torch.ones(layer_idx.shape, dtype=torch.float32)
        coefs = torch.where(layer_idx < model.truncation_cutoff, ones, ones)
        # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
        return model.decoder(x, layer_count - 1, 1, noise=True)

    # NOTE(review): `rnd` and this `latents` are not referenced again in this function
    # (open_image below defines its own local `latents`).
    rnd = np.random.RandomState(4)
    latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    # File names of the four images whose latents form the corners of the grid;
    # they must exist in cfg.DATASET.SAMPLES_PATH.
    pathA = '00001.png'
    pathB = '00022.png'
    pathC = '00077.png'
    pathD = '00016.png'

    def open_image(filename):
        """Load `filename` from the samples folder and return its encoded latent vector."""
        img = np.asarray(Image.open(path + '/' + filename))
        # Drop a fourth channel if present.
        if img.shape[2] == 4:
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        # Average-pool by an integer factor so the width equals im_size.
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...].cuda())
        # encode() repeats one latent per layer; keep the first copy only.
        latents = _latents[0, 0]
        return latents

    def make(w):
        """Decode a single latent vector `w` after repeating it once per style layer."""
        with torch.no_grad():
            w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)
            x_rec = decode(w)
            return x_rec

    wa = open_image(pathA)
    wb = open_image(pathB)
    wc = open_image(pathC)
    wd = open_image(pathD)

    height = 7
    width = 7

    images = []
    for i in range(height):
        for j in range(width):
            # Vertical and horizontal position in [0, 1].
            kv = i / (height - 1.0)
            kh = j / (width - 1.0)

            # Bilinear weights: wa at (i=0, j=0), wb at (i=0, j=width-1),
            # wc at (i=height-1, j=0), wd at (i=height-1, j=width-1).
            ka = (1.0 - kh) * (1.0 - kv)
            kb = kh * (1.0 - kv)
            kc = (1.0 - kh) * kv
            kd = kh * kv

            w = ka * wa + kb * wb + kc * wc + kd * wd

            interpolated = make(w)
            images.append(interpolated)
 
    images = torch.cat(images)
    # Images were appended row by row, so nrow=width puts one i-value per grid row.
    grid_ = make_grid(images, nrow=width)
    wandb.log({"Interpolations from {} Images".format(cfg.NAME): [wandb.Image(torch.Tensor(grid_ *0.5 + 0.5),caption="{} Interpolation Output".format(cfg.NAME))]})
    #save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations.png' % cfg.NAME, nrow=width)
    #save_image(images * 0.5 + 0.5, 'make_figures/output/%s/interpolations.jpg' % cfg.NAME, nrow=width)

In [None]:
all_configs = ['configs/bedroom.yaml','configs/celeba.yaml','configs/ffhq.yaml']
# The [1:] slice skips the first entry, so interpolation runs for celeba and ffhq only.
for config in all_configs[1:]:
  cfg, logger, device = parse_configs(config)
  sample(cfg, logger)

2020-06-11 18:50:23,398 logger INFO: World size: 1
2020-06-11 18:50:23,398 logger INFO: World size: 1
2020-06-11 18:50:23,398 logger INFO: World size: 1
2020-06-11 18:50:23,398 logger INFO: World size: 1
2020-06-11 18:50:23,398 logger INFO: World size: 1
2020-06-11 18:50:23,401 logger INFO: Loaded configuration file configs/celeba.yaml
2020-06-11 18:50:23,401 logger INFO: Loaded configuration file configs/celeba.yaml
2020-06-11 18:50:23,401 logger INFO: Loaded configuration file configs/celeba.yaml
2020-06-11 18:50:23,401 logger INFO: Loaded configuration file configs/celeba.yaml
2020-06-11 18:50:23,401 logger INFO: Loaded configuration file configs/celeba.yaml
2020-06-11 18:50:23,405 logger INFO: 
# Config for training ALAE on CelebA at resolution 128x128

NAME: celeba
PPL_CELEBA_ADJUSTMENT: True
DATASET:
  PART_COUNT: 16
  SIZE: 182637
  SIZE_TEST: 202576 - 182637
  PATH: /data/datasets/celeba/tfrecords/celeba-r%02d.tfrecords.%03d
  PATH_TEST: /data/datasets/celeba-test/tfrecords/cel

[34m[1mwandb[0m: [32m[41mERROR[0m Error uploading "media/images/Interpolations from ffhq Images_5.png": CommError, File /tmp/tmpyqo4168kwandb/fda8oh2v-media/images/Interpolations from ffhq Images_5.png size shrank from 20909841 to 0 while it was being uploaded.
[34m[1mwandb[0m: [32m[41mERROR[0m Error uploading "media/images/Interpolations from ffhq Images_5.png": CommError, File /tmp/tmpyqo4168kwandb/2294s9q9-media/images/Interpolations from ffhq Images_5.png size shrank from 42409585 to 0 while it was being uploaded.
[34m[1mwandb[0m: [32m[41mERROR[0m Error uploading "media/images/Interpolations from ffhq Images_5.png": CommError, File /tmp/tmpyqo4168kwandb/4tl5ywdy-media/images/Interpolations from ffhq Images_5.png size shrank from 61615149 to 0 while it was being uploaded.


## Traversal
Only available for FFHQ (principal directions would first need to be computed for the other datasets).

In [None]:
# Same call as in the imports cell; repeated here at the top of this cell.
lreq.use_implicit_lreq.set(True)

# Generate GIF
def log_gif(ims, fname, log_name):
    """Save the frames in `ims` to `fname` as an animated GIF and log that file to wandb.

    Args:
        ims: Non-empty sequence of image objects; the first one's `save` method
            writes the file and the rest are passed as `append_images`.
        fname: Output path of the GIF.
        log_name: Key under which the `wandb.Video` is logged.
    """
    first_frame, later_frames = ims[0], ims[1:]
    first_frame.save(
        fname,
        save_all=True,
        append_images=later_frames,
        duration=150,
        fps=1,
        loop=0,
    )
    log_key = "{}".format(log_name)
    video = wandb.Video(fname, fps=4, format="gif")
    wandb.log({log_key: video})

# Helper to generate Images for GIF generation
def prep_list(torch_list):
    """Convert a list of image tensors into 512x512 PIL images for GIF generation.

    Each tensor must be 4-D, since `permute(2, 3, 1, 0)` is applied to it.
    It is presumably shaped (1, 3, H, W): only then do the permute/squeeze below
    yield the H x W x 3 array that the 'RGB' mode expects. TODO confirm against callers.

    Every tensor is min-max normalised to the 0..255 range independently of the
    others. A constant tensor (zero value range) becomes an all-zero image
    instead of dividing by zero.

    Returns:
        List of PIL images with every frame appended twice, i.e. twice as many
        entries as `torch_list`.
    """
    dims = (512, 512)
    ims = []
    for im in torch_list:
        im = im * 0.5 + 0.5
        arr = im.permute(2, 3, 1, 0).squeeze().cpu().numpy()

        # Normalise the data to 0 - 1, guarding against a constant image whose
        # value range is zero (the division would otherwise produce NaN).
        value_range = np.ptp(arr)
        if value_range > 0:
            arr = (arr - np.min(arr)) / value_range
        else:
            arr = np.zeros_like(arr)

        arr = 255 * arr  # Now scale by 255
        arr = arr.astype(np.uint8)
        img = Image.fromarray(arr, 'RGB')
        img = img.resize(dims)
        # Each frame is added twice (behaviour kept from the original code;
        # presumably to hold every frame longer in the GIF — confirm).
        ims.append(img)
        ims.append(img)

    return ims

def sample(cfg, logger):
    """Log attribute-traversal GIFs for a fixed set of sample images to wandb.

    Builds `Model` from `cfg.MODEL`, loads its weights through `Checkpointer`, and
    for each (image, attribute, start, end) entry in `traversal_specs` moves the
    image's latent along the direction loaded from
    ``principal_directions/direction_<idx>.npy``, decodes seven steps and logs
    them as an animated GIF.

    Args:
        cfg: Frozen config as returned by `parse_configs`.
        logger: Logger used for progress messages and passed to `Checkpointer`.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)
    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder
    mapping_tl = model.mapping_tl
    mapping_fl = model.mapping_fl
    dlatent_avg = model.dlatent_avg

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Keys under which the sub-modules are registered with the checkpointer.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': mapping_tl,
        'mapping_fl_s': mapping_fl,
        'dlatent_avg': dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    # Only the side effect (loading the weights) is needed; the returned extra
    # checkpoint data is not used here.
    checkpointer.load()

    model.eval()

    layer_count = cfg.MODEL.LAYER_COUNT

    def encode(x):
        """Encode `x` and repeat the latent once per style layer."""
        Z, _ = model.encode(x, layer_count - 1, 1)
        Z = Z.repeat(1, model.mapping_fl.num_layers, 1)
        return Z

    def decode(x):
        """Decode style tensor `x` into an image."""
        # The former truncation coefficients were only used by a disabled
        # torch.lerp call, so they are no longer computed.
        return model.decoder(x, layer_count - 1, 1, noise=True)

    path = cfg.DATASET.SAMPLES_PATH
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)

    def do_attribute_traversal(path, attrib_idx, start, end):
        """Traverse attribute `attrib_idx` for the image at `path`.

        Returns:
            Tuple ``(res, label_, traversal)``: the concatenated frames mapped
            through `v * 0.5 + 0.5`, the attribute label, and the list of
            individual decoded frames.

        Raises:
            ValueError: if `attrib_idx` is not one of the known indices.
        """
        img = np.asarray(Image.open(path))
        # Drop a fourth channel if present.
        if img.shape[2] == 4:
            img = img[:, :, :3]
        im = img.transpose((2, 0, 1))
        x = torch.tensor(np.asarray(im, dtype=np.float32), device='cpu', requires_grad=True).cuda() / 127.5 - 1.
        if x.shape[0] == 4:
            x = x[:3]
        # Average-pool by an integer factor so the width equals im_size.
        factor = x.shape[2] // im_size
        if factor != 1:
            x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
        assert x.shape[2] == im_size
        _latents = encode(x[None, ...].cuda())
        latents = _latents[0, 0]

        # Centre the latent on the running average (update_image adds it back).
        latents -= model.dlatent_avg.buff.data[0]

        w0 = torch.tensor(np.load("principal_directions/direction_%d.npy" % attrib_idx), dtype=torch.float32)

        # Component of the latent along the attribute direction ...
        attr0 = (latents * w0).sum()

        # ... which is removed so the traversal can set it explicitly.
        latents = latents - attr0 * w0

        def update_image(w):
            """Decode latent `w` after adding the average back and repeating it per layer."""
            with torch.no_grad():
                w = w + model.dlatent_avg.buff.data[0]
                w = w[None, None, ...].repeat(1, model.mapping_fl.num_layers, 1)

                layer_idx = torch.arange(model.mapping_fl.num_layers)[np.newaxis, :, np.newaxis]
                # Layers below this index take the traversed latent; the rest
                # keep the encoded one.
                cur_layers = (7 + 1) * 2
                mixing_cutoff = cur_layers
                styles = torch.where(layer_idx < mixing_cutoff, w, _latents[0])

                x_rec = decode(styles)
                return x_rec

        traversal = []

        # Number of frames and the per-frame increment along the direction.
        r = 7
        inc = (end - start) / (r - 1)

        for i in range(r):
            W = latents + w0 * (attr0 + start)
            im = update_image(W)

            traversal.append(im)
            attr0 += inc
        res = torch.cat(traversal)

        # Direction indices and their labels, matched by position.
        indices = [0, 1, 2, 3, 4, 10, 11, 17, 19]
        labels = ["gender",
                  "smile",
                  "attractive",
                  "wavy-hair",
                  "young",
                  "big_lips",
                  "big_nose",
                  "chubby",
                  "glasses",
                  ]
        label_ = labels[indices.index(attrib_idx)]
        return res * 0.5 + 0.5, label_, traversal

        #save_image(res * 0.5 + 0.5, "make_figures/output/%s/traversal_%s.jpg" % (
        #    cfg.NAME, labels[indices.index(attrib_idx)]), pad_value=1)

    # (sample file name, direction index, start offset, end offset)
    traversal_specs = [
        ('00049.png', 0, 0.6, -34),
        ('00125.png', 1, -3, 15.0),
        ('00057.png', 3, -2, 30.0),
        ('00031.png', 4, -10, 30.0),
        ('00088.png', 10, -0.3, 30.0),
        ('00004.png', 11, -25, 20.0),
        ('00012.png', 17, -40, 40.0),
        ('00017.png', 19, 0, 30.0),
    ]

    for filename, attrib_idx, start, end in traversal_specs:
        _, label_, traversal = do_attribute_traversal(path + '/' + filename, attrib_idx, start, end)
        ims = prep_list(traversal)
        log_gif(ims, '{}.gif'.format(label_), "Traversals of the {} Attribute".format(label_))

In [None]:
# Traversal is run for the ffhq config only.
all_configs = ['configs/ffhq.yaml']
for config in all_configs:
  cfg, logger, device = parse_configs(config)
  sample(cfg, logger)

2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,830 logger INFO: World size: 1
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,841 logger INFO: Loaded configuration file configs/ffhq.yaml
2020-06-11 18:51:43,855 logger INFO: 
 # Config for training ALAE on FFHQ at

## Generation

In [None]:
def draw_uncurated_result_figure(cfg, png, model, cx, cy, cw, ch, rows, lods, seed):
    """Render a grid of randomly generated samples onto one white RGB canvas.

    One column is drawn per entry of ``lods``. A column with level ``lod``
    holds ``rows * 2**lod`` images; each image is cropped to
    ``(cx, cy, cx + cw, cy + ch)`` and then shrunk by a factor of ``2**lod``.

    Args:
        cfg: Config object; reads ``cfg.MODEL.LATENT_SPACE_SIZE`` and
            ``cfg.DATASET.MAX_RESOLUTION_LEVEL``.
        png: Intended output path. It is only printed; nothing is written to
            disk by this function.
        model: Object whose ``generate`` method is called as
            ``generate(MAX_RESOLUTION_LEVEL - 2, 1, z, 1, mixing=True)`` and
            whose result is indexed with ``[0]`` to get one image tensor.
            Presumably that tensor is channel-first with values in [-1, 1]
            (it is transposed with ``(1, 2, 0)`` and mapped via ``x*0.5+0.5``)
            -- confirm against the model.
        cx, cy: Top-left corner of the crop applied to every image.
        cw, ch: Width and height of the crop, in pixels.
        rows: Number of images in a column whose level is 0.
        lods: Sequence of per-column downscale levels.
        seed: Seed for the NumPy ``RandomState`` that draws the latents.

    Returns:
        The assembled PIL canvas of size
        ``(sum(cw // 2**lod for lod in lods), ch * rows)``.
    """
    print(png)

    # Per-column image counts and pixel widths, computed once up front.
    column_counts = [rows * 2**lod for lod in lods]
    column_widths = [cw // 2**lod for lod in lods]

    # Honour the caller-supplied seed (it used to be ignored in favour of a
    # hard-coded 5).
    rnd = np.random.RandomState(seed)
    images = []
    for _ in range(sum(column_counts)):
        latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
        samplez = torch.tensor(latents).float().cuda()
        generated = model.generate(cfg.DATASET.MAX_RESOLUTION_LEVEL-2, 1, samplez, 1, mixing=True)
        # Copy to host right away so the whole batch does not pile up on the GPU.
        images.append(generated[0].cpu().numpy())

    canvas = PIL.Image.new('RGB', (sum(column_widths), ch * rows), 'white')
    image_iter = iter(images)
    x_offset = 0  # left edge of the current column
    for lod, count, width in zip(lods, column_counts, column_widths):
        for row in range(count):
            im = next(image_iter).transpose(1, 2, 0)
            im = im * 0.5 + 0.5
            tile = PIL.Image.fromarray(np.clip(im * 255, 0, 255).astype(np.uint8), 'RGB')
            tile = tile.crop((cx, cy, cx + cw, cy + ch))
            # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).
            tile = tile.resize((width, ch // 2**lod), PIL.Image.LANCZOS)
            canvas.paste(tile, (x_offset, row * ch // 2**lod))
        x_offset += width
    return canvas

def sample(cfg, logger):
    """Build the model described by ``cfg``, load its checkpoint, and log a grid
    of generated images through ``wandb.log``.

    Args:
        cfg: Config object. Reads the ``cfg.MODEL.*`` fields passed to ``Model``
            below, ``cfg.MODEL.LAYER_COUNT`` for the image size, and ``cfg.NAME``
            for the figure path and the wandb key. It is also handed to
            ``Checkpointer`` and ``draw_uncurated_result_figure``.
        logger: Logger used for the two parameter-count headers and passed to
            ``Checkpointer``.

    Returns:
        None. The figure is sent to wandb under the key
        ``"Generation from <cfg.NAME> Images"``.

    Note:
        Runs on CUDA device 0 unconditionally.
        NOTE(review): an earlier cell already calls ``sample(cfg, logger)``
        before this definition appears, so this cell rebinds the name; the
        notebook must be executed top to bottom for each call to hit the
        intended version.
    """
    torch.cuda.set_device(0)
    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        channels=cfg.MODEL.CHANNELS,
        generator=cfg.MODEL.GENERATOR,
        encoder=cfg.MODEL.ENCODER)

    model.cuda(0)
    model.eval()
    model.requires_grad_(False)

    decoder = model.decoder
    encoder = model.encoder

    logger.info("Trainable parameters generator:")
    count_parameters(decoder)

    logger.info("Trainable parameters discriminator:")
    count_parameters(encoder)

    # Sub-modules handed to the checkpointer; presumably these keys are the
    # names under which the checkpoint stores each part -- confirm against
    # Checkpointer.
    model_dict = {
        'discriminator_s': encoder,
        'generator_s': decoder,
        'mapping_tl_s': model.mapping_tl,
        'mapping_fl_s': model.mapping_fl,
        'dlatent_avg': model.dlatent_avg
    }

    checkpointer = Checkpointer(cfg,
                                model_dict,
                                {},
                                logger=logger,
                                save=False)

    checkpointer.load()

    # Kept from the original: eval mode is asserted again once the checkpoint
    # has been loaded.
    model.eval()

    # The unused `arguments` dict, `layer_count` local and DataParallel-wrapped
    # decoder were dropped: generation below goes through `model` directly.
    im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)
    with torch.no_grad():
        res = draw_uncurated_result_figure(cfg, 'make_figures/output/%s/generations.jpg' % cfg.NAME,
                                     model, cx=0, cy=0, cw=im_size, ch=im_size, rows=6, lods=[0, 0, 0, 1, 1, 2], seed=5)
        wandb.log({"Generation from {} Images".format(cfg.NAME): [wandb.Image(res)]})

In [None]:
# Driver cell: run `sample` once for each of the bedroom, celeba and ffhq configs.
all_configs = ['configs/bedroom.yaml','configs/celeba.yaml','configs/ffhq.yaml']
for config in all_configs:
  # parse_configs yields three values; `device` is unpacked but not used in this cell.
  cfg, logger, device = parse_configs(config)
  sample(cfg, logger)

2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,790 logger INFO: World size: 1
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020-06-11 18:52:02,798 logger INFO: Loaded configuration file configs/bedroom.yaml
2020

[34m[1mwandb[0m: [32m[41mERROR[0m Error uploading "media/images/Generation from ffhq Images_16.png": CommError, File /tmp/tmpb0npts29wandb/3qn3br0y-media/images/Generation from ffhq Images_16.png size shrank from 27661285 to 0 while it was being uploaded.
[34m[1mwandb[0m: [32m[41mERROR[0m Error uploading "media/images/Generation from ffhq Images_16.png": CommError, File /tmp/tmpb0npts29wandb/2tov9hpa-media/images/Generation from ffhq Images_16.png size shrank from 32774029 to 0 while it was being uploaded.
