In [5]:
import os
import json
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from src.data_loader import Shifted_Data_Loader
from src.plot import orig_vs_transformed as plot_ovt
from src.plot import enc_dec_samples
from src.models import GResNet,EDense
from src.config import get_config
from src.trainer import Trainer
from src.utils import prepare_dirs_and_logger
from keras.datasets import fashion_mnist,mnist
from keras.layers import Dense
# from tabulate import tabulate

In [6]:
# Start from the project's default configuration, then override the settings
# used for this run. Plain attribute assignment is used for every override.
config,_ = get_config()

# Batch size, dataset name and number of epochs.
config.batch_size = 512
config.dataset = 'fashion_mnist'
config.epochs = 100

# Encoder layer sizes, decoder block layout and latent sizes
# (z_dim / y_dim are later passed to the GResNet and EDense builders).
config.enc_layers = [3000,2000]
config.dec_blocks = [4,2,1]
config.z_dim = 25
config.y_dim = 10

# NOTE(review): xcov / recon are presumably loss-term weights -- confirm in Trainer.
config.xcov = 10
config.recon = 20

# Project locations (absolute, machine-specific paths) and dev-mode switch.
config.proj_root = '/home/elijahc/projects/vae'
config.log_dir = '/home/elijahc/projects/vae/logs'
config.dev_mode = False

# NOTE(review): monitor / min_delta presumably drive early stopping -- confirm in Trainer.
config.monitor = 'val_G_loss'
config.min_delta = 0.5
config.optimizer = 'adam'
# setattr(config, 'xcov', None)
# setattr(config,'model_dir','/home/elijahc/projects/vae/models/2019-01-17/')

In [7]:
# Outside dev mode, let the project helper set up this run; per the captured
# log below it creates and symlinks the run's models/data/figures directories.
if not config.dev_mode:
    print('setting up...')
    prepare_dirs_and_logger(config)

# Display the resolved configuration as a plain dict.
vars(config)

setting up...
/home/elijahc/projects/vae/logs/0128_134026_fashion_mnist/models  does not exist...
creating...
symlinking /home/elijahc/projects/vae/logs/0128_134026_fashion_mnist -> /home/elijahc/projects/vae/models/2019-01-28/0128_134026_fashion_mnist
/home/elijahc/projects/vae/data/2019-01-28  does not exist...
creating...
symlinking /home/elijahc/projects/vae/logs/0128_134026_fashion_mnist -> /home/elijahc/projects/vae/data/2019-01-28/0128_134026_fashion_mnist
symlinking /home/elijahc/projects/vae/logs/0128_134026_fashion_mnist -> /home/elijahc/projects/vae/figures/2019-01-28/0128_134026_fashion_mnist


{'batch_size': 512,
 'data_dir': '/home/elijahc/projects/vae/logs/0128_134026_fashion_mnist/data',
 'dataset': 'fashion_mnist',
 'dec_blocks': [4, 2, 1],
 'dev_mode': False,
 'enc_layers': [3000, 2000],
 'epochs': 100,
 'fig_dir': '/home/elijahc/projects/vae/logs/0128_134026_fashion_mnist/figures',
 'log_dir': '/home/elijahc/projects/vae/logs',
 'log_level': 'INFO',
 'min_delta': 0.5,
 'model_dir': '/home/elijahc/projects/vae/logs/0128_134026_fashion_mnist/models',
 'model_name': '0128_134026_fashion_mnist',
 'monitor': 'val_G_loss',
 'optimizer': 'adam',
 'proj_root': '/home/elijahc/projects/vae',
 'recon': 20,
 'run_dir': '/home/elijahc/projects/vae/logs/0128_134026_fashion_mnist',
 'xcov': 10,
 'xent': 10,
 'y_dim': 10,
 'z_dim': 25}

In [None]:
from src.utils import export
# NOTE(review): export() is a project helper not shown here; presumably it
# exports this notebook/run under the name 'model' -- confirm.
export(config,'model')

In [None]:
# Value passed as `translation` to the shifted-data loader (reused for DL2 further down).
translation_amt = 0.3
# Loader for config.dataset with flatten=True, rotation=None and the translation above.
DL = Shifted_Data_Loader(dataset=config.dataset,flatten=True,
                         rotation=None,
                         translation=translation_amt,
                        )

In [None]:
# Builder for the G network, sized from the configured latent dims and decoder blocks.
G_builder = GResNet(y_dim=config.y_dim,z_dim=config.z_dim,dec_blocks=config.dec_blocks)
# Builder for the dense encoder, sized from the configured layer widths and z_dim.
E_builder = EDense(enc_layers=config.enc_layers,z_dim=config.z_dim,)
# The Trainer receives the config, the data loader and both builders.
trainer = Trainer(config,DL,E_builder,G_builder,)
# setattr(trainer.config,'model_dir','/home/elijahc/projects/vae/models/2019-01-22/')

