In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

from keras.layers import Dense,Input,Lambda,Concatenate
from keras.models import Model
from keras.losses import categorical_crossentropy
import keras.backend as K
from keras.datasets import fashion_mnist
from keras.utils import to_categorical

from scipy.stats import norm
from tqdm import tqdm as tqdm

Using TensorFlow backend.


In [3]:
from src.utils import process_mnist,gen_trajectory,gen_sorted_isomap,limit_mem
from src.models import build_dense
from src.data_loader import prepare_keras_dataset,Shifted_Data_Loader

In [5]:
# Load the fashion_mnist dataset through the project's Shifted_Data_Loader.
DL = Shifted_Data_Loader(dataset='fashion_mnist')

# One boolean mask per distinct label, for the train and test splits.
class_ids = np.unique(DL.y_train)
masks_train = [DL.y_train == label for label in class_ids]
masks_test = [DL.y_test == label for label in class_ids]

# One-hot targets for the 10-way classification head.
y_train_oh = to_categorical(DL.y_train, num_classes=10)
y_test_oh = to_categorical(DL.y_test, num_classes=10)

# Flat network input size: four times the 784 values of one x_train row.
input_shape = (4 * 784,)
print(DL.x_train.shape)
print(DL.x_test.shape)

loading fashion_mnist


  5%|▌         | 3125/60000 [00:00<00:01, 31238.71it/s]

making training data...


100%|██████████| 60000/60000 [00:01<00:00, 33758.21it/s]
 37%|███▋      | 3651/10000 [00:00<00:00, 36494.37it/s]

making testing data...


100%|██████████| 10000/10000 [00:00<00:00, 36660.90it/s]


(60000, 784)
(10000, 784)


In [None]:
# Show one training example of the third class in `class_ids`: the original
# 28x28 image next to its 56x56 counterpart.
# The arrays are read from the loader because no visible cell defines bare
# x_train / y_train / sx_train names.
# NOTE(review): assumes Shifted_Data_Loader exposes the 56x56 images as
# `DL.sx_train` — TODO confirm against src/data_loader.py.
i = 250
class_mask = masks_train[2]
print(DL.y_train[class_mask][i])

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(DL.x_train[class_mask][i].reshape(28, 28))
axs[1].imshow(DL.sx_train[class_mask][i].reshape(28 * 2, 28 * 2))

# Hide both axes on each panel.
for ax in axs:
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

# fig.savefig('./shifted_mnist_3.png',dpi=300)

In [None]:
# Architecture hyper-parameters.
z_dim = 2                     # width of the `z_mean` latent layer
y_dim = 3                     # NOTE(review): only referenced by commented-out code
encoding_dims = [3000, 1500]  # layer widths handed to build_dense for the encoder

inputs = Input(shape=input_shape)

# Encoder trunk (the original note credits Brian Cheung's network).
encoded = build_dense(inputs, encoding_dims, activations='relu')

# encoded = build_dense(inputs,[512,encoding_dim],activations='relu')

In [None]:
# Dense layer mapping the encoder output to the z_dim-wide latent code.
z_mean = Dense(z_dim,name='z_mean')(encoded)
# z_log_sigma = Dense(latent_dim)(encoded)

def sampler(args):
    """Return ``mean + exp(log_stddev) * eps`` with ``eps`` drawn from N(0, 1).

    Args:
        args: pair ``(mean, log_stddev)`` of Keras tensors, as passed by a
            ``Lambda`` layer called on a two-element list.

    Returns:
        A tensor with the same batch size and width as ``mean``.
    """
    mean, log_stddev = args
    # Size the noise from the incoming tensor itself so the function does not
    # depend on a notebook-level `latent_dim` global (none is defined).
    batch = K.shape(mean)[0]
    dim = K.int_shape(mean)[1]
    eps = K.random_normal(shape=(batch, dim), mean=0, stddev=1)

    return mean + K.exp(log_stddev) * eps

# lat_vec = Lambda(sampler)([z_mean,z_log_sigma])

In [None]:
# Class-prediction head: 10-way softmax on the encoder output.
# (A sampled variant built from y_mean / y_sigma Dense layers through
# `sampler` is disabled.)
y_hat = Dense(10, activation='softmax', name='y_hat')(encoded)

# Decoder input: the latent code followed by the class probabilities.
combo_vec = Concatenate()([z_mean, y_hat])

