# In [None]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.datasets import cifar10
from keras.callbacks import LearningRateScheduler
from keras.models import model_from_json
from keras.models import load_model
import numpy as np
import keras
import math
import time
import matplotlib.pyplot as plt
import sys
import tensorflow as tf

from ourwnet import create_wide_residual_network

'''
Function that returns the train and test data of CIFAR10, already preprocessed
'''
def getCIFAR10():
    '''
    Load CIFAR10 and return it ready to be fed to the networks.

    The images are reshaped to the layout reported by
    K.image_data_format(), cast to float32 and standardized with one
    global mean and standard deviation computed on the training images.
    The labels are converted to categorical matrices over 10 classes.

    Returns:
        The tuple (x_train, y_train, x_test, y_test).
    '''
    side = 32
    n_classes = 10

    # the data, split between train and test sets
    (train_x, train_y), (test_x, test_y) = cifar10.load_data()

    # per-image layout expected by the backend
    if K.image_data_format() == 'channels_first':
        per_image = (3, side, side)
    else:
        per_image = (side, side, 3)

    train_x = train_x.reshape((train_x.shape[0],) + per_image).astype('float32')
    test_x = test_x.reshape((test_x.shape[0],) + per_image).astype('float32')

    # z-score normalization: every axis is reduced, so the statistics are
    # scalars; they come from the training set and are reused on the test set
    all_axes = (0, 1, 2, 3)
    mu = np.mean(train_x, axis=all_axes)
    sigma = np.std(train_x, axis=all_axes)
    train_x = (train_x - mu) / (sigma + 1e-7)
    test_x = (test_x - mu) / (sigma + 1e-7)

    # class vectors -> binary class matrices
    train_y = keras.utils.to_categorical(train_y, n_classes)
    test_y = keras.utils.to_categorical(test_y, n_classes)

    print('CIFAR10 loaded')
    return train_x, train_y, test_x, test_y

'''
Small function that returns the shape of the CIFAR10 images
'''
def getCIFAR10InputShape():
    '''
    Return the shape of a single CIFAR10 image for the current backend.

    Returns:
        (3, 32, 32) when K.image_data_format() is 'channels_first',
        (32, 32, 3) otherwise.
    '''
    side = 32
    channels_first = K.image_data_format() == 'channels_first'
    return (3, side, side) if channels_first else (side, side, 3)

'''
Function that loads the teacher from file: the architecture is read from
file_name + '.json' and the weights from file_name + '.h5'
'''
def getTeacher(file_name):
    '''
    Rebuild the teacher model from disk.

    Args:
        file_name: path prefix without extension. The architecture is read
            from file_name + '.json' and the weights from file_name + '.h5'.

    Returns:
        The reconstructed model with its weights loaded. The model is not
        compiled in this function.
    '''
    # Model reconstruction from JSON file
    with open(file_name + '.json', 'r') as f:
        model = model_from_json(f.read())

    # Load weights into the new model
    model.load_weights(file_name + '.h5')

    # trailing space after "from" keeps the file name readable in the log
    print('Teacher loaded from ' + file_name + '.h5')
    return model
    
'''
Function that loads the teacher from a file and tests it on the CIFAR10 dataset
'''
def testTeacher(file_name):
    '''
    Load the teacher and print its loss and accuracy on the CIFAR10 test set.

    Args:
        file_name: path prefix (without extension) forwarded to getTeacher.
    '''
    # the training split is loaded by getCIFAR10 but not needed here
    _, _, x_test, y_test = getCIFAR10()

    teacher = getTeacher(file_name)

    # compile with a loss and the accuracy metric so evaluate reports both
    teacher.compile(loss=keras.losses.categorical_crossentropy,
                    optimizer=keras.optimizers.rmsprop(lr=0.001, decay=1e-6),
                    metrics=['accuracy'])

    # final evaluation on test
    results = teacher.evaluate(x_test, y_test, verbose=0)
    test_loss = results[0]
    test_acc = results[1]
    print('Teacher test loss:', test_loss)
    print('Teacher test accuracy:', test_acc)

    
'''
Function that returns the student: a wide residual network built by
create_wide_residual_network with N=2 and k=1
'''
def getSimpleStudent(input_shape):
    '''
    Build the student network.

    Args:
        input_shape: shape of one input image, forwarded unchanged to
            create_wide_residual_network.

    Returns:
        The first of the four values returned by
        create_wide_residual_network(input_shape, nb_classes=10, N=2, k=1).
    '''
    n_classes = 10

    # the builder returns four objects; only the first one is used here
    student, _, _, _ = create_wide_residual_network(
        input_shape, nb_classes=n_classes, N=2, k=1)

    print('Simple student loaded')
    return student
    
'''
Function that trains the simple student on CIFAR10 in order to understand its capabilities
'''
def trainSimpleStudent(epochs):
    '''
    Train the student directly on CIFAR10 to see what it can learn alone.

    The training set is visited in order, in batches of 128, with
    train_on_batch. The loss and accuracy of every batch are printed, and
    after each epoch the model is evaluated on the test set.

    Args:
        epochs: number of passes over the training set.
    '''
    x_train, y_train, x_test, y_test = getCIFAR10()

    model = getSimpleStudent(getCIFAR10InputShape())

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])

    batch_size = 128
    # Integer floor division keeps n_batches an int on every Python version;
    # math.floor returns a float on Python 2, which range() rejects.
    # Samples beyond the last full batch are not used.
    n_batches = x_train.shape[0] // batch_size

    for e in range(epochs):

        for i in range(n_batches):
            start = i * batch_size
            stop = start + batch_size
            # train_on_batch returns [loss, accuracy] because one metric is set
            loss = model.train_on_batch(x_train[start:stop], y_train[start:stop])
            print("Epoch: " + str(e+1) + " batch " + str(i) + " loss: " + str(loss[0]) + " acc: " + str( 100*loss[1]))

        score = model.evaluate(x_test, y_test, verbose=0)
        print('After epoch ' + str(e+1) + ' test loss ' + str(score[0]) + ' test accuracy ' + str(score[1]))

