In [2]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.datasets import cifar10
from keras.callbacks import LearningRateScheduler
from keras.models import model_from_json
from keras.models import load_model
import numpy as np
import keras
import math
import time
import matplotlib.pyplot as plt
import sys
import tensorflow as tf

import warnings
warnings.filterwarnings('ignore')

from ourwrnet import create_wide_residual_network
from student_wrnet import create_wide_residual_network_student
from cifar10utils import getCIFAR10, getCIFAR10InputShape


def getTeacher(file_name):
    """Load the pretrained teacher and its three layer sub-models from disk.

    Every model is stored as two files sharing a path prefix: the
    architecture in ``<prefix>.json`` and the weights in ``<prefix>.h5``.
    The sub-models use the prefixes ``<file_name>_layer1``,
    ``<file_name>_layer2`` and ``<file_name>_layer3``.

    Args:
        file_name: path prefix (no extension) of the teacher files,
            e.g. ``'./pretrained_models/wrn_16_2'``.

    Returns:
        Tuple ``(model, m1, m2, m3)``: the full teacher followed by the
        layer-1, layer-2 and layer-3 sub-models, all with weights loaded.
    """
    def _load(prefix):
        # Rebuild the architecture from its JSON description, then restore weights.
        with open(prefix + '.json', 'r') as f:
            loaded = model_from_json(f.read())
        loaded.load_weights(prefix + '.h5')
        return loaded

    model = _load(file_name)
    m1, m2, m3 = (_load(file_name + suffix)
                  for suffix in ('_layer1', '_layer2', '_layer3'))

    print('Teacher loaded from ' + file_name + '.h5')
    return model, m1, m2, m3
    
def testTeacher(file_name):
    """Load the teacher from disk and evaluate it on the CIFAR10 test split.

    Args:
        file_name: path prefix of the teacher files, as accepted by ``getTeacher``.

    Prints the teacher's test loss and test accuracy.
    """
    _, _, x_test, y_test = getCIFAR10()
    # getTeacher returns (model, layer1, layer2, layer3); only the full
    # teacher is evaluated. Assigning the whole tuple to `model` would make
    # the compile call below fail.
    model, _, _, _ = getTeacher(file_name)
    opt_rms = keras.optimizers.rmsprop(lr=0.001, decay=1e-6)
    model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt_rms, metrics=['accuracy'])
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Teacher test loss:', score[0])
    print('Teacher test accuracy:', score[1])
    
def getStudent(input_shape, num_classes=10):
    """Build the student network: a wide residual network with N=2, k=1 (WRN-16-1).

    Args:
        input_shape: shape of a single input image, e.g. from
            ``getCIFAR10InputShape()``.
        num_classes: number of output classes (default 10, as for CIFAR10).

    Returns:
        Tuple ``(model_train, model_test)`` exactly as returned by
        ``create_wide_residual_network_student``.
    """
    model_train, model_test = create_wide_residual_network_student(
        input_shape, nb_classes=num_classes, N=2, k=1)

    print('Simple student loaded')
    return model_train, model_test

