In [5]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.datasets import cifar10
from keras.callbacks import LearningRateScheduler
from keras.models import model_from_json
from keras.models import load_model
import numpy as np
import keras
import math
import time
import matplotlib.pyplot as plt
import sys
import tensorflow as tf

import warnings
warnings.filterwarnings('ignore')

from ourwrnet import create_wide_residual_network
from new_stud import create_wide_residual_network_student
from cifar10utils import getCIFAR10, getCIFAR10InputShape


def _load_model(path_prefix):
    """Rebuild a Keras model from ``path_prefix + '.json'`` and load ``path_prefix + '.h5'`` weights.

    Args:
        path_prefix: Path without extension shared by the JSON architecture
            file and the HDF5 weights file.

    Returns:
        The model built by ``model_from_json`` with its weights loaded.
    """
    with open(path_prefix + '.json', 'r') as f:
        model = model_from_json(f.read())
    model.load_weights(path_prefix + '.h5')
    return model


def getTeacher(file_name):
    """Load the pretrained teacher and its three layer models from disk.

    Each model is stored as a ``<prefix>.json`` architecture file plus a
    ``<prefix>.h5`` weights file. The full teacher uses ``file_name`` as its
    prefix. The layer models use ``file_name + '_layer1'``, ``'_layer2'`` and
    ``'_layer3'``.

    Args:
        file_name: Path prefix (no extension) of the saved teacher files.

    Returns:
        Tuple ``(model, m1, m2, m3)``: the full teacher, followed by the
        ``_layer1``, ``_layer2`` and ``_layer3`` models in that order.
    """
    model = _load_model(file_name)
    # Presumably sub-models exposing intermediate activations (used for
    # attention transfer elsewhere in this file) -- TODO confirm.
    m1, m2, m3 = [_load_model(file_name + suffix)
                  for suffix in ('_layer1', '_layer2', '_layer3')]

    print('Teacher loaded from ' + file_name + '.h5')
    return model, m1, m2, m3
    
def testTeacher(file_name):
    """Load the saved teacher and print its loss and accuracy on the CIFAR10 test split.

    Args:
        file_name: Path prefix (no extension) passed to ``getTeacher``.
    """
    _, _, x_test, y_test = getCIFAR10()
    # getTeacher returns (model, layer1, layer2, layer3); only the full model
    # is evaluated here. Compiling the raw tuple would raise AttributeError.
    model = getTeacher(file_name)[0]
    opt_rms = keras.optimizers.rmsprop(lr=0.001, decay=1e-6)
    model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt_rms, metrics=['accuracy'])
    # With one metric requested, evaluate() yields [loss, accuracy].
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Teacher test loss:', score[0])
    print('Teacher test accuracy:', score[1])
    
def getStudent(input_shape, num_classes=10):
    """Build the student network: a Wide Residual Network created with N=2, k=1.

    Args:
        input_shape: Image input shape forwarded to
            ``create_wide_residual_network_student``.
        num_classes: Number of output classes (default 10, for CIFAR10).

    Returns:
        The 5-tuple produced by ``create_wide_residual_network_student``.
        ``main`` unpacks it as (train model, test model, layer1, layer2,
        layer3).
    """
    model_train, model_test, m1, m2, m3 = create_wide_residual_network_student(
        input_shape, nb_classes=num_classes, N=2, k=1)

    print('Simple student loaded')
    return model_train, model_test, m1, m2, m3