# Decoder widths are the encoder widths in reverse, then a 4*784-unit sigmoid
# output layer.
decoder_widths = [encoding_dims[1], encoding_dims[0], 4 * 784]
decoder_activations = ['relu', 'relu', 'sigmoid']
decoded_mean = build_dense(combo_vec, decoder_widths, activations=decoder_activations)
# decoded_mean = build_dense(combo_vec,[encoding_dim,512,784],activations=['relu','relu','sigmoid'])

In [None]:
from src.losses import *
from keras.metrics import categorical_accuracy

def acc(y_true, y_pred):
    """Keras metric: categorical accuracy of the `y_hat` head against `y_true`.

    `y_pred` (the model output Keras passes in) is not used; the comparison is
    made with the notebook-level `y_hat` tensor.
    """
    class_probs = y_hat
    return categorical_accuracy(y_true, class_probs)

def kl_loss_tot(y_true, y_pred):
    """Forward both arguments to `kl_loss_z` and return its result.

    NOTE(review): the only visible assignment of `kl_loss_z` is commented out,
    so this presumably relies on `from src.losses import *` supplying that
    name — confirm before using this metric.
    """
    total_kl = kl_loss_z(y_true, y_pred)
    return total_kl

def xentropy(y_true, y_pred):
    """Keras metric: twice the categorical cross-entropy of `y_hat` vs `y_true`.

    `y_pred` is not used; the notebook-level `y_hat` tensor is scored instead.
    """
    class_xent = categorical_crossentropy(y_true, y_hat)
    return 2 * class_xent

def recon_mse(y_true, y_pred):
    """Batch mean of the per-sample summed squared error between `y_pred` and `inputs`.

    `y_true` is not used; the target is the notebook-level `inputs` tensor.
    """
    squared_err = K.square(y_pred - inputs)
    per_sample = K.sum(squared_err, axis=-1)  # sum over the last (feature) axis
    return K.mean(per_sample, axis=0)         # average over the batch axis

In [None]:
# Loss callables wired to this graph's tensors; both are later called as
# fn(y_true, y_pred) inside vae_loss.
# NOTE(review): ReconstructionLoss / XCov are presumably provided by
# `from src.losses import *`; their formulas are not visible here.
recon_loss = ReconstructionLoss(inputs=inputs,outputs=decoded_mean)
xcov = XCov(y_hat,z_mean,weight=1)
# kl_loss_z = KLDivergenceLoss(z_log_sigma,z_mean,weight=0.001,name='DKL_z')

In [None]:
# End-to-end model: input -> decoder output.
vae = Model(inputs, decoded_mean)


def vae_loss(y_true, y_pred):
    """Training objective: the sum of the three terms listed below.

    `y_true` carries the one-hot labels supplied to `fit`; the cross-entropy
    term scores them against the notebook-level `y_hat` tensor rather than
    `y_pred`. The KL terms of earlier versions are disabled.
    """
    terms = (
        K.sum(recon_loss(y_true, y_pred)),
        10 * xcov(y_true, y_pred),
        K.sum(10 * categorical_crossentropy(y_true, y_hat)),
    )
    # The built-in sum starts at 0 and adds the terms in order.
    return sum(terms)


# vae.compile(loss=vae_loss,optimizer='rmsprop')
vae.compile(loss=vae_loss, optimizer='adadelta', metrics=[acc, xentropy])

In [None]:
# Print the layer-by-layer summary of the full model.
vae.summary()

In [None]:
# Sanity check: shape of the one-hot training labels.
y_train_oh.shape

In [None]:
# Train on the 4*784-wide images. The one-hot labels go in as `y` because the
# custom loss and metrics compare `y_true` with the y_hat head.
# NOTE(review): no visible cell defines a bare `sx_train`; this assumes the
# loader exposes those images as `DL.sx_train` — TODO confirm.
vae.fit(x=DL.sx_train, y=y_train_oh,
        shuffle=True,
        epochs=50,
        batch_size=128,
       )

In [None]:
# Close the current backend session and install a new TensorFlow session
# whose GPU memory allocation grows on demand.
# NOTE(review): in notebook order this runs AFTER vae.fit; closing the session
# that holds the trained variables looks like it would invalidate the model
# used by the cells below — presumably this belongs before the model is
# built. Confirm and move if so.
K.get_session().close()
cfg = K.tf.ConfigProto()
cfg.gpu_options.allow_growth = True
K.set_session(K.tf.Session(config=cfg))

In [None]:
# Sub-models that reuse the VAE's layers.
encoder = Model(inputs, z_mean)     # input -> latent code
classifier = Model(inputs, y_hat)   # input -> class probabilities

