In [None]:
# Notebook setup: select the GPU, configure the `foundation` framework through
# environment variables, then import all libraries used in this session.
import sys, os, time
# %pdb
# Environment variables are set before torch/foundation are imported — presumably
# so they are picked up at import time (TODO confirm against `foundation`).
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ['FOUNDATION_RUN_MODE'] = 'jupyter'
# NOTE(review): absolute, user-specific paths; adjust when running on another machine.
os.environ['FOUNDATION_SAVE_DIR'] = '/is/ei/fleeb/workspace/chome/trained_nets'
os.environ['FOUNDATION_DATA_DIR'] = '/is/ei/fleeb/workspace/local_data'
# %load_ext autoreload
# %autoreload 2
# NOTE(review): most of the imports below appear unused by the visible cells
# (only os, shutil and foundation.util are needed for the export steps).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim as O
import torch.distributions as distrib
import torch.multiprocessing as mp
import torchvision.models
import torchvision
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook as tqdm
import gym
import json
import inspect
import numpy as np
import h5py as hf
%matplotlib notebook
# %matplotlib tk
import matplotlib.pyplot as plt
import seaborn as sns
#plt.switch_backend('Qt5Agg') #('Qt5Agg')
import foundation as fd
from foundation import models
from foundation import util
from foundation import train as trn
import shutil
# from foundation import sim as SIM
#from foundation.util import replicate, Cloner
from scipy import stats
# Wider numpy print lines when inspecting arrays interactively.
np.set_printoptions(linewidth=120)

In [None]:
# Input tree laid out as <root>/<dataset>/<model>/viz/<image>, and the flat
# output folder that receives the renamed supplement images.
# NOTE(review): absolute, user-specific paths.
root, savedir = (
    os.path.join('/is/ei/fleeb/workspace/media/hybrid', 'supplement'),
    os.path.join('/is/ei/fleeb/workspace/media/hybrid', 'supplement_imgs'),
)
util.create_dir(savedir)
# os.listdir(root)

In [None]:
# Image files exported for every run under its viz/ folder.
imgs = 'gen.png hybrids.png originals.png recs.png interventions.png latent.png'.split()
# Short keys for the four panel images. They are zipped positionally with `imgs`,
# so their order must match the first four entries above (gen, hybrids, originals, recs).
im_names = 'gen hyb orig rec'.split()
traversals = list()

In [None]:
# (dataset_dir, model_dir) pairs discovered under `root`.
combos = list()

In [None]:
# Copy every run's visualisation images into `savedir`, flattening
# <root>/<dataset>/<model>/viz/<img>  ->  <savedir>/<dataset>_<model>_<img>,
# and record each (dataset, model) pair in `combos` for the LaTeX cells.
# `combos` is rebuilt here so re-running this cell does not duplicate entries,
# and directories are visited in sorted order (os.listdir order is arbitrary)
# so the generated figure order is reproducible.
combos = []
for d in sorted(os.listdir(root)):
    droot = os.path.join(root, d)
    if not os.path.isdir(droot):
        continue  # ignore stray files (e.g. .DS_Store) next to the dataset folders
    for m in sorted(os.listdir(droot)):
        path = os.path.join(droot, m)
        if not os.path.isdir(path):
            continue
        combos.append((d, m))
        for im in imgs:
            shutil.copy(os.path.join(path, 'viz', im), os.path.join(savedir, f'{d}_{m}_{im}'))

In [None]:
# Display the collected (dataset, model) pairs.
combos

In [None]:
# Caption names for model directories; directories missing here are
# rendered as their upper-cased name by the figure/video cells.
# NOTE(review): 'bvae' and '4vae' share the title '4-VAE' — confirm this is intended.
model_titles = dict([
    ('fwae', 'FWAE'),
    ('fvae', 'FVAE'),
    ('vae', 'VAE'),
    ('4vae', '4-VAE'),
    ('16vae', '16-VAE'),
    ('bvae', '4-VAE'),
    ('s3-d0', 'AE (0,3)'),
    ('s1-d8', 'AE (8,1)'),
    ('s1-d0', 'AE (0,1)'),
    ('s1-d2', 'AE (2,1)'),
])
# Caption names for dataset directories.
# NOTE(review): several directories map to the same title ('3dshapes', 'arch' and
# '3dshapes-adain' -> 'Shapes-3D'; 'celeba', 'celeba-adain' -> 'CelebA').
data_titles = dict([
    ('celeba', 'CelebA'),
    ('mpi3d-toy', 'MPI-3D Toy'),
    ('mpi3d-real', 'MPI-3D Real'),
    ('3dshapes', 'Shapes-3D'),
    ('arch', 'Shapes-3D'),
    ('celeba-adain', 'CelebA'),
    ('3dshapes-adain', 'Shapes-3D'),
    ('pacman', 'Pacman'),
    ('spaceinv', 'SpaceInvaders'),
])