In [None]:
# Plot original vs. transformed samples (src.plot.orig_vs_transformed) with a gray colormap.
# NOTE(review): the meaning of the returned (pt, idx) pair is not visible here -- confirm.
pt,idx = plot_ovt(DL,cmap='gray')

In [None]:
# trainer.build_model()
# Compile the trainer's model, then print the layer summary of its G sub-model.
trainer.compile_model()
trainer.G.summary()

In [None]:
from keras.utils import to_categorical
# One-hot targets with 2 classes in which every training sample is labelled 1;
# passed as the 'D' target when training.
RF = to_categorical(np.ones(len(DL.sx_train)),num_classes=2)

In [None]:
# Train on the shifted training images with three named targets:
#   'class' -> one-hot class labels, 'D' -> the constant RF labels,
#   'G'     -> the input images themselves (reconstruction target).
# validation_split=0.05 and verbose=0 are presumably forwarded to Keras fit -- confirm in Trainer.go.
trainer.go(x=DL.sx_train,
           y={'class':DL.y_train_oh,'D':RF,'G':DL.sx_train},
           validation_split=0.05,
           verbose=0)

In [None]:
# trainer.go_gen(DL.train_generator(batch_size=128),verbose=1)

In [None]:
# Turn the recorded training history into a DataFrame
# (assumes a Keras History object: a dict of per-epoch lists, one column per metric).
hist_df = pd.DataFrame.from_records(trainer.model.history.history)
hist_df.head()

In [None]:
# One stacked panel per metric, each showing the training curve together with
# its 'val_' counterpart from hist_df; the panels share the x axis.
sns.set_context('paper')
metrics = ['loss','G_loss','class_acc']
fig,axs = plt.subplots(nrows=len(metrics),sharex=True,figsize=(5,10))
for ax,metric_name in zip(axs,metrics):
    sns.scatterplot(
        data=hist_df.loc[:,[metric_name,'val_'+metric_name]],
        ax=ax,
    )

In [None]:
# if not config.dev_mode:
# Save the trained model unconditionally (the dev-mode guard above is commented out).
trainer.save_model()

In [None]:
from keras.models import Model
from keras.layers import Input

In [None]:
# Alias for the trainer's G model; used below to turn latent vectors back into images.
generator = trainer.G

In [None]:
# Two sub-models that share the encoder's input tensor:
# one outputs trainer.z_lat, the other outputs trainer.y_class.
z_encoder = Model(trainer.E.input,trainer.z_lat)
classifier = Model(trainer.E.input,trainer.y_class)
# y_lat_encoder = Model(trainer.E.input,trainer.y_lat)
# decoder_inp = Input(shape=(config.y_dim+config.z_dim,))
# dec_layers = trainer.model.layers[-(1+(5*2)):]
# print(dec_layers)
# _gen_x = dec_layers[0](decoder_inp)
# l = dec_layers[1]
# isinstance(l,keras.layers.core.Reshape)
# F = None
# for l in dec_layers[1:]:
#     print(type(l))
    
#     if isinstance(l,keras.layers.merge.Add):
#         _gen_x = l([F,_gen_x])
#     else:
#         _gen_x = l(_gen_x)
    
#     if isinstance(l,keras.layers.convolutional.Conv2DTranspose):
#         if l.kernel_size==(1,1):
#             F = _gen_x
            
# # generator = Model(decoder_inp,_gen_x)

In [None]:
# Print the layer summary of the classifier sub-model.
classifier.summary()

In [None]:
# Shape of the one-hot test labels used as evaluation targets below.
DL.y_test_oh.shape

In [None]:
# Compile so that evaluate() can report loss and accuracy; no fit() is called on this sub-model here.
classifier.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
# Score the classifier on the shifted test images against the one-hot test labels.
classifier.evaluate(DL.sx_test,DL.y_test_oh,batch_size=config.batch_size)

In [None]:
# z output of the encoder for every shifted test image.
z_enc = z_encoder.predict(DL.sx_test,batch_size=config.batch_size)
# y_lat = y_lat_encoder.predict(DL.sx_test,batch_size=config.batch_size)
# The classifier's output stands in for the y part of the latent vector.
y_lat = classifier.predict(DL.sx_test,batch_size=config.batch_size)

