In [1]:
import os
import json
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from src.data_loader import Shifted_Data_Loader,upsample_dataset
from src.plot import orig_vs_transformed as plot_ovt
from src.plot import enc_dec_samples
# from src.models import GResNet,EDense,EResNet,EConvNet
from src.test_models.drduplex import DRDuplex
from src.config import get_config
from src.trainer import Trainer
from src.utils import prepare_dirs_and_logger
from keras.datasets import fashion_mnist,mnist
from keras.layers import Dense
from keras.models import Model
from keras.utils import to_categorical
from keras.optimizers import adadelta

Using TensorFlow backend.


In [2]:
config,_ = get_config()

# --- Project / bookkeeping ---
config.proj_root = '/home/elijahc/projects/vae'
config.log_dir = '/home/elijahc/projects/vae/logs'
config.dev_mode = True
config.seed = 7
config.project = 'vae'

# --- Stimulus settings ---
# ecc_max and rot_max are later handed to the data loader as its
# translation / rotation arguments; bg_noise becomes the noise width.
config.ecc_max = 4.8/8.0
config.bg_noise = 0.8
config.contrast_level = 0.8
# config.rot_max = 90.0/360.0
config.rot_max = 0

# --- Training parameters ---
config.batch_size = 512
config.dataset = 'fashion_mnist'
config.epochs = 1000
config.monitor = None
# config.lr = 10
# config.min_delta = 0.25
# config.monitor = 'val_loss'
config.optimizer = 'adam'
config.label_corruption = 0.0

In [3]:
# Architecture hyper-parameters, stored on the shared config object.
config.enc_blocks = [128,256,512]
config.enc_arch = 'dense'
config.dec_blocks = [4,2,1]
config.z_dim = 35
config.y_dim = 35

In [4]:
# A configured value of 0 is mapped to None; translation_amt and rot_max
# are what the data loader receives as `translation` and `rotation`.
translation_amt = None if config.ecc_max == 0. else config.ecc_max
rot_max = None if config.rot_max == 0. else config.rot_max
bg_noise = None if config.bg_noise == 0. else config.bg_noise

# Loss weights
config.xcov = 0
config.recon = 1
config.xent = 15

# The output directory encodes the loss weights, encoder type and noise level.
# Alternative layout used for the label-corruption runs:
# config.model_dir = '/home/elijahc/projects/vae/models/2019-06-07/recon_{}_xent_{}/label_corruption_{}'.format(config.recon,config.xent,config.label_corruption)
config.model_dir = '/home/elijahc/projects/vae/models/2019-06-05/xent_{}_recon_{}_{}/bg_noise_{}'.format(config.xent,config.recon,config.enc_arch,config.bg_noise)

In [5]:
# Seed NumPy's global RNG from the configured seed rather than a literal,
# so changing config.seed actually changes the run (config.seed is 7 here,
# so current behaviour is unchanged).
np.random.seed(config.seed)

# Directory/logger setup is skipped in dev mode.
if not config.dev_mode:
    print('setting up...')
    prepare_dirs_and_logger(config)

# Show the effective configuration.
vars(config)


{'batch_size': 512,
 'bg_noise': 0.8,
 'contrast_level': 0.8,
 'dataset': 'fashion_mnist',
 'dec_blocks': [4, 2, 1],
 'dev_mode': True,
 'ecc_max': 0.6,
 'enc_arch': 'dense',
 'enc_blocks': [128, 256, 512],
 'enc_layers': [500, 500],
 'epochs': 1000,
 'label_corruption': 0.0,
 'log_dir': '/home/elijahc/projects/vae/logs',
 'log_level': 'INFO',
 'model_dir': '/home/elijahc/projects/vae/models/2019-06-05/xent_15_recon_1_dense/bg_noise_0.8',
 'monitor': None,
 'optimizer': 'adam',
 'proj_root': '/home/elijahc/projects/vae',
 'project': 'vae',
 'recon': 1,
 'rot_max': 0,
 'seed': 7,
 'xcov': 0,
 'xent': 15,
 'y_dim': 35,
 'z_dim': 35}