def getGenerator():
    """Build the image generator that maps a 100-d noise vector to an image.

    Layout:
    - Dense, then reshape to an 8x8x128 map, then batch-norm.
    - Two stages of (2x upsample, 3x3 conv, batch-norm, LeakyReLU(0.2)) with
      128 and 64 filters.
    - A final 3x3 conv to 3 channels followed by batch-norm.

    There is no squashing activation at the end. Assuming channels-last
    layout, the output is 32x32x3.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    latent_dim = 100
    base_size = 8
    base_channels = 128

    generator = Sequential()

    # NOTE(review): this value is never used; the layers below hard-code the
    # output geometry. The call is kept so behaviour is unchanged.
    img_shape = getCIFAR10InputShape()

    # Project the noise vector onto a small feature map.
    generator.add(Dense(base_channels * base_size ** 2, input_shape=(latent_dim,)))
    generator.add(Reshape((base_size, base_size, base_channels)))
    generator.add(BatchNormalization())

    # Upsampling stages: (filters, kernel size) for each.
    for filters, ksize in ((128, 3), (64, (3, 3))):
        generator.add(UpSampling2D())
        generator.add(Conv2D(filters, kernel_size=ksize, strides=1, padding="same"))
        generator.add(BatchNormalization())
        generator.add(LeakyReLU(alpha=0.2))

    # Map to 3 output channels.
    generator.add(Conv2D(3, kernel_size=(3, 3), strides=1, padding="same"))
    generator.add(BatchNormalization())

    print('Generator loaded')
    return generator

def getGAN(teacher, student, generator):
    """Chain the generator and the student into one noise -> student-output model.

    Args:
        teacher: Teacher model. It is not part of the returned graph; the only
            effect on it is that ``teacher.trainable`` is set to False.
        student: Model applied to the generated images. It is intentionally
            left trainable (original note: "it works if it is not true").
        generator: Model applied to the ``(100,)`` noise input.

    Returns:
        Uncompiled ``Model`` computing ``student(generator(noise))``.
    """
    # Only mutates the caller's teacher object; it is not wired in below.
    teacher.trainable = False

    noise_in = Input(shape=(100,))
    # NOTE(review): main() calls student_train.predict with a list of seven
    # inputs, so applying the student to a single image tensor here may not
    # match that model -- confirm before re-enabling the GAN loop.
    student_out = student(generator(noise_in))
    return Model(noise_in, student_out)

def gan_loss(y_true, y_pred):
    """Generator loss: the negated Keras KL divergence of y_pred from y_true.

    Minimising this value maximises the divergence. In the disabled training
    loop in this file, ``y_true`` is the teacher's predictions, so the
    generator is pushed toward samples on which the student disagrees with
    the teacher.
    """
    return -keras.losses.kullback_leibler_divergence(y_true, y_pred)

def attention_loss(ta, sa, beta=250):
    """Scaled mean squared difference between teacher and student tensors.

    Args:
        ta: Teacher tensor.
        sa: Student tensor; must be broadcast-compatible with ``ta``.
        beta: Scale factor applied to the result. Defaults to 250, the value
            previously hard-coded.

    Returns:
        ``mean((ta - sa) ** 2)`` reduced over the last axis only, times ``beta``.
    """
    squared_diff = K.pow(ta - sa, 2)
    return K.mean(squared_diff, -1) * beta

def student_loss(y_true, y_pred):
    """Student distillation loss: Keras KL divergence of y_pred from y_true.

    Args:
        y_true: Teacher predictions, used as the soft target (per the
            original comments).
        y_pred: Student predictions.

    Returns:
        ``keras.losses.kullback_leibler_divergence(y_true, y_pred)``.
    """
    # TODO: an attention-transfer term was sketched here but is disabled:
    #   attention_loss(sa1, ta1) + attention_loss(sa2, ta2) + attention_loss(sa3, ta3)
    # The names sa1..sa3 / ta1..ta3 are not in scope inside this loss function.
    return keras.losses.kullback_leibler_divergence(y_true, y_pred)


def main():
    """Debug driver: build generator, teacher and student, then run one student forward pass.

    Steps:
    - Load CIFAR10 and the pretrained teacher from './pretrained_models/wrn_16_2'.
    - Build the generator and the student.
    - Push 4 generated images through the teacher's layer models.
    - Print the shape of the student training model's output.

    The generator/student training loop later in this file is disabled:
    it is kept inside string literals and never executed.
    """
    
    # NOTE(review): only input_shape is used below; the CIFAR10 arrays are unused here.
    x_train,y_train,x_test,y_test = getCIFAR10()
    input_shape = getCIFAR10InputShape()
    
    # optim_gen is only referenced by the disabled GAN loop.
    optim_stud = Adam(lr=2e-3, clipnorm=5.0)
    optim_gen = Adam(lr=1e-3, clipnorm=5.0)  
    
    generator = getGenerator()
    
    teacher, t_layer1, t_layer2,t_layer3 = getTeacher('./pretrained_models/wrn_16_2')
    # NOTE(review): these losses/optimizers are never used in this function (only predict() is called).
    teacher.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    t_layer1.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    t_layer2.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    t_layer3.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])

    
    # Smoke-test batch: 4 standard-normal 100-d noise vectors.
    noise = np.random.normal(0, 1, (4, 100))
    gen_imgs = generator.predict(noise)
    # NOTE(review): t_predictions is not used below.
    t_predictions = teacher.predict(gen_imgs)
    # [0] takes the first element of each prediction: the first output if the
    # layer model has several outputs, otherwise the first sample -- TODO confirm which.
    ta1 = t_layer1.predict(gen_imgs)[0]
    ta2 = t_layer2.predict(gen_imgs)[0]
    ta3 = t_layer3.predict(gen_imgs)[0]
    
    # NOTE(review): the "student" activations are taken from the TEACHER layer
    # models (t_layer*), not s_layer* (which do not exist yet at this point).
    # Looks like a shape placeholder -- confirm.
    sa1 = t_layer1.predict(gen_imgs)[0]
    sa2 = t_layer2.predict(gen_imgs)[0]
    sa3 = t_layer3.predict(gen_imgs)[0]
    
    student_train, student_test, s_layer1, s_layer2, s_layer3 = getStudent(input_shape)
    
    student_test.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
    
    
    student_train.compile(loss= student_loss, optimizer=optim_stud)
    s_layer1.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam')
    s_layer2.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam')
    s_layer3.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam')
    
    # student_train is fed seven inputs: the images, then the three teacher
    # and three "student" activation arrays computed above, in this order.
    out_stud = student_train.predict([gen_imgs,ta1,ta2,ta3,sa1,sa2,sa3])
    
    print('out for the stud')
    print(str(out_stud.shape))
    
    
# NOTE: the string literal below is disabled code and is never executed.
# It is the adversarial training loop:
# - the generator is updated through the GAN (gan_loss) once per batch;
# - the student is then updated `ns` times on the same generated images.
# It references locals of main() (generator, teacher, student_train,
# t_layer*, s_layer*, optim_gen), so it must be moved back inside main()
# to run.
'''
    gan = getGAN(teacher,student_train,generator)
    #gan.summary()
    
    gan.compile(loss=gan_loss, optimizer=optim_gen)
    
    #student.trainable = False # it works if it is not true
    teacher.trainable = False
    
    n_batches = 1000
    batch_size = 128
    log_freq = 10
    ns = 10
    
    
    print("loop starting:")
    
    for i in range(n_batches):
        
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        
        t_predictions = teacher.predict(gen_imgs)
        s_predictions = student_train.predict(gen_imgs)
        
        gen_loss = gan.train_on_batch(noise,t_predictions)
        
        
        t_predictions = teacher.predict(gen_imgs)
        ta1 = t_layer1.predict(gen_imgs)[0]
        ta2 = t_layer2.predict(gen_imgs)[0]
        ta3 = t_layer3.predict(gen_imgs)[0]
        
        s_loss=0
        for j in range(ns):
            #s_predictions = student_train.predict(gen_imgs)
            sa1 = s_layer1.predict(gen_imgs)[0]
            sa2 = s_layer2.predict(gen_imgs)[0]
            sa3 = s_layer3.predict(gen_imgs)[0]
            
            fake_lbl = K.zeros((batch_size,100))
            
            s_loss += student_train.train_on_batch(gen_imgs,fake_lbl)

        
        print('batch ' + str(i) + '/' + str(n_batches) + ' G loss: ' + str(gen_loss) + ' S loss: ' + str(s_loss/ns))
        
        '''
# NOTE: a second disabled block (never executed). Every `log_freq` batches it
# would evaluate student_test on the CIFAR10 test split and save its JSON
# architecture and weights as tmp-model<i>.json / tmp-model<i>.h5.
'''     
        if (i % log_freq) == 0:
            score = student_test.evaluate(x_test, y_test, verbose=0)
            print('Student test loss: '  + str(score))
            
            model_json = student_test.to_json()
            with open('tmp-model' + str(i) + '.json','w') as json_file:
                json_file.write(model_json)
            student_test.save_weights('tmp-model' + str(i) + '.h5')
            print('saved model ' + str(i))
   '''
        
    #score = student_test.evaluate(x_test, y_test, verbose=0)
    #print('Student test loss: '  + str(score))

# Run the debug driver when this cell executes.
main()


# Disabled snippet (never executed). It prints the shapes of teacher/student
# predictions and of their per-layer outputs. As its first line says, paste
# it before the n_batches definition to use it. Because this literal is the
# cell's last expression, the notebook echoes it as the cell output.
''' 
#If I want to see shapes put this before nbatches definition

    noise = np.random.normal(0, 1, (4, 100))
    gen_imgs = generator.predict(noise)
    t_predictions = teacher.predict(gen_imgs)
    ta1 = t_layer1.predict(gen_imgs)
    ta2 = t_layer2.predict(gen_imgs)
    ta3 = t_layer3.predict(gen_imgs)
    
    print('Teacher predictions shape: ' + str(t_predictions.shape))
    print('Teacher LAYER1 predictions shape: ' + str(ta1[0].shape))
    print('Teacher LAYER2 predictions shape: ' + str(ta2[0].shape))
    print('Teacher LAYER3 predictions shape: ' + str(ta3[0].shape))
    
    s_predictions = student_train.predict(gen_imgs)
    sa1 = s_layer1.predict(gen_imgs)
    sa2 = s_layer2.predict(gen_imgs)
    sa3 = s_layer3.predict(gen_imgs)
    
    print('Student predictions shape: ' + str(s_predictions.shape))
    print('Student LAYER1 predictions shape: ' + str(sa1[0].shape))
    print('Student LAYER2 predictions shape: ' + str(sa2[0].shape))
    print('Student LAYER3 predictions shape: ' + str(sa3[0].shape))
'''  

CIFAR10 loaded
Generator loaded
Teacher loaded from./pretrained_models/wrn_16_2.h5
Wide Residual Network-16-1 created.
Simple student loaded
out for the stud
(4, 2698)


" \n#If I want to see shapes put this before nbatches definition\n\n    noise = np.random.normal(0, 1, (4, 100))\n    gen_imgs = generator.predict(noise)\n    t_predictions = teacher.predict(gen_imgs)\n    ta1 = t_layer1.predict(gen_imgs)\n    ta2 = t_layer2.predict(gen_imgs)\n    ta3 = t_layer3.predict(gen_imgs)\n    \n    print('Teacher predictions shape: ' + str(t_predictions.shape))\n    print('Teacher LAYER1 predictions shape: ' + str(ta1[0].shape))\n    print('Teacher LAYER2 predictions shape: ' + str(ta2[0].shape))\n    print('Teacher LAYER3 predictions shape: ' + str(ta3[0].shape))\n    \n    s_predictions = student_train.predict(gen_imgs)\n    sa1 = s_layer1.predict(gen_imgs)\n    sa2 = s_layer2.predict(gen_imgs)\n    sa3 = s_layer3.predict(gen_imgs)\n    \n    print('Student predictions shape: ' + str(s_predictions.shape))\n    print('Student LAYER1 predictions shape: ' + str(sa1[0].shape))\n    print('Student LAYER2 predictions shape: ' + str(sa2[0].shape))\n    print('Stude

In [4]:
import torch.nn.functional as F
import torch

def attention(x):
    """Return the L2-normalised spatial attention map of an activation tensor.

    The map is computed in three steps:
    1. Average the squared activations over dim 1 (channels, for an
       (N, C, H, W) tensor).
    2. Flatten the result per sample.
    3. L2-normalise along the flattened axis.

    An (N, C, H, W) input yields an (N, H*W) output.
    """
    energy = torch.mean(x ** 2, dim=1)
    flat = energy.reshape(x.shape[0], -1)
    return F.normalize(flat, p=2, dim=1)


def attention_diff(x, y):
    """Return the mean squared difference between the attention maps of x and y.

    x and y may have different channel counts. They must share batch size
    and spatial size, because only their attention maps are compared.
    """
    gap = attention(x) - attention(y)
    return torch.mean(gap ** 2)

# Demo: attention maps of two random activation batches that differ in
# channel count (32 vs 16) but share the 32x32 spatial size.
x1 = torch.rand(4, 32, 32, 32)
y1 = attention(x1)
print(str(y1.shape), str(y1), sep="\n")

x2 = torch.rand(4, 16, 32, 32)
y2 = attention(x2)
print(str(y2.shape), str(y2), sep="\n")

# Scalar distance between the two maps (value printed before its shape).
y3 = attention_diff(x1, x2)
print(str(y3), str(y3.shape), sep="\n")



torch.Size([4, 1024])
tensor([[0.0281, 0.0180, 0.0306,  ..., 0.0311, 0.0323, 0.0307],
        [0.0376, 0.0267, 0.0293,  ..., 0.0250, 0.0251, 0.0296],
        [0.0344, 0.0360, 0.0278,  ..., 0.0272, 0.0277, 0.0373],
        [0.0358, 0.0309, 0.0307,  ..., 0.0341, 0.0362, 0.0355]])
torch.Size([4, 1024])
tensor([[0.0275, 0.0311, 0.0335,  ..., 0.0392, 0.0358, 0.0284],
        [0.0417, 0.0398, 0.0382,  ..., 0.0369, 0.0415, 0.0331],
        [0.0202, 0.0169, 0.0289,  ..., 0.0372, 0.0257, 0.0403],
        [0.0257, 0.0245, 0.0465,  ..., 0.0305, 0.0282, 0.0244]])
tensor(6.9508e-05)
torch.Size([])
