In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
from keras.layers import Input, Lambda, BatchNormalization, Conv2D, Reshape, Dense,\
                         Dropout, Activation, Flatten, LeakyReLU, Add, MaxPooling2D,\
                         GlobalMaxPooling2D, Subtract, Concatenate, Average, Conv2DTranspose,\
                         GlobalAveragePooling2D
from keras.losses import categorical_crossentropy, mean_squared_error
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.callbacks import Callback

from pathlib import Path
import numpy as np
from src.data.dataset import load_ferg
from src.evaluation.resnet import resnet_v1
from src import PROJECT_ROOT
import imageio
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

Using TensorFlow backend.


In [14]:
# Build the FERG data loader and pull the train/test splits. Each split is an
# (x, y, p) triple: images plus two label arrays.
# NOTE(review): presumably y = facial expression and p = character identity
# (the printed class counts are 7 and 6) -- confirm against src.data.dataset.load_ferg.
loader = load_ferg()
(x_train, y_train, p_train), (x_test, y_test, p_test) = loader.load_data()
# Number of classes for the y labels and for the p labels, in that order.
num_y, num_p = loader.get_num_classes()
# Shape of a single image, taken from the first training sample.
input_shape = x_train[0].shape
print(f'{num_y}, {num_p}, {input_shape}')

7, 6, (64, 64, 3)


In [4]:
def empty_loss(y_true, y_pred):
    """Pass-through Keras loss: the model output already *is* the loss.

    Used for models whose final layer computes the training loss inside the
    graph. ``y_true`` is accepted only to satisfy the Keras loss signature and
    is ignored.

    Args:
        y_true: Unused placeholder targets.
        y_pred: Model output, returned unchanged.

    Returns:
        ``y_pred`` as-is.
    """
    return y_pred
def make_trainable(net, val):
    """Set the ``trainable`` flag on a network and on each of its layers.

    Args:
        net: Object exposing a ``trainable`` attribute and a ``layers`` iterable
            (e.g. a Keras model).
        val: Value assigned to every ``trainable`` flag.
    """
    # The container's own flag is set first, then each direct child layer.
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val
def show_model(model):
    """Print a model's layer summary and metric names between two rulers.

    Args:
        model: Object exposing ``summary()`` and ``metrics_names`` (e.g. a
            Keras model).
    """
    ruler = '-' * 80
    print(ruler)
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line after the table.
    model.summary()
    print(model.metrics_names)
    print(ruler)

In [44]:
#train model
def evaluate_encoder(train_data, test_data, num_classes, batch_size=256, num_epochs=20):
    """Fit a fresh classifier on the given data and report its best validation accuracy.

    Args:
        train_data: ``(inputs, one_hot_labels)`` pair used for fitting.
        test_data: ``(inputs, one_hot_labels)`` pair used as validation data.
        num_classes: Passed to ``build_classifier``.
        batch_size: Mini-batch size for ``fit``.
        num_epochs: Number of training epochs.

    Returns:
        The maximum of the recorded ``'val_acc'`` history.
    """
    # NOTE(review): a module-level build_classifier(num_classes) is not defined
    # in the visible notebook (only Vae.build_classifier, with a different
    # signature) -- presumably it came from a deleted cell; confirm before use.
    probe = build_classifier(num_classes)
    train_x, train_y = train_data
    test_x, test_y = test_data
    probe.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
    fit_kwargs = dict(
        x=train_x,
        y=train_y,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_data=(test_x, test_y),
        verbose=0,
    )
    fit_log = probe.fit(**fit_kwargs)
    return np.max(fit_log.history['val_acc'])

def shuffling(x):
    """Return ``x`` with its entries along axis 0 randomly permuted.

    A random permutation of ``0 .. K.shape(x)[0] - 1`` is drawn with
    TensorFlow's ``random_shuffle`` and used to gather rows from ``x``.
    """
    n_rows = K.shape(x)[0]
    permutation = K.tf.random_shuffle(K.arange(0, n_rows))
    return K.gather(x, permutation)
