In [22]:
import os
import json
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import tensorflow as tf

from src.data_loader import Shifted_Data_Loader
from src.plot import orig_vs_transformed as plot_ovt
from src.plot import enc_dec_samples
from src.models import GResNet,EDense
from src.config import get_config
from src.trainer import Trainer
from src.utils import prepare_dirs_and_logger
from src.losses import sse, mse
from src.test_models.EBGAN import EBGAN,Generator,Encoder,resample,gradient_penalty_loss

from keras.datasets import cifar10
from keras.layers import Dense,Concatenate,Input,Lambda,Activation,Reshape,LSTM,ConvLSTM2D
from keras.models import Model
from keras.callbacks import EarlyStopping
import keras.backend as K
from src.keras_callbacks import PrintHistory,Update_k
# from tabulate import tabulate

In [2]:
config, _ = get_config()

# Boilerplate
# The project root can be overridden with the VAE_PROJ_ROOT environment variable so the
# notebook is not tied to one machine; the default reproduces the original location.
_proj_root = os.environ.get('VAE_PROJ_ROOT', '/home/elijahc/projects/vae')
setattr(config, 'proj_root', _proj_root)
setattr(config, 'log_dir', os.path.join(_proj_root, 'logs'))
setattr(config, 'dev_mode', True)  # dev runs skip prepare_dirs_and_logger (next cell)
# setattr(config,'model_dir','/home/elijahc/projects/vae/models/2019-01-17/')

# Architecture Params
setattr(config, 'enc_layers', [3000, 2000])  # hidden units of the encoder's dense layers
setattr(config, 'dec_blocks', [4, 2, 1])
setattr(config, 'z_dim', 10)  # size of the continuous latent code z
setattr(config, 'y_dim', 10)  # size of the class posterior y

# Training Params
setattr(config, 'batch_size', 512)
setattr(config, 'dataset', 'fashion_mnist')
setattr(config, 'epochs', 100)
setattr(config, 'monitor', 'val_loss')  # EarlyStopping is skipped when this is None
setattr(config, 'min_delta', 0.5)       # in units of the monitored quantity
setattr(config, 'optimizer', 'adam')

# Loss Weights
# NOTE(review): these weights are not read by any visible cell (EBGAN.compile passes
# no loss_weights); presumably consumed by Trainer or older loss code - confirm.
setattr(config, 'xcov', 0)
setattr(config, 'recon', 10)
setattr(config, 'xent', 10)
setattr(config, 'xent', 10)

In [None]:
if not config.dev_mode:
    print('setting up...')
    prepare_dirs_and_logger(config)
    
vars(config)

In [None]:
translation_amt = 0.5 # Med
DL = Shifted_Data_Loader(dataset=config.dataset,flatten=True,
                         rotation=None,
                         translation=translation_amt,
                        )

In [None]:
""" Model inputs"""

# One-hot class label fed to the conditional generator. Sized from config so it stays
# consistent with the encoder's y head and the generator's G_y input (both y_dim wide).
class_input = Input(shape=(config.y_dim,), name='class_input')

In [None]:
""" AutoEncoder Critic"""
x = Input(name='Image_input', shape=DL.input_shape)

# Encoder backbone; its build() returns a list of heads used below:
# [0] -> class posterior y (softmax), [1] -> latent code z (linear).
encoder = Encoder(
    input_shape=DL.input_shape,
    y_dim=config.y_dim,
    z_dim=config.z_dim,
    layer_units=config.enc_layers,
)
net_out = encoder.build(x)

y, z = [
    Activation(activation, name=layer_name)(head)
    for activation, layer_name, head in (
        ('softmax', 'y', net_out[0]),
        ('linear', 'z', net_out[1]),
    )
]

# A third head (net_out[2]) used to be exposed as a critic score; currently disabled:
# c = Activation('linear',name='critic_score')(net_out[2])

# Joint latent in [y, z] order - the ordering the decoder is trained on.
yz = Concatenate(name='yz')([y, z])

E = Model(inputs=x, outputs=[y, z], name='Encoder')

