In [1]:
import tensorflow as tf
from net.generator import *
from utils.cosine_anealing import *
from utils.losses import *
from tensorflow.keras.optimizers import Adam
from net.wide_resnet import WideResidualNetwork
from train_scratch import *
import numpy as np

In [2]:
tf.__version__

'2.0.0'

In [3]:
z_dim = 100
batch_size = 128
ng_batches = 1
ns_batches = 10
attn_beta = 250
total_n_pseudo_batches = 10
n_generator_items = ng_batches + ns_batches
total_batches = 0
student_lr = 2e-3
generator_lr = 1e-3
number_of_batches = 10

teacher_model = WideResidualNetwork(16, 1, input_shape=(32, 32, 3), dropout_rate=0.0, output_activations=True)
teacher_model.load_weights('saved_models/cifar10_WRN-16-1_model.005.h5')

student_model = WideResidualNetwork(16, 1, input_shape=(32, 32, 3), dropout_rate=0.0, output_activations=True)
student_optimizer=Adam(learning_rate=student_lr)
student_scheduler = CosineAnnealingScheduler(T_max=number_of_batches, eta_max=student_lr, eta_min=0)

generator_model = generator(100)
generator_optimizer=Adam(learning_rate=generator_lr)
generator_scheduler = CosineAnnealingScheduler(T_max=number_of_batches, eta_max=generator_lr, eta_min=0)

student_model.trainable = True
teacher_model.trainable = False
generator_model.trainable = True

gen_loss_metric = tf.keras.metrics.Mean()
stu_loss_metric = tf.keras.metrics.Mean()

def cosine_lr_schedule(epoch, T_max, eta_max, eta_min=0):
    lr = eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2
    return lr

# algo 1

    for total_batches in range(total_n_pseudo_batches):
    
        z = tf.random.normal([batch_size, z_dim])
        pseudo_images = get_gen_images(z)
        teacher_logits, *teacher_activations = get_model_outputs(teacher_model, pseudo_images, mode=0)

        #generator training
        for ng in range(ng_batches):
            student_logits, *student_activations = get_model_outputs(student_model, pseudo_images, mode=1)
            generator_loss = generator_loss(teacher_logits, student_logits)

            #################################
            # BACK PROP AND tick schedulers #
            #################################  

        for ns in range(ns_batches):
            student_logits, *student_activations = get_model_outputs(student_model, pseudo_images, mode=1)
            student_loss = student_loss(teacher_logits, teacher_activations, 
                                        student_logits, student_activations, attn_beta)

            #################################
            # BACK PROP AND tick schedulers #
            #################################   

        ######################################################
        ### Val accuracy computation and best model saving ###
        ######################################################    

In [4]:
# output_layer_names = ['logits', 'attention1', 'attention2', 'attention3']
# t_model = Model(teacher_model.input, [teacher_model.get_layer(l).output for l in output_layer_names])
# s_model = Model(student_model.input, [student_model.get_layer(l).output for l in output_layer_names])

In [5]:
for total_batches in range(total_n_pseudo_batches):
    z = tf.random.normal([batch_size, z_dim])
    #pseudo_images = generator_model(z)
    #teacher_logits, *teacher_activations = get(teacher_model, pseudo_images, mode=0)
    #teacher_logits, *teacher_activations = t_model(pseudo_images)

    #generator training
    for ng in range(ng_batches):
        with tf.GradientTape() as gtape:
            pseudo_images = generator_model(z)
            teacher_logits, *teacher_activations = teacher_model(pseudo_images)
            student_logits, *student_activations = student_model(pseudo_images)
            generator_loss = kd_loss(tf.math.softmax(teacher_logits), tf.math.softmax(student_logits))

        gen_grads = gtape.gradient(generator_loss, generator_model.trainable_weights)
        
        #cosine annealing for learning rate
        generator_optimizer.learning_rate = cosine_lr_schedule(total_batches, total_n_pseudo_batches, generator_lr)
        
        #update gradient
        generator_optimizer.apply_gradients(zip(gen_grads, generator_model.trainable_weights))

        gen_loss_metric(generator_loss)

        if total_batches % 2 == 0:
            print('step %s: generator mean loss = %s' % (total_batches, gen_loss_metric.result()))
    
    for ns in range(2):
        with tf.GradientTape() as stape:
            z = tf.random.normal([batch_size, z_dim])
            pseudo_images = generator_model(z)
            teacher_logits, *teacher_activations = teacher_model(pseudo_images)
            student_logits, *student_activations = student_model(pseudo_images)
            student_loss = student_loss(teacher_logits, teacher_activations, 
                                student_logits, student_activations, attn_beta)

        st_grads = stape.gradient(student_loss, student_model.trainable_weights)
        
        student_optimizer.learning_rate = cosine_lr_schedule(total_batches, total_n_pseudo_batches, student_lr)
        student_optimizer.apply_gradients(zip(st_grads, student_model.trainable_weights))

        stu_loss_metric(student_loss)

        if total_batches % 2 == 0:
            print('step %s: studnt mean loss = %s' % (total_batches, stu_loss_metric.result()))
    

step 0: generator mean loss = tf.Tensor(0.02347466, shape=(), dtype=float32)
step 0: studnt mean loss = tf.Tensor(0.55845463, shape=(), dtype=float32)


TypeError: 'tensorflow.python.framework.ops.EagerTensor' object is not callable

In [4]:
with tf.GradientTape() as stape:
            z = tf.random.normal([batch_size, z_dim])
            pseudo_images = generator_model(z)
            teacher_logits, *teacher_activations = teacher_model(pseudo_images)
            student_logits, *student_activations = student_model(pseudo_images)
            student_loss = student_loss(teacher_logits, teacher_activations, 
                                student_logits, student_activations, attn_beta)

In [6]:
student_model.trainable_weights

[<tf.Variable 'conv2d_15/kernel:0' shape=(3, 3, 3, 16) dtype=float32, numpy=
 array([[[[-1.36175558e-01,  1.92794502e-02, -6.28035963e-02,
            1.53739661e-01,  1.17667347e-01, -7.92065039e-02,
           -8.43740553e-02, -1.50640279e-01,  4.67011184e-02,
           -1.62312582e-01, -1.73553124e-01,  3.36108506e-02,
           -1.46332651e-01, -1.83363602e-01,  7.31568635e-02,
            1.24933809e-01],
          [ 8.92638564e-02,  1.54020756e-01,  1.49682105e-01,
            1.71467543e-01,  1.42989367e-01,  1.83314472e-01,
           -1.19284786e-01, -1.39543414e-02,  4.37110513e-02,
           -1.63303137e-01,  1.25635952e-01,  9.34698582e-02,
            1.15560830e-01,  6.96158409e-02, -9.88365859e-02,
           -9.93480757e-02],
          [-9.93471369e-02, -1.71751186e-01,  1.61541700e-01,
           -1.49901956e-01,  6.69551790e-02,  1.69312090e-01,
           -8.07022080e-02,  1.24500960e-01, -6.69332072e-02,
            4.92233038e-03,  3.37925702e-02,  1.65842712e-0