In [2]:
import os
import json
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import tensorflow as tf

from src.data_loader import Shifted_Data_Loader
from src.plot import orig_vs_transformed as plot_ovt
from src.plot import enc_dec_samples
from src.models import GResNet,EDense
from src.config import get_config
from src.trainer import Trainer
from src.utils import prepare_dirs_and_logger
from src.losses import sse, mse
from src.test_models.EBGAN import EBGAN,Generator,Encoder,resample,gradient_penalty_loss

from keras.datasets import cifar10
from keras.layers import *
from keras.layers.merge import add
from keras.regularizers import l2
from keras.models import Model
from keras.callbacks import EarlyStopping
import keras.backend as K
from src.keras_callbacks import PrintHistory,Update_k
from src.resnet import _bn_relu_conv, _shortcut, basic_block, _handle_dim_ordering,ResnetBuilder,_conv_bn_relu,_residual_block
from recurrentshop import *
# from tabulate import tabulate

Using TensorFlow backend.


In [3]:
config,_ = get_config()

# Boilerplate
setattr(config, 'proj_root', '/home/elijahc/projects/vae')
setattr(config, 'log_dir', '/home/elijahc/projects/vae/logs')
setattr(config, 'dev_mode',True)
# setattr(config,'model_dir','/home/elijahc/projects/vae/models/2019-01-17/')

# Architecture Params
setattr(config, 'enc_layers', [3000,2000])
setattr(config, 'dec_blocks', [4,2,1])
setattr(config, 'z_dim', 10)
setattr(config, 'y_dim', 10)

# Training Params
setattr(config, 'batch_size', 512)
setattr(config, 'dataset', 'fashion_mnist')
setattr(config, 'epochs', 100)
setattr(config, 'monitor', 'val_loss')
setattr(config, 'min_delta', 0.5)
setattr(config, 'optimizer', 'adam')

# Loss Weights
setattr(config, 'xcov', 0)
setattr(config, 'recon', 10)
setattr(config, 'xent', 10)

In [4]:
if not config.dev_mode:
    print('setting up...')
    prepare_dirs_and_logger(config)
    
vars(config)

{'batch_size': 512,
 'dataset': 'fashion_mnist',
 'dec_blocks': [4, 2, 1],
 'dev_mode': True,
 'enc_layers': [3000, 2000],
 'epochs': 100,
 'log_dir': '/home/elijahc/projects/vae/logs',
 'log_level': 'INFO',
 'min_delta': 0.5,
 'monitor': 'val_loss',
 'optimizer': 'adam',
 'proj_root': '/home/elijahc/projects/vae',
 'recon': 10,
 'xcov': 0,
 'xent': 10,
 'y_dim': 10,
 'z_dim': 10}

In [5]:
(x_tr,y_tr),(x_te,y_te) = cifar10.load_data()

In [6]:
# encoder_inputs = Input(shape=(None, num_enc_tokens))

# encoder = LSTM(latent_dim, return_state=True)
# enc_out,state_h,state_c = encoder(encoder_inputs)

In [7]:
# encoder = ConvLSTM2D(8,(2,2),strides=(2,2),return_state=True,return_sequences=True)

# encoder_outputs, state_h, state_c = encoder(enc_inputs)

In [8]:
class DiracDeltaFunc(object):
    def __init__(self,t_open=0):
        self.t = 0.
        self.t_open = float(t_open)
        
    def __call__(self,x):
        if self.t == self.t_open:
            out = x * 1
        else:
            out = x * 0
            
        self.t += 1
        return out

In [17]:
from recurrentshop import *
x_t = Input(shape=(32,32,16)) # The input to the RNN at time t
h_tm1 = Input(shape=(32,32,16))  # Previous hidden state

# Compute new hidden state
_handle_dim_ordering()
k = _residual_block(basic_block,filters=32,repetitions=2)(h_tm1)

h_t = add([x_t, k])

# tanh activation
# h_t = Activation('tanh')(h_t)

# Build the RNN
# RecurrentModel is a standard Keras `Recurrent` layer. 
# RecurrentModel also accepts arguments such as unroll, return_sequences etc
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t])

ValueError: Operands could not be broadcast together with shapes (32, 32, 16) (16, 16, 32)

In [16]:
block

<tf.Tensor 'add_4/add:0' shape=(?, 16, 16, 32) dtype=float32>

