# Latent tour

In [1]:
%matplotlib inline

import sys
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import torch
import imageio
from scipy import interpolate

sys.path.append("../../")
from experiments.datasets import FFHQStyleGAN2DLoader
from experiments.architectures.image_transforms import create_image_transform, create_image_encoder
from experiments.architectures.vector_transforms import create_vector_transform
from manifold_flow.flows import ManifoldFlow, EncoderManifoldFlow


## Load models

In [2]:
def load_model(
    filename,
    latentdim=2,
    outerlayers=20,
    innerlayers=6,
    levels=4,
    splinebins=11,
    splinerange=10.0,
    dropout=0.0,
    actnorm=True,
    batchnorm=False,
    contextfeatures=None,
    linlayers=2,
    linchannelfactor=1,
    lineartransform="lu",
):
    """Build a ManifoldFlow for (3, 64, 64) data and load saved weights onto the CPU.

    The architecture is assembled from an outer image transform and an inner
    vector transform; the keyword arguments of this function are forwarded to
    those two factories. The state dict is then read from
    ``../data/models/<filename>.pt``.

    Parameters
    ----------
    filename : str
        Checkpoint name, without directory or the ``.pt`` extension.
    latentdim : int
        Latent dimension of the flow (size of the inner transform).
    outerlayers, levels : int
        The image transform gets ``outerlayers // levels`` steps per level.
    innerlayers : int
        Second positional argument of ``create_vector_transform``.
    splinebins, splinerange
        Number of spline bins and tail bound, used by both transforms.
    dropout : float
        Dropout probability passed to both transforms.
    actnorm, batchnorm : bool
        Normalisation switches passed to the transforms.
    contextfeatures : int or None
        ``context_features`` of the inner transform.
    linlayers, linchannelfactor
        Post-processing layers / channel factor of the image transform.
    lineartransform : str
        ``linear_transform_type`` of the inner transform.

    Returns
    -------
    ManifoldFlow
        The model with loaded weights, switched to evaluation mode.
    """
    # Outer transform: acts on the 3 x 64 x 64 images. The rational-quadratic
    # spline settings are written inline since nothing else needs them.
    image_flow = create_image_transform(
        3,
        64,
        64,
        levels=levels,
        hidden_channels=100,
        steps_per_level=outerlayers // levels,
        num_res_blocks=2,
        alpha=0.05,
        num_bits=8,
        preprocessing="glow",
        dropout_prob=dropout,
        multi_scale=True,
        spline_params=dict(
            apply_unconditional_transform=False,
            min_bin_height=0.001,
            min_bin_width=0.001,
            min_derivative=0.001,
            num_bins=splinebins,
            tail_bound=splinerange,
        ),
        postprocessing="partial_mlp",
        postprocessing_layers=linlayers,
        postprocessing_channel_factor=linchannelfactor,
        use_actnorm=actnorm,
        use_batchnorm=batchnorm,
    )

    # Inner transform: acts on the latentdim-dimensional latent vector.
    latent_flow = create_vector_transform(
        latentdim,
        innerlayers,
        linear_transform_type=lineartransform,
        base_transform_type="rq-coupling",
        context_features=contextfeatures,
        dropout_probability=dropout,
        tail_bound=splinerange,
        num_bins=splinebins,
        use_batch_norm=batchnorm,
    )

    flow = ManifoldFlow(
        data_dim=(3, 64, 64),
        latent_dim=latentdim,
        outer_transform=image_flow,
        inner_transform=latent_flow,
        apply_context_to_outer=False,
        pie_epsilon=0.1,
        clip_pie=None,
    )

    # Restore the trained parameters; map_location keeps everything on the CPU.
    checkpoint_path = "../data/models/{}.pt".format(filename)
    weights = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    flow.load_state_dict(weights)

    # Inference only from here on.
    flow.eval()
    return flow

In [5]:
# Model with a two-dimensional latent space, using load_model's default architecture settings
mf2 = load_model("mf_2_gan2d_april")

# Model with a 64-dimensional latent space; it takes one context feature and
# overrides the inner-layer count and post-processing channel factor
mf64 = load_model(
    "mf_64_gan64d_april_run1",
    latentdim=64,
    innerlayers=8,
    linchannelfactor=2,
    contextfeatures=1,
)

## Master function: render a latent-space tour as a GIF