In [None]:
# Build the shifted dataset; the training-set size is a multiple of the
# 60000-image base set.
oversample_factor=2
DL = Shifted_Data_Loader(
    dataset=config.dataset,
    flatten=False,
    num_train=oversample_factor*60000,
    translation=translation_amt,
    rotation=rot_max,
    # Options currently disabled:
    #   contrast_level=config.contrast_level,
    #   bg='natural',
    #   blend=None,
    noise_mode='uniform',
    noise_kws=dict(amount=1,width=config.bg_noise),
    bg_only=True,
)

input_shape:  (56, 56, 1)
dataset:  fashion_mnist
background:  None
blend mode:  None
scale:  2
tx_max:  0.6
rot_max:  None
contrast_level:  1
noise_mode:  uniform
  amount: 1
  width: 0.8
creating noise uniform({'amount': 1, 'width': 0.8})...


train images: 100%|██████████| 120000/120000 [00:05<00:00, 23824.17it/s]
test_images: 100%|██████████| 10000/10000 [00:00<00:00, 30950.53it/s]


adding noise to training set


In [None]:
# Visual sanity check of the loader using src.plot.orig_vs_transformed.
pt,idx = plot_ovt(DL,cmap='gray')

In [None]:
# plt.imshow(DL.fg_train[50].reshape(56,56),cmap='gray',vmin=0,vmax=1)

In [None]:
# Inspect the shape of the loader's sx_test array.
DL.sx_test.shape

In [None]:
# Instantiate the DRDuplex model; recon and xent are the loss weights set on config.
# NOTE(review): img_shape is hard-coded to the loader's reported input_shape (56, 56, 1).
mod = DRDuplex(img_shape=(56,56,1),
               num_classes=DL.num_classes,
               recon=config.recon,
               xent=config.xent,n_residual_blocks=4,
#                kernel_regularization=1e-5,
              )

In [None]:
# Print the layer summary of the combined model.
mod.combined.summary()

In [None]:
# Same shape check as earlier; the validation indices below are drawn from this array.
DL.sx_test.shape

In [None]:
# Hold out a random subset of the test split for validation during training.
val_pct = 0.05
n_test = DL.sx_test.shape[0]
# The subset size stays val_pct of the 60000-image base training set, as
# before, but is capped at the number of test images: sampling without
# replacement raises ValueError when more indices are requested than exist.
n_val = min(int(val_pct*60000),n_test)
val_idxs = np.random.choice(np.arange(n_test),n_val,replace=False)

# Targets are keyed by the model's output names.
validation_set = (
    DL.sx_test[val_idxs],
    {
        'Classifier':DL.y_test_oh[val_idxs],
        'Generator':DL.fg_test[val_idxs],
    },
)

In [None]:
# Train for config.epochs epochs, validating on the held-out subset built above.
mod.train(config.epochs,DL,config.batch_size,verbose=0,shuffle=True,
          validation_data=validation_set,
         )

In [None]:
# Collect the per-epoch training history into a DataFrame (one column per metric).
hist_df = pd.DataFrame.from_records(mod.combined.history.history)
hist_df.head()

In [None]:
sns.set_context('paper')
metrics = ['loss','Generator_loss','Classifier_acc']
fig,axs = plt.subplots(nrows=len(metrics),sharex=True,figsize=(10,10))

# One panel per metric, training and validation values together.
for ax,metric_name in zip(axs,metrics):
    train_and_val = [metric_name,'val_'+metric_name]
    sns.scatterplot(data=hist_df[train_and_val],ax=ax)
#     ax.set_xscale('log')

# Dashed reference line at 1/num_classes on the accuracy panel (third entry of `metrics`).
chance_level = 1.0/DL.num_classes
last_epoch = hist_df.index.values.max()
axs[2].hlines(y=chance_level,xmin=0,xmax=last_epoch,linestyles='dashed')