In [None]:
class RRNCell2(Layer):
    def __init__(self,filters,
                 kernel_size,
                 block_fn,
                 num_states=1,
                 data_format='channels_last',
                 **kwargs):
        
        self.filters=filters
        self.kernel_size=kernel_size
        self.block_fn = block_fn
        self.num_states=num_states
        self.data_format = data_format
        self.state_size = (self.filters)
        super(RRNCell2, self).__init__(**kwargs)
        
    def build(self,input_shape):
        self.state_size = (1,)
    
    def call(self,inputs,states):
        prev_output = states[0]
        I = prev_output
        k = self.block_fn(prev_output)
        k_i = add([k,I])
        
        h = add([inputs,k_i])
        
        return h,[h]
        

In [None]:
class ResRNN(ConvRNN2D):
    def __init__(self, cell,
                 return_sequences=False,
                 return_state=False,
                 go_backwards=False,
                 stateful=False,
                 unroll=False,
                 **kwargs):
        super(ConvRNN2D, self).__init__(cell,
                                        return_sequences,
                                        return_state,
                                        go_backwards,
                                        stateful,
                                        unroll,
                                        **kwargs)
    
    def compute_output_shape(self,input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        cell = self.cell
        if cell.data_format == 'channels_first':
            rows = input_shape[3]
            cols = input_shape[4]
        elif cell.data_format == 'channels_last':
            rows = input_shape[2]
            cols = input_shape[3]
        
        output_shape = input_shape[:2] + (rows, cols, cell.filters)
        output_shape = transpose_shape(output_shape, cell.data_format,
                                       spatial_axes=(2, 3))
        
        if not self.return_sequences:
            output_shape = output_shape[:1] + output_shape[2:]

        if self.return_state:
            output_shape = [output_shape]
            base = (input_shape[0], rows, cols, cell.filters)
            base = transpose_shape(base, cell.data_format, spatial_axes=(1, 2))
            output_shape += [base[:] for _ in range(1)]
        
        return output_shape
    
    def build(self, input_shape):
        # Note input_shape will be list of shapes of initial states and
        # constants if these are passed in __call__.
        
        cell = self.cell
        self.input_spec = [input_shape[2:5] for _ in range(input_shape[1])]
        self.state_spec = [input_shape[2:5] for _ in range(input_shape[1])]
        
        if self.stateful:
            self.reset_states()
            
        self.built = True
        
    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            kwargs['initial_state'] = initial_state
            additional_inputs += initial_state
            self.state_spec = []
            for state in initial_state:
                try:
                    shape = K.int_shape(state)
                # Fix for Theano
                except TypeError:
                    shape = tuple(None for _ in range(K.ndim(state)))
                self.state_spec.append(InputSpec(shape=shape))

            additional_specs += self.state_spec
            
        for tensor in additional_inputs:
            if K.is_keras_tensor(tensor) != K.is_keras_tensor(additional_inputs[0]):
                raise ValueError('The initial state or constants of an RNN'
                                 ' layer cannot be specified with a mix of'
                                 ' Keras tensors and non-Keras tensors')

        if K.is_keras_tensor(additional_inputs[0]):
            # Compute the full input spec, including state and constants
            full_input = [inputs] + additional_inputs
            full_input_spec = self.input_spec + additional_specs
            # Perform the call with temporarily replaced input_spec
            original_input_spec = self.input_spec
            self.input_spec = full_input_spec
            output = super(ConvRNN2D, self).__call__(full_input, **kwargs)
            self.input_spec = original_input_spec
            return output
        else:
            return super(ConvRNN2D, self).__call__(inputs, **kwargs)

In [None]:
RR = ResRNN(RRNCell2(16,kernel_size=(3,3),block_fn=basic_block),return_sequences=False,return_state=True)

In [None]:
RR.compute_output_shape((1,5,32,32,16))

In [None]:
input_shape = tuple([1,5]+list(x_tr.shape[1:]))
enc_inputs = Input(batch_shape=input_shape)
pre_x = TimeDistributed(Conv2D(16,kernel_size=(3,3),strides=(1,1),padding='same'))(enc_inputs)
RR(pre_x,initial_state=pre_x)

In [None]:
class V1Cell(ConvLSTM2DCell):
    def __init__(self, filters,
                 kernel_size,
                 num_states,
                 **kwargs):
        self.num_states = num_states
        self.input_gate = DiracDeltaFunc()
        super(V1Cell, self).__init__(filters,kernel_size,**kwargs)

    def build(self, input_shape):

        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters*self.num_states)
        self.kernel_shape = kernel_shape
        recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters*self.num_states)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=recurrent_kernel_shape,
            initializer=self.recurrent_initializer,
            name='recurrent_kernel',
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)
        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.filters,), *args, **kwargs),
                        initializers.Ones()((self.filters,), *args, **kwargs),
                        self.bias_initializer((self.filters * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.filters * 2,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        for i in np.arange(self.num_states):
            setattr(self,'kernel_{}'.format(i),self.kernel[:,:,:,i*self.filters:self.filters*(i+1)])
            setattr(self,'recurrent_kernel_{}'.format(i),self.recurrent_kernel[:,:,:,i*self.filters:self.filters*(i+1)])

#         self.kernel_i = self.kernel[:, :, :, :self.filters]
        
#         self.recurrent_kernel_0 = self.recurrent_kernel[:, :, :, :self.filters]
#         self.recurrent_kernel_1 = self.recurrent_kernel[:,:,:,self.filters:self.filters*2]

        if self.use_bias:
            self.bias_i = self.bias[:self.filters]
#             self.bias_f = self.bias[self.filters: self.filters * 2]
#             self.bias_c = self.bias[self.filters * 2: self.filters * 3]
#             self.bias_o = self.bias[self.filters * 3:]
        else:
            self.bias_i = None
#             self.bias_f = None
#             self.bias_c = None
#             self.bias_o = None
        self.built = True

    def call(self,inputs,states):
        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask
        
        if 0 < self.dropout < 1.:
            inputs_i = inputs * dp_mask[0]
            inputs_f = inputs * dp_mask[1]
            inputs_c = inputs * dp_mask[2]
            inputs_o = inputs * dp_mask[3]
        else:
            inputs_i = inputs
            inputs_f = inputs
            inputs_c = inputs
            inputs_o = inputs
    
        h_tm1 = states[0]  # previous memory state
        
        if 0 < self.recurrent_dropout < 1.:
            h_tm1_i = h_tm1 * rec_dp_mask[0]
            h_tm1_f = h_tm1 * rec_dp_mask[1]
            h_tm1_c = h_tm1 * rec_dp_mask[2]
            h_tm1_o = h_tm1 * rec_dp_mask[3]
        else:
            h_tm1_i = h_tm1
            h_tm1_f = h_tm1
            h_tm1_c = h_tm1
            h_tm1_o = h_tm1
        
#         I = UpSampling2D(data_format=self.data_format)(h_tm1)
#         print('I: ',K.int_shape(I))
#         print('inputs: ',inputs_i)
#         print('h_tm1: ',K.int_shape(h_tm1))
        I = UpSampling2D(data_format=self.data_format)(h_tm1_i)
#         I = h_tm1_i
        print('I : ',K.int_shape(I))
        bn = BatchNormalization()
        relu = Activation('relu')
        _bn_relu(bn,relu,)
        net = Activation('relu')(I)
#         print('BN: ',K.int_shape(net))
        self.K = self.input_conv(net,self.kernel_0,self.bias_i,padding='same')
#         self.K = self.BRC(I, self.kernel_0)
        print('K: ',K.int_shape(self.K))
#         self.K_2 = self.BRC(K_1,self.recurrent_kernel_1)
#         print('out: ',K.int_shape(out))
        print('inputs_i: ',K.int_shape(inputs_i))
        out = Add()([self.K,h_tm1_i])
        print('out: ',K.int_shape(out))
        
        self.h1 = _shortcut()


#         K_recurrent = UpSampling2D(size=(3,3),data_format='channels_last')(K_1)+I
                
        return self.h1, [self.h1]
    
    def _bn_relu(x,bn,relu,idx):
        return Lambda(lambda inp: relu(bn(inp)),name='bn_relu_{}'.format(idx))(x)
    
    def BRC(self,x,w):
        net = BatchNormalization()(x)
        net = Activation('relu')(net)
        print('BN: ',K.int_shape(net))
        out = self.recurrent_conv(net,w)
        
              
#         out = UpSampling2D(data_format=self.data_format)
        print('out: ',K.int_shape(out))
              
        return out
VC = V1Cell(filters=8,num_states=1,kernel_size=(2,2),strides=(2,2),padding='same')

In [None]:
from keras.layers.merge import add

In [None]:
C_RNN = ConvRNN2D(VC,return_sequences=True,return_state=True)
t_unroll = 5
# Channels_last format [batch,time,rows,cols,channels]
input_shape = tuple([1,t_unroll]+list(x_tr.shape[1:]))
enc_inputs = Input(batch_shape=input_shape)
block_fn=basic_block
enc_input = Input(shape=input_shape)
conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(enc_input)
pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)
pool1
block = pool1
px = TimeDistributed(Conv2D(filters=8,kernel_size=(3,3),padding='same'))(enc_inputs)
print("px: ",K.int_shape(px))
out_c = C_RNN(px)

In [None]:
out_c

In [None]:
class RNNResBlockCell(Layer):
    def __init__(self,filters,
                 kernel_size,
                 padding='valid',
                 data_format=None,
                 strides=(1, 1),
                 dilation_rate=(1, 1),
                 kernel_activation='relu',
                 kernel_initializer='glorot_uniform',
                 kernel_constraint=None,
                 recurrent_activation='relu',
                 recurrent_initializer='orthogonal',
                 recurrent_constraint=None,
                 **kwargs):
        
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.dilation_rate = dilation_rate
#         self.state_size = filters
        self.data_format = K.normalize_data_format(data_format)
        self.kernel_activation=kernel_activation
        self.kernel_initializer=kernel_initializer
        self.kernel_constraint=kernel_constraint
        self.recurrent_activation=recurrent_activation
        self.recurrent_initializer=recurrent_initializer
        self.recurrent_constraint=recurrent_constraint
        
        self.input_gate_func = DiracDeltaFunc()
        self.state_size = (self.filters,self.filters)
    
        super(RNNResBlockCell, self).__init__(**kwargs)
        
    def build(self, input_shape):
#         super(RNNResBlockCell, self).__init__(**kwargs)
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)
        
        self.kernel_shape = kernel_shape
        
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        
        recurrent_kernel_shape = self.kernel_size + (self.filters,self.filters)
        
        self.recurrent_kernel = self.add_weight(
            shape=recurrent_kernel_shape,
            initializer=self.recurrent_initializer,
            name='recurrent_kernel',
#             regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)
        
    
    def call(self, inputs, states):
        inputs = self.input_gate_func(inputs)
        prev_output = states[0]
        K = self.recurrent_conv(x=prev_output,w=self.recurrent_kernel)
        I = self.input_identity(prev_output)
        output = inputs + K + I
        return output, [output]
    
    def input_conv(self, x, w, b=None, padding='valid'):
        conv_out = K.conv2d(x, w, strides=self.strides,
                            padding=padding,
                            data_format=self.data_format,
                            dilation_rate=self.dilation_rate)
        if b is not None:
            conv_out = K.bias_add(conv_out, b,
                                  data_format=self.data_format)
        return conv_out
    
    def input_identity(self,x):
        I_out = K.identity(x,name='Identity')
        return I_out
    
    def h1_recurrent(x,w):
        return self.BRC(x,w)
        
    def BRC(self,x,w):
        net = BatchNormalization()(x)
        net = Activation('relu')(net)
        net = K.conv2d(net,w,strides=(2,2),
                       padding='same',
                       data_format=self.data_format)
        net = BatchNormalization()(net)
        net = Activation('relu')(net)
        out = K.conv2d(net,w,strides=(2,2),
                       padding='same',
                       data_format=self.data_format)
        
        return out
        
    def recurrent_conv(self, x, w):
        conv_out = K.conv2d(x, w, strides=(1, 1),
                            padding='same',
                            data_format=self.data_format)
        return conv_out

RRCell = RNNResBlockCell(filters=8,kernel_size=(3,3),data_format='channels_last',padding='same')

In [None]:
C_RNN = ConvRNN2D(ConvLSTM2DCell(filters=64,kernel_size=(2,2),strides=(2,2)),return_sequences=True,return_state=True,stateful=True)
t_unroll = 5
# Channels_last format [batch,time,rows,cols,channels]
input_shape = tuple([1,t_unroll]+list(x_tr.shape[1:]))
enc_inputs = Input(batch_shape=input_shape)
px = TimeDistributed(Conv2D(filters=8,kernel_size=(3,3),padding='same'))(enc_inputs)

out_c = C_RNN(px)

In [None]:
out_c

In [None]:
mod = Model(inputs=enc_inputs,outputs=out_c)

In [None]:
mod.outputs

In [None]:
def pre_net(x,filts=1,tsteps=5):
#     net = RepeatVector(tsteps)(net)
    
    return net

def fwd_res(x,features=16):
    net = BatchNormalization()(x)
    net = Activation('relu')(net)
    c_layer = ConvLSTM2D(filters=features,kernel_size=(3,3),padding="same",
                     activation=None,stateful=True)
    out = c_layer(initial_state=)
    
    return out

In [None]:

print(input_shape)

latent_dim = 10



In [None]:

# pnet = Conv2D(16,kernel_size=(3,3),padding="same")(enc_inputs)

# enc_inputs
# px = pre_net(enc_inputs,filts=16)
# net = BatchNormalization()(pnet)
# net = Activation('relu')(net)
c_layer = ConvLSTM2D(16,kernel_size=(3,3),padding="same",
                     activation=None,stateful=True,return_state=True,return_sequences=True)
out_c, state_h, state_c = c_layer(enc_inputs)
out
mod = Model(inputs=enc_inputs,outputs=out_c)
# net = fwd_res(enc_inputs,features=32)

In [None]:
out

In [None]:
mod.summary()

In [None]:
px = pre_net(enc_inputs,filts=16)
max_tsteps = 5
from keras.layers import Masking

In [None]:
Masking()

In [None]:
net

In [None]:
out = fwd_res(enc_inputs)

In [None]:
out

In [None]:
translation_amt = 0.5 # Med
DL = Shifted_Data_Loader(dataset=config.dataset,flatten=True,
                         rotation=None,
                         translation=translation_amt,
                        )

In [None]:
""" Model inputs"""

class_input = Input(shape=(10,),name='class_input')

In [None]:
""" AutoEncoder Critic"""
x = Input(shape=DL.input_shape,name='Image_input')

encoder = Encoder(input_shape=DL.input_shape,
                  y_dim=config.y_dim,
                  z_dim=config.z_dim,
                  layer_units=config.enc_layers)

net_out = encoder.build(x)
y = Activation('softmax',name='y')(net_out[0])
z = Activation('linear',name='z')(net_out[1])

# c = Activation('linear',name='critic_score')(net_out[2])

yz = Concatenate(name='yz')([y,z])

E = Model(inputs = x,
          outputs = [y,z],
          name='Encoder')

In [None]:
""" Decoder """
decoder = Generator(y_dim = config.y_dim,
                      z_dim = config.z_dim,
                      dec_blocks= config.dec_blocks)

Dec_input = Input(shape=(config.y_dim+config.z_dim,),name='Decoder_input')
Dec_output = decoder.build(Dec_input)

G = Model(inputs=Dec_input,
          outputs=Dec_output,
          name='Decoder')
# G.summary()

In [None]:
x_pred = Activation('linear',name='x_pred')(G(yz))


sse_layer = lambda x: K.expand_dims(sse(x,AE(x)))
AE = Model(inputs=x,outputs=x_pred,name='AE')
sse_out = Lambda(sse_layer)(AE(x))
D = Model(
    inputs=x,
    outputs=sse_out,
    name='D'
)

In [None]:
""" Generator """
def gen_Z(y):
    Z = K.random_normal(shape=(K.shape(y)[0],config.z_dim))
    
    return Z

generator = Generator(y_dim = config.y_dim,
                      z_dim = config.z_dim,
                      dec_blocks= config.dec_blocks)

G_input_y = Input(shape=(config.y_dim,),name='G_y')
G_input_z = Lambda(gen_Z,name='G_z')(G_input_y)

G_input = Concatenate(name='zy')([G_input_z,G_input_y])

G_img = generator.build(G_input)

Gen = Model(
    inputs=G_input_y,
    outputs=G_img,
    name='Generator'
)

In [None]:
""" Model Outputs """
fake_img = Activation('linear',name='fake_img')(Gen(class_input))

c_real = Activation('linear',name='C_real')(D(x))
c_fake = Activation('linear',name='C_fake')(D(fake_img))

# c_real = Activation('linear',name='C_real')(D(x))
# c_recon = D(recon_img)
# c_fake = Activation('linear',name='C_fake')(D(fake_img))

""" Losses """
# GAN Losses
GAN_d_loss = -1*(c_real - c_fake)
GAN_g_loss = -1*c_fake

# Gradient Penalty
gp_loss = gradient_penalty_loss(x,fake_img,D)

# Add Discriminator losses
D.add_loss([GAN_d_loss])

# Add Generator losses
# Gen.add_loss([GAN_g_loss])

EBGAN = Model(
    inputs=[x,class_input],
    outputs=[y,c_real,c_fake],
    name='EBGAN'
)
# mod_outputs = [
#     (recon_img, sse, config.recon),
#     (y, 'categorical_crossentropy', config.xent),
#     (c_fake,lambda yt,yp: GAN_d_loss+GAN_g_loss, 1),
# ]

# outs,ls,ws = zip(*mod_outputs)

# VGAN = Model(
# inputs=x,
# outputs=outs)

# losses = {k:v for k,v in zip(VGAN.output_names,ls)}
# loss_W = {k:v for k,v in zip(VGAN.output_names,ws)}

metrics = {
    'y': 'accuracy',
}

EBGAN.compile(optimizer=config.optimizer,loss={'y':'categorical_crossentropy','C_real':lambda yt,yp:GAN_d_loss,'C_fake':lambda yt,yp:GAN_g_loss},metrics=metrics)

In [None]:
EBGAN.output_names

In [None]:
from keras.utils import to_categorical
RF = to_categorical(np.ones(len(DL.sx_train)),num_classes=2)

In [None]:
print_history = PrintHistory(print_keys=['loss','val_loss','val_y_acc'])
# update_k = Update_k(k_var = k)
callbacks=[
    print_history,
#     update_k
]
if config.monitor is not None:
    early_stop = EarlyStopping(monitor=config.monitor,min_delta=config.min_delta,patience=10,restore_best_weights=True)
    callbacks.append(early_stop)
    
history = EBGAN.fit(x={'Image_input':DL.sx_train,'class_input':DL.y_train_oh},
              y={
                  'y':DL.y_train_oh,
                  'C_real':RF,
                  'C_fake':RF,
                  },
              verbose=0,
              batch_size=config.batch_size,
              callbacks=callbacks,
              validation_split=0.05,
              epochs=config.epochs,
              )

In [None]:
# # true_latent_vec = Concatenate()([y_class,z_lat_stats[0]])
# latent_vec = Concatenate()([y,z_lat])
# shuffled_lat = Concatenate()([y,z_sampled])
# G = trainer.G
# # recon = Activation('linear',name='G')(G(true_latent_vec))
# fake_inp = G(latent_vec)
# G_shuff = G(shuffled_lat)
# # fake_lat_vec = Concatenate()(E(fake_inp))
# # fake_ae = G(fake_lat_vec)

# D_real = Activation('linear',name='D_real')(D(real_inp))
# D_fake = Activation('linear',name='D_fake')(D(G_shuff))
# # D_fake = E(fake_inp)[2]
# D_all = Concatenate(axis=0,name='D_all')([D_fake,D_real])

In [None]:
pt,idx = plot_ovt(DL,cmap='gray')

In [None]:
# hist_df = pd.DataFrame.from_records(trainer.model.history.history)
hist_df = pd.DataFrame.from_records(VGAN.history.history)
hist_df.tail()

In [None]:
sns.set_context('paper')
metrics = ['loss','C_f_loss','y_acc']
fig,axs = plt.subplots(nrows=len(metrics),sharex=True,figsize=(5,10))
for metric_name,ax in zip(metrics,axs):
    sns.scatterplot(data=hist_df[[metric_name,'val_'+metric_name]],ax=ax)

In [None]:
# if not config.dev_mode:
# trainer.save_model()

In [None]:
from keras.models import Model
from keras.layers import Input

In [None]:
generator = G

In [None]:
z_encoder = Model(x,z)
classifier = Model(x,y)
# y_lat_encoder = Model(trainer.E.input,trainer.y_lat)
# decoder_inp = Input(shape=(config.y_dim+config.z_dim,))
# dec_layers = trainer.model.layers[-(1+(5*2)):]
# print(dec_layers)
# _gen_x = dec_layers[0](decoder_inp)
# l = dec_layers[1]
# isinstance(l,keras.layers.core.Reshape)
# F = None
# for l in dec_layers[1:]:
#     print(type(l))
    
#     if isinstance(l,keras.layers.merge.Add):
#         _gen_x = l([F,_gen_x])
#     else:
#         _gen_x = l(_gen_x)
    
#     if isinstance(l,keras.layers.convolutional.Conv2DTranspose):
#         if l.kernel_size==(1,1):
#             F = _gen_x
            
# # generator = Model(decoder_inp,_gen_x)

In [None]:
classifier.summary()

In [None]:
DL.y_test_oh.shape

In [None]:
classifier.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
classifier.evaluate(DL.sx_test,DL.y_test_oh,batch_size=config.batch_size)

In [None]:
z_enc = z_encoder.predict(DL.sx_test,batch_size=config.batch_size)
# y_lat = y_lat_encoder.predict(DL.sx_test,batch_size=config.batch_size)
y_lat = classifier.predict(DL.sx_test,batch_size=config.batch_size)

In [None]:
_lat_vec = np.concatenate([y_lat,z_enc],axis=1)
_lat_vec.shape

In [None]:
z_enc_mu = np.mean(z_enc,axis=0)
z_enc_cov = np.cov(z_enc,rowvar=False)

In [None]:
np.random.multivariate_normal(z_enc_mu,z_enc_cov,size=50).shape

In [None]:
regen = generator.predict(_lat_vec,batch_size=config.batch_size)

In [None]:
rand_im = np.random.randint(0,10000)
plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
_lat_vec[rand_im]

In [None]:
DL2 = Shifted_Data_Loader(dataset=config.dataset,flatten=True,
                         rotation=None,
                         translation=translation_amt,
                        )

In [None]:
enc_dec_samples(DL.x_test,DL.sx_test,z_enc,y_lat,generator)

In [None]:
z_enc2 = z_encoder.predict(DL2.sx_test,batch_size=config.batch_size)
y_lat2 = classifier.predict(DL2.sx_test,batch_size=config.batch_size)
_lat_vec2 = np.concatenate([y_lat2,z_enc2],axis=1)
regen2 = generator.predict(_lat_vec2,batch_size=config.batch_size)

In [None]:
from src.plot import remove_axes,remove_labels
from src.utils import gen_trajectory

In [None]:
examples = 5
rand_im = np.random.randint(0,10000,size=examples)
fix,axs = plt.subplots(examples,11,figsize=(8,4))
_lat_s = []
regen_s = []
out = gen_trajectory(z_enc[rand_im],z_enc2[rand_im],delta=.25)
out_y = gen_trajectory(y_lat[rand_im],y_lat2[rand_im],delta=.25)

for z,y in zip(out,out_y):
    _lat = np.concatenate([y,z],axis=1)
    _lat_s.append(_lat)
    regen_s.append(generator.predict(_lat,batch_size=config.batch_size))

i=0
for axr,idx in zip(axs,rand_im):
    axr[0].imshow(DL.x_test[idx].reshape(28,28),cmap='gray')
    axr[1].imshow(DL.sx_test[idx].reshape(56,56),cmap='gray')
    axr[2].imshow(regen[idx].reshape(56,56),cmap='gray')
    for j,a in enumerate(axr[3:-3]):
        a.imshow(regen_s[j][i,:].reshape(56,56),cmap='gray')
#         a.imshow(s.reshape(56,56),cmap='gray')
    axr[-3].imshow(regen2[idx].reshape(56,56),cmap='gray')
    axr[-2].imshow(DL2.sx_test[idx].reshape(56,56),cmap='gray')
    axr[-1].imshow(DL2.x_test[idx].reshape(28,28),cmap='gray')
    for a in axr:
        remove_axes(a)
        remove_labels(a)
    i+=1
# plt.imshow(regen[rand_im].reshape(56,56),cmap='gray')

In [None]:
from sklearn.preprocessing import MinMaxScaler

feat_range = (0,50)
z_enc_scaled = [MinMaxScaler(feat_range).fit_transform(z_enc[:,i].reshape(-1,1)).tolist() for i in np.arange(config.z_dim)]
z_enc_scaled = np.squeeze(np.array(z_enc_scaled,dtype=int))

In [None]:
from collections import Counter
import dit
from dit import Distribution
dxs = DL.dx[1]-14
dys = DL.dy[1]-14

def mutual_information(X,Y):
    XY_c = Counter(zip(X,Y))
    XY_pmf = {k:v/float(sum(XY_c.values())) for k,v in XY_c.items()}
    XY_jdist = Distribution(XY_pmf)
        
    return dit.shannon.mutual_information(XY_jdist,[0],[1])

In [None]:
z_dx_I = [mutual_information(z_enc_scaled[i],dxs.astype(int)+14) for i in np.arange(config.z_dim)]

In [None]:
z_dy_I = [mutual_information(z_enc_scaled[i],dys.astype(int)+14) for i in np.arange(config.z_dim)]

In [None]:
z_class_I = [mutual_information(z_enc_scaled[i],DL.y_test) for i in np.arange(config.z_dim)]

In [None]:
z_I_df = pd.DataFrame.from_records({'class':z_class_I,'dy':z_dy_I,'dx':z_dx_I})
z_I_df['class'] = z_I_df['class'].values.round(decimals=1)

In [None]:
sns.set_context('talk')
fig,ax = plt.subplots(1,1,figsize=(6,5))
ax.set_ylim(0,0.8)
ax.set_xlim(0,0.8)
points = plt.scatter(x=z_I_df['dx'],y=z_I_df['dy'],c=z_I_df['class'],cmap='plasma')
plt.colorbar(points)

In [None]:
fig,ax = plt.subplots(1,1,figsize=(5,5))
ax.scatter(z_dx_I,z_dy_I)
ax.set_ylim(0,0.6)
ax.set_xlim(0,0.6)

In [None]:
plt.scatter(np.arange(config.z_dim),sorted(z_dy_I,reverse=True))

In [None]:
from src.metrics import var_expl,norm_var_expl
from collections import Counter



dtheta = DL.dtheta[1]
fve_dx = norm_var_expl(features=z_enc,cond=dxs,bins=21)
fve_dy = norm_var_expl(features=z_enc,cond=dys,bins=21)
# fve_dt = norm_var_expl(features=z_enc,cond=dtheta,bins=21)

In [None]:
# fve_dx_norm = (dxs.var()-fve_dx)/dxs.var()
# fve_dy_norm = (dys.var()-fve_dy)/dys.var()
# fve_dth_norm = (dtheta.var()-fve_dt)/dtheta.var()
fve_dx_norm = fve_dx
fve_dy_norm = fve_dy

In [None]:
import seaborn as sns
sns.set_context('talk')

In [None]:
fve_dx_norm.shape
# np.save(os.path.join(config.model_dir,'fve_dx_norm'),fve_dx_norm)

In [None]:
fig,ax = plt.subplots(1,1,figsize=(5,5))
plt.scatter(fve_dx_norm.mean(axis=0),fve_dy_norm.mean(axis=0))
plt.xlabel('fve_dx')
plt.ylabel('fve_dy')
plt.tight_layout()
# plt.savefig(os.path.join(config.model_dir,'fve_dx.png'))
# plt.ylim(-0.125,0.25)
xdim = np.argmax(fve_dx_norm.mean(axis=0))

In [None]:
fve_dy_norm.mean(axis=0)
# np.save(os.path.join(config.model_dir,'fve_dy_norm'),fve_dy_norm)

In [None]:
plt.scatter(np.arange(config.z_dim),fve_dy_norm.mean(axis=0))
plt.xlabel('Z_n')
plt.ylabel('fve_dy')
plt.tight_layout()
# plt.savefig(os.path.join(config.model_dir,'fve_dy.png'))
# plt.ylim(-0.125,0.25)
ydim = np.argmax(fve_dy_norm.mean(axis=0))

In [None]:
# plt.scatter(np.arange(config.z_dim),fve_dth_norm.mean(axis=0))
# plt.xlabel('Z_n')
# plt.ylabel('fve_dtheta')
# # plt.ylim(0.0,0.5)
# np.argmax(fve_dth_norm.mean(axis=0))

In [None]:
from src.plot import Z_color_scatter
Z_color_scatter(z_enc,[xdim,ydim],dxs)

In [None]:
Z_color_scatter(z_enc,[xdim,ydim],dys)

In [None]:
Z_color_scatter(z_enc,[7,18],dtheta)

In [None]:
from plt.