In [25]:
def make_tour(model, z_checkpoints, filename, n_frames=200, fps=25, interpolate_dims=10, context=None):
    """Decode a closed path through latent space and save the frames as a GIF.

    The first ``interpolate_dims`` latent coordinates follow a periodic
    interpolating spline through ``z_checkpoints`` (the path returns to the
    first checkpoint). All remaining coordinates are held fixed at the values
    of the first checkpoint. Each point on the path is decoded with
    ``model.decode`` and the resulting images are written with imageio.

    Parameters
    ----------
    model
        Object exposing ``decode(z, context=...)`` that returns a torch tensor.
    z_checkpoints : array_like, shape (n_checkpoints, latent_dim)
        Latent points the tour passes through, in order.
    filename : str
        Output path handed to ``imageio.mimsave``.
    n_frames : int
        Number of points sampled along the path (one decoded image each).
    fps : float
        Frames per second; ``1.0 / fps`` is passed as the per-frame duration.
    interpolate_dims : int
        Number of leading latent coordinates that are interpolated.
        NOTE(review): SciPy's ``splprep`` is documented to reject curves with
        more than 10 dimensions, which is presumably why the default is 10.
    context : torch.Tensor or None
        Optional context for the model. When given, a leading batch axis is
        added before it is passed to ``decode``; ``None`` is passed through
        unchanged so that models without context can be used.
    """
    z_checkpoints = np.asarray(z_checkpoints)

    # Coordinates that are not interpolated: repeat the first checkpoint's values for every frame.
    z_fix = np.repeat(z_checkpoints[:1, interpolate_dims:], n_frames, axis=0)

    # Close the curve by appending the first checkpoint, then fit a periodic
    # interpolating spline (s=0) through the interpolated coordinates.
    z_moving = z_checkpoints[:, :interpolate_dims]
    z_closed = np.concatenate((z_moving, z_moving[:1]), axis=0)
    tck, _ = interpolate.splprep(z_closed.T, s=0, per=True)

    # Sample the spline uniformly in its parameter and re-attach the fixed coordinates.
    z_frames = np.array(interpolate.splev(np.linspace(0, 1, n_frames), tck)).T
    z_frames = np.concatenate((z_frames, z_fix), axis=1)

    # The context is the same for every frame, so batch it once. Guarding against
    # None is required: the signature advertises context=None as the default.
    batched_context = None if context is None else context.unsqueeze(0)

    ims = []
    for z in z_frames:
        z_batch = torch.tensor(z).to(torch.float).unsqueeze(0)
        x = model.decode(z_batch, context=batched_context).squeeze().detach().numpy()
        # Move axis 0 to the end (assumes decode yields channel-first images — TODO confirm).
        ims.append(np.transpose(x, [1, 2, 0]))

    imageio.mimsave(filename, ims, 'GIF-FI', duration=1.0/fps)
    

## n=2 circle tour

In [None]:
n_frames = 200

# n_frames angles evenly spaced over one full turn; the endpoint 2*pi is left
# out because it coincides with the starting angle 0
ts = np.linspace(0., 2. * np.pi, n_frames, endpoint=False)
# One (cos, sin) row per angle, i.e. points on the unit circle
zs = np.column_stack((np.cos(ts), np.sin(ts)))

make_tour(mf2, zs, filename="../figures/gan2d_tour_circle_mf.gif", n_frames=n_frames)


## n=64 tour from a real image

In [27]:
# Context value used for both encoding and the later tour: a single float set to zero
param = torch.zeros(1, dtype=torch.float)

# Read the photo and stack it into a batch of one image
x_reals = np.stack([plt.imread("../data/merle.jpg")])
# Reorder the axes to (0, 3, 1, 2) and convert to a float tensor
# (presumably channel-last -> channel-first, as the model's data_dim suggests)
x_reals_ = torch.tensor(np.transpose(x_reals, (0, 3, 1, 2))).float()

# Encode into latent space; the context gets a leading batch axis as well
z_reals = mf64.encode(x_reals_, context=param[None]).detach().numpy()
z_reals

array([[-1.2029319 , -1.7852677 ,  1.267674  , -1.4313987 , -0.24781178,
         1.5943398 ,  0.70669454,  0.51296514,  0.33693933, -0.88983214,
        -1.9633608 ,  0.5775551 ,  0.45369503, -0.71944   , -1.7163948 ,
        -0.7424832 ,  2.0496745 ,  0.0384796 ,  0.9369421 , -2.0319183 ,
        -4.220821  ,  1.6531723 ,  0.4732769 , -0.672833  , -0.9516078 ,
         2.4319878 , -1.1113434 , -1.1410133 ,  1.4168046 , -2.5789747 ,
        -2.4077215 , -0.68494076,  1.2116934 , -1.7872741 ,  1.3651558 ,
        -0.60139227,  0.91517925, -1.2779715 ,  1.2163581 ,  1.5368803 ,
         0.45496592, -0.94389385,  2.3437946 ,  2.128365  ,  0.06525467,
        -0.07784646,  1.6460439 , -1.7622306 ,  4.755513  , -0.12985978,
         0.8114347 ,  0.04490801,  3.2975829 , -0.12547274,  3.7040296 ,
         0.15134586, -0.9683202 , -0.42259032,  0.9044508 , -0.21217161,
         1.157878  , -0.13085735,  1.5256104 , -0.93967944]],
      dtype=float32)

In [32]:
# Checkpoints for the tour: the encoded real image first, followed by three
# random latent points drawn from a standard normal distribution.
# A locally seeded generator is used so that the random checkpoints (and with
# them the saved GIF) are identical on every Restart & Run All.
tour_rng = np.random.default_rng(0)
z_tour = np.concatenate(
    (
        z_reals[:1],
        # Latent size is taken from the encoded image instead of being hard-coded
        tour_rng.normal(size=(3, z_reals.shape[1])),
    ),
    axis=0,
)
# Show only the first ten coordinates of each checkpoint
z_tour[:,:10]

array([[-1.20293188, -1.78526771,  1.26767397, -1.43139875, -0.24781178,
         1.59433985,  0.70669454,  0.51296514,  0.33693933, -0.88983214],
       [-2.4200444 , -0.42285216,  0.84097715,  0.16027569, -1.19188345,
        -0.35780866, -0.00586058,  0.60336087,  1.72084808,  1.31790089],
       [-1.14084624, -0.73506788, -0.70755633, -1.48814686, -0.33758468,
         1.94075239,  0.04507095, -0.68679437,  0.26435337,  1.20883743],
       [ 0.43246569, -0.34514785,  0.45194108, -0.22670406,  1.31825989,
        -0.64215863,  0.66855823,  0.53808919,  0.15377462,  0.46934764]])

In [31]:
# Render the tour through the checkpoints with the n=64 model, passing the
# same context value that was used when encoding the real image
make_tour(
    mf64,
    z_tour,
    filename="../figures/merle.gif",
    n_frames=200,
    context=param,
)

