In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from keras.layers import Input, Lambda, BatchNormalization, Conv2D, Reshape, Dense,\
                         Dropout, Activation, Flatten, LeakyReLU, Add, MaxPooling2D,\
                         GlobalMaxPooling2D, Subtract, Concatenate
from keras.losses import categorical_crossentropy
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam

from pathlib import Path
import numpy as np
from src.data.dataset import load_ferg
from src import PROJECT_ROOT
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

Using TensorFlow backend.


In [3]:
loader = load_ferg()
(x_train, y_train, p_train), (x_test, y_test, p_test) = loader.load_data()
num_y, num_p = loader.get_num_classes()
input_shape = x_train[0].shape
print(f'{num_y}, {num_p}, {input_shape}')

7, 6, (64, 64, 3)


In [4]:
def empty_loss(y_true, y_pred):
    return y_pred
def make_trainable(net, val):
    net.trainable = val
    for l in net.layers:
        l.trainable = val
def show_model(model):
    print('-'*80)
    print(model.summary())
    print(model.metrics_names)
    print('-'*80)

In [15]:
def build_classifier(num_classes, feature_dim=128):
    x_in = Input((feature_dim,))
    out = Dense(num_classes, activation='softmax')(x_in)
    model = Model(x_in, out)
    return model
def evaluate_encoder(train_data, test_data, num_classes, batch_size=256, num_epochs=20):
    decoder = build_classifier(num_classes)
    x_train, y_train = train_data
    x_test, y_test = test_data
    decoder.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
    history = decoder.fit(x=x_train, y=y_train, epochs=num_epochs,batch_size=batch_size,\
                validation_data=(x_test, y_test),verbose=0)
    return np.max(history.history['val_acc'])
def build_encoder(input_shape, z_dim=128):
    x_in = Input(input_shape)
    x = x_in
    field_size = 8
    for i in range(3):
        x = Conv2D(int(z_dim / 2**(2-i)),
                   kernel_size=(field_size, field_size),
                   padding='SAME')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
        x = MaxPooling2D((2, 2))(x)
    x = GlobalMaxPooling2D()(x)
    return Model(x_in, x)
def shuffling(x):
    idxs = K.arange(0, K.shape(x)[0])
    idxs = K.tf.random_shuffle(idxs)
    return K.gather(x, idxs)
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0.,
                              stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon

def build_discriminator(input_shape, z_dim=32):
    x_in = Input(input_shape)
    z = x_in
    z = Dense(z_dim, activation='relu')(z)
    z = Dense(z_dim, activation='relu')(z)
    z = Dense(z_dim, activation='relu')(z)
    z = Dense(1, activation='sigmoid')(z)
    return Model(x_in, z)
def build_mi_estimator(x_dim, y_dim):
    x_in = Input((x_dim,))
    y_in = Input((y_dim,))
    y_shuffle = Lambda(shuffling)(y_in)
    dis = build_discriminator((x_dim+y_dim,))
    score_joint = dis(Concatenate()([x_in, y_in]))
    score_indep = dis(Concatenate()([x_in, y_shuffle]))
    score_layer = Lambda(lambda x: - K.mean(K.log(x[0] + 1e-6) + K.log(1 - x[1] + 1e-6)))
    score = score_layer([score_joint, score_indep])
    return Model([x_in, y_in], score)

def train(train_data, input_shape, num_y, num_p, num_epochs=20, batch_size=128, dry_run=False, z_dim=128):
    x_train, y_train, p_train = train_data
    num_example = x_train.shape[0]
    encoder = build_encoder(input_shape, z_dim)
    decoder_y = build_mi_estimator(z_dim, num_y)
    decoder_y.compile(optimizer=Adam(1e-3), loss=empty_loss)
    decoder_p = build_mi_estimator(z_dim, num_p)
    decoder_p.compile(optimizer=Adam(1e-3), loss=empty_loss)
    adv  = build_adv(input_shape, encoder, decoder_y, decoder_p, num_y, num_p)
    adv.compile(optimizer=Adam(1e-3), loss=empty_loss)
    if dry_run:
        show_model(decoder_y)
        show_model(decoder_p)
        show_model(adv)
        return
    batch_count = num_example // batch_size
    for i in range(num_epochs):
        loss_y_history, loss_p_history, loss_adv_history = [], [], []
        idx = np.random.permutation(num_example)
        #evaluate encoder
        z_train = encoder.predict(x_train)
        z_test = encoder.predict(x_test)
        acc2 = evaluate_encoder((z_train, p_train), (z_test, p_test), num_p, num_epochs=20)
        acc1 = evaluate_encoder((z_train, y_train), (z_test, y_test), num_y, num_epochs=20)
        print(f'epoch{i}: acc1={acc1}, acc2={acc2}')
        for j in range(batch_count):
            selected_idx = idx[j*batch_size: (j+1)*batch_size]
            x_batch = x_train[selected_idx]
            y_batch = y_train[selected_idx]
            p_batch = p_train[selected_idx]
            #train decoder_y
            z_batch = encoder.predict_on_batch(x_batch)
            make_trainable(decoder_y, True)
            make_trainable(decoder_p, True)
            loss_y = decoder_y.train_on_batch([z_batch, y_batch], y_batch)
            loss_p = decoder_p.train_on_batch([z_batch, p_batch], p_batch)
            loss_y_history.append(loss_y)
            loss_p_history.append(loss_p)
            #train adv
            make_trainable(decoder_y, False)
            make_trainable(decoder_p, False)
            loss_adv = adv.train_on_batch([x_batch, y_batch, p_batch], y_batch)
            loss_adv_history.append(loss_adv)
        print(f'epoch{i}: loss_y:{np.mean(loss_y_history)}, loss_p:{np.mean(loss_p_history)}, loss_adv:{np.mean(loss_adv_history)}')
train((x_train, y_train, p_train), input_shape, num_y, num_p, dry_run=False)

epoch0: acc1=0.5396294082842825, acc2=1.0
epoch0: loss_y:1.3855125904083252, loss_p:13.72799015045166, loss_adv:-12.193449020385742
epoch1: acc1=0.8334728035681702, acc2=1.0


KeyboardInterrupt: 