# Latent tour

In [None]:
%matplotlib inline

import sys
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import torch
import imageio
from scipy import interpolate

sys.path.append("../../")
from experiments.datasets import FFHQStyleGAN2DLoader
from experiments.architectures.image_transforms import create_image_transform, create_image_encoder
from experiments.architectures.vector_transforms import create_vector_transform
from manifold_flow.flows import ManifoldFlow, EncoderManifoldFlow


## Load models

In [None]:
def load_model(
    filename,
    outerlayers=20,
    innerlayers=6,
    levels=4,
    splinebins=11,
    splinerange=10.0,
    dropout=0.0,
    actnorm=True,
    batchnorm=False,
    linlayers=2,
    linchannelfactor=1,
    lineartransform="lu"
):
    """Rebuild a ManifoldFlow for 3x64x64 images with a 2D latent and load its weights.

    The architecture must match the one the checkpoint was trained with, so
    every keyword argument mirrors a training hyperparameter.

    Args:
        filename: checkpoint name without extension; the weights are read from
            ``../data/models/<filename>.pt``.
        outerlayers: total number of outer-flow steps, split across ``levels``
            with integer division.
        innerlayers: number of layers of the 2D inner vector transform.
        levels: number of levels of the multi-scale outer image transform.
        splinebins: number of rational-quadratic spline bins.
        splinerange: spline tail bound.
        dropout: dropout probability used by both transforms.
        actnorm: whether the outer transform uses actnorm.
        batchnorm: whether both transforms use batch norm.
        linlayers: number of "partial_mlp" postprocessing layers.
        linchannelfactor: channel factor of that postprocessing.
        lineartransform: linear transform type of the inner transform.

    Returns:
        The model with the loaded state dict, switched to eval mode. The
        checkpoint tensors are mapped to the CPU while loading.
    """
    image_shape = (3, 64, 64)
    n_latent = 2

    # Rational-quadratic spline settings for the outer coupling layers.
    rq_spline = dict(
        apply_unconditional_transform=False,
        min_bin_height=0.001,
        min_bin_width=0.001,
        min_derivative=0.001,
        num_bins=splinebins,
        tail_bound=splinerange,
    )

    # Outer flow: multi-scale image transform with 8-bit "glow" preprocessing.
    outer = create_image_transform(
        *image_shape,
        levels=levels,
        hidden_channels=100,
        steps_per_level=outerlayers // levels,
        num_res_blocks=2,
        alpha=0.05,
        num_bits=8,
        preprocessing="glow",
        dropout_prob=dropout,
        multi_scale=True,
        spline_params=rq_spline,
        postprocessing="partial_mlp",
        postprocessing_layers=linlayers,
        postprocessing_channel_factor=linchannelfactor,
        use_actnorm=actnorm,
        use_batchnorm=batchnorm,
    )

    # Inner flow acting on the latent coordinates.
    inner = create_vector_transform(
        n_latent,
        innerlayers,
        linear_transform_type=lineartransform,
        base_transform_type="rq-coupling",
        context_features=None,
        dropout_probability=dropout,
        tail_bound=splinerange,
        num_bins=splinebins,
        use_batch_norm=batchnorm,
    )

    flow = ManifoldFlow(
        data_dim=image_shape,
        latent_dim=n_latent,
        outer_transform=outer,
        inner_transform=inner,
        apply_context_to_outer=False,
        pie_epsilon=0.1,
        clip_pie=None,
    )

    checkpoint_path = "../data/models/{}.pt".format(filename)
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    flow.load_state_dict(state_dict)
    flow.eval()
    return flow

In [None]:
def load_emf_model(
    filename,
    outerlayers=20,
    innerlayers=6,
    levels=4,
    splinebins=11,
    splinerange=10.0,
    dropout=0.0,
    actnorm=True,
    batchnorm=False,
    linlayers=2,
    linchannelfactor=1,
    lineartransform="lu"
):
    """Rebuild an EncoderManifoldFlow for 3x64x64 images with a 2D latent and load its weights.

    This is the same architecture as ``load_model``, plus a separate image
    encoder passed to ``EncoderManifoldFlow``. The hyperparameters must match
    the ones the checkpoint was trained with.

    Args:
        filename: checkpoint name without extension; the weights are read from
            ``../data/models/<filename>.pt``.
        outerlayers: total number of outer-flow steps, split across ``levels``
            with integer division.
        innerlayers: number of layers of the 2D inner vector transform.
        levels: number of levels of the multi-scale outer image transform.
        splinebins: number of rational-quadratic spline bins.
        splinerange: spline tail bound.
        dropout: dropout probability used by both transforms.
        actnorm: whether the outer transform uses actnorm.
        batchnorm: whether both transforms use batch norm.
        linlayers: number of "partial_mlp" postprocessing layers.
        linchannelfactor: channel factor of that postprocessing.
        lineartransform: linear transform type of the inner transform.

    Returns:
        The model with the loaded state dict, switched to eval mode. The
        checkpoint tensors are mapped to the CPU while loading.
    """
    image_shape = (3, 64, 64)
    n_latent = 2

    # Rational-quadratic spline settings for the outer coupling layers.
    rq_spline = dict(
        apply_unconditional_transform=False,
        min_bin_height=0.001,
        min_bin_width=0.001,
        min_derivative=0.001,
        num_bins=splinebins,
        tail_bound=splinerange,
    )

    # Image encoder that maps directly to the 2D latent space.
    image_encoder = create_image_encoder(
        *image_shape,
        latent_dim=n_latent,
        context_features=None,
    )

    # Outer flow: multi-scale image transform with 8-bit "glow" preprocessing.
    outer = create_image_transform(
        *image_shape,
        levels=levels,
        hidden_channels=100,
        steps_per_level=outerlayers // levels,
        num_res_blocks=2,
        alpha=0.05,
        num_bits=8,
        preprocessing="glow",
        dropout_prob=dropout,
        multi_scale=True,
        spline_params=rq_spline,
        postprocessing="partial_mlp",
        postprocessing_layers=linlayers,
        postprocessing_channel_factor=linchannelfactor,
        use_actnorm=actnorm,
        use_batchnorm=batchnorm,
    )

    # Inner flow acting on the latent coordinates.
    inner = create_vector_transform(
        n_latent,
        innerlayers,
        linear_transform_type=lineartransform,
        base_transform_type="rq-coupling",
        context_features=None,
        dropout_probability=dropout,
        tail_bound=splinerange,
        num_bins=splinebins,
        use_batch_norm=batchnorm,
    )

    flow = EncoderManifoldFlow(
        data_dim=image_shape,
        latent_dim=n_latent,
        encoder=image_encoder,
        outer_transform=outer,
        inner_transform=inner,
        apply_context_to_outer=False,
        pie_epsilon=0.1,
        clip_pie=None,
    )

    checkpoint_path = "../data/models/{}.pt".format(filename)
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    flow.load_state_dict(state_dict)
    flow.eval()
    return flow