In [None]:
def enc_dec(model,DL):
    """Plot one random training image beside its reconstruction and labels.

    A random index is drawn with ``np.random.randint`` (upper bound taken
    from ``DL.x_train``) and used to pick an image from ``DL.sx_train`` and
    a one-hot label from ``DL.y_train_oh``.  The image goes through
    ``model.E`` and ``model.Q``, and the output of ``model.E`` is fed to
    ``model.G`` to obtain the reconstruction.

    The 2x2 figure shows the input image and its reconstruction on the top
    row, and the true label vector and the ``model.Q`` output on the bottom row.

    Args:
        model: object exposing ``E``, ``Q`` and ``G`` with ``predict`` methods.
        DL: data loader exposing ``x_train``, ``sx_train`` and ``y_train_oh``.

    Returns:
        None. The figure is drawn on the current matplotlib backend.
    """
    sample_idx = np.random.randint(0,DL.x_train.shape[0])
    image = DL.sx_train[sample_idx]
    label_onehot = DL.y_train_oh[sample_idx]

    # Images are assumed to be 56x56 single-channel, as hard-coded here.
    latent_rep = model.E.predict(image.reshape(1,56,56,1))
    y_pred = model.Q.predict(image.reshape(1,56,56,1))

    fig,axs = plt.subplots(2,2,figsize=(8,6))
    (image_ax,recon_ax),(true_ax,pred_ax) = axs

    # Bottom row: label vectors drawn as 1-row images.
    true_ax.imshow(label_onehot.reshape(1,-1))
    pred_ax.imshow(y_pred.reshape(1,-1))

    # Top row: input and reconstruction, titled with the argmax class.
    image_ax.imshow(image.reshape(56,56),cmap='gray')
    image_ax.set_title('Image; class: {}'.format(np.argmax(label_onehot)))
    recon_ax.set_title('Recon; class: {}'.format(np.argmax(y_pred)))
    reconstruction = model.G.predict(latent_rep)
    recon_ax.imshow(reconstruction.reshape(56,56),cmap='gray')

    for panel in axs.ravel():
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)

    plt.tight_layout()

In [None]:
# Draw one random training example with its reconstruction.
enc_dec(mod,DL)

In [None]:
# Per-epoch gap between validation and training loss (validation minus
# training), for the total loss and for each output head.
hist_df['generalization_error'] = hist_df['val_loss'] - hist_df['loss']
hist_df['G_generalization_error'] = hist_df['val_Generator_loss'] - hist_df['Generator_loss']
hist_df['class_generalization_error'] = hist_df['val_Classifier_loss'] - hist_df['Classifier_loss']

In [None]:
# Classifier validation-minus-training loss over epochs.
sns.lineplot(data=hist_df[['class_generalization_error']])
# plt.yscale('log')

In [None]:
import datetime as dt
def clean_config(config,keys=['dev_mode','log_dir','log_level','proj_root']):
    """Return a metadata dict built from ``config`` with some keys removed.

    Args:
        config: object whose attributes (``vars(config)``) hold the run
            configuration.
        keys: attribute names to leave out of the result. Names that are not
            present are ignored. The default list is only read, never modified.

    Returns:
        A new ``dict`` with the remaining config entries plus
        ``'uploaded_by'`` and ``'last_updated'`` (current local time as a
        string). ``config`` itself is left untouched.
    """
    # vars() returns the object's own __dict__; work on a copy so that
    # dropping keys and adding metadata does not alter the live config
    # (previously this deleted e.g. config.dev_mode as a side effect).
    c = dict(vars(config))
    for k in keys:
        c.pop(k,None)

    c['uploaded_by']='elijahc'
    c['last_updated']= str(dt.datetime.now())
    return c

In [None]:
# Build the run-metadata dict from the config and tag it with the project name.
run_meta = clean_config(config)
run_meta['project']='vae'
# run_meta['ecc_max']=0.8
run_meta

In [None]:
# NOTE(review): `trainer` is not defined in any cell visible in this notebook
# (only the Trainer class is imported; the model lives in `mod`). This line
# raises NameError on a fresh kernel — confirm what should save the model.
trainer.save_model()
run_conf = clean_config(config)

# prepare_dirs_and_logger() is skipped when config.dev_mode is True, so
# nothing visible guarantees that the run directory exists; create it before
# writing the config and history files.
os.makedirs(run_conf['model_dir'],exist_ok=True)

with open(os.path.join(run_conf['model_dir'],'config.json'), 'w') as fp:
    json.dump(run_conf, fp)

hist_df.to_parquet(os.path.join(run_conf['model_dir'],'train_history.parquet'))