def getGenerator():
    """Build the generator that maps a 100-d noise vector to a 32x32x3 image.

    Layout:
      - Dense, reshaped to 8x8x128, then BN.
      - Two stages of (UpSampling2D -> 3x3 conv -> BN -> LeakyReLU(0.2)),
        with 128 and 64 filters.
      - A final 3x3 conv down to 3 channels, followed by BatchNormalization
        (no output activation).

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    noise_shape = (100,)
    # NOTE(review): img_shape is never used; the output size is fixed by the layers below.
    img_shape = getCIFAR10InputShape()

    stack = [
        # noise -> 8x8x128 feature map
        Dense(128 * 8 ** 2, input_shape=noise_shape),
        Reshape((8, 8, 128)),
        BatchNormalization(),
        # 8x8 -> 16x16
        UpSampling2D(),
        Conv2D(128, kernel_size=3, strides=1, padding="same"),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        # 16x16 -> 32x32
        UpSampling2D(),
        Conv2D(64, kernel_size=(3, 3), strides=1, padding="same"),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        # project to 3 image channels
        Conv2D(3, kernel_size=(3, 3), strides=1, padding="same"),
        BatchNormalization(),
    ]
    model = Sequential(stack)

    print('Generator loaded')
    return model

def getGAN(teacher, student, generator, freeze_student=False):
    """Chain generator -> (teacher, student) into one adversarial model.

    The returned model maps a 100-d noise vector to the concatenation
    ``[teacher output | student output 0]``, which ``gan_loss`` splits apart.

    The teacher is always frozen. Keras reads ``trainable`` when a model is
    compiled, so the flags set here only affect models compiled afterwards
    (the GAN). Models compiled earlier, such as the student's own training
    model, keep updating their weights in their own ``train_on_batch``.

    Args:
        teacher: teacher model; its single output is used.
        student: student training model; only its first output is used.
        generator: model mapping ``(None, 100)`` noise to images.
        freeze_student: if True, also freeze the student. Training the GAN
            then updates only the generator. Defaults to False, which keeps
            the original behaviour: the GAN step also updates the student,
            pushing it *away* from the teacher.

    Returns:
        Uncompiled keras ``Model`` from noise to the concatenated outputs.
    """
    z = Input(shape=(100,))
    img = generator(z)
    teacher.trainable = False
    # NOTE(review): the original author reported problems when freezing the
    # student here, so freezing is opt-in.
    if freeze_student:
        student.trainable = False

    out_t = teacher(img)
    out_s = student(img)

    joined_output = Concatenate()([out_t, out_s[0]])
    return Model(z, joined_output)

def gan_loss(y_true, y_pred, num_classes=10):
    """Generator loss: negated KL divergence between teacher and student.

    ``y_pred`` is the concatenated GAN output ``[teacher | student]`` of width
    ``2 * num_classes``. ``y_true`` is ignored; it is a dummy target that
    Keras requires. Minimising this loss maximises KL(teacher || student),
    i.e. drives the generator toward inputs where the two disagree.

    Args:
        y_true: unused dummy target.
        y_pred: tensor of shape ``(batch, 2 * num_classes)``.
        num_classes: width of each half (default 10).

    Returns:
        Tensor of per-sample negated KL divergences.
    """
    t_out = y_pred[:, :num_classes]
    # The student half ends at 2 * num_classes; the previous `10:21` slice
    # only worked because slicing clamps at the tensor edge.
    s_out = y_pred[:, num_classes:2 * num_classes]

    return -keras.losses.kullback_leibler_divergence(t_out, s_out)

def student_loss(y_true, y_pred):
    """Scaled mean squared error between targets and predictions.

    Computes ``250 * mean((y_true - y_pred) ** 2)`` over the last axis.
    It is used for the student's three intermediate outputs (presumably
    attention maps).
    """
    # Scale of this term relative to the other losses it is summed with.
    beta = 250
    squared_error = K.pow(y_true - y_pred, 2)
    return K.mean(squared_error, -1) * beta
        
def main(n_batches=1000, batch_size=2, log_freq=10, ns=10):
    """Train a student from the pretrained teacher using generated images only.

    Each outer iteration:
      1. Sample noise and generate images.
      2. Record the teacher's predictions and layer outputs for those images.
      3. Take one generator step through the GAN (maximising teacher/student
         disagreement).
      4. Take ``ns`` student steps on the same images to match the teacher.

    Args:
        n_batches: number of outer (generator + student) iterations.
        batch_size: number of generated images per iteration.
        log_freq: evaluate the student on the CIFAR10 test set every
            ``log_freq`` iterations (0 disables periodic evaluation).
        ns: number of student updates per generator update.
    """
    # Only the test split is used, for evaluating the student; the training
    # images are never shown to it.
    _, _, x_test, y_test = getCIFAR10()
    input_shape = getCIFAR10InputShape()

    teacher, t_layer1, t_layer2, t_layer3 = getTeacher('./pretrained_models/wrn_16_2')
    for teacher_model in (teacher, t_layer1, t_layer2, t_layer3):
        teacher_model.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])

    optim_stud = Adam(lr=2e-3, clipnorm=5.0)
    optim_gen = Adam(lr=1e-3, clipnorm=5.0)

    student_train, student_test = getStudent(input_shape)

    student_test.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])

    # Output 0 is matched to the teacher predictions with KL divergence;
    # outputs 1-3 are matched to the teacher's layer outputs with student_loss.
    student_train.compile(loss=[keras.losses.kullback_leibler_divergence, student_loss, student_loss, student_loss],
                          optimizer=optim_stud)

    generator = getGenerator()

    gan = getGAN(teacher, student_train, generator)
    gan.summary()

    gan.compile(loss=gan_loss, optimizer=optim_gen)

    # Sanity check: teacher targets and student outputs must have matching shapes.
    noise = np.random.normal(0, 1, (4, 100))
    gen_imgs = generator.predict(noise)
    t_predictions = teacher.predict(gen_imgs)
    t_pred_l1 = t_layer1.predict(gen_imgs)
    t_pred_l2 = t_layer2.predict(gen_imgs)
    t_pred_l3 = t_layer3.predict(gen_imgs)

    print('Teacher predictions shape: ' + str(t_predictions.shape))
    print('Teacher LAYER1 predictions shape: ' + str(t_pred_l1[0].shape))
    print('Teacher LAYER2 predictions shape: ' + str(t_pred_l2[0].shape))
    print('Teacher LAYER3 predictions shape: ' + str(t_pred_l3[0].shape))

    s_predictions = student_train.predict(gen_imgs)
    for i in range(len(s_predictions)):
        print('Student Predictions ' + str(i) + ' shape: ' + str(s_predictions[i].shape))

    # gan_loss ignores its targets, but Keras still needs an array matching the
    # GAN output (10 teacher + 10 student columns). It is built once with numpy:
    # creating it with K.zeros inside the loop added new backend variables to
    # the graph on every iteration.
    fake_lbl = np.zeros((batch_size, 20))

    for i in range(n_batches):
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        t_predictions = teacher.predict(gen_imgs)
        # Each layer sub-model's predict returns a sequence; element 0 is the target used.
        t_pred_l1 = t_layer1.predict(gen_imgs)[0]
        t_pred_l2 = t_layer2.predict(gen_imgs)[0]
        t_pred_l3 = t_layer3.predict(gen_imgs)[0]

        # One generator step.
        g_loss = gan.train_on_batch(noise, fake_lbl)

        # ns student steps on the images generated before the generator update.
        s_loss = 0
        for _ in range(ns):
            s_loss = student_train.train_on_batch(gen_imgs, [t_predictions, t_pred_l1, t_pred_l2, t_pred_l3])

        print('batch ' + str(i) + '/' + str(n_batches) + ' G loss: ' + str(g_loss) + ' S loss: ' + str(s_loss))

        if log_freq and i % log_freq == 0:
            score = student_test.evaluate(x_test, y_test, verbose=0)
            print('Student test [loss, accuracy]: ' + str(score))

    score = student_test.evaluate(x_test, y_test, verbose=0)
    print('Student test [loss, accuracy]: ' + str(score))


main()


CIFAR10 loaded
Teacher loaded from./pretrained_models/wrn_16_2.h5
Wide Residual Network-16-1 created.
Simple student loaded
Generator loaded
Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 32, 32, 3)    1051791     input_4[0][0]                    
__________________________________________________________________________________________________
model_1 (Model)                 (None, 10)           693498      sequential_2[1][0]               
__________________________________________________________________________________________________
model_4 (Model)                 [(None, 10), (None

W1011 14:44:49.050402 139700836164992 deprecation_wrapper.py:119] From /home/test/anaconda3/envs/tensorflow_gpuenv/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.



Teacher predictions shape: (4, 10)
Teacher LAYER1 predictions shape: (4, 1024)
Teacher LAYER2 predictions shape: (4, 256)
Teacher LAYER3 predictions shape: (4, 64)
Student Predictions 0 shape: (4, 10)
Student Predictions 1 shape: (4, 1024)
Student Predictions 2 shape: (4, 256)
Student Predictions 3 shape: (4, 64)


W1011 14:44:56.875252 139700836164992 deprecation.py:323] From /home/test/anaconda3/envs/tensorflow_gpuenv/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


batch 0/1000 G loss: -0.11603979 S loss: [2.0202286, 1.3896966, 0.03821964, 0.02991215, 0.006195548]
Student test loss: [2.8708274932861326, 0.10000000149011612]
batch 1/1000 G loss: -0.08200515 S loss: [1.6171075, 1.0523716, 0.022444464, 0.01997071, 0.006492977]
batch 2/1000 G loss: 0.1518204 S loss: [1.3439136, 0.8210882, 0.017655114, 0.020597614, 0.006885791]
batch 3/1000 G loss: -0.065604374 S loss: [1.3488047, 0.86577517, 0.013746072, 0.019581402, 0.00597917]
batch 4/1000 G loss: 0.0034220964 S loss: [0.9953609, 0.5381958, 0.012969375, 0.022635676, 0.0073476126]
batch 5/1000 G loss: -0.09316184 S loss: [0.98068476, 0.5522411, 0.011239606, 0.018858574, 0.009418648]
batch 6/1000 G loss: -0.91599685 S loss: [0.7694698, 0.34780324, 0.011329107, 0.021881284, 0.019346815]
batch 7/1000 G loss: -0.74169374 S loss: [1.4053622, 0.99699354, 0.012456604, 0.026853785, 0.01631473]
batch 8/1000 G loss: -0.75841373 S loss: [0.8205828, 0.42415768, 0.011068223, 0.021811627, 0.025738763]
batch 9/100

KeyboardInterrupt: 

In [4]:
import torch.nn.functional as F
import torch

def attention(x):
    """Spatial attention map of a feature tensor.

    Squares the activations and averages them over dim 1 (channels). The
    result is flattened per sample and each row is L2-normalised. For an
    (N, C, H, W) input the result has shape (N, H*W).
    """
    batch = x.size(0)
    energy = x.pow(2).mean(1)
    flat = energy.view(batch, -1)
    return F.normalize(flat)

def attention_diff(x, y):
    """Mean squared difference between the attention maps of ``x`` and ``y``.

    The channel counts of the inputs may differ, but their attention maps
    must have the same shape (same batch size and spatial size). Returns a
    0-dim tensor.
    """
    gap = attention(x) - attention(y)
    return gap.pow(2).mean()

# Demo: attention maps of two feature tensors with different channel counts
# but equal batch and spatial size, and the scalar distance between them.
torch.manual_seed(0)  # make the random inputs, and so the printed values, reproducible

x1 = torch.rand(4, 32, 32, 32)
y1 = attention(x1)
print(y1.shape)
print(y1)

x2 = torch.rand(4, 16, 32, 32)
y2 = attention(x2)
print(y2.shape)
print(y2)

y3 = attention_diff(x1, x2)
print(y3)
print(y3.shape)



torch.Size([4, 1024])
tensor([[0.0281, 0.0180, 0.0306,  ..., 0.0311, 0.0323, 0.0307],
        [0.0376, 0.0267, 0.0293,  ..., 0.0250, 0.0251, 0.0296],
        [0.0344, 0.0360, 0.0278,  ..., 0.0272, 0.0277, 0.0373],
        [0.0358, 0.0309, 0.0307,  ..., 0.0341, 0.0362, 0.0355]])
torch.Size([4, 1024])
tensor([[0.0275, 0.0311, 0.0335,  ..., 0.0392, 0.0358, 0.0284],
        [0.0417, 0.0398, 0.0382,  ..., 0.0369, 0.0415, 0.0331],
        [0.0202, 0.0169, 0.0289,  ..., 0.0372, 0.0257, 0.0403],
        [0.0257, 0.0245, 0.0465,  ..., 0.0305, 0.0282, 0.0244]])
tensor(6.9508e-05)
torch.Size([])
