In [1]:
import numpy as np
import scipy as scipy
import neptune
# from neptunecontrib.api.utils import get_filepaths
import pandas as pd
import matplotlib.pyplot as plt
import os
from sklearn.preprocessing import MinMaxScaler

import seaborn as sns

from src.data_generator import ShiftedDataBatcher
from src.test_models.drduplex import DRDuplex
from src.data_loader import _shift_image
from edcutils.datasets import bsds500
from edcutils.image import get_patch

Using TensorFlow backend.


In [2]:
# Experiment-level descriptors: dataset, architectures and augmentation mode.
PROPS = dict(
    dataset='fashion_mnist',
    encoder_arch='dense',
    generator_arch='resnet',
    augmentation='dynamic',
)

# Hyper-parameters grouped by concern; they are merged into the flat PARAMS
# dict below so later cells can look any of them up by key.
train_conf = dict(n_epochs=54000, batch_sz=512)
data_conf = dict(bg='natural', im_translation=0.75, bg_contrast=0.3)
model_conf = dict(xent_weight=15, recon_weight=0)

PARAMS = {}
for conf in (train_conf, data_conf, model_conf):
    PARAMS.update(conf)

In [3]:
# Instantiate the project's ShiftedDataBatcher from the config values above.
DB = ShiftedDataBatcher(
    PROPS['dataset'],
    translation=PARAMS['im_translation'],
    bg=PARAMS['bg'],
    blend=None,  # blend='difference' is the alternative that was tried
    batch_size=PARAMS['batch_sz'],
)

In [4]:
num_pan = 3
test_generator = DB.gen_test_batches(4, batch_size=PARAMS['batch_sz'], bg=PARAMS['bg'])

# Rejection sampling: keep drawing `num_pan` distinct indices until the labels
# found at those indices are all different from one another.
while True:
    pan_idx = np.random.choice(np.arange(len(DB.x_te)), size=num_pan, replace=False)
    n_distinct_labels = len(np.unique(DB.y_test[pan_idx]))
    if n_distinct_labels == num_pan:
        break

# Exemplars that get panned below, plus one batch from the test generator.
px_ = DB.x_te[pan_idx]
tX, tX_fg, ty = next(test_generator)

(3, 28, 28)


In [22]:
# Load data through the project's bsds500 helper; only the first of the two
# returned values is kept (presumably the background images — TODO confirm).
bg_imgs,_ = bsds500.load_data()

In [19]:
def gen_pan_X(pX_fg, bg_imgs, rand=True):
    """Composite foreground images onto backgrounds produced by the batcher.

    Parameters
    ----------
    pX_fg : array
        Foreground images, passed to ``DB.gen_backgrounds`` and then to
        ``DB.rasterize`` as the second layer.
    bg_imgs : array
        Pool of background images forwarded to ``DB.gen_backgrounds``.
    rand : optional
        Forwarded unchanged as ``DB.gen_backgrounds(..., rand=rand)``.
        The original body read a notebook-global ``rand`` that no visible
        cell defines, so the function only worked through leftover kernel
        state. NOTE(review): the default of True is an assumption — confirm
        against ``ShiftedDataBatcher.gen_backgrounds``.

    Returns
    -------
    Whatever ``DB.rasterize`` returns for ``[background, foreground]`` using
    the batcher's configured ``DB.blend`` mode.
    """
    pX_bg = DB.gen_backgrounds(pX_fg, bg_imgs, rand=rand)
    # Hand rasterize a copy so the generated background array is not mutated.
    return DB.rasterize([pX_bg.copy(), pX_fg], blend=DB.blend)

In [25]:
# Shift the selected exemplars by every (dx, dy) offset the batcher yields and
# stack the results along a new leading axis, adding a trailing channel axis.
x_span = DB.gen_pan_deltas(step=2)
shifted_frames = [
    np.expand_dims(_shift_image(X=px_, dx=dx, dy=dy), -1)
    for dx, dy in x_span
]
pX = np.stack(shifted_frames)
# Replicate along axis 4 three times.
pX = np.concatenate([pX, pX, pX], axis=4)

scaler01 = MinMaxScaler(feature_range=(0, 1))
scalerminus = MinMaxScaler(feature_range=(-1, 1))
n, r, x, y, c = pX.shape