In [None]:
# Shorthand for the model's G sub-network, used for the regeneration cells below.
generator = mod.G

In [None]:
# Sub-models that expose intermediate representations of the network.
# NOTE(review): apart from z_encoder, every model here is built from `trainer`,
# which no visible cell defines — these lines raise NameError on a fresh kernel.
z_encoder = Model(mod.combined.input,mod.E.z_lat)
y_encoder = Model(trainer.input,trainer.y_lat)
classifier = Model(trainer.input,trainer.y_class)

# Layer taps are selected by Keras layer name: l1 <- 'conv2d_1', l2 <- 'conv2d_3', l3 <- 'dense_1'.
l3_encoder = Model(trainer.input,trainer.model.get_layer(name='dense_1').output)
l1_encoder = Model(trainer.input,trainer.model.get_layer(name='conv2d_1').output)
# l2_encoder = Model(trainer.input,trainer.model.get_layer(name='block_2_Add_2').output)
# l2_encoder = Model(trainer.input,trainer.model.get_layer(name='block_4_Add_1').output)
l2_encoder = Model(trainer.input,trainer.model.get_layer(name='conv2d_3').output)

In [None]:
# NOTE(review): rebinds `mod` (previously the DRDuplex instance) to trainer.model;
# `trainer` is not defined in the visible cells, and earlier cells that use
# mod.combined / mod.G would see a different object if re-run after this one.
mod = trainer.model

In [None]:
# mod.summary()

In [None]:
def get_weight_grad(model, inputs, outputs):
    """Gets gradient of model for given inputs and outputs for all weights.

    Args:
        model: a compiled Keras model (the function reads ``model.optimizer``
            and ``model.total_loss``).
        inputs: input data, in the form accepted by the model.
        outputs: target data matching ``inputs``.

    Returns:
        The evaluated gradients of ``model.total_loss`` with respect to
        ``model.trainable_weights``, as returned by the backend function.
    """
    # `K` was referenced without ever being imported in this notebook, so
    # calling the function raised NameError. Import the backend locally to
    # keep the function self-contained.
    from keras import backend as K

    grads = model.optimizer.get_gradients(model.total_loss, model.trainable_weights)
    # NOTE(review): relies on private Keras attributes (_feed_*, _standardize_user_data),
    # which are tied to the installed Keras version.
    symb_inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights)
    f = K.function(symb_inputs, grads)
    x, y, sample_weight = model._standardize_user_data(inputs, outputs)
    output_grad = f(x + y + sample_weight)
    return output_grad

In [None]:
# Compile the classifier sub-model only so that evaluate() can report accuracy;
# with metrics=['acc'], res is [loss, accuracy].
classifier.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
res = classifier.evaluate(DL.sx_test,DL.y_test_oh,batch_size=config.batch_size)
# NOTE(review): ts_error is not used by any visible cell.
ts_error = 1-res[1]
print(res[1])
# One-row summary of test accuracy alongside the settings that produced it.
df = pd.DataFrame.from_records({'test_acc':[res[1]],
                                'label_corruption':[config.label_corruption],
                                'recon':[config.recon],
                                'xent':[config.xent],
                                'ecc_max':[config.ecc_max],
                                'xcov': [config.xcov]})
df.to_json(os.path.join(config.model_dir,'performance.json'))

In [None]:
# Check what kind of object Keras returns for a model's output_shape.
out_s = l1_encoder.output_shape
type(out_s)

In [None]:
def _flat_activations(encoder,images,batch_size):
    """Run ``encoder`` on ``images`` and flatten each sample's output to 1-D.

    Returns a 2-D array with one row per input sample. The row count comes
    from the prediction itself instead of a hard-coded 10000, so the same
    code works for any number of images.
    """
    acts = encoder.predict(images,batch_size=batch_size)
    return acts.reshape(acts.shape[0],-1)


# Flattened activations of the three tapped layers on the test images.
l1_enc = _flat_activations(l1_encoder,DL.sx_test,config.batch_size)
l2_enc = _flat_activations(l2_encoder,DL.sx_test,config.batch_size)
l3_enc = _flat_activations(l3_encoder,DL.sx_test,config.batch_size)