In [None]:
""" Decoder """
# Decoder network: maps a joint latent vector [y, z] to an image.
decoder = Generator(
    y_dim=config.y_dim,
    z_dim=config.z_dim,
    dec_blocks=config.dec_blocks,
)

latent_size = config.y_dim + config.z_dim
Dec_input = Input(name='Decoder_input', shape=(latent_size,))
Dec_output = decoder.build(Dec_input)

G = Model(inputs=Dec_input, outputs=Dec_output, name='Decoder')
# G.summary()

In [None]:
# Reconstruction path through the autoencoder: encoder heads (y, z) -> decoder G.
x_pred = Activation('linear',name='x_pred')(G(yz))

AE = Model(inputs=x,outputs=x_pred,name='AE')


def sse_layer(tensors):
    """Per-sample reconstruction error between an image batch and its reconstruction.

    Args:
        tensors: ``[images, reconstructions]`` as delivered by a multi-input
            Keras ``Lambda`` layer.

    Returns:
        ``sse(images, reconstructions)`` with a trailing axis added by
        ``K.expand_dims``. NOTE(review): assumes ``sse`` yields one value per
        sample, which makes the output (batch, 1) - confirm in src.losses.
    """
    images, recons = tensors
    return K.expand_dims(sse(images, recons))


# Critic "energy" D(x) = SSE(x, AE(x)). Both the input image and its reconstruction are
# passed to the Lambda, so the error is measured against the critic's actual input.
# Passing only AE(x) and reconstructing again inside the Lambda would instead measure
# SSE(AE(x), AE(AE(x))).
sse_out = Lambda(sse_layer)([x, AE(x)])
D = Model(
    inputs=x,
    outputs=sse_out,
    name='D'
)

In [None]:
""" Generator """
def gen_Z(y):
    """Sample a standard-normal latent batch with the same batch size as `y`.

    Only ``K.shape(y)[0]`` is read from `y`; the result has shape
    ``(batch, config.z_dim)``.
    """
    Z = K.random_normal(shape=(K.shape(y)[0], config.z_dim))

    return Z

generator = Generator(y_dim = config.y_dim,
                      z_dim = config.z_dim,
                      dec_blocks= config.dec_blocks)

# Conditional generator: class label y in, fresh noise z sampled inside the graph.
G_input_y = Input(shape=(config.y_dim,),name='G_y')
G_input_z = Lambda(gen_Z,name='G_z')(G_input_y)

# Concatenate as [y, z]: the same ordering used for the decoder's input `yz` and for the
# latent vectors decoded later in the notebook, since both networks are built from the
# same Generator class. NOTE(review): confirm Generator.build does not depend on a
# [z, y] layout.
G_input = Concatenate(name='G_yz')([G_input_y, G_input_z])

G_img = generator.build(G_input)

Gen = Model(
    inputs=G_input_y,
    outputs=G_img,
    name='Generator'
)

In [None]:
""" Model Outputs """
# Full EBGAN graph: generated images come from Gen conditioned on class_input, and both
# real images (x) and generated images are scored by the critic D.
fake_img = Activation('linear',name='fake_img')(Gen(class_input))

# Critic scores. D outputs the sse-based reconstruction error of its input, so
# (assuming sse is a squared error) lower means "better reconstructed".
c_real = Activation('linear',name='C_real')(D(x))
c_fake = Activation('linear',name='C_fake')(D(fake_img))

# c_real = Activation('linear',name='C_real')(D(x))
# c_recon = D(recon_img)
# c_fake = Activation('linear',name='C_fake')(D(fake_img))

""" Losses """
# GAN Losses
# NOTE(review): with D measuring reconstruction error (an "energy"), the usual
# energy-based signs would be d_loss = c_real - c_fake and g_loss = c_fake. Both signs
# here are flipped. The game is still self-consistent, but confirm that driving the
# energy of real images up is intended.
GAN_d_loss = -1*(c_real - c_fake)
GAN_g_loss = -1*c_fake

