In [None]:
import os
import sys
import argparse
import collections
import numpy as np
import torch
from scipy.io import savemat
from tqdm import trange
from torchvision.utils import save_image
from torch.utils.data import DataLoader

sys.path.insert(0, 'src')
import config
import utils

In [None]:
class dotdict(dict):
  """Dictionary whose entries can also be accessed as attributes.

  `d.key` reads `d['key']` and yields `None` when the key is absent,
  `d.key = value` stores `d['key'] = value`, and `del d.key` removes the
  entry (raising `KeyError` when it does not exist).
  """

  def __getattr__(self, name):
    # Only invoked when regular attribute lookup fails; behaves like dict.get.
    return dict.get(self, name)

  def __setattr__(self, name, value):
    # Attribute assignment always lands in the mapping itself.
    dict.__setitem__(self, name, value)

  def __delattr__(self, name):
    # Attribute deletion removes the matching key.
    dict.__delitem__(self, name)
  
# Notebook-wide settings, kept in one attribute-style dictionary.
args = dotdict(
  data_dir='',  # set this for your environment
  exp_dir='',   # set this too: the checkpoint is read from here and images are written here
  # Use the GPU when one is visible to torch, otherwise run on the CPU.
  device='cuda' if torch.cuda.is_available() else 'cpu',
  dataset='celebA',
  red_rate=0.0,
  test_split=0.2,
  validation_split=0.0,
  d_latent=128,
)

In [None]:
def strip_data_parallel_prefix(state_dict):
  """Return a copy of `state_dict` with `module` path components removed.

  Keys are treated as dot-separated paths. Every component that is exactly
  `module` is dropped, except for the final component, so `module.conv.weight`
  becomes `conv.weight`. Names that merely contain the text (for example
  `submodule.weight`) are left untouched, which a plain substring replacement
  of 'module.' would have mangled.

  Args:
    state_dict: mapping from parameter name to value.

  Returns:
    A `collections.OrderedDict` with the cleaned keys, in the original order.
  """
  cleaned = collections.OrderedDict()
  for key, value in state_dict.items():
    parts = key.split('.')
    # The last component is the parameter/buffer name itself; never drop it.
    kept = [part for part in parts[:-1] if part != 'module'] + parts[-1:]
    cleaned['.'.join(kept)] = value
  return cleaned


train_set, test_set = config.load_dataset(args)

enc, dec = config.load_model(args)
# Evaluation mode for both networks before any forward pass.
enc.eval(); dec.eval()

print('loading pretrained auto-encoder checkpoint')
ckpt_file = os.path.join(args.exp_dir, 'model.ckpt')
# Remap every stored tensor onto the configured device, so the checkpoint
# loads regardless of which device (or GPU index) it was saved from.
ckpt = torch.load(ckpt_file, map_location=args.device)

enc_ckpt = strip_data_parallel_prefix(ckpt['encoder'])
enc.load_state_dict(enc_ckpt)

dec_ckpt = strip_data_parallel_prefix(ckpt['decoder'])
dec.load_state_dict(dec_ckpt)

In [None]:
# For each split, save a grid that pairs randomly drawn inputs with their
# auto-encoder reconstructions.
n_compare = 64
for _setname, _set in zip(['train', 'test'], [train_set, test_set]):
  print('generating samples for {} set'.format(_setname))

  inds = np.random.choice(len(_set), n_compare)
  x_comb = []
  # Inference only: no autograd graph is needed for these forward passes.
  with torch.no_grad():
    for i in inds:
      x = _set[i].unsqueeze(0)
      x_rec = dec(enc(x.to(args.device))).to('cpu')
      # Input and reconstruction are stored back to back so they sit next to
      # each other in the saved grid.
      x_comb.extend([x, x_rec])

  # 16 images per row, i.e. 8 (input, reconstruction) pairs.
  save_image(
    torch.cat(x_comb, 0),
    os.path.join(args.exp_dir, 'x_vs_xrec_{}.png'.format(_setname)),
    16)
    

In [None]:
# Interpolate between random pairs of test images, once linearly in pixel
# space and once linearly in latent space (decoded), and save both strips.
n_steps = 10

for _ in range(5):
  # replace=False guarantees the source and the target are different samples.
  source_idx, target_idx = np.random.choice(len(test_set), 2, replace=False)
  x_source = test_set[source_idx].unsqueeze(0).detach().to(args.device)
  x_target = test_set[target_idx].unsqueeze(0).detach().to(args.device)

  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    z_source = enc(x_source)
    z_target = enc(x_target)

    z_diff = z_target - z_source
    z_step = z_diff / n_steps
    x_diff = x_target - x_source
    x_step = x_diff / n_steps

    # n_steps steps need n_steps + 1 frames: frame 0 is the source and frame
    # n_steps is the target, so the strip actually arrives at the target.
    # The first and last decoded frames are the reconstructions of the source
    # and the target respectively.
    x_interp = [x_source + x_step * i for i in range(n_steps + 1)]
    xrec_interp = [dec(z_source + z_step * i) for i in range(n_steps + 1)]

  # One row per strip, wide enough for every frame.
  save_image(torch.cat(x_interp).to('cpu'), os.path.join(args.exp_dir, 'x_interp_{}-{}.png'.format(source_idx, target_idx)), n_steps + 1)
  save_image(torch.cat(xrec_interp).to('cpu'), os.path.join(args.exp_dir, 'xrec_interp_{}-{}.png'.format(source_idx, target_idx)), n_steps + 1)

In [None]:
# For a few random test images, perturb one latent coordinate at a time and
# decode the result. Each row of the saved grid is one latent dimension, swept
# symmetrically around the encoded value.
n_step = 14
half_span = n_step // 2
dim_step = 0.2 # z_std[dim].to(args.device) / n_step
# Symmetric integer offsets -half_span..+half_span with 0 (the unperturbed
# code) appearing exactly once in the middle.
offsets = range(-half_span, half_span + 1)

for source_idx in np.random.choice(len(test_set), 10):
  x_source = test_set[source_idx].unsqueeze(0).detach().to(args.device)
  xrec_targets = []

  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    z_source = enc(x_source)

    # Sweep every latent dimension configured for the model.
    for dim in trange(args.d_latent):
      for offset in offsets:
        z_target = z_source.clone()
        z_target[0, dim] += dim_step * offset
        xrec_targets.append(dec(z_target).to('cpu'))

  # One row per latent dimension.
  save_image(
    torch.cat(xrec_targets),
    os.path.join(args.exp_dir, 'xrec_dim-interp_{}.png'.format(source_idx)),
    len(offsets))