# Latent encodings of the test images.
z_enc = z_encoder.predict(DL.sx_test,batch_size=config.batch_size)
# y_lat = y_lat_encoder.predict(DL.sx_test,batch_size=config.batch_size)
y_enc = y_encoder.predict(DL.sx_test,batch_size=config.batch_size)

In [None]:
# Shape of the flattened first-layer activations.
l1_enc.shape

In [None]:
import xarray
import hashlib
import random
def raw_to_xr(encodings,l_2_depth,stimulus_set):
    """Pack per-layer activations into a single ``xarray.DataArray``.

    Args:
        encodings: mapping of layer name -> 2-D activation array with one
            row per stimulus and one column per unit.
        l_2_depth: mapping of layer name -> depth value, stored in the
            ``layer`` coordinate.
        stimulus_set: DataFrame with columns ``image_id``, ``dx``, ``dy``,
            ``rxy`` and ``numeric_label`` (one row per stimulus).

    Returns:
        A float32 DataArray with dims ``('presentation', 'neuroid')``, made by
        concatenating the per-layer arrays along ``neuroid``.

        * ``presentation`` levels: ``image_id``, ``dx``, ``dy``, ``rxy``,
          ``numeric_label``, ``object_name``, ``tx`` (``dx/28``),
          ``ty`` (``dy/28``) and ``s`` (constant 1.0).
        * ``neuroid`` levels: ``neuroid_id`` (``"<layer>_<i>"``), ``layer``
          (the depth from ``l_2_depth``) and ``region`` (the layer name).
    """
    obj_names = [
        "T-shirt",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Dress Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]

    def neuroid_index(region,n_units):
        # One entry per unit of this layer.
        unit_ids = ['{}_{}'.format(region,unit) for unit in np.arange(n_units)]
        return pd.MultiIndex.from_arrays([
            pd.Series(unit_ids,name='neuroid_id'),
            pd.Series([l_2_depth[region]]*n_units,name='layer'),
            pd.Series([region]*n_units,name='region')
        ])

    def presentation_index():
        # One entry per stimulus; identical for every layer.
        labels = stimulus_set.numeric_label
        return pd.MultiIndex.from_arrays([
            stimulus_set.image_id,
            stimulus_set.dx,
            stimulus_set.dy,
            stimulus_set.rxy,
            labels.astype('int8'),
            pd.Series([obj_names[i] for i in labels],name='object_name'),
            pd.Series(stimulus_set.dx.values/28, name='tx'),
            pd.Series(stimulus_set.dy.values/28, name='ty'),
            pd.Series([1.0]*len(stimulus_set),name='s'),
        ])

    def layer_to_da(region,activations):
        units = neuroid_index(region,activations.shape[1])
        stimuli = presentation_index()
        return xarray.DataArray(activations.astype('float32'),
                                coords={'presentation':stimuli,'neuroid':units},
                                dims=['presentation','neuroid'])

    per_layer = [layer_to_da(region,activations) for region,activations in encodings.items()]
    return xarray.concat(per_layer,dim='neuroid')

In [None]:
# Activations to export, keyed by the region name used in the assembly.
# NOTE(review): the 'dense_1'/'dense_2'/'dense_3' keys label l1_enc/l2_enc/l3_enc,
# which were tapped from layers 'conv2d_1', 'conv2d_3' and 'dense_1'
# respectively — confirm the labels are intended.
encodings = {
    # Flatten every test image to one row; the row count comes from the
    # array itself rather than a hard-coded 10000.
    'pixel':DL.sx_test.reshape(DL.sx_test.shape[0],-1),
    'dense_1':l1_enc,
    'dense_2':l2_enc,
    'dense_3':l3_enc,
    'y_lat':y_enc,
    'z_lat':z_enc
}
depths = {
    'pixel':0,
    'dense_1':1,
    'dense_2':2,
    'dense_3':3,
    'y_lat':4,
    'z_lat':4
}