# Stand-alone decoder: a fresh Input pushed through the VAE's last three
# layers. Its width is read from the concatenated [z_mean, y_hat] tensor
# instead of being hard-coded as 12, so it follows z_dim / the class count.
# NOTE(review): assumes build_dense adds exactly one layer per decoder width,
# so that vae.layers[-3:] is the whole decoder — confirm.
decoder_inp = Input(shape=(K.int_shape(combo_vec)[-1],))
# _generator_x = build_dense(decoder_inp,[encoding_dim,256,784],activations=['relu','relu','sigmoid'])
# generator = Model(decoder_inp,decoded_mean)
# print(generator.summary())
dec_layers = vae.layers[-3:]
outputs = decoder_inp
for dec_layer in dec_layers:
    outputs = dec_layer(outputs)
generator = Model(decoder_inp, outputs)

In [None]:
# Latent codes and class probabilities for the test images.
# NOTE(review): no visible cell defines a bare `sx_test`; this assumes the
# loader exposes the 4*784-wide test images as `DL.sx_test` — TODO confirm.
x_test_encoded = encoder.predict(DL.sx_test, batch_size=128)
y_oh_enc = classifier.predict(DL.sx_test, batch_size=128)

In [None]:
# Display test input number 5 as a 56x56 image.
# NOTE(review): assumes the loader exposes these images as `DL.sx_test`
# (no bare `sx_test` is defined in the visible cells) — TODO confirm.
plt.imshow(DL.sx_test[5].reshape(56, 56))
# generator.predict()

In [None]:
# 2-D histogram of the two latent coordinates over the test set.
plt.hist2d(x_test_encoded[:,0],x_test_encoded[:,1])

In [None]:
# Decoder input for the first five test samples: latent code followed by the
# predicted class probabilities, joined along the feature axis.
cat_vec = np.concatenate([x_test_encoded[:5],y_oh_enc[:5]],axis=1)
cat_vec.shape

In [None]:
# x_test_encoded[2]
# Decode the five concatenated vectors and show the last one as a 56x56 image.
dec_test = generator.predict(cat_vec)
plt.imshow(dec_test[4].reshape(56,56))

In [None]:
# Test-set latent codes coloured by true class label. Labels are read from
# the loader because no visible cell defines a bare `y_test`.
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=DL.y_test)
plt.colorbar()
plt.show()

In [None]:
# Print the layer-by-layer summary of the encoder sub-model.
encoder.summary()

In [None]:
from datetime import date

# Dated output directory for saved models, e.g. ../models/2024-01-31/ (the
# trailing separator is kept so file names can be appended directly).
# Built relative to the notebook, like the ../figures paths used for plots,
# instead of an absolute path under one user's home directory.
# NOTE(review): assumes the notebook lives one level below the project root.
print(date.today())
save_dir = os.path.join('..', 'models', str(date.today()), '')


In [None]:
# vae.save(save_dir+'vae_3layer.h5',include_optimizer=False)
# encoder.save(save_dir+'enc.h5',include_optimizer=False)

In [None]:
def sweep_lat(z,y_class,sweep=0,hold=1,num_std=2,delta=.1):
    """Build a trajectory that varies one latent coordinate and fixes the rest.

    The swept coordinate runs from ``mean - num_std*std`` to
    ``mean + num_std*std`` of column ``sweep`` of ``z``; every other latent
    coordinate stays at its column mean. The swept value is written into
    position ``sweep`` of the latent vector, so ``sweep=1`` really moves the
    second coordinate (previously the swept value always landed in slot 0).

    Args:
        z: 2-D array of latent codes, one row per sample.
        y_class: 1-D class vector appended to both endpoints.
        sweep: column index of the latent coordinate to vary.
        hold: kept for backward compatibility; all non-swept coordinates are
            held at their mean, which covers the two-coordinate case this
            argument was written for.
        num_std: half-width of the sweep in standard deviations.
        delta: step argument forwarded to ``gen_trajectory``.

    Returns:
        Whatever ``gen_trajectory`` returns for the two endpoints.
    """
    z = np.asarray(z)
    center = z[:, sweep].mean()
    half_width = num_std * z[:, sweep].std()

    # Start from the mean of every latent coordinate, then overwrite only the
    # swept position in each endpoint.
    start = z.mean(axis=0)
    end = start.copy()
    start[sweep] = center - half_width
    end[sweep] = center + half_width

    return gen_trajectory(np.concatenate([start, y_class], axis=0),
                          np.concatenate([end, y_class], axis=0),
                          delta=delta)