def sampling(args):
    """Reparameterised draw: ``z_mean + exp(0.5 * z_log_var) * noise``.

    Args:
        args: Pair ``(z_mean, z_log_var)`` of 2-D tensors; the first axis is
            read dynamically, the second statically.

    Returns:
        A tensor shaped like ``z_mean``.
    """
    mean, log_var = args
    n_rows = K.shape(mean)[0]       # dynamic size of axis 0
    n_cols = K.int_shape(mean)[1]   # static size of axis 1
    noise = K.random_normal(shape=(n_rows, n_cols))
    std = K.exp(0.5 * log_var)
    return mean + std * noise
def kl_loss_func(args):
    """Compute ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))`` over the last axis.

    Args:
        args: Pair ``(z_mean, z_log_var)``.

    Returns:
        The summed term with the last axis reduced away.
    """
    mean, log_var = args
    per_dim = 1 + log_var - K.square(mean) - K.exp(log_var)
    return - 0.5 * K.sum(per_dim, axis=-1)
def rec_loss_func(args):
    """Mean squared error between a target and a prediction.

    Args:
        args: Pair ``(y_true, y_pred)``.

    Returns:
        ``K.mean`` of the squared difference, taken with no ``axis`` argument.
    """
    target, prediction = args
    squared_error = K.square(prediction - target)
    return K.mean(squared_error)
def categorical_loss_func(args):
    """Categorical cross-entropy wrapper taking a single ``(y_true, y_pred)`` pair.

    The pair form lets the loss be used as the function of a ``Lambda`` layer.
    """
    labels, probabilities = args
    return categorical_crossentropy(labels, probabilities)

def build_mi_1(z_dim):
    """Build a model that maps a batch of codes to a scalar discriminator-style loss.

    An inner scorer (three ReLU dense layers of width ``z_dim`` followed by a
    single sigmoid unit) rates concatenated pairs of codes. It is applied to
    ``[x, x]`` (the "real" pair) and to ``[x, shuffled x]`` (the "fake" pair);
    the model output is ``-mean(log(real + 1e-6) + log(1 - fake + 1e-6))``.

    Args:
        z_dim: Width of a single code vector.

    Returns:
        A Keras ``Model`` from a ``(z_dim,)`` input to the scalar loss.
    """
    # Scorer over concatenated code pairs.
    pair_in = Input(shape=(z_dim*2,))
    hidden = pair_in
    for _ in range(3):
        hidden = Dense(z_dim, activation='relu')(hidden)
    score = Dense(1, activation='sigmoid')(hidden)
    model_t = Model(pair_in, score)

    # Score each code paired with itself and paired with a batch-shuffled code.
    x = Input(shape=(z_dim,))
    x_shuffle = Lambda(shuffling)(x)
    real_score = model_t(Concatenate()([x, x]))
    fake_score = model_t(Concatenate()([x, x_shuffle]))

    def loss_func(args):
        real, fake = args
        # The 1e-6 offsets keep both logarithms finite when a score saturates.
        real_term = K.log(real + 1e-6)
        fake_term = K.log(1 - fake + 1e-6)
        return - K.mean(real_term + fake_term)

    loss = Lambda(loss_func)([real_score, fake_score])
    return Model(x, loss)
    
class Evaluate(Callback):
    """Keras callback that writes a sample image every epoch and keeps the best weights.

    At the end of each epoch it asks ``task`` to render a sample grid to
    ``<sample_dir>/test_<epoch>.png``, records the epoch's training loss, and,
    when that loss is the lowest seen so far, saves the model weights.

    Args:
        task: Object exposing ``sample_all(path)``.
        save_path: Where to save weights on improvement; ``None`` disables saving.
        sample_dir: Directory for the per-epoch sample images (created if missing).
    """

    def __init__(self, task, save_path=None, sample_dir='samples'):
        # Initialise the base Callback state before adding our own.
        super().__init__()
        self._task = task
        self.lowest = 1e10
        self.losses = []
        self.sample_dir = sample_dir
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(sample_dir, exist_ok=True)
        self.save_path = save_path

    def on_epoch_end(self, epoch, logs=None):
        """Render samples, log the loss and checkpoint on improvement."""
        logs = logs or {}  # Keras may pass None
        path = os.path.join(self.sample_dir, 'test_%s.png' % epoch)
        self._task.sample_all(path)
        loss = logs.get('loss')
        if loss is None:
            # Nothing to track for this epoch; the sample image was still written.
            return
        self.losses.append((epoch, loss))
        if loss <= self.lowest:
            self.lowest = loss
            if self.save_path is not None:
                self.model.save_weights(self.save_path)