# Gradient Penalty
# NOTE(review): gp_loss is computed but not added to any loss in the visible cells.
gp_loss = gradient_penalty_loss(x,fake_img,D)

# Add Discriminator losses
# NOTE(review): GAN_d_loss is also passed to EBGAN.compile below as the 'C_real' loss.
# If Keras collects this nested-model loss into EBGAN's total, it is counted twice.
D.add_loss([GAN_d_loss])

# Add Generator losses
# Gen.add_loss([GAN_g_loss])

# Single trainable model: inputs are real images + class labels; outputs are the class
# posterior y and the two critic scores.
EBGAN = Model(
    inputs=[x,class_input],
    outputs=[y,c_real,c_fake],
    name='EBGAN'
)
# mod_outputs = [
#     (recon_img, sse, config.recon),
#     (y, 'categorical_crossentropy', config.xent),
#     (c_fake,lambda yt,yp: GAN_d_loss+GAN_g_loss, 1),
# ]

# outs,ls,ws = zip(*mod_outputs)

# VGAN = Model(
# inputs=x,
# outputs=outs)

# losses = {k:v for k,v in zip(VGAN.output_names,ls)}
# loss_W = {k:v for k,v in zip(VGAN.output_names,ws)}

# Accuracy on the y head. PrintHistory in the fit cell reads it as 'val_y_acc'
# (NOTE(review): newer Keras versions name this key 'y_accuracy').
metrics = {
    'y': 'accuracy',
}

# The C_real/C_fake loss lambdas ignore (y_true, y_pred) and return the closed-over GAN
# loss tensors, so the targets supplied for those outputs in fit() are dummies.
# NOTE(review): all losses are minimized together by ONE optimizer over all weights,
# each with weight 1. The GAN terms therefore sum to GAN_d_loss + GAN_g_loss = -c_real:
# the c_fake terms cancel in the compiled loss. A working GAN needs alternating D/G
# updates (separate compiled models, each with the other side frozen).
# config.recon / config.xent are not applied as loss weights here.
EBGAN.compile(optimizer=config.optimizer,loss={'y':'categorical_crossentropy','C_real':lambda yt,yp:GAN_d_loss,'C_fake':lambda yt,yp:GAN_g_loss},metrics=metrics)

In [None]:
EBGAN.output_names

In [None]:
from keras.utils import to_categorical
RF = to_categorical(np.ones(len(DL.sx_train)),num_classes=2)

In [None]:
# Per-epoch console logging of a few selected keys.
print_history = PrintHistory(print_keys=['loss', 'val_loss', 'val_y_acc'])
callbacks = [print_history]
# Update_k(k_var = k) is currently disabled.

# Optional early stopping on config.monitor, restoring the best weights at the end.
if config.monitor is not None:
    early_stop = EarlyStopping(
        monitor=config.monitor,
        min_delta=config.min_delta,
        patience=10,
        restore_best_weights=True,
    )
    callbacks.append(early_stop)

train_inputs = {
    'Image_input': DL.sx_train,
    'class_input': DL.y_train_oh,
}
train_targets = {
    'y': DL.y_train_oh,
    'C_real': RF,  # dummy - ignored by the loss lambda
    'C_fake': RF,  # dummy - ignored by the loss lambda
}
history = EBGAN.fit(
    x=train_inputs,
    y=train_targets,
    verbose=0,
    batch_size=config.batch_size,
    callbacks=callbacks,
    validation_split=0.05,
    epochs=config.epochs,
)

In [None]:
# # true_latent_vec = Concatenate()([y_class,z_lat_stats[0]])
# latent_vec = Concatenate()([y,z_lat])
# shuffled_lat = Concatenate()([y,z_sampled])
# G = trainer.G
# # recon = Activation('linear',name='G')(G(true_latent_vec))
# fake_inp = G(latent_vec)
# G_shuff = G(shuffled_lat)
# # fake_lat_vec = Concatenate()(E(fake_inp))
# # fake_ae = G(fake_lat_vec)