# MinMaxScaler works on 2-D input, so flatten everything but the first axis,
# rescale into [0, 1], and restore the 5-D layout.
pX_flat = pX.reshape(n, x*y*c*r)
pX = scaler01.fit_transform(pX_flat).reshape(n, r, x, y, c)

In [28]:
# Inspect the stacked array's dimensions; the recorded output is (41, 3, 56, 56, 3).
pX.shape

(41, 3, 56, 56, 3)

In [29]:
# Merge the two leading axes so gen_pan_X receives a flat batch of images, then
# restore the (n, r, x, y, c) layout on its result.
# NOTE(review): the recorded run of this cell ended in a TypeError about casting
# a float64 ufunc result to uint8.
pXX = gen_pan_X(pX.reshape(n*r,x,y,c),bg_imgs).reshape(n,r,x,y,c)

TypeError: Cannot cast ufunc add output from dtype('float64') to dtype('uint8') with casting rule 'same_kind'

In [11]:
# Shape check of the flattened batch; the recorded output is (123, 56, 56, 1),
# i.e. from a run where the channel axis had length 1.
pX.reshape(n*r,x,y,c).shape

(123, 56, 56, 1)

In [17]:
# Manually overrides the channel count `c` that was unpacked from pX.shape.
# NOTE(review): later reshapes using `c` depend on whether this cell has been run.
c = 3

TypeError: Cannot cast ufunc add output from dtype('float64') to dtype('uint8') with casting rule 'same_kind'

