In [20]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.datasets import cifar10
from keras.callbacks import LearningRateScheduler
from keras.models import model_from_json
from keras.models import load_model
import numpy as np
import keras
import math
import time
import matplotlib.pyplot as plt
import sys
import tensorflow as tf

from ourwrnet import create_wide_residual_network
from cifar10utils import getCIFAR10, getCIFAR10InputShape


'''
Function that loads from a file the teacher
'''
def getTeacher(file_name):
    # Model reconstruction from JSON file
    with open(file_name + '.json', 'r') as f:
        model = model_from_json(f.read())

    # Load weights into the new model
    model.load_weights(file_name + '.h5')
    
    print('Teacher loaded from' + file_name + '.h5')
    return model
    
'''
Function that loads from a file the teacher and test it on the CIRAF10 dataset
'''
def testTeacher(file_name):
    x_train,y_train,x_test,y_test = getCIFAR10()
    model = getTeacher(file_name)
    opt_rms = keras.optimizers.rmsprop(lr=0.001,decay=1e-6)
    model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt_rms, metrics=['accuracy'])
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Teacher test loss:', score[0])
    print('Teacher test accuracy:', score[1])
    
'''
Function that returns a simple student done by 2 convolutions, a maxpool and a final two fully connected layers
'''
def getStudent(input_shape):
    num_classes = 10
    
    model_train,model_test,m1,m2,m3 = create_wide_residual_network(input_shape, nb_classes=10, N=2,k=1)
    
    print('Simple student loaded')
    return model_train, model_test

'''
Function that returns a simple generator
'''
def getGenerator():

    noise_shape = (100,)

    model = Sequential()

    img_shape = getCIFAR10InputShape()

    model.add(Dense(128*8**2, input_shape=noise_shape))
    model.add(Reshape((8, 8, 128)))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2)) 

    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=(3,3), strides=1, padding="same"))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))

    model.add(Conv2D(3, kernel_size=(3,3), strides=1, padding="same"))
    model.add(BatchNormalization())   

    #model.summary()
    print('Generator loaded')
    
    return model

def getGAN(teacher,student,generator):
    z = Input(shape=(100,))
    img = generator(z)
    student.trainable = False
    teacher.trainable = False
    
    out_t = teacher(img)
    out_s = student(img)
    
    joinedOutput = Concatenate()([out_t,out_s])
    
    gan = Model(z,joinedOutput)
    
    return gan

def gan_loss(y_true, y_pred):
    
    t_out = y_pred[0:9]
    s_out = y_pred[10:-1]
    
    loss = keras.losses.kullback_leibler_divergence(t_out,s_out)
    min_loss = -loss
    
    return min_loss

def useless_loss(y_true, y_pred):
    
    zer = K.zeros((1,1))
    
    return zer
        
def main():
    
    x_train,y_train,x_test,y_test = getCIFAR10()
    input_shape = getCIFAR10InputShape()
    
    teacher = getTeacher('./pretrained_models/wrn_16_2')
    teacher.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    
    optim_stud = Adam(lr=2e-3, clipnorm=5.0)
    optim_gen = Adam(lr=1e-3, clipnorm=5.0)
    
    student_train, student_test = getStudent(input_shape)
    
    student_test.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    
    student_train.compile(loss=keras.losses.kullback_leibler_divergence, optimizer=optim_stud)
    
    generator = getGenerator()
    
    gan = getGAN(teacher,student_train,generator)
    gan.summary()
    
    gan.compile(loss=useless_loss, optimizer=optim_gen)
    
    
    test_z = np.random.normal(0, 1, (2, 100))
    out = gan.predict(test_z)
    print(str(out))
    print(out.shape)
    
    
    
    n_batches = 1000
    batch_size = 128
    log_freq = 10
    ns = 10
    
    for i in range(n_batches):
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        
        fake_lbl = K.zeros((batch_size,20))
        
        g_loss = gan.train_on_batch(noise,fake_lbl)
        
        s_loss = 0
        for j in range(ns):
            s_loss += student_train.train_on_batch(gen_imgs,t_predictions)[0]
        
        print('batch ' + str(i) + '/' + str(n_batches) + ' G loss: ' + str(g_loss) + ' S loss: ' + str(s_loss))
        
        if (i % log_freq) == 0:
            score = student_test.evaluate(x_test, y_test, verbose=0)
            print('Student test loss: '  + str(score))
        
        
    score = student_test.evaluate(x_test, y_test, verbose=0)
    print('Student test loss: '  + str(score))
    

main()


CIFAR10 loaded
Teacher loaded from./pretrained_models/wrn_16_2.h5
Wide Residual Network-16-1 created.
Simple student loaded
Generator loaded
Model: "model_104"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_35 (InputLayer)           (None, 100)          0                                            
__________________________________________________________________________________________________
sequential_17 (Sequential)      (None, 32, 32, 3)    1051791     input_35[0][0]                   
__________________________________________________________________________________________________
model_17 (Model)                (None, 10)           693498      sequential_17[1][0]              
__________________________________________________________________________________________________
model_99 (Model)                (None, 10)      

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.