In [None]:
# z0_mean = np.mean(x_test_encoded[:,0])
# z0_std = x_test_encoded[:,0].std()
# z1_mean = x_test_encoded[:,1].mean()
# z1_std = x_test_encoded[:,1].std()
# x0 = np.array([z0_mean-(2*z0_std),z1_mean])
# x1 = np.array([z0_mean+(2*z0_std),z1_mean])


In [None]:
# Sweep with sweep_lat's defaults, using test sample 5's predicted class
# vector, decode each step and show the frames side by side.
traj = sweep_lat(x_test_encoded,y_oh_enc[5])
# predict() runs the existing generator on the NumPy trajectory, so no new
# backend variable is created each time the cell is run.
dec_traj = generator.predict(traj).reshape(-1, 56, 56)
# One panel per decoded step instead of a hard-coded count of 11.
fig, axs = plt.subplots(1, len(dec_traj), figsize=(10, 10))
for ax, frame in zip(axs, dec_traj):
    ax.imshow(frame)

In [None]:
# Sanity check: shape of the latent trajectory fed to the generator.
traj.shape

In [None]:
# Same sweep as above but with sweep=1 / hold=0, again using test sample 5's
# predicted class vector.
traj = sweep_lat(x_test_encoded,y_oh_enc[5],sweep=1,hold=0)
# predict() runs the existing generator on the NumPy trajectory, so no new
# backend variable is created each time the cell is run.
dec_traj = generator.predict(traj).reshape(-1, 56, 56)
# One panel per decoded step instead of a hard-coded count of 11.
fig, axs = plt.subplots(1, len(dec_traj), figsize=(10, 10))
for ax, frame in zip(axs, dec_traj):
    ax.imshow(frame)

In [None]:
# Re-draw the most recently decoded trajectory, one panel per frame.
fig, axs = plt.subplots(1, 11, figsize=(10, 10))
for ax, frame in zip(axs, dec_traj):
    ax.imshow(frame)

In [None]:
# Grid of `examples` random test samples; columns are: original 28x28 image,
# 56x56 network input, 56x56 reconstruction, predicted class probabilities.
# Arrays are read from the loader because no visible cell defines bare
# x_test / y_test / sx_test names.
# NOTE(review): assumes the loader exposes the 56x56 test images as
# `DL.sx_test` — TODO confirm against src/data_loader.py.
examples=3
sns.set_context('talk')
# sns.set_style('whitegrid')

fig, axs = plt.subplots(examples, 4, figsize=(6, 8))
# Drawn with replacement and without a fixed seed, so the figure changes
# from run to run.
choices = np.random.choice(np.arange(len(DL.y_test)), examples)
lat_vec_ = np.concatenate([x_test_encoded[choices], y_oh_enc[choices]], axis=1)
print(lat_vec_.shape)
dec_test = generator.predict(lat_vec_)

for i, idx in enumerate(choices):
    panels = [
        DL.x_test[idx].reshape(28, 28),
        DL.sx_test[idx].reshape(28*2, 28*2),
        dec_test[i].reshape(28*2, 28*2),
        y_oh_enc[idx].reshape(-1, 1).T,  # probabilities as a single-row strip
    ]
    for ax, im in zip(axs[i], panels):
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    # Label the probability strip with the arg-max class.
    axs[i, 3].set_xlabel("class: {}".format(str(np.argmax(y_oh_enc[idx]))))

plt.tight_layout()
sns.despine(fig=fig)
# plt.imshow(dec_test[2].reshape(28,28).T)

In [None]:
# Latent codes of the test set coloured by the first delta column.
# NOTE(review): no visible cell defines a bare `delta_test`; this assumes the
# loader exposes the per-sample values as `DL.delta_test` with two columns —
# TODO confirm. The `-14` offset is presumably a re-centring — confirm.
dxs = DL.delta_test[:, 0]
dys = DL.delta_test[:, 1]
sns.set_context('talk')
plt.scatter(x_test_encoded[:,0],x_test_encoded[:,1],c=dxs-14)
plt.colorbar()
plt.title(r"dx in $\hat{Z}$")
plt.xlabel(r"$\hat{Z}_0$")
plt.ylabel(r"$\hat{Z}_1$")
plt.savefig("../figures/shifted_fashion_mnist_dx.pdf",dpi=300)