# D_real = Activation('linear',name='D_real')(D(real_inp))
# D_fake = Activation('linear',name='D_fake')(D(G_shuff))
# # D_fake = E(fake_inp)[2]
# D_all = Concatenate(axis=0,name='D_all')([D_fake,D_real])

In [None]:
pt,idx = plot_ovt(DL,cmap='gray')

In [None]:
# hist_df = pd.DataFrame.from_records(trainer.model.history.history)
hist_df = pd.DataFrame.from_records(VGAN.history.history)
hist_df.tail()

In [None]:
sns.set_context('paper')
# Keras logs per-output losses as '<output name>_loss'. The critic-on-fakes output is
# named 'C_fake', so its curve is 'C_fake_loss'. NOTE(review): 'y_acc' matches the
# Keras version this notebook targets (newer Keras uses 'y_accuracy').
metrics = ['loss','C_fake_loss','y_acc']
# squeeze=False keeps `axs` 2-D, so the loop also works for a single metric.
fig,axs = plt.subplots(nrows=len(metrics),sharex=True,figsize=(5,10),squeeze=False)
for metric_name,ax in zip(metrics,axs[:,0]):
    sns.scatterplot(data=hist_df[[metric_name,'val_'+metric_name]],ax=ax)
    ax.set_title(metric_name)

In [None]:
# if not config.dev_mode:
# trainer.save_model()

In [None]:
from keras.models import Model
from keras.layers import Input

In [None]:
generator = G

In [None]:
z_encoder = Model(x,z)
classifier = Model(x,y)
# y_lat_encoder = Model(trainer.E.input,trainer.y_lat)
# decoder_inp = Input(shape=(config.y_dim+config.z_dim,))
# dec_layers = trainer.model.layers[-(1+(5*2)):]
# print(dec_layers)
# _gen_x = dec_layers[0](decoder_inp)
# l = dec_layers[1]
# isinstance(l,keras.layers.core.Reshape)
# F = None
# for l in dec_layers[1:]:
#     print(type(l))
    
#     if isinstance(l,keras.layers.merge.Add):
#         _gen_x = l([F,_gen_x])
#     else:
#         _gen_x = l(_gen_x)
    
#     if isinstance(l,keras.layers.convolutional.Conv2DTranspose):
#         if l.kernel_size==(1,1):
#             F = _gen_x
            
# # generator = Model(decoder_inp,_gen_x)

In [None]:
classifier.summary()

In [None]:
DL.y_test_oh.shape

In [None]:
classifier.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
classifier.evaluate(DL.sx_test,DL.y_test_oh,batch_size=config.batch_size)

In [None]:
z_enc = z_encoder.predict(DL.sx_test,batch_size=config.batch_size)
# y_lat = y_lat_encoder.predict(DL.sx_test,batch_size=config.batch_size)
y_lat = classifier.predict(DL.sx_test,batch_size=config.batch_size)

In [None]:
_lat_vec = np.concatenate([y_lat,z_enc],axis=1)
_lat_vec.shape

In [None]:
z_enc_mu = np.mean(z_enc,axis=0)
z_enc_cov = np.cov(z_enc,rowvar=False)

In [None]:
np.random.multivariate_normal(z_enc_mu,z_enc_cov,size=50).shape

In [None]:
regen = generator.predict(_lat_vec,batch_size=config.batch_size)

In [None]:
rand_im = np.random.randint(0,10000)
plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
_lat_vec[rand_im]

In [None]:
DL2 = Shifted_Data_Loader(dataset=config.dataset,flatten=True,
                         rotation=None,
                         translation=translation_amt,
                        )

In [None]:
enc_dec_samples(DL.x_test,DL.sx_test,z_enc,y_lat,generator)

In [None]:
z_enc2 = z_encoder.predict(DL2.sx_test,batch_size=config.batch_size)
y_lat2 = classifier.predict(DL2.sx_test,batch_size=config.batch_size)
_lat_vec2 = np.concatenate([y_lat2,z_enc2],axis=1)
regen2 = generator.predict(_lat_vec2,batch_size=config.batch_size)