In [None]:
# One panel per panned exemplar. squeeze=False keeps `axs` a 2-D array even
# when num_pan == 1, so .ravel() always works.
fig,axs = plt.subplots(1,num_pan,squeeze=False)
for i,ax in enumerate(axs.ravel()):
    # pX is (n, r, x, y, c). Index channel 0 instead of reshaping to (56, 56):
    # the reshape fails once the channel axis has been replicated to c == 3
    # (the recorded pX.shape is (41, 3, 56, 56, 3)), whereas channel indexing
    # works for both c == 1 and c == 3.
    # 27 picks one pan offset along the first axis.
    ax.imshow(pX[27,i,:,:,0],cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.tight_layout()

In [None]:
# Credentials must not live in the notebook. Export NEPTUNE_API_TOKEN in the
# shell (or a local, untracked env file) before starting Jupyter; this cell
# only checks that it is present.
# NOTE(review): a literal token used to be assigned here — it should be
# treated as leaked and revoked.
if not os.environ.get('NEPTUNE_API_TOKEN'):
    import warnings
    warnings.warn(
        "NEPTUNE_API_TOKEN is not set; export it in your environment "
        "before calling neptune.init()."
    )

In [None]:
# neptune.init('elijahc/sandbox')
# neptune.set_project('elijahc/sandbox')

In [None]:
# exp = neptune.project.get_experiments(id='SAN-18')

In [None]:
# Keep only the first element of `exp`.
# NOTE(review): the only visible statement that would bind `exp` to a list
# (neptune.project.get_experiments) is commented out in the preceding cell, so
# on a fresh kernel this raises NameError; the cell is also not re-runnable
# because it rebinds `exp` from itself.
exp = exp[0]

In [None]:
# Display whatever exp.get_properties() returns for the experiment selected above.
exp.get_properties()

In [None]:
def blockshaped(arr, nrows, ncols):
    """
    Return an array of shape (n, nrows, ncols) where
    n * nrows * ncols = arr.size

    If arr is a 2D array, the returned array looks like n subblocks with
    each subblock preserving the "physical" layout of arr. Blocks are
    ordered row-major: left to right, then top to bottom.

    Raises
    ------
    ValueError
        If arr is not 2-D, if nrows/ncols are not positive, or if the
        array's height/width is not an exact multiple of nrows/ncols.
        Without this check a non-divisible shape could still reshape
        successfully and silently return scrambled blocks.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError(
            "blockshaped expects a 2-D array, got {} dimension(s)".format(arr.ndim))
    h, w = arr.shape
    if nrows <= 0 or ncols <= 0:
        raise ValueError("nrows and ncols must be positive")
    if h % nrows or w % ncols:
        raise ValueError(
            "array shape {} is not divisible into {}x{} blocks".format(
                (h, w), nrows, ncols))
    block_rows, block_cols = h // nrows, w // ncols
    # Split both axes into (block index, offset within block), bring the two
    # block indices together, then merge them into one leading axis.
    return (arr.reshape(block_rows, nrows, block_cols, ncols)
               .swapaxes(1,2)
               .reshape(block_rows * block_cols, nrows, ncols))

def unblockshaped(arr, h, w):
    """
    Return an array of shape (h, w) where
    h * w = arr.size

    If arr is of shape (n, nrows, ncols), n sublocks of shape (nrows, ncols),
    then the returned array preserves the "physical" layout of the sublocks
    (blocks are taken in row-major order: left to right, then top to bottom).

    Raises
    ------
    ValueError
        If arr is not 3-D, if h/w are not exact multiples of the block
        height/width, or if h * w does not equal arr.size. Without this
        check some mismatched shapes reshape successfully and silently
        return a scrambled image.
    """
    arr = np.asarray(arr)
    if arr.ndim != 3:
        raise ValueError(
            "unblockshaped expects a 3-D array, got {} dimension(s)".format(arr.ndim))
    n, nrows, ncols = arr.shape
    if nrows == 0 or ncols == 0 or h % nrows or w % ncols:
        raise ValueError(
            "target shape {} is not a multiple of block shape {}".format(
                (h, w), (nrows, ncols)))
    if n * nrows * ncols != h * w:
        raise ValueError(
            "cannot arrange {} blocks of shape {} into {}".format(
                n, (nrows, ncols), (h, w)))
    # Recover the 2-D grid of blocks, interleave block rows with in-block
    # rows, then flatten back to the image plane.
    return (arr.reshape(h//nrows, w//ncols, nrows, ncols)
               .swapaxes(1,2)
               .reshape(h, w))

In [None]:
exp_date = '2019-09-25'
# Project checkout location. The historical absolute path stays as the default
# so existing runs are unaffected; set VAE_PROJ_ROOT to run from another
# machine or checkout.
proj_root = os.environ.get('VAE_PROJ_ROOT', '/home/elijahc/projects/vae')
# Per-date directory under which experiment outputs are stored.
models_root = os.path.join(proj_root,'models',exp_date)


In [None]:
# Initialise the Neptune client with no explicit arguments.
neptune.init()

# Rebinds PROPS: this replaces the dict from the config cell at the top of the
# notebook and additionally carries 'bg' and 'n_epochs'.
PROPS = dict(
    dataset='fashion_mnist',
    bg='natural',
    encoder_arch='dense',
    generator_arch='resnet',
    n_epochs=36000,
    augmentation='dynamic',
)

In [None]:
# List the .py/.yaml/.yml files under ./src.
# `get_filepaths` is never imported (its neptunecontrib import is commented out
# in the first cell), so the original call raised NameError on a fresh kernel.
# The same listing is produced here with os.walk, which is already available.
source_exts = ('.py', '.yaml', '.yml')
sorted(
    os.path.join(dirpath, fname)
    for dirpath, _dirnames, fnames in os.walk('./src')
    for fname in fnames
    if fname.endswith(source_exts)
)

In [None]:
# Source files passed as `upload_source_files` when the experiment is created.
src_files = ['./src/data_generator.py','./src/test_models/drduplex.py']

In [None]:
# Register a new Neptune experiment carrying the run properties and a snapshot
# of the listed source files.
exp = neptune.create_experiment(name='test_exp',
                                properties=PROPS,
                                upload_source_files=src_files,
                               )
# Output directory relative to the project root: models/<date>/<experiment id>.
exp_dir = os.path.join('models',exp_date,exp.id)

# makedirs with exist_ok=True: os.mkdir raised when the models/<date> parent
# did not exist yet, and again when the cell was re-run for an existing
# directory.
os.makedirs(os.path.join(proj_root,exp_dir), exist_ok=True)

In [None]:
# Store the experiment's relative output directory under the 'dir' property.
exp.set_property('dir',exp_dir)

In [None]:
# Display the experiment's properties after the 'dir' property has been set.
exp.get_properties()

In [None]:
# Rebuild the batcher with literal settings (these match the values held in
# PROPS / PARAMS at the top of the notebook).
bg = 'natural'
DB = ShiftedDataBatcher(
    'fashion_mnist',
    translation=0.75,
    bg=bg,
    blend=None,  # blend='difference' is the alternative that was tried
    batch_size=512,
)

In [None]:
# Create a training-batch generator with background contrast 0.3.
# NOTE(review): the positional 2 is presumably the number of batches/epochs to
# yield — confirm against ShiftedDataBatcher.gen_train_batches.
gen = DB.gen_train_batches(2,bg=bg,bg_contrast=0.3)

In [None]:
# Pull one batch (three arrays) from the generator and report two of the shapes.
# NOTE(review): this rebinds `x` and `y`, which an earlier cell used for the
# integer dimensions unpacked from pX.shape.
x,xbg,y = next(gen)
print(x.shape,y.shape)

In [None]:
# Shape of the second array yielded by the training generator.
xbg.shape

In [None]:
# Side-by-side view of sample 5 from the first and second arrays of the batch.
sample_idx = 5
fig, axs = plt.subplots(1, 2)
left_ax, right_ax = axs

left_ax.imshow(x[sample_idx].reshape(56, 56), cmap='gray')
right_ax.imshow(xbg[sample_idx].reshape(56, 56), cmap='gray')

In [None]:
# Loss weights handed to DRDuplex below (`xent=` and `recon=`).
# NOTE(review): model_conf at the top of the notebook sets recon_weight to 0,
# whereas w_recon here is 1 — confirm which value is intended.
w_xent = 15
w_recon = 1

In [None]:
# Collect the constructor arguments in one place, then build the model.
duplex_kwargs = dict(
    img_shape=(56, 56, 1),
    num_classes=DB.num_classes,
    recon=w_recon,
    xent=w_xent,
    n_residual_blocks=4,
    kernel_regularization=1e-5,
)
mod = DRDuplex(**duplex_kwargs)

In [None]:
# Show the metric names reported by the combined model; compare with the
# hand-written `hist_labels` used in the training cells.
mod.combined.metrics_names

In [None]:
# val_pct = 0.05
# val_idxs = np.random.choice(np.arange(10000),int(val_pct*60000),replace=False)
# validation_set = (DB.x_te[val_idxs],
#                   {'Classifier':DB.y_test_oh[val_idxs],
#                    'Generator':DB.fg_test[val_idxs]}
#                  )

In [None]:
# Batch counter and the number of batches that make up one "epoch"
# (training-set size over the batch size of 512).
i = 0
epoch_sz = int(DB.num_train / 512)

# Labels zipped against the values returned by train_on_batch/test_on_batch;
# hand-written rather than read from mod.combined.metrics_names.
hist_labels = ['loss', 'G_loss', 'C_loss', 'G_mse', 'acc']

train_hist, test_hist = [], []

# A single fixed batch from the test generator, reused at every evaluation step.
test_generator = DB.gen_test_batches(4, batch_size=1024, bg=bg)
tX, tX_fg, ty = next(test_generator)

In [None]:
# Main training loop. `i` is initialised in the setup cell and incremented by
# hand, so it keeps counting if this cell is run again.
for X, X_fg, y in DB.gen_train_batches(36000, bg=bg):
    targets = {'Classifier': y, 'Generator': X_fg}
    r = dict(zip(hist_labels, mod.combined.train_on_batch(X, targets)))

    # Evaluate once per epoch_sz batches, skipping the first 100 batches.
    is_eval_step = i > 100 and i % epoch_sz == 0
    if is_eval_step:
        val_targets = {'Classifier': ty, 'Generator': tX_fg}
        r_te = dict(zip(hist_labels, mod.combined.test_on_batch(tX, val_targets)))

        r.update(batch=i, result_type='train')
        r_te.update(batch=i, result_type='valid')

        # Both records go into train_hist; 'result_type' tells them apart.
        # test_hist is left untouched.
        train_hist.extend([r, r_te])

        p_loss = "{:5d} (train/val) G/C_loss[{:2.2f}/{:.2f},   {:2.2f}/{:.2f}]".format(
            i, r['G_loss'], r['C_loss'], r_te['G_loss'], r_te['C_loss'])
        p_acc = "[acc: {:2.2%},   val_acc: {:2.2%}]".format(r['acc'], r_te['acc'])
        print(p_loss, ' ', p_acc)
    i += 1

In [None]:
# One row per logged record; train and validation rows share the frame and are
# distinguished by the 'result_type' column.
hist_tr = pd.DataFrame.from_records(train_hist)

In [None]:
# One line plot per metric, train vs. validation split by hue.
sns.set_context('talk')
plotted_metrics = ['loss', 'G_loss', 'C_loss', 'acc']
fig, axs = plt.subplots(1, 4, figsize=(16, 4), sharex=True)
for ax, metric in zip(axs.ravel(), plotted_metrics):
    sns.lineplot(data=hist_tr, x='batch', y=metric, hue='result_type', ax=ax)

In [None]:
# sns.scatterplot(x='batch',y='val_acc',data=hist_tr)

In [None]:
# Output directory named after the background type and the two loss weights.
# NOTE(review): the date segment (2019-09-11) is hard-coded and differs from
# `exp_date` defined earlier — confirm which directory is intended.
mod_dir = '/home/elijahc/projects/vae/models/2019-09-11/{}_xent_{}_recon_{}'.format(bg,w_xent,w_recon)
# to_csv does not create missing directories, so make sure the target exists
# before writing (no-op when it already does).
os.makedirs(mod_dir, exist_ok=True)
hist_tr.to_csv(os.path.join(mod_dir,'training_hist.csv'))

In [None]:
# Save the combined model's weights next to the training history in mod_dir.
mod.combined.save_weights(os.path.join(mod_dir,'weights.h5'))

In [None]:
# Take the next batch from the test generator, keeping only its first array.
# NOTE(review): assigning to `_` shadows IPython's last-output variable.
teX,_,_ = next(test_generator)

In [None]:
# Run the combined model on the test batch; `out` is unpacked in the next cell.
out = mod.combined.predict_on_batch(teX)

In [None]:
# Split the two model outputs. NOTE(review): assumes the reconstruction comes
# first and the classifier output second — confirm against the model's output
# order. This also rebinds `y`.
x_recon,y = out

In [None]:
# Pick 5 distinct sample indices from the actual batch. The range used to be
# hard-coded to 1024 (the requested test batch size), which would index past
# the end of teX whenever the generator yields a shorter batch.
choices = np.random.choice(np.arange(len(teX)),size=5,replace=False)

In [None]:
# 2 x 5 grid: top row shows the model inputs, bottom row the reconstructions.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))

for i, idx in enumerate(choices):
    input_ax, recon_ax = axes[:, i]
    input_ax.imshow(teX[idx].reshape(56, 56), cmap='gray')
    recon_ax.imshow(x_recon[idx].reshape(56, 56), cmap='gray')

# Hide ticks and labels on every panel.
for ax in axes.ravel():
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)

plt.tight_layout()

In [None]:
# NOTE(review): `DL` is not defined in any visible cell (only `DB` is), so this
# fails on a fresh kernel unless DL is created elsewhere.
x_tr = DL.sx_train
x_tro = DL.x_train

In [None]:
# Inspect the dimensions of DL.x_train.
x_tro.shape

In [None]:
# `im` (sample 5 of x_tr with singleton axes removed) is assigned but not used
# in the visible cells; the plot shows sample 5 of x_tro instead.
im = np.squeeze(x_tr[5])
plt.imshow(x_tro[5])

In [None]:
# `import scipy` alone does not guarantee that the ndimage submodule is loaded
# (depending on the SciPy version, scipy.ndimage then raises AttributeError),
# so import the submodule explicitly before using it.
import scipy.ndimage

x_t_warp = np.zeros((28,28))
# Anisotropic rescale of sample 5: factor 1.1 along axis 0 and 0.9 along axis 1.
xt_warp = scipy.ndimage.zoom(x_tro[5],(1.1,0.9))
plt.imshow(xt_warp)

In [None]:
# Display DL.meta_train. NOTE(review): `DL` is not defined in any visible cell.
DL.meta_train