In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [10]:
from keras.layers import Input, Lambda, BatchNormalization, Conv2D, Reshape, Dense,\
                         Dropout, Activation, Flatten, LeakyReLU, Add, MaxPooling2D,\
                         GlobalMaxPooling2D, Subtract
from keras.losses import categorical_crossentropy
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam

from pathlib import Path
import numpy as np
from src.data.dataset import load_ferg
from src import PROJECT_ROOT
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [3]:
loader = load_ferg()
(x_train, y_train, p_train), (x_test, y_test, p_test) = loader.load_data()
num_y, num_p = loader.get_num_classes()
input_shape = x_train[0].shape
print(f'{num_y}, {num_p}, {input_shape}')

7, 6, (64, 64, 3)


In [19]:
def empty_loss(y_true, y_pred):
    return y_pred
def make_trainable(net, val):
    net.trainable = val
    for l in net.layers:
        l.trainable = val
def show_model(model):
    print('-'*80)
    print(model.summary())
    print(model.metrics_names)
    print('-'*80)

In [34]:
def build_decoder(num_classes, z_dim=128):
    x_in = Input((z_dim,))
    out = Dense(num_classes, activation='softmax')(x_in)
    model = Model(x_in, out)
    return model
def build_encoder(input_shape, z_dim=128):
    x_in = Input(input_shape)
    x = x_in
    field_size = 8
    for i in range(3):
        x = Conv2D(int(z_dim / 2**(2-i)),
                   kernel_size=(field_size, field_size),
                   padding='SAME')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
        x = MaxPooling2D((2, 2))(x)
    x = GlobalMaxPooling2D()(x)
    return Model(x_in, x)
def build_adv(input_shape, encoder, decoder_y, decoder_p, num_y, num_p):
    x_in = Input(input_shape)
    y_in = Input((num_y,))
    p_in = Input((num_p,))
    x = x_in
    z = encoder(x)
    out_y = decoder_y(z)
    out_p = decoder_p(z)
    loss_layer = Lambda(lambda args: categorical_crossentropy(args[0], args[1]))
    loss_y = loss_layer([y_in, out_y])
    loss_p = loss_layer([p_in, out_p])
    loss = Subtract()([loss_y, loss_p])
    adv = Model([x_in, y_in, p_in], loss)
    return adv
def evaluate_encoder(train_data, test_data, num_classes, batch_size=256, num_epochs=20):
    decoder = build_decoder(num_classes)
    x_train, y_train = train_data
    x_test, y_test = test_data
    decoder.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
    history = decoder.fit(x=x_train, y=y_train, epochs=num_epochs,batch_size=batch_size,\
                validation_data=(x_test, y_test),verbose=0)
    return np.max(history.history['val_acc'])
def train(train_data, input_shape, num_y, num_p, num_epochs=20, batch_size=128, dry_run=False):
    x_train, y_train, p_train = train_data
    num_example = x_train.shape[0]
    encoder = build_encoder(input_shape)
    decoder_y = build_decoder(num_y)
    decoder_y.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy')
    decoder_p = build_decoder(num_p)
    decoder_p.compile(optimizer=Adam(1e-3), loss='categorical_crossentropy')
    adv = build_adv(input_shape, encoder, decoder_y, decoder_p, num_y, num_p)
    adv.compile(optimizer=Adam(1e-3), loss=empty_loss)
    if dry_run:
        show_model(decoder_y)
        show_model(decoder_p)
        show_model(adv)
        return
    batch_count = num_example // batch_size
    #print(evaluate_encoder((x_train, y_train), (x_test, y_test), num_y, num_epochs=20))
    #print(evaluate_encoder((x_train, p_train), (x_test, p_test), num_p, num_epochs=20))
    for i in range(num_epochs):
        loss_y_history, loss_p_history, loss_adv_history = [], [], []
        idx = np.random.permutation(num_example)
        #evaluate encoder
        z_train = encoder.predict(x_train)
        z_test = encoder.predict(x_test)
        acc1 = evaluate_encoder((z_train, y_train), (z_test, y_test), num_y, num_epochs=20)
        acc2 = evaluate_encoder((z_train, p_train), (z_test, p_test), num_p, num_epochs=20)
        print(f'epoch{i}: acc1={acc1}, acc2={acc2}')
        for j in range(batch_count):
            selected_idx = idx[j*batch_size: (j+1)*batch_size]
            x_batch = x_train[selected_idx]
            y_batch = y_train[selected_idx]
            p_batch = p_train[selected_idx]
            #train decoder_y
            z_batch = encoder.predict_on_batch(x_batch)
            make_trainable(decoder_y, True)
            make_trainable(decoder_p, True)
            loss_y = decoder_y.train_on_batch(z_batch, y_batch)
            loss_p = decoder_p.train_on_batch(z_batch, p_batch)
            loss_y_history.append(loss_y)
            loss_p_history.append(loss_p)
            #train adv
            #make_trainable(decoder_y, False)
            make_trainable(decoder_p, False)
            loss_adv = adv.train_on_batch([x_batch, y_batch, p_batch], y_batch)
            loss_adv_history.append(loss_adv)
        print(f'epoch{i}: loss_y:{np.mean(loss_y_history)}, loss_p:{np.mean(loss_p_history)}, loss_adv:{np.mean(loss_adv_history)}')
train((x_train, y_train, p_train), input_shape, num_y, num_p, dry_run=False)

epoch0: acc1=0.5298266587539554, acc2=1.0
epoch0: loss_y:0.3752005100250244, loss_p:14.921903610229492, loss_adv:-15.293718338012695
epoch1: acc1=1.0, acc2=1.0
epoch1: loss_y:0.00109429145231843, loss_p:16.11809730529785, loss_adv:-16.117412567138672
epoch2: acc1=1.0, acc2=1.0
epoch2: loss_y:0.00025368554634042084, loss_p:16.11809730529785, loss_adv:-16.117868423461914
epoch3: acc1=0.9813508664143705, acc2=1.0
epoch3: loss_y:0.0001281958248000592, loss_p:16.11809730529785, loss_adv:-16.11797332763672
epoch4: acc1=1.0, acc2=1.0
epoch4: loss_y:7.683429430471733e-05, loss_p:16.11809730529785, loss_adv:-16.11802101135254
epoch5: acc1=1.0, acc2=1.0
epoch5: loss_y:0.5187912583351135, loss_p:15.665657997131348, loss_adv:-16.078916549682617
epoch6: acc1=0.9995218170950388, acc2=0.7260011959204187
epoch6: loss_y:0.086405448615551, loss_p:14.478635787963867, loss_adv:-16.090120315551758
epoch7: acc1=1.0, acc2=1.0
epoch7: loss_y:0.09762684255838394, loss_p:14.853155136108398, loss_adv:-16.0327377