In [None]:
# Latent codes of the test set coloured by `dys` (assigned in the previous
# cell), then saved as a PDF.
# NOTE(review): the `-14` offset is presumably a re-centring — confirm.
plt.scatter(x_test_encoded[:,0],x_test_encoded[:,1],c=dys-14)
plt.colorbar()
plt.title(r"dy in $\hat{Z}$")
plt.xlabel(r"$\hat{Z}_0$")
plt.ylabel(r"$\hat{Z}_1$")
plt.savefig("../figures/shifted_fashion_mnist_dy.pdf",dpi=300)

In [None]:
# Inverse view of the two plots above: offset dxs/dys on the axes, coloured
# by the first latent coordinate. `dxs` / `dys` come from an earlier cell.
# fig,axs = plt.subplots(1,2,figsize=(12,5))
plt.scatter(dxs-14,dys-14,c=x_test_encoded[:,0])
# con = plt.contourf(dxs-14,dys-14,z_mean_enc[:,0])
# ax[1].scatter(dxs-14,dys-14,c=z_mean_enc[:,1])
# ax[0].set_xlabel('dx')
# ax[1].set_ylabel('dy')
plt.colorbar()
plt.xlabel(r"Shift ($\Delta x$)")
plt.ylabel(r"Shift ($\Delta y$)")
plt.title(r"dxdy shift in $\hat{Z}_0$")

In [None]:
# Pick two layers of the VAE by position.
# NOTE(review): the names suggest 32- and 256-unit layers, but no layer of
# either width is built in the visible cells; these look like leftovers from
# an earlier architecture — confirm which layers indices 6 and 7 select.
enc_32 = vae.layers[6]
enc_256 = vae.layers[7]

In [None]:
# Freeze the encoder, pass its output through the `enc_32` layer and add a
# new 10-way softmax classifier on top.
# NOTE(review): `enc_32` is an existing VAE layer being re-applied to the
# encoder's z_dim-wide output; whether its expected input width matches is
# not visible here — confirm.
encoder.trainable=False
x = enc_32(encoder.outputs[0])
y_class_oh = Dense(10,activation='softmax')(x)

In [None]:
# Classifier model from the VAE input to the new softmax head. The
# second-to-last layer and layers 1-3 are frozen before compiling, so the
# trainable flags are set before compile() is called.
med = Model(inputs=inputs,outputs=y_class_oh)
med.layers[-2].trainable=False
for l in med.layers[1:4]:
    l.trainable=False
med.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

In [None]:
# One-hot training labels, read from the loader because no visible cell
# defines a bare `y_train`.
y_train_oh = to_categorical(DL.y_train, num_classes=10)

In [None]:
# Sanity check: shape of the one-hot training labels.
y_train_oh.shape

In [None]:
# Print the layer-by-layer summary of the classifier model.
med.summary()

In [None]:
# Train the classifier for 25 epochs, validating on the test split.
# NOTE(review): `x_train`, `x_test` and `y_test` are not defined in any
# visible cell. `med` is built on `inputs`, whose shape is (4*784,), while
# DL.x_train printed as (60000, 784) — so the 784-wide arrays look like the
# wrong input for this model. Confirm which arrays are intended.
med.fit(x_train,y_train_oh,
        batch_size=128,
        epochs=25,
        validation_data=(x_test,to_categorical(y_test,num_classes=10))
       )

In [None]:
# Decode the first three test latent codes and reshape them to images.
# NOTE(review): `generator` was built with a 12-wide input (latent code plus
# class vector) and a 4*784-wide output, whereas this passes only the
# z_dim-wide latent codes and reshapes to 28x28 — looks like a leftover from
# an earlier architecture; confirm.
x_g = generator.predict(x_test_encoded[:3])
y_test_im = x_g.reshape(3,28,28)

In [None]:
# First three original test images as 28x28 arrays. They are read from the
# loader because no visible cell defines a bare `x_test`, and only the rows
# that are displayed get reshaped.
x_test_im = DL.x_test[:3].reshape(-1, 28, 28)

fig, axs = plt.subplots(1, 3)
for im, ax in zip(x_test_im, axs):
    ax.imshow(im)

In [None]:
# Show the three decoded images built in the earlier cell, one per panel.
fig,axs = plt.subplots(1,3)
for im,ax in zip(y_test_im,axs):
    ax.imshow(im)

In [None]:
# Evaluate the classifier on the test split.
# NOTE(review): `x_test` and `y_test` are not defined in any visible cell, and
# `med` takes (4*784,)-shaped input while DL.x_test printed as (10000, 784) —
# confirm which arrays are intended.
med.evaluate(x_test,to_categorical(y_test,num_classes=10))