In [None]:
from src.plot import remove_axes,remove_labels
from src.utils import gen_trajectory

In [None]:
examples = 5
# Random test images; bounded by the actual test-set size rather than a literal 10000.
rand_im = np.random.randint(0, len(regen), size=examples)
_lat_s = []
regen_s = []
# Latent trajectories from each image's encoding under DL to its encoding under DL2.
out = gen_trajectory(z_enc[rand_im], z_enc2[rand_im], delta=.25)
out_y = gen_trajectory(y_lat[rand_im], y_lat2[rand_im], delta=.25)

# Decode every intermediate [y, z]. The loop variables must NOT be named y/z: those
# globals hold the Keras tensors that z_encoder/classifier were built from, and
# clobbering them breaks re-running earlier cells.
for z_step, y_step in zip(out, out_y):
    _lat = np.concatenate([y_step, z_step], axis=1)
    _lat_s.append(_lat)
    regen_s.append(generator.predict(_lat, batch_size=config.batch_size))

# Columns: DL original | DL shifted | DL reconstruction | one per trajectory step |
#          DL2 reconstruction | DL2 shifted | DL2 original.
n_cols = 3 + len(regen_s) + 3
fig, axs = plt.subplots(examples, n_cols, figsize=(8, 4), squeeze=False)
for i, (axr, idx) in enumerate(zip(axs, rand_im)):
    axr[0].imshow(DL.x_test[idx].reshape(28, 28), cmap='gray')
    axr[1].imshow(DL.sx_test[idx].reshape(56, 56), cmap='gray')
    axr[2].imshow(regen[idx].reshape(56, 56), cmap='gray')
    for j, a in enumerate(axr[3:-3]):
        # Row i of each step corresponds to the i-th selected image.
        a.imshow(regen_s[j][i, :].reshape(56, 56), cmap='gray')
    axr[-3].imshow(regen2[idx].reshape(56, 56), cmap='gray')
    axr[-2].imshow(DL2.sx_test[idx].reshape(56, 56), cmap='gray')
    axr[-1].imshow(DL2.x_test[idx].reshape(28, 28), cmap='gray')
    for a in axr:
        remove_axes(a)
        remove_labels(a)

In [None]:
from sklearn.preprocessing import MinMaxScaler

# Discretize each latent dimension into integers 0..50 for the count-based MI estimator.
feat_range = (0,50)
# MinMaxScaler scales every column independently, so scaling all latent columns at once
# and transposing gives one row per latent dimension: shape (z_dim, n_samples).
# Unlike the former per-column list + np.squeeze, this stays 2-D when z_dim == 1.
# astype(int) truncates toward zero, the same as building an int array from the floats.
z_enc_scaled = MinMaxScaler(feature_range=feat_range).fit_transform(z_enc[:, :config.z_dim]).T.astype(int)

In [None]:
from collections import Counter
import dit
from dit import Distribution
# Per-image shifts, re-centered by 14 px. NOTE(review): presumably index 1 selects the
# test split (these are paired with test-set encodings) and 14 is half the 28-px source
# image - confirm against Shifted_Data_Loader.
dxs = DL.dx[1]-14
dys = DL.dy[1]-14

def mutual_information(X,Y):
    """Plug-in mutual information between two discrete sequences.

    Builds the empirical joint distribution of the pairs (X[i], Y[i]) with a Counter
    and returns dit's mutual information between its two coordinates.

    Args:
        X, Y: iterables of hashable values (e.g. ints) of equal length. Pairs are
            formed with zip, so trailing elements of the longer one are ignored.

    Raises:
        ValueError: if no (X, Y) pair can be formed (empty input).
    """
    XY_c = Counter(zip(X,Y))
    # Normalizer computed once instead of re-summing all counts for every outcome.
    total = float(sum(XY_c.values()))
    if total == 0:
        raise ValueError('mutual_information needs at least one (X, Y) pair')
    XY_pmf = {k: v/total for k,v in XY_c.items()}
    XY_jdist = Distribution(XY_pmf)

    return dit.shannon.mutual_information(XY_jdist,[0],[1])

