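# Latent tour

Render looping GIF tours through the latent spaces of trained manifold-flow (`ManifoldFlow`) models by decoding latent points sampled along closed paths.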

In [1]:
%matplotlib inline

import sys
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import torch
import imageio
from scipy import interpolate

sys.path.append("../../")
from experiments.datasets import FFHQStyleGAN2DLoader
from experiments.architectures.image_transforms import create_image_transform, create_image_encoder
from experiments.architectures.vector_transforms import create_vector_transform
from manifold_flow.flows import ManifoldFlow, EncoderManifoldFlow


## Load models

In [4]:
def load_model(
    filename,
    latentdim=2,
    outerlayers=20,
    innerlayers=6,
    levels=4,
    splinebins=11,
    splinerange=10.0,
    dropout=0.0,
    actnorm=True,
    batchnorm=False,
    contextfeatures=None,
    linlayers=2,
    linchannelfactor=1,
    lineartransform="lu",
    model_dir="../data/models",
):
    """Rebuild a ManifoldFlow for 3x64x64 images and load trained weights into it.

    The architecture arguments must match the ones used when the checkpoint was
    trained; otherwise ``load_state_dict`` (strict by default) raises on missing,
    unexpected, or mis-shaped parameters.

    Parameters
    ----------
    filename : str
        Checkpoint name without extension; ``<model_dir>/<filename>.pt`` is loaded.
    latentdim : int
        Dimension of the learned manifold (inner transform size).
    outerlayers : int
        Total number of outer image-flow steps, divided across ``levels``
        (integer division, so a remainder is dropped).
    innerlayers : int
        Number of layers in the inner vector transform.
    levels : int
        Number of multi-scale levels of the outer image transform.
    splinebins, splinerange : int, float
        Spline bin count and tail bound passed to both transforms.
    dropout : float
        Dropout probability passed to both transforms.
    actnorm, batchnorm : bool
        Normalization switches passed to the transforms.
    contextfeatures : int or None
        Number of context features for the inner transform (None: unconditional).
    linlayers, linchannelfactor : int
        Layers and channel factor of the outer transform's postprocessing.
    lineartransform : str
        Linear transform type used in the inner transform.
    model_dir : str
        Directory containing the checkpoint files.

    Returns
    -------
    ManifoldFlow
        The model with weights loaded onto the CPU, switched to eval mode.
    """
    steps_per_level = outerlayers // levels
    outer_transform = create_image_transform(
        3,
        64,
        64,
        levels=levels,
        hidden_channels=100,
        steps_per_level=steps_per_level,
        num_res_blocks=2,
        alpha=0.05,
        num_bits=8,
        preprocessing="glow",
        dropout_prob=dropout,
        multi_scale=True,
        num_bins=splinebins,
        tail_bound=splinerange,
        postprocessing="partial_mlp",
        postprocessing_layers=linlayers,
        postprocessing_channel_factor=linchannelfactor,
        use_actnorm=actnorm,
        use_batchnorm=batchnorm,
    )
    inner_transform = create_vector_transform(
        latentdim,
        innerlayers,
        linear_transform_type=lineartransform,
        base_transform_type="rq-coupling",
        context_features=contextfeatures,
        dropout_probability=dropout,
        tail_bound=splinerange,
        num_bins=splinebins,
        use_batch_norm=batchnorm,
    )
    model = ManifoldFlow(
        data_dim=(3, 64, 64),
        latent_dim=latentdim,
        outer_transform=outer_transform,
        inner_transform=inner_transform,
        apply_context_to_outer=False,
        pie_epsilon=0.1,
        clip_pie=None
    )

    # torch.load unpickles the file, which can execute arbitrary code:
    # only load checkpoints from a trusted source.
    checkpoint_path = "{}/{}.pt".format(model_dir, filename)
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=torch.device("cpu"))
    )
    model.eval()

    return model

In [5]:
# Trained checkpoints: 2- and 64-dimensional manifold flows for the gan2d / gan64d
# data, and a 128-dimensional manifold flow for CelebA. The architecture keywords
# must match the settings each checkpoint was trained with.
mf2 = load_model("mf_2_gan2d_april")
mf64 = load_model(
    "mf_64_gan64d_april",
    latentdim=64,
    innerlayers=8,
    linchannelfactor=2,
    contextfeatures=1,
)
mfc = load_model(
    "mf_128_celeba_april_run1",
    latentdim=128,
    innerlayers=8,
    linchannelfactor=2,
)

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)


## Tour rendering functions

In [6]:
def spline_tour(model, z_checkpoints, filename, n_frames=200, fps=25, context=None):
    """Render a looping GIF that follows a closed spline through latent checkpoints.

    A periodic interpolating spline is fitted through ``z_checkpoints``. It is
    sampled at ``n_frames`` evenly spaced parameter values, each sample is decoded
    with ``model.decode``, and the frames are written to ``filename``.

    Parameters
    ----------
    model : object with ``decode(z, context=...)`` (e.g. the ManifoldFlow models above)
    z_checkpoints : np.ndarray of shape (n_checkpoints, latent_dim)
        Latent points the tour passes through, in order.
    filename : str
        Output path of the GIF.
    n_frames : int
        Number of frames in one loop.
    fps : int
        Playback speed; each frame lasts ``1 / fps`` seconds.
    context : torch.Tensor or None
        Optional context; a batch dimension is added and it is passed to every decode.
    """
    # With per=True, splprep treats the data as periodic and ignores the last
    # point's value, so the first checkpoint is appended to close the loop.
    z_closed = np.concatenate((z_checkpoints, z_checkpoints[0:1]), axis=0)
    tck, _ = interpolate.splprep(z_closed.T, s=0, per=True)
    # The automatically computed spline parameter spans [0, 1], and u=1 coincides
    # with u=0 on the closed curve. Exclude the endpoint so the looping GIF does
    # not show the seam frame twice.
    u_frames = np.linspace(0.0, 1.0, n_frames, endpoint=False)
    z_frames = np.array(interpolate.splev(u_frames, tck)).T

    batch_context = None if context is None else context.unsqueeze(0)
    ims = []
    # Inference only: no autograd graph is needed for the decoded frames.
    with torch.no_grad():
        for z in z_frames:
            z_batch = torch.tensor(z).to(torch.float).unsqueeze(0)
            x = model.decode(z_batch, context=batch_context).squeeze().detach().numpy()
            # Channels-first -> channels-last for imageio (assumes decode returns
            # a (1, C, H, W) batch, as for data_dim=(3, 64, 64) models).
            ims.append(np.transpose(x, [1, 2, 0]))

    # NOTE(review): frames are float arrays that imageio converts to uint8 on save;
    # confirm the decoder's output range makes that conversion appropriate.
    imageio.mimsave(filename, ims, 'GIF-FI', duration=1.0 / fps)
    

In [7]:
def linear_interpolator(inputs, n_frames):
    """Sample a closed piecewise-linear path through ``inputs`` at ``n_frames`` points.

    The path visits inputs[0], inputs[1], ..., inputs[-1] and then returns to
    inputs[0]. Each of the ``len(inputs)`` segments gets the same share of the
    frames. The final point (back at inputs[0]) is excluded, so the sampled path
    loops seamlessly.

    Parameters
    ----------
    inputs : array-like of shape (n_inputs, ...)
        Waypoints along the first axis. Any trailing axes (e.g. a latent
        dimension) are interpolated elementwise.
    n_frames : int
        Number of samples to return.

    Returns
    -------
    np.ndarray of shape (n_frames, ...)

    Raises
    ------
    ValueError
        If ``inputs`` is empty.
    """
    inputs = np.asarray(inputs)
    n_inputs = len(inputs)
    if n_inputs == 0:
        raise ValueError("linear_interpolator needs at least one input")

    # Append the first waypoint so the last segment leads back to the start.
    inputs_ = np.concatenate((inputs, inputs[0:1]), axis=0)
    # Path parameter t in [0, n_inputs); an integer t sits exactly on a waypoint.
    t_frames = np.linspace(0.0, float(n_inputs), n_frames + 1)[:-1]

    # np.int / np.float were removed in NumPy 1.24; use the builtin types instead.
    last = np.floor(t_frames).astype(int)
    next_ = np.ceil(t_frames).astype(int)
    alpha = t_frames - last
    # Shape the weights to broadcast over every trailing axis of the waypoints,
    # so scalar (1-D) inputs work as well as vectors or images.
    alpha = alpha.reshape((-1,) + (1,) * (inputs_.ndim - 1))

    return (1.0 - alpha) * inputs_[last] + alpha * inputs_[next_]
    

def linear_tour(model, z_checkpoints, filename, n_frames=200, fps=25, context=None):
    """Render a looping GIF that moves along straight lines between latent checkpoints.

    ``linear_interpolator`` samples a closed piecewise-linear path through
    ``z_checkpoints``. Each sample is decoded with ``model.decode``, and the
    frames are written to ``filename``.

    Parameters
    ----------
    model : object with ``decode(z, context=...)`` (e.g. the ManifoldFlow models above)
    z_checkpoints : np.ndarray of shape (n_checkpoints, latent_dim)
        Latent waypoints, visited in order before returning to the first one.
    filename : str
        Output path of the GIF.
    n_frames : int
        Number of frames in one loop.
    fps : int
        Playback speed; each frame lasts ``1 / fps`` seconds.
    context : torch.Tensor or None
        Optional context; a batch dimension is added and it is passed to every decode.
    """
    z_frames = linear_interpolator(z_checkpoints, n_frames)

    batch_context = None if context is None else context.unsqueeze(0)
    ims = []
    # Inference only: no autograd graph is needed for the decoded frames.
    with torch.no_grad():
        for z in z_frames:
            z_batch = torch.tensor(z).to(torch.float).unsqueeze(0)
            x = model.decode(z_batch, context=batch_context).squeeze().detach().numpy()
            # Channels-first -> channels-last for imageio (assumes decode returns
            # a (1, C, H, W) batch, as for data_dim=(3, 64, 64) models).
            ims.append(np.transpose(x, [1, 2, 0]))

    # NOTE(review): frames are float arrays that imageio converts to uint8 on save;
    # confirm the decoder's output range makes that conversion appropriate.
    imageio.mimsave(filename, ims, 'GIF-FI', duration=1.0 / fps)
    

## Circle tour through the 2D latent space (n=2)

In [None]:
n_frames = 200

# n_frames evenly spaced angles on [0, 2*pi) (the endpoint duplicates 0), mapped
# onto the unit circle as (cos, sin) latent points.
ts = np.linspace(0.0, 2.0 * np.pi, n_frames, endpoint=False)
zs = np.column_stack((np.cos(ts), np.sin(ts)))

spline_tour(
    mf2,
    zs,
    filename="../figures/gan2d_tour_circle_mf.gif",
    n_frames=n_frames,
)


## CelebA tour through real test images

In [11]:
# CelebA test images whose latent codes serve as the tour checkpoints.
celeba_test_ids = [17, 18, 19, 20]
x_reals = np.array([
    plt.imread("../data/samples/celeba/test/{}.jpg".format(i)) for i in celeba_test_ids
])
# (N, H, W, C) -> (N, C, H, W): the flows take channels-first image batches.
x_reals_ = torch.tensor(x_reals.transpose(0, 3, 1, 2)).to(torch.float)
# Inference only: skip building an autograd graph for the encoder pass.
with torch.no_grad():
    z_reals = mfc.encode(x_reals_).detach().numpy()

In [12]:
# Independent copy of the encoded checkpoints, so editing the tour leaves z_reals intact.
z_tour = z_reals.copy()

In [13]:
# Straight-line tour through the encoded CelebA test images, decoded by the 128D model.
linear_tour(
    mfc,
    z_tour,
    filename="../figures/celeba_tour_test_mf.gif",
    n_frames=200,
)