In [None]:
# Join the class outputs (first) and the z encodings (second) along axis 1;
# this combined array is what gets passed to the generator.
_lat_vec = np.concatenate([y_lat,z_enc],axis=1)
_lat_vec.shape

In [None]:
# Mean of z_enc over axis 0, and covariance with columns treated as variables (rowvar=False).
z_enc_mu = np.mean(z_enc,axis=0)
z_enc_cov = np.cov(z_enc,rowvar=False)

In [None]:
# Shape of 50 draws from a Gaussian with the empirical z statistics; the draws themselves are not kept.
np.random.multivariate_normal(z_enc_mu,z_enc_cov,size=50).shape

In [None]:
# Run the generator on every combined latent vector.
regen = generator.predict(_lat_vec,batch_size=config.batch_size)

In [None]:
# Show one randomly chosen generator output as a 56x56 grayscale image.
# The upper bound is taken from `regen` itself instead of a hard-coded 10000,
# so the index is always valid for the array being indexed.
rand_im = np.random.randint(0,len(regen))
plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
# Latent vector at the index currently held in rand_im.
_lat_vec[rand_im]

In [None]:
# Second loader built with the same arguments as DL.
# NOTE(review): presumably this draws a different set of random shifts for the
# same images, giving a second view used for the interpolation below -- confirm.
DL2 = Shifted_Data_Loader(dataset=config.dataset,flatten=True,
                         rotation=None,
                         translation=translation_amt,
                        )

In [None]:
# Project plotting helper fed with the original and shifted test images, their
# encodings and the generator. NOTE(review): exact figure layout is defined in src.plot -- confirm.
enc_dec_samples(DL.x_test,DL.sx_test,z_enc,y_lat,generator)

In [None]:
# Repeat the encode -> concatenate -> generate pipeline on the second loader's shifted test images.
z_enc2 = z_encoder.predict(DL2.sx_test,batch_size=config.batch_size)
y_lat2 = classifier.predict(DL2.sx_test,batch_size=config.batch_size)
# Same layout as _lat_vec: class outputs first, z encodings second.
_lat_vec2 = np.concatenate([y_lat2,z_enc2],axis=1)
regen2 = generator.predict(_lat_vec2,batch_size=config.batch_size)

In [None]:
from src.plot import remove_axes,remove_labels
from src.utils import gen_trajectory

In [None]:
# Grid of `examples` rows x 11 columns comparing the two shifted views of the
# same test images and the generator outputs in between.
#   cols 0-2   : DL original, DL shifted, reconstruction from DL's encoding
#   cols 3-7   : generator outputs along the latent trajectory DL -> DL2
#   cols 8-10  : reconstruction from DL2's encoding, DL2 shifted, DL2 original
examples = 5
# Sample row indices from the actual number of encoded test images rather than
# a hard-coded 10000, so every index is valid for the arrays indexed below.
rand_im = np.random.randint(0,len(z_enc),size=examples)
fig,axs = plt.subplots(examples,11,figsize=(8,4))
_lat_s = []
regen_s = []
# NOTE(review): gen_trajectory is a project helper; presumably it yields a
# sequence of intermediate arrays between its two inputs. The plotting loop
# below indexes regen_s[0..4], so it must yield at least 5 steps -- confirm.
out = gen_trajectory(z_enc[rand_im],z_enc2[rand_im],delta=.25)
out_y = gen_trajectory(y_lat[rand_im],y_lat2[rand_im],delta=.25)

# Generate images for every trajectory step (class part first, z part second,
# matching the layout of _lat_vec).
for z,y in zip(out,out_y):
    _lat = np.concatenate([y,z],axis=1)
    _lat_s.append(_lat)
    regen_s.append(generator.predict(_lat,batch_size=config.batch_size))

# `i` is the row within the sampled batch; `idx` is the index into the full test set.
for i,(axr,idx) in enumerate(zip(axs,rand_im)):
    axr[0].imshow(DL.x_test[idx].reshape(28,28),cmap='gray')
    axr[1].imshow(DL.sx_test[idx].reshape(56,56),cmap='gray')
    axr[2].imshow(regen[idx].reshape(56,56),cmap='gray')
    for j,a in enumerate(axr[3:-3]):
        a.imshow(regen_s[j][i,:].reshape(56,56),cmap='gray')
    axr[-3].imshow(regen2[idx].reshape(56,56),cmap='gray')
    axr[-2].imshow(DL2.sx_test[idx].reshape(56,56),cmap='gray')
    axr[-1].imshow(DL2.x_test[idx].reshape(28,28),cmap='gray')
    for a in axr:
        remove_axes(a)
        remove_labels(a)
# plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
from src.metrics import var_expl
# Element 1 of the loader's dx / dy, shifted by -14.
# NOTE(review): index 1 presumably selects the test split and 14 presumably
# re-centres the offsets around zero -- confirm against Shifted_Data_Loader.
dxs = DL.dx[1]-14
dys = DL.dy[1]-14
# NOTE(review): DL was built with rotation=None; confirm DL.dtheta[1] is a usable array in that case.
dtheta = DL.dtheta[1]
# Project metric evaluated on the z encodings, conditioned on each factor with 21 bins.
fve_dx = var_expl(features=z_enc,cond=dxs,bins=21)
fve_dy = var_expl(features=z_enc,cond=dys,bins=21)
fve_dt = var_expl(features=z_enc,cond=dtheta,bins=21)

In [None]:
# Normalise each var_expl result by the variance of its factor: (var - fve) / var.
# NOTE(review): a constant factor has zero variance and makes this a division
# by zero (possible for dtheta, since the loader uses rotation=None) -- confirm.
fve_dx_norm = (dxs.var()-fve_dx)/dxs.var()
fve_dy_norm = (dys.var()-fve_dy)/dys.var()
fve_dth_norm = (dtheta.var()-fve_dt)/dtheta.var()

In [None]:
# seaborn is already imported as `sns` in the first cell; this re-import is redundant but harmless.
import seaborn as sns
# Switch plot scaling from the 'paper' context used earlier to 'talk' for the figures below.
sns.set_context('talk')

In [None]:
# This bare expression is not the cell's last line, so its value is not displayed.
fve_dx_norm.shape
# Store the normalised dx scores in the run's model directory (np.save appends '.npy').
np.save(os.path.join(config.model_dir,'fve_dx_norm'),fve_dx_norm)

In [None]:
# Scatter the normalised dx score, averaged over axis 0, against the latent dimension index.
plt.scatter(np.arange(config.z_dim),fve_dx_norm.mean(axis=0))
plt.xlabel('Z_n')
plt.ylabel('fve_dx')
plt.tight_layout()
# The figure is written to the run's model directory.
plt.savefig(os.path.join(config.model_dir,'fve_dx.png'))
# plt.ylim(-0.125,0.25)
# Latent dimension with the highest mean dx score; used as an axis in the Z scatter plots below.
xdim = np.argmax(fve_dx_norm.mean(axis=0))

In [None]:
# This bare expression is not the cell's last line, so its value is not displayed.
fve_dy_norm.mean(axis=0)
# Store the normalised dy scores in the run's model directory (np.save appends '.npy').
np.save(os.path.join(config.model_dir,'fve_dy_norm'),fve_dy_norm)

In [None]:
# Scatter the normalised dy score, averaged over axis 0, against the latent dimension index.
plt.scatter(np.arange(config.z_dim),fve_dy_norm.mean(axis=0))
plt.xlabel('Z_n')
plt.ylabel('fve_dy')
plt.tight_layout()
# The figure is written to the run's model directory.
plt.savefig(os.path.join(config.model_dir,'fve_dy.png'))
# plt.ylim(-0.125,0.25)
# Latent dimension with the highest mean dy score; used as an axis in the Z scatter plots below.
ydim = np.argmax(fve_dy_norm.mean(axis=0))

In [None]:
# Same scatter for the dtheta score; unlike the dx/dy cells this figure is not saved.
plt.scatter(np.arange(config.z_dim),fve_dth_norm.mean(axis=0))
plt.xlabel('Z_n')
plt.ylabel('fve_dtheta')
# plt.ylim(0.0,0.5)
# The best dtheta dimension is displayed as the cell output but not stored in a variable.
np.argmax(fve_dth_norm.mean(axis=0))

In [None]:
from src.plot import Z_color_scatter
# Project plotting helper called with the z encodings, the two selected latent
# dimensions (xdim, ydim) and the dx values.
# NOTE(review): presumably the third argument colours the points -- confirm in src.plot.
Z_color_scatter(z_enc,[xdim,ydim],dxs)

In [None]:
# Same call on the (xdim, ydim) dimensions, this time passing the dy values.
Z_color_scatter(z_enc,[xdim,ydim],dys)

In [None]:
# NOTE(review): latent dimensions 7 and 18 are hard-coded, unlike xdim/ydim
# which are derived from the scores of the current run -- confirm they are
# still the intended dimensions after retraining.
Z_color_scatter(z_enc,[7,18],dtheta)

In [None]:
# NOTE: an unfinished `from plt.` statement stood here. It was a SyntaxError
# (so "Restart Kernel -> Run All" failed on this cell) and imported nothing,
# so it has been removed.