'''
Function that returns a simple generator
'''
def getGenerator():
    '''
    Build the generator network.

    A stack of three Dense layers (128, 256 and 512 units), each followed by
    LeakyReLU(alpha=0.2) and BatchNormalization(momentum=0.8), then a Dense
    layer with tanh activation whose output is reshaped to the CIFAR10
    image shape.

    Returns:
        A Model mapping a noise vector of shape (100,) to an image.
    '''
    latent_shape = (100,)
    img_shape = getCIFAR10InputShape()

    net = Sequential()
    for position, width in enumerate((128, 256, 512)):
        if position == 0:
            # only the first layer declares the input shape
            net.add(Dense(width, input_shape=latent_shape))
        else:
            net.add(Dense(width))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))

    # one output unit per image value, then fold back into image layout
    net.add(Dense(np.prod(img_shape), activation='tanh'))
    net.add(Reshape(img_shape))

    z = Input(shape=latent_shape)
    generated = net(z)

    print('Generator loaded')
    return Model(z, generated)

def getGAN(teacher,student,generator):
    '''
    Chain the generator with the teacher and the student in a single model.

    The model maps a noise vector of shape (100,) to the concatenation of
    the teacher output and the student output for the generated image,
    teacher first.

    Note that this sets trainable = False on the teacher and student objects
    passed in. The intent is that training the returned model only updates
    the generator.

    Args:
        teacher: model applied to the generated image.
        student: model applied to the generated image.
        generator: model mapping the noise vector to an image.

    Returns:
        The combined Model.
    '''
    z = Input(shape=(100,))
    img = generator(z)

    student.trainable = False
    teacher.trainable = False

    out_t = teacher(img)
    out_s = student(img)

    # the layer lives in keras.layers; `keras.layer` does not exist
    joinedOutput = keras.layers.Concatenate()([out_t, out_s])

    gan = Model(z, joinedOutput)

    return gan

def gan_loss(y_true, y_pred):
    '''
    Loss for the generator: the negated KL divergence between the teacher
    output and the student output, so that minimizing it maximizes their
    disagreement.

    Args:
        y_true: unused; present only because Keras losses take
            (y_true, y_pred).
        y_pred: tensor whose second axis holds the 10 teacher outputs
            followed by the 10 student outputs, as concatenated in getGAN.
            Presumably both halves are class probabilities -- confirm
            against the teacher and student output layers.

    Returns:
        A tensor with -KL(teacher || student) for each sample.
    '''
    num_classes = 10

    # slice along the class axis (axis 1); slicing axis 0 would cut the batch
    t_out = y_pred[:, :num_classes]
    s_out = y_pred[:, num_classes:2 * num_classes]

    # functional form: computes the divergence instead of building a loss object
    loss = keras.losses.kullback_leibler_divergence(t_out, s_out)

    return -loss
        
def main():
    '''
    Train the student from the teacher using generated images only.

    For every batch: sample noise, generate images, update the generator
    through the combined model using gan_loss, then train the student ns
    times to match the teacher predictions on those images. The student is
    evaluated on the CIFAR10 test set every log_freq batches and once more
    at the end.
    '''
    # only the test split is used, to monitor the student
    _, _, x_test, y_test = getCIFAR10()
    input_shape = getCIFAR10InputShape()

    teacher = getTeacher('teacher-16-2')
    teacher.compile(loss=keras.losses.categorical_crossentropy,
                    optimizer='adam',
                    metrics=['accuracy'])

    student = getSimpleStudent(input_shape)
    student.compile(loss='kld',
                    optimizer='sgd',
                    metrics=['accuracy'])

    generator = getGenerator()

    gan = getGAN(teacher, student, generator)
    gan.summary()

    # pass the loss function itself; calling it here would raise a TypeError
    gan.compile(loss=gan_loss, optimizer='adam')

    n_batches = 1000
    batch_size = 128
    log_freq = 10
    ns = 10

    # gan_loss ignores its targets, so a constant dummy array with the shape
    # of the combined output (10 teacher + 10 student values) is enough
    fake_lbl = np.zeros((batch_size, 20))

    for i in range(n_batches):
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        # targets for the student: what the teacher says about these images
        t_predictions = teacher.predict(gen_imgs)

        g_loss = gan.train_on_batch(noise, fake_lbl)

        s_loss = 0
        for j in range(ns):
            # train_on_batch returns [loss, accuracy]; accumulate the loss
            s_loss += student.train_on_batch(gen_imgs, t_predictions)[0]

        print('batch ' + str(i) + '/' + str(n_batches) + ' G loss: ' + str(g_loss) + ' S loss: ' + str(s_loss))

        if (i % log_freq) == 0:
            score = student.evaluate(x_test, y_test, verbose=0)
            print('Student test loss: '  + str(score))

    score = student.evaluate(x_test, y_test, verbose=0)
    print('Student test loss: '  + str(score))
    

# Run the training only when executed as a script or notebook cell, so that
# importing this file for its helper functions does not start a training run.
if __name__ == "__main__":
    main()
