In [1]:
import os
import json
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from src.data_loader import Shifted_Data_Loader
from src.plot import orig_vs_transformed as plot_ovt
from src.plot import enc_dec_samples
from src.models import GResNet,EDense
from src.config import get_config
from src.trainer import Trainer
from src.utils import prepare_dirs_and_logger
from keras.datasets import fashion_mnist,mnist
from keras.layers import Dense
from keras.models import Model
from keras.utils import to_categorical
from src.metrics import var_expl
# from tabulate import tabulate

Using TensorFlow backend.


In [2]:
# Start from the project's default configuration and override the settings
# used for this experiment. The second value returned by get_config() is unused.
config, _ = get_config()

# Training schedule / data
config.batch_size = 512
config.dataset = 'fashion_mnist'
config.epochs = 100

# Architecture
config.enc_layers = [3000, 2000]
config.dec_blocks = [4, 2, 1]
config.z_dim = 10

# Loss-term settings (semantics are defined by the project's Trainer)
config.xcov = 1000
config.recon = 5
# config.xcov = None

# Logging / run control
config.log_dir = '../logs'
config.dev_mode = False
config.monitor = 'val_G_loss'
config.min_delta = 0.5
config.optimizer = 'adam'

# Display the effective configuration as the cell output.
vars(config)

{'batch_size': 512,
 'data_dir': 'data',
 'dataset': 'fashion_mnist',
 'dec_blocks': [4, 2, 1],
 'dev_mode': False,
 'enc_layers': [3000, 2000],
 'epochs': 100,
 'log_dir': '../logs',
 'log_level': 'INFO',
 'min_delta': 0.5,
 'monitor': 'val_G_loss',
 'optimizer': 'adam',
 'recon': 5,
 'xcov': 1000,
 'xent': 10,
 'y_dim': 10,
 'z_dim': 10}

In [3]:
# Skip the one-off setup when running in dev mode.
# NOTE(review): prepare_dirs_and_logger is a project helper; by its name it
# sets up output directories and logging for `config` — confirm in src.utils.
if not config.dev_mode:
    print('setting up...')
    prepare_dirs_and_logger(config)

setting up...


In [4]:
# Notebook-level default; the training loop reassigns this for every run.
tx_val = 0
def trainer_generator(translation_vals):
    """Lazily yield one ``Trainer`` per maximum-translation setting.

    For each value in ``translation_vals`` the shared, module-level ``config``
    is updated in place (``max_translation`` and ``model_name``), the run
    directories/logger are prepared, the effective configuration is dumped to
    ``params.json`` inside ``config.model_dir``, and a fresh data loader,
    encoder builder and generator builder are wrapped in a ``Trainer``.

    Because this is a generator, nothing happens until the caller iterates;
    each run is set up only when the previous one has been consumed.

    Parameters
    ----------
    translation_vals : iterable
        Maximum translation values to sweep. A value equal to 0 disables
        translation (``translation=None`` is passed to the data loader).

    Yields
    ------
    Trainer
        A trainer built from ``config``, the data loader and the two builders.
        Every yielded trainer receives the same ``config`` object.
    """
    # Capture the un-prefixed name once. Re-reading config.model_name inside
    # the loop would stack prefixes from earlier runs
    # (e.g. 'translation0.15translation0.0<base>').
    base_model_name = config.model_name

    for tx_val in translation_vals:

        # Update config
        config.max_translation = tx_val
        config.model_name = 'translation' + str(tx_val) + base_model_name
        # NOTE(review): presumably refreshes config.model_dir from the new
        # model_name — confirm in src.utils.
        prepare_dirs_and_logger(config)
        with open(os.path.join(config.model_dir, 'params.json'), 'w') as fp:
            json.dump(vars(config), fp)

        print('Starting run tx=', tx_val)

        # Build Dataloader; a zero translation is expressed as None.
        shift = None if tx_val == 0 else tx_val
        DL = Shifted_Data_Loader(
            dataset=config.dataset,
            flatten=True,
            rotation=None,
            translation=shift,
        )
        G_builder = GResNet(
            y_dim=config.y_dim,
            z_dim=config.z_dim,
            dec_blocks=config.dec_blocks,
        )
        E_builder = EDense(enc_layers=config.enc_layers, z_dim=config.z_dim)
        yield Trainer(config, DL, E_builder, G_builder)