In [None]:
# Load the trained checkpoints. The "emf" checkpoint needs the
# encoder-based architecture; the other two use the plain ManifoldFlow one.
mf = load_model(filename="mf_2_gan2d_april")
emf = load_emf_model(filename="emf_2_gan2d_april")
pie = load_model(filename="pie_2_gan2d_april")

## Tour-rendering function

In [None]:
def make_tour(model, z_checkpoints, filename, n_frames=200, fps=25):
    """Render a closed latent-space tour through ``z_checkpoints`` as a looping GIF.

    A periodic interpolating spline (``s=0``) is fitted through the
    checkpoints, returning to the first one at the end. It is sampled at
    ``n_frames`` evenly spaced parameter values, and each sample is decoded
    into one frame.

    Args:
        model: object with a ``decode`` method; it is called once per frame on
            a float tensor of shape ``(1, latent_dim)``.
        z_checkpoints: array-like of shape ``(n_checkpoints, latent_dim)``
            with the latent points the tour passes through, in order.
        filename: output path of the GIF.
        n_frames: number of frames in the animation.
        fps: playback speed in frames per second.
    """
    z_checkpoints = np.asarray(z_checkpoints, dtype=float)

    # With per=True, splprep treats the data as periodic and ignores the value
    # of the last point, so the first checkpoint is appended to close the loop.
    z_closed = np.concatenate((z_checkpoints, z_checkpoints[:1]), axis=0)
    tck, _ = interpolate.splprep(z_closed.T, s=0, per=True)

    # endpoint=False: on a periodic spline u=0 and u=1 are the same point, so
    # sampling both would repeat the first frame and make the loop stutter.
    u_frames = np.linspace(0.0, 1.0, n_frames, endpoint=False)
    z_frames = np.array(interpolate.splev(u_frames, tck)).T

    ims = []
    for z in z_frames:
        z_tensor = torch.tensor(z).to(torch.float).unsqueeze(0)
        x = model.decode(z_tensor).squeeze().detach().numpy()
        x = np.transpose(x, [1, 2, 0])  # CHW -> HWC for the image writer
        # NOTE(review): assumes decoded images are on the 0..255 pixel scale,
        # like the raw uint8 pixels this notebook feeds to encode() — confirm.
        # Converting here with one fixed range avoids imageio rescaling every
        # float frame to its own min/max, which makes brightness flicker.
        ims.append(np.clip(x, 0, 255).astype(np.uint8))

    imageio.mimsave(filename, ims, 'GIF-FI', duration=1.0 / fps)
    

## Circle tour

In [None]:
n_frames = 200

# n_frames evenly spaced angles in [0, 2*pi). The final angle (equal to the
# first) is dropped so the closed loop does not contain the same point twice.
ts = np.linspace(0., 2. * np.pi, n_frames + 1)[:-1]
# Points on the unit circle in the 2D latent plane, shape (n_frames, 2).
zs = np.column_stack((np.cos(ts), np.sin(ts)))

make_tour(mf, zs, filename="../figures/gan2d_tour_circle.gif", n_frames=n_frames)


## Tour from a real image

In [None]:
# Encode a real photo with the MF model.
# NOTE(review): the image is not resized here, so it presumably already
# matches the model's 3x64x64 input — confirm.
x_reals = np.stack([plt.imread("../data/merle.jpg")])  # prepend a batch axis
x_reals_ = torch.tensor(np.moveaxis(x_reals, 3, 1), dtype=torch.float)  # NHWC -> NCHW
z_reals = mf.encode(x_reals_).detach().numpy()
z_reals

In [None]:
# Tour checkpoints. The first one is the latent code of the real image, and
# make_tour closes the loop, so the tour also ends there.
z_tour = np.array(
    [
        z_reals[0],
        [1.0, 0.0],
        [1.0, 1.0],
        [0.0, 0.9],
        [-0.7, -0.7],
        [-1.0, 0.0],
        [-0.5, -1.5],
        [0.0, -1.2],
    ]
)

In [None]:
# Render the checkpoint tour that starts from the real image.
make_tour(model=mf, z_checkpoints=z_tour, filename="../figures/merle.gif", n_frames=200)