# The image id is an md5 of (dx, dy, label, random salt). The salt used to
# come from the unseeded global `random` module, which made the ids differ
# on every run; a dedicated RNG seeded from config.seed makes them
# reproducible without touching global random state.
# NOTE(review): with only 20 salt values the ids are not guaranteed unique,
# and `rxy` is unpacked but not part of the hashed tuple — confirm intent.
_salt_rng = random.Random(config.seed)
slug = [(dx,dy,float(lab),float(_salt_rng.randrange(20))) for dx,dy,rxy,lab in zip(DL.dx[1],DL.dy[1],DL.dtheta[1],DL.y_test)]
image_id = [hashlib.md5(json.dumps(list(p),sort_keys=True).encode('utf-8')).digest().hex() for p in slug]

# Offsets are stored relative to 14 (dx-14, dy-14).
stim_set = pd.DataFrame({'dx':DL.dx[1]-14,'dy':DL.dy[1]-14,'numeric_label':DL.y_test,'rxy':DL.dtheta[1],'image_id':image_id})

In [None]:
# Assemble all layer activations and stimulus metadata into one DataArray.
out = raw_to_xr(encodings,depths,stim_set)

In [None]:
# NOTE(review): same call as the previous cell; the assembly is rebuilt here before saving.
out = raw_to_xr(encodings,depths,stim_set)
from collections import OrderedDict
def save_assembly(da,run_dir,fname,**kwargs):
    """Write ``da`` to ``run_dir/fname`` in netCDF format.

    Before writing, every index on the array's coordinate dimensions is reset
    (presumably because MultiIndex coordinates cannot be serialised as-is —
    confirm) and the attributes are replaced by an empty ``OrderedDict``.
    The reset produces a new object, so the caller's array is not modified.

    Args:
        da: the ``xarray.DataArray`` to save.
        run_dir: directory to write into.
        fname: file name inside ``run_dir``.
        **kwargs: forwarded unchanged to ``DataArray.to_netcdf``.
    """
    flattened = da.reset_index(da.coords.dims)
    flattened.attrs = OrderedDict()
    target = os.path.join(run_dir,fname)
    with open(target, 'wb') as handle:
        flattened.to_netcdf(handle,**kwargs)
        
    
# Write the assembly next to the model files; `format` is forwarded to to_netcdf.
save_assembly(out,run_dir=config.model_dir,fname='dataset.nc',
    format='NETCDF3_64BIT',
#         engine=
#         encoding=enc,
)

In [None]:
# z_enc_tr = z_encoder.predict(DL.sx_train,batch_size=config.batch_size)
# y_lat = y_lat_encoder.predict(DL.sx_test,batch_size=config.batch_size)
# y_enc_tr = y_encoder.predict(DL.sx_train,batch_size=config.batch_size)

In [None]:
# Persist the raw encodings in the model directory, one file per array,
# named after the variable.
for stem,array in [('z_enc',z_enc),('l1_enc',l1_enc),('l2_enc',l2_enc),('y_enc',y_enc)]:
    np.save(os.path.join(config.model_dir,stem),array)

In [None]:
# Shape of the y encodings of the test images.
y_enc.shape

In [None]:
# Join the y and z encodings column-wise (y first); this is the input later given to the generator.
_lat_vec = np.concatenate([y_enc,z_enc],axis=1)
_lat_vec.shape

In [None]:
# Empirical mean and covariance of the z encodings (rows = samples, columns = dimensions).
z_enc_mu = np.mean(z_enc,axis=0)
z_enc_cov = np.cov(z_enc,rowvar=False)

In [None]:
# Shape check: 50 draws from a Gaussian with the empirical z mean and covariance.
np.random.multivariate_normal(z_enc_mu,z_enc_cov,size=50).shape

In [None]:
# Regenerate the test images from their concatenated [y, z] encodings.
regen = generator.predict(_lat_vec,batch_size=config.batch_size)

In [None]:
# Show one randomly chosen regenerated image.
# NOTE(review): the upper bound 10000 is hard-coded to the test-set size.
rand_im = np.random.randint(0,10000)
plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
# Latent vector of the image displayed in the previous cell.
_lat_vec[rand_im]

In [None]:
# enc_dec_samples(DL.x_train,DL.sx_train,z_enc_tr,y_enc_tr,generator)

In [None]:
# Sample plots for the test split, using the helper from src.plot.
enc_dec_samples(DL.x_test,DL.sx_test,z_enc,y_enc,generator)