In [None]:
z_dx_I = [mutual_information(z_enc_scaled[i],dxs.astype(int)+14) for i in np.arange(config.z_dim)]

In [None]:
z_dy_I = [mutual_information(z_enc_scaled[i],dys.astype(int)+14) for i in np.arange(config.z_dim)]

In [None]:
z_class_I = [mutual_information(z_enc_scaled[i],DL.y_test) for i in np.arange(config.z_dim)]

In [None]:
z_I_df = pd.DataFrame.from_records({'class':z_class_I,'dy':z_dy_I,'dx':z_dx_I})
z_I_df['class'] = z_I_df['class'].values.round(decimals=1)

In [None]:
sns.set_context('talk')
fig,ax = plt.subplots(1,1,figsize=(6,5))
ax.set_ylim(0,0.8)
ax.set_xlim(0,0.8)
points = plt.scatter(x=z_I_df['dx'],y=z_I_df['dy'],c=z_I_df['class'],cmap='plasma')
plt.colorbar(points)

In [None]:
fig,ax = plt.subplots(1,1,figsize=(5,5))
ax.scatter(z_dx_I,z_dy_I)
ax.set_ylim(0,0.6)
ax.set_xlim(0,0.6)

In [None]:
plt.scatter(np.arange(config.z_dim),sorted(z_dy_I,reverse=True))

In [None]:
from src.metrics import var_expl,norm_var_expl
from collections import Counter



dtheta = DL.dtheta[1]
fve_dx = norm_var_expl(features=z_enc,cond=dxs,bins=21)
fve_dy = norm_var_expl(features=z_enc,cond=dys,bins=21)
# fve_dt = norm_var_expl(features=z_enc,cond=dtheta,bins=21)

In [None]:
# fve_dx_norm = (dxs.var()-fve_dx)/dxs.var()
# fve_dy_norm = (dys.var()-fve_dy)/dys.var()
# fve_dth_norm = (dtheta.var()-fve_dt)/dtheta.var()
fve_dx_norm = fve_dx
fve_dy_norm = fve_dy

In [None]:
import seaborn as sns
sns.set_context('talk')

In [None]:
fve_dx_norm.shape
# np.save(os.path.join(config.model_dir,'fve_dx_norm'),fve_dx_norm)

In [None]:
fig,ax = plt.subplots(1,1,figsize=(5,5))
plt.scatter(fve_dx_norm.mean(axis=0),fve_dy_norm.mean(axis=0))
plt.xlabel('fve_dx')
plt.ylabel('fve_dy')
plt.tight_layout()
# plt.savefig(os.path.join(config.model_dir,'fve_dx.png'))
# plt.ylim(-0.125,0.25)
xdim = np.argmax(fve_dx_norm.mean(axis=0))

In [None]:
fve_dy_norm.mean(axis=0)
# np.save(os.path.join(config.model_dir,'fve_dy_norm'),fve_dy_norm)

In [None]:
plt.scatter(np.arange(config.z_dim),fve_dy_norm.mean(axis=0))
plt.xlabel('Z_n')
plt.ylabel('fve_dy')
plt.tight_layout()
# plt.savefig(os.path.join(config.model_dir,'fve_dy.png'))
# plt.ylim(-0.125,0.25)
ydim = np.argmax(fve_dy_norm.mean(axis=0))

In [None]:
# plt.scatter(np.arange(config.z_dim),fve_dth_norm.mean(axis=0))
# plt.xlabel('Z_n')
# plt.ylabel('fve_dtheta')
# # plt.ylim(0.0,0.5)
# np.argmax(fve_dth_norm.mean(axis=0))

In [None]:
from src.plot import Z_color_scatter
Z_color_scatter(z_enc,[xdim,ydim],dxs)

In [None]:
Z_color_scatter(z_enc,[xdim,ydim],dys)

In [None]:
Z_color_scatter(z_enc,[7,18],dtheta)

In [None]:
from plt.