In [None]:
# Reusable LaTeX template for the 4-panel supplement figure
# (original / reconstruction / hybrid / prior samples). Fill it with
#   tmpl.format(orig=..., rec=..., hyb=..., gen=..., dataset=..., model=..., fig=...)
# It is a raw, non-f string: the previous f-string with `{:orig}` fields was a
# SyntaxError, and escapes such as `\b` (backspace) and `\t` (tab) corrupted the
# LaTeX commands. Literal LaTeX braces are doubled so str.format emits them as-is.
tmpl = r'''
\begin{{figure}}
    \centering
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{orig}}}}}
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{rec}}}}}
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{hyb}}}}}
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{gen}}}}}
    \caption{{{dataset} {model} (a) Original (b) Reconstructions (c) Hybrid (d) Prior}}
    \label{{fig:{fig}}}
\end{{figure}}
'''

In [None]:
# Print one LaTeX figure per (dataset, model) run that pairs the latent-distribution
# plot with the interventions ("image effects") plot, ready to paste into the supplement.
# The snippet is kept in its own name (`latent_tex`) instead of reassigning `tmpl`,
# so running this cell no longer clobbers the reusable template.
for d, m in combos:
    # Human-readable model name; unknown model dirs fall back to the upper-cased name.
    model = model_titles.get(m, m.upper())
    if 'adain' in d:
        model += ' (4,2)'
    data = data_titles[d]

    lat = f'{d}_{m}_latent.png'
    inv = f'{d}_{m}_interventions.png'
    fig = f'{d}_{m}_latent'

    latent_tex = rf'''
\begin{{figure}}
    \centering
    \subfigure[]{{\includegraphics[width=0.49\textwidth]{{figures/{lat}}}}}
    \subfigure[]{{\includegraphics[width=0.49\textwidth]{{figures/{inv}}}}}
    \caption{{{data} {model} (a) Latent Distributions (b) Image Effects}}
    \label{{fig:{fig}}}
\end{{figure}}
'''
    print(latent_tex)

In [None]:
# Print one 4-panel LaTeX figure (original / reconstruction / hybrid / prior) per
# (dataset, model) run. The snippet is kept in its own name (`panels_tex`) instead of
# reassigning `tmpl`, so running this cell no longer clobbers the reusable template.
for d, m in combos:
    # Human-readable model name; unknown model dirs fall back to the upper-cased name.
    model = model_titles.get(m, m.upper())
    if 'adain' in d:
        model += ' (4,2)'
    data = data_titles[d]

    # Copied file names keyed by panel; relies on `im_names` being positionally
    # aligned with the first entries of `imgs`.
    fnames = {key: f'{d}_{m}_{im}' for key, im in zip(im_names, imgs)}
    orig, rec, hyb, gen = fnames['orig'], fnames['rec'], fnames['hyb'], fnames['gen']
    fig = f'{d}_{m}'

    panels_tex = rf'''
\begin{{figure}}
    \centering
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{orig}}}}}
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{rec}}}}}
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{hyb}}}}}
    \subfigure[]{{\includegraphics[width=0.24\textwidth]{{figures/{gen}}}}}
    \caption{{{data} {model} (a) Original (b) Reconstructions (c) Hybrid (d) Prior}}
    \label{{fig:{fig}}}
\end{{figure}}
'''
    print(panels_tex)

In [None]:
# Path (relative to each run's viz/ folder) of traversal video number k.
name = 'traversals/walk{}.mp4'
# Flat output folder for the exported videos.
# NOTE(review): this rebinds `savedir` (previously the image destination); re-running
# the image-copy cell after this one would copy images into the video folder.
savedir = os.path.join('/is/ei/fleeb/workspace/media/hybrid', 'supplement_vids')
util.create_dir(savedir)

In [None]:
# Copy one latent-traversal video per run into `savedir`, renamed to
# <Dataset>_<Model>_trav<idx>.mp4 using the caption titles (spaces -> dashes).
# Runs without that video are reported and skipped (best-effort export), but only
# file-system errors are tolerated: the previous bare `except:` also swallowed
# KeyboardInterrupt and hid the actual cause.
idx = 1  # which traversal video (walk<idx>.mp4) to export
exported = {}  # destination file name -> (dataset dir, model dir) that produced it
for d in sorted(os.listdir(root)):
    droot = os.path.join(root, d)
    if not os.path.isdir(droot):
        continue  # ignore stray files next to the dataset folders
    for m in sorted(os.listdir(droot)):
        path = os.path.join(droot, m)
        if not os.path.isdir(path):
            continue

        model = model_titles.get(m, m.upper())
        if 'adain' in d:
            model += ' (4,2)'
        data = data_titles[d]

        vid_name = '{}_{}_trav{}.mp4'.format(data.replace(' ', '-'), model.replace(' ', '-'), idx)
        try:
            shutil.copy(os.path.join(path, 'viz', name.format(idx)), os.path.join(savedir, vid_name))
        except OSError as err:
            print(d, m, 'failed', err)
        else:
            # Different run dirs can share display titles (e.g. '3dshapes' and 'arch'
            # both map to 'Shapes-3D'); report instead of silently overwriting.
            if vid_name in exported:
                print(d, m, 'overwrote', vid_name, 'from', exported[vid_name])
            exported[vid_name] = (d, m)