In [None]:
# Encode and regenerate the test images of a second loader, DL2.
# NOTE(review): `DL2` is not defined in any visible cell, so this raises
# NameError on a fresh kernel — it needs a second Shifted_Data_Loader, or removal.
z_enc2 = z_encoder.predict(DL2.sx_test,batch_size=config.batch_size)
y_lat2 = y_encoder.predict(DL2.sx_test,batch_size=config.batch_size)
_lat_vec2 = np.concatenate([y_lat2,z_enc2],axis=1)
regen2 = generator.predict(_lat_vec2,batch_size=config.batch_size)

In [None]:
from src.plot import remove_axes,remove_labels
from src.utils import gen_trajectory

In [None]:
# Latent-space trajectory figure: one row per randomly chosen test image,
# going from its DL encoding to its DL2 encoding.
# NOTE(review): z_enc2, y_lat2, regen2 and DL2 all depend on `DL2`, which no
# visible cell defines.
examples = 5
rand_im = np.random.randint(0,10000,size=examples)
# NOTE(review): `fix` is presumably a typo for `fig`; the commented-out
# savefig cell below uses the same name.
fix,axs = plt.subplots(examples,11,figsize=(8,4))
_lat_s = []
regen_s = []
# NOTE(review): this rebinds `out`, which earlier held the xarray assembly.
out = gen_trajectory(z_enc[rand_im],z_enc2[rand_im],delta=.25)
out_y = gen_trajectory(y_enc[rand_im],y_lat2[rand_im],delta=.25)

# Decode every step of the trajectory; each step holds one latent per example.
for z,y in zip(out,out_y):
    _lat = np.concatenate([y,z],axis=1)
    _lat_s.append(_lat)
    regen_s.append(generator.predict(_lat,batch_size=config.batch_size))

# Column layout per row: original (28x28), shifted input, its regeneration,
# the trajectory steps, then the DL2 regeneration, DL2 shifted input and
# DL2 original.
# `i` is the row position of the example inside each regen_s batch.
i=0
for axr,idx in zip(axs,rand_im):
    axr[0].imshow(DL.x_test[idx].reshape(28,28),cmap='gray')
    axr[1].imshow(DL.sx_test[idx].reshape(56,56),cmap='gray')
    axr[2].imshow(regen[idx].reshape(56,56),cmap='gray')
    for j,a in enumerate(axr[3:-3]):
        a.imshow(regen_s[j][i,:].reshape(56,56),cmap='gray')
#         a.imshow(s.reshape(56,56),cmap='gray')
    axr[-3].imshow(regen2[idx].reshape(56,56),cmap='gray')
    axr[-2].imshow(DL2.sx_test[idx].reshape(56,56),cmap='gray')
    axr[-1].imshow(DL2.x_test[idx].reshape(28,28),cmap='gray')
    for a in axr:
        remove_axes(a)
        remove_labels(a)
    i+=1
# plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
# fix.savefig('../../updates/2019-02-05/assets/img/translocate_{}.png'.format(translation_amt))

In [None]:
# NOTE(review): not meaningful code — evaluating `fdjsakl` raises NameError,
# which stops a "Run All" at this cell. Presumably a deliberate stop marker;
# confirm, then delete it or replace it with an explicit, documented stop.
fdjsakl;fdsa

In [None]:
from collections import Counter
import dit
from dit import Distribution

def mutual_information(X,Y):
    """Mutual information between two paired sequences of discrete values.

    The joint distribution is the empirical one: each distinct ``(x, y)``
    pair is weighted by its relative frequency. ``X`` and ``Y`` are paired
    with ``zip``, so extra items in the longer sequence are ignored.

    Args:
        X: iterable of hashable outcomes.
        Y: iterable of hashable outcomes, paired element-wise with ``X``.

    Returns:
        The value of ``dit.shannon.mutual_information`` between the first
        and second component of the joint distribution.
    """
    XY_c = Counter(zip(X,Y))
    # The total count is the same for every pair; compute it once instead of
    # re-summing all counts for each key.
    n_pairs = float(sum(XY_c.values()))
    XY_pmf = {k:v/n_pairs for k,v in XY_c.items()}
    XY_jdist = Distribution(XY_pmf)

    return dit.shannon.mutual_information(XY_jdist,[0],[1])