class Vae:
    """VAE with an auxiliary classifier, trained through an in-graph loss.

    The combined model (``self.cvae``) takes ``[image, one_hot_y]`` and outputs
    ``0.005 * KL + reconstruction MSE + classifier cross-entropy``; it is
    compiled with ``empty_loss`` so that output is minimised directly.

    Args:
        data_loader: Object exposing ``load_data(...)`` returning
            ``((x, y, p), (x, y, p))`` and ``get_num_classes()`` returning
            ``(num_y, num_p)``.
        z_dim: Size of the latent code.
        debug: If true, load at most 100 train and 100 test samples.
    """

    def __init__(self, data_loader, z_dim=128, debug=False):
        # Load the data (a small subset in debug mode).
        if debug:
            train_data, test_data = data_loader.load_data(max_train=100, max_test=100)
        else:
            train_data, test_data = data_loader.load_data()
        num_y, num_p = data_loader.get_num_classes()
        num_train = train_data[0].shape[0]
        num_test = test_data[0].shape[0]
        # Images are treated as square RGB: only axis 1 of the test images is read.
        img_dim = test_data[0].shape[1]
        input_shape = (img_dim, img_dim, 3)
        train_data = Vae.transform_data(train_data, num_y, num_p)
        test_data = Vae.transform_data(test_data, num_y, num_p)
        # Build the sub-models and the combined training graph.
        encoder = self.build_encoder(input_shape, z_dim=z_dim)
        decoder = self.build_decoder(z_dim, img_dim)
        classifier = self.build_classifier(input_shape, num_y)
        cvae = self.build_cvae(input_shape, num_y, encoder, decoder, classifier)
        # Keep everything on the instance.
        self.z_dim = z_dim
        self.img_dim = img_dim
        self.input_shape = input_shape
        self.num_y = num_y
        self.num_p = num_p
        self.num_train = num_train
        self.num_test = num_test
        self.train_data = train_data
        self.test_data = test_data
        self.encoder = encoder
        self.decoder = decoder
        self.classifier = classifier
        self.cvae = cvae

    def build_encoder(self, input_shape, z_dim):
        """Build the encoder: image -> ``[z, z_mean, z_log_var]``.

        Three Conv(8x8) / BatchNorm / LeakyReLU(0.2) / MaxPool(2) stages with
        ``z_dim/4``, ``z_dim/2`` and ``z_dim`` filters, a global max pool, then
        two dense heads for the mean and log-variance; ``z`` is drawn with
        ``sampling``.
        """
        x_in = Input(input_shape)
        x = x_in
        field_size = 8
        for i in range(3):
            x = Conv2D(int(z_dim / 2**(2-i)),
                       kernel_size=(field_size, field_size),
                       padding='SAME')(x)
            x = BatchNormalization()(x)
            x = LeakyReLU(0.2)(x)
            x = MaxPooling2D((2, 2))(x)
        x = GlobalMaxPooling2D()(x)
        z_mean = Dense(z_dim)(x)
        z_log_var = Dense(z_dim)(x)
        x = Lambda(sampling)([z_mean, z_log_var])
        return Model(x_in, [x, z_mean, z_log_var])

    def build_decoder(self, z_dim, img_dim):
        """Build the decoder: latent code -> image with tanh output.

        A dense layer is reshaped to 4x4x128 and upsampled by four stride-2
        transposed convolutions, so the output is always 64x64x3.

        NOTE(review): ``img_dim`` is accepted but not used; the architecture
        only matches the input images when ``img_dim == 64``.
        """
        k = 8
        units = z_dim
        x = Input(shape=(z_dim,))
        h = x
        h = Dense(4 * 4 * 128, activation='relu')(h)
        h = Reshape((4, 4, 128))(h)
        h = Conv2DTranspose(units, (k, k), strides=(2, 2), padding='same', activation='relu')(h)  # 8 x 8 x units
        h = BatchNormalization(momentum=0.8)(h)
        h = Conv2DTranspose(units // 2, (k, k), strides=(2, 2), padding='same', activation='relu')(h)  # 16 x 16 x units//2
        h = BatchNormalization(momentum=0.8)(h)
        h = Conv2DTranspose(units // 2, (k, k), strides=(2, 2), padding='same', activation='relu')(h)  # 32 x 32 x units//2
        h = BatchNormalization(momentum=0.8)(h)

        h = Conv2DTranspose(3, (k, k), strides=(2, 2), padding='same', activation='tanh')(h)  # 64 x 64 x 3
        return Model(x, h, name="Decoder")

    def build_classifier(self, input_shape, num_classes):
        """Build a classifier whose output is its own cross-entropy loss.

        Inputs are ``[image, one_hot_label]``; the softmax prediction is fed
        together with the label into ``categorical_loss_func`` and that loss is
        the model output.
        """
        x_in = Input(shape=input_shape)
        label_in = Input(shape=(num_classes,))
        img_dim = input_shape[0]
        x = x_in
        x = Conv2D(img_dim // 2,
                   (5, 5),
                   strides=(2, 2),
                   padding='same')(x)
        x = LeakyReLU()(x)

        # Four further stride-2 stages with img_dim * 1, 2, 4, 8 filters.
        for i in range(4):
            x = Conv2D(img_dim * 2**i,
                       (5, 5),
                       strides=(2, 2),
                       padding='same')(x)
            x = BatchNormalization()(x)
            x = LeakyReLU()(x)

        x = GlobalAveragePooling2D()(x)
        out = Dense(num_classes, activation='softmax')(x)
        loss = Lambda(categorical_loss_func)([label_in, out])
        model = Model([x_in, label_in], loss)
        return model

    def build_cvae(self, input_shape, num_y, encoder, decoder,classifier):
        """Wire encoder, decoder and classifier into one loss-producing model.

        The classifier is applied to the *reconstruction* with the true label,
        so its cross-entropy is part of the weighted total
        ``0.005 * kl + rec + classify``.
        """
        x_in = Input(shape=input_shape)
        y_in = Input(shape=(num_y,))
        z, z_mean, z_log_var = encoder(x_in)
        kl_loss = Lambda(kl_loss_func)([z_mean, z_log_var])
        x_rec = decoder(z)
        classify_loss = classifier([x_rec, y_in])
        rec_loss = Lambda(rec_loss_func)([x_in, x_rec])

        def weight_loss_func(args):
            kl_loss, rec_loss, classify_loss = args
            return 0.005 * kl_loss + rec_loss + classify_loss
        loss = Lambda(weight_loss_func)([kl_loss, rec_loss, classify_loss])
        return Model([x_in, y_in], loss)

    @staticmethod
    def transform_data(data, num_y, num_p):
        """Scale images with ``(x - 127.5) / 127.5`` and one-hot encode both label sets.

        NOTE(review): the scaling lands in [-1, 1] only if the raw pixels are
        in [0, 255] -- confirm against the data loader.
        """
        x, y, p= data
        x = (x-127.5)/127.5
        y = to_categorical(y, num_y)
        p = to_categorical(p, num_p)
        return (x, y, p)

    @staticmethod
    def recover_data(data):
        """Invert ``transform_data``: images back to uint8, labels back to class indices."""
        x, y, p = data
        x = (x + 1) / 2 * 255
        # Clip BEFORE casting: casting an out-of-range float to uint8 wraps
        # around, after which clipping can no longer repair the value.
        x = np.clip(x, 0, 255).astype(np.uint8)
        y = np.argmax(y, axis=-1)
        p = np.argmax(p, axis=-1)
        return x, y, p

    def _build_sample_grid(self):
        """Assemble the uint8 image grid that ``sample_all`` writes to disk.

        The grid has ``2 * num_p`` rows and ``num_y`` columns of tiles: the top
        half shows, for every (p, y) class pair, the first matching training
        image; the bottom half shows its reconstruction. Pairs with no matching
        sample stay black.
        """
        x_train, y_train, p_train = self.train_data
        num_y = self.num_y
        num_p = self.num_p
        num_data = self.num_train
        img_dim = self.img_dim
        # uint8 canvas so the writer stores the pixel values verbatim instead
        # of having to convert a float64 array itself.
        output = np.zeros((2*num_p*img_dim, num_y*img_dim, 3), dtype=np.uint8)
        # Decode all labels once instead of once per (pair, sample) combination.
        y_idx = np.argmax(y_train[:num_data], axis=-1)
        p_idx = np.argmax(p_train[:num_data], axis=-1)
        for i in range(num_p):
            for j in range(num_y):
                matches = np.flatnonzero((p_idx == i) & (y_idx == j))
                if matches.size == 0:
                    continue
                idx = matches[0]
                x, y, p = x_train[idx], y_train[idx], p_train[idx]
                x_fake = self.predict_single(x)
                cols = slice(j*img_dim, (j+1)*img_dim)
                rows_real = slice(i*img_dim, (i+1)*img_dim)
                rows_fake = slice(i*img_dim+num_p*img_dim, (i+1)*img_dim+num_p*img_dim)
                output[rows_real, cols, :] = Vae.recover_data((x, y, p))[0]
                output[rows_fake, cols, :] = Vae.recover_data((x_fake, y, p))[0]
        return output

    def sample_all(self, file_path):
        """Write the original-vs-reconstruction grid for all class pairs to ``file_path``."""
        imageio.imwrite(file_path, self._build_sample_grid())

    def load_weights(self, path):
        """Load weights for the combined model (and thereby its sub-models)."""
        self.cvae.load_weights(path)

    def predict(self, x):
        """Reconstruct a batch of (already transformed) images via a sampled code."""
        z, z_mean, z_logvar = self.encoder.predict(x)
        rec_x = self.decoder.predict(z)
        return rec_x

    def predict_single(self, x):
        """Reconstruct one image by wrapping it in a batch of size 1."""
        x = np.expand_dims(x, axis=0)
        rec_x = self.predict(x)[0]
        return rec_x

    def _evaluate_reconstructions(self, label_index, num_classes, num_epochs, batch_size):
        """Train a ResNet on reconstructed images and return its best validation accuracy.

        Args:
            label_index: Position of the label array inside the data triples
                (1 for y, 2 for p).
            num_classes: Number of classes for that label set.
            num_epochs: Training epochs for the ResNet.
            batch_size: Mini-batch size for the ResNet.
        """
        labels_train = self.train_data[label_index]
        labels_test = self.test_data[label_index]
        x_train = self.predict(self.train_data[0])
        x_test = self.predict(self.test_data[0])
        # NOTE(review): resnet_v1 presumably returns an already-compiled model,
        # since fit() is called directly -- confirm in src.evaluation.resnet.
        resnet = resnet_v1(self.input_shape, num_classes)
        history = resnet.fit(x_train, labels_train, validation_data=(x_test, labels_test), \
                             batch_size=batch_size, epochs=num_epochs)
        acc = np.max(history.history['val_acc'])
        print(acc)
        return acc

    def evaluate_y(self, num_epochs=20, batch_size=128):
        """Best validation accuracy of a ResNet predicting y from reconstructions."""
        return self._evaluate_reconstructions(1, self.num_y, num_epochs, batch_size)

    def evaluate_p(self, num_epochs=20, batch_size=128):
        """Best validation accuracy of a ResNet predicting p from reconstructions."""
        return self._evaluate_reconstructions(2, self.num_p, num_epochs, batch_size)

    def train(self, save_path=None, num_epochs=20, batch_size=128):
        """Train the combined model, sampling images and checkpointing each epoch.

        The model output is the loss, so it is compiled with ``empty_loss``;
        the ``y_train`` / ``y_test`` arrays passed as targets are ignored
        placeholders.
        """
        self.cvae.compile(optimizer=Adam(1e-4), loss=empty_loss)
        x_train, y_train, p_train = self.train_data
        x_test, y_test, p_test = self.test_data
        evaluator = Evaluate(self, save_path=save_path)
        self.cvae.fit([x_train, y_train], y_train, validation_data=([x_test, y_test], y_test), \
                      batch_size=batch_size, epochs=num_epochs, callbacks=[evaluator])
# Driver: build the model on the full dataset, restore previously trained
# weights and score the reconstructions.
# `loader` comes from the data-loading cell above; './best.h5' must already
# exist (the commented-out train() call below is what writes it).
vae = Vae(loader, debug=False)
vae.load_weights('./best.h5')
# Alternative evaluation: predict the p labels from reconstructions instead.
#vae.evaluate_p()
# Train a ResNet on reconstructed images to predict the y labels and print
# its best validation accuracy.
vae.evaluate_y()
# Uncomment to (re)train and overwrite the checkpoint.
#vae.train(save_path='./best.h5', num_epochs=40)

Learning rate:  0.001
Train on 47401 samples, validate on 8365 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
1.0


1.0