In [5]:
# Sweep over maximum-translation settings. trainer_generator is lazy, so each
# run is only set up when the training loop reaches it.
trainers = trainer_generator(
    translation_vals=[0.0, 0.15, 0.3, 0.5, 0.7, 0.9]
)

In [6]:
# Train every trainer in the sweep, then export its training history, test-set
# encodings and translation offsets into that run's model directory.
for trn in trainers:
    # Keep the current run's objects available at notebook scope.
    DL = trn.data_loader
    config = trn.config
    tx_val = config.max_translation
    trn.compile_model()

    # Target for the 'D' output: every training sample is labelled class 1 of 2.
    RF = to_categorical(np.ones(len(DL.sx_train)), num_classes=2)

    # Hold out the trailing 5% of the training set for validation.
    val_pct = 0.05
    cut_pt = int(len(DL.sx_train) * (1 - val_pct))

    tr_x, val_x = DL.sx_train[:cut_pt], DL.sx_train[cut_pt:]
    tr_y = {
        'class': DL.y_train_oh[:cut_pt],
        'D': RF[:cut_pt],
        'G': DL.sx_train[:cut_pt],
    }
    val_y = {
        'class': DL.y_train_oh[cut_pt:],
        'D': RF[cut_pt:],
        'G': DL.sx_train[cut_pt:],
    }

    trn.go(x=tr_x, y=tr_y, validation_data=(val_x, val_y), verbose=0)

    # Persist the per-epoch metrics recorded during training.
    hist_df = pd.DataFrame.from_records(trn.model.history.history)
    hist_path = os.path.join(config.model_dir, 'tx_' + str(tx_val) + 'training_hist.pk')
    hist_df.to_pickle(hist_path)

    # Sub-models that expose the latent code and the class output of the encoder.
    z_encoder = Model(trn.E.input, trn.z_lat)
    classifier = Model(trn.E.input, trn.y_class)
    z_enc = z_encoder.predict(DL.sx_test, batch_size=config.batch_size)
    class_enc = classifier.predict(DL.sx_test, batch_size=config.batch_size)

    # NOTE(review): index 1 presumably selects the test split and 14 looks like
    # a centring offset — confirm against Shifted_Data_Loader.
    dxs = DL.dx[1] - 14
    dys = DL.dy[1] - 14

    # Save all arrays as '<prefix>tr<tx_val>' inside the run directory.
    for prefix, arr in (('z_enc_', z_enc),
                        ('class_enc_', class_enc),
                        ('dxs_', dxs),
                        ('dys_', dys)):
        np.save(os.path.join(config.model_dir, prefix + 'tr' + str(tx_val)), arr)

# Disabled: variance-explained analysis of the latent code w.r.t. the offsets.
# #     dtheta = DL.dtheta[1]
#     fve_dx = var_expl(features=z_enc,cond=dxs,bins=21)
#     fve_dy = var_expl(features=z_enc,cond=dys,bins=21)

#     fve_dx_norm = np.nan_to_num((dxs.var()-fve_dx)/dxs.var())
#     fve_dy_norm = np.nan_to_num((dys.var()-fve_dy)/dys.var())
# #     fve_dt = var_expl(features=z_enc,cond=dtheta,bins=21)

Starting run tx= 0.0
input_shape:  (3136,)
dataset:  fashion_mnist
scale:  2
tx_max:  None
rot_max:  None
loading fashion_mnist...
sx_train:  (60000, 3136)
making training data...
making testing data...
building encoder...
building decoder/generator...
Epoch        G_loss      val_G_loss  val_class_acc
0:           117.2644    63.5031     0.3213      
1:           27.6169     28.9169     0.4767      
2:           22.0376     25.207      0.5747      
3:           19.7434     20.7395     0.602       
4:           18.0581     18.6643     0.6077      
5:           17.0386     18.7438     0.7473      
6:           16.283      18.0107     0.836       
7:           15.4867     15.858      0.8487      
8:           14.9485     15.0814     0.856       
9:           14.4488     14.3334     0.857       
10:          14.0752     14.0168     0.8633      
11:          13.7958     14.2705     0.867       
12:          13.4436     13.39       0.869       
13:          13.1954     13.1105     0.8723   