In [1]:
import tensorflow as tf
from net.generator import *
from utils.cosine_anealing import *
from utils.losses import *
from tensorflow.keras.optimizers import Adam

from net.wide_resnet import WideResidualNetwork
#from train_scratch import *
import numpy as np

In [2]:
# Display the TensorFlow version this notebook was executed with.
tf.__version__

'2.0.0'

In [3]:
def get_model_outputs(model, input, mode=0):
    """
    Given a model and an input batch, return the logits and the activations used for attention training.

    :param model: either student or teacher model; must contain layers named
        'logits', 'attention1', 'attention2' and 'attention3'
    :param input: input batch of images
    :param mode: 0 for test mode or 1 for train mode; forwarded to the model call
        as ``training=bool(mode)`` (default 0 = test mode)
    :return: [logits, activations of 3 main blocks] as a list of tensors. The tensors
        are differentiable, so they can be used inside a ``tf.GradientTape``.

    TODO:
        1. use ```get_intm_outputs_of``` instead
        2. Remove this when fully migrated
    """
    output_layer_names = ['logits', 'attention1', 'attention2', 'attention3']
    # Build a functional sub-model that shares the original model's layers and weights.
    # K.function returned NumPy arrays, which cut the gradient path and ignored `mode`.
    # Calling a Keras model keeps the outputs on the tape and honours the training flag.
    extractor = tf.keras.Model(inputs=model.inputs,
                               outputs=[model.get_layer(name).output for name in output_layer_names])
    return list(extractor(input, training=bool(mode)))

In [6]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
seed = 0                     # global TF seed so the pseudo-data stream is reproducible
z_dim = 100                  # length of the generator's noise vector
batch_size = 128
ng_batches = 1               # generator updates per outer iteration
ns_batches = 10              # student updates per outer iteration
attn_beta = 250              # passed to student_loss; presumably weights the attention term — confirm
total_n_pseudo_batches = 10  # outer iterations
n_generator_items = ng_batches + ns_batches
total_batches = 0
student_lr = 2e-3
generator_lr = 1e-3
number_of_batches = 10       # T_max of both cosine-annealing schedulers
log_every = max(1, total_n_pseudo_batches // 10)
teacher_weights_path = 'saved_models/cifar10_WRN-16-1_model.005.h5'
output_layer_names = ['logits', 'attention1', 'attention2', 'attention3']

tf.random.set_seed(seed)

# ---------------------------------------------------------------------------
# Models, optimizers, schedulers
# ---------------------------------------------------------------------------
teacher_model = WideResidualNetwork(16, 1, input_shape=(32, 32, 3), dropout_rate=0.0)
teacher_model.load_weights(teacher_weights_path)

student_model = WideResidualNetwork(16, 1, input_shape=(32, 32, 3), dropout_rate=0.0)
student_optimizer = Adam(learning_rate=student_lr)
# TODO: compile student model with loss and lr_scheduler (student_scheduler is not applied yet).
student_scheduler = CosineAnnealingScheduler(T_max=number_of_batches, eta_max=student_lr, eta_min=0)

generator_model = generator(z_dim)
generator_optimizer = Adam(learning_rate=generator_lr)
# TODO: compile generator model with loss and lr_scheduler (generator_scheduler is not applied yet).
generator_scheduler = CosineAnnealingScheduler(T_max=number_of_batches, eta_max=generator_lr, eta_min=0)

student_model.trainable = True
teacher_model.trainable = False
generator_model.trainable = True

# Sub-models that expose the logits and attention activations of each network.
# They are built once here rather than per batch. They are called directly so their
# outputs stay differentiable inside tf.GradientTape.
teacher_extractor = tf.keras.Model(
    inputs=teacher_model.inputs,
    outputs=[teacher_model.get_layer(name).output for name in output_layer_names])
student_extractor = tf.keras.Model(
    inputs=student_model.inputs,
    outputs=[student_model.get_layer(name).output for name in output_layer_names])

gen_loss_metric = tf.keras.metrics.Mean()  # running mean of the generator loss
loss_metric = tf.keras.metrics.Mean()      # running mean of the student loss

# ---------------------------------------------------------------------------
# algo 1
# ---------------------------------------------------------------------------
for total_batches in range(total_n_pseudo_batches):
    z = tf.random.normal([batch_size, z_dim])

    # Generator training. The generator forward pass must run inside the tape;
    # otherwise tape.gradient returns None for every generator weight.
    for _ in range(ng_batches):
        with tf.GradientTape() as tape:
            pseudo_images = generator_model(z)
            teacher_logits, *teacher_activations = teacher_extractor(pseudo_images, training=False)
            student_logits, *student_activations = student_extractor(pseudo_images, training=True)
            # Use a new name: assigning to `generator_loss` would replace the loss
            # function with a tensor and make the next call fail.
            generator_loss_value = generator_loss(teacher_logits, student_logits)
        grads = tape.gradient(generator_loss_value, generator_model.trainable_weights)
        generator_optimizer.apply_gradients(zip(grads, generator_model.trainable_weights))
        gen_loss_metric(generator_loss_value)

    # Student training on images from the just-updated generator. Only the student's
    # forward pass needs to be recorded; the teacher outputs are fixed targets.
    for _ in range(ns_batches):
        pseudo_images = generator_model(z)
        teacher_logits, *teacher_activations = teacher_extractor(pseudo_images, training=False)
        with tf.GradientTape() as tape:
            student_logits, *student_activations = student_extractor(pseudo_images, training=True)
            student_loss_value = student_loss(teacher_logits, teacher_activations,
                                              student_logits, student_activations, attn_beta)
        grads = tape.gradient(student_loss_value, student_model.trainable_weights)
        student_optimizer.apply_gradients(zip(grads, student_model.trainable_weights))
        loss_metric(student_loss_value)

    if total_batches % log_every == 0:
        print('step %s: generator mean loss = %s, student mean loss = %s'
              % (total_batches, gen_loss_metric.result().numpy(), loss_metric.result().numpy()))

    # TODO: best test accuracy computation and best model saving.

In [5]:
# Scratch variant of the generator step: uses the models' final outputs and Keras KLD.
for total_batches in range(total_n_pseudo_batches):
    z = tf.random.normal([batch_size, z_dim])

    for _ in range(ng_batches):
        # Open a fresh tape per step. A non-persistent tape allows only one gradient()
        # call, so sharing one tape across ng_batches iterations fails from the second step on.
        with tf.GradientTape() as tape:
            pseudo_images = generator_model(z)
            teacher_logits = teacher_model(pseudo_images)
            student_logits = student_model(pseudo_images)
            # KLD returns one value per example; reduce it to a scalar loss.
            # NOTE(review): KLD expects probability distributions. If these models emit raw
            # logits, apply tf.nn.softmax first — confirm.
            # NOTE(review): minimizing KL makes the generator produce images the two models
            # agree on; an adversarial generator would maximize it — confirm the intended sign.
            # The new name avoids shadowing the imported `generator_loss` function.
            generator_loss_value = tf.reduce_mean(tf.keras.metrics.KLD(teacher_logits, student_logits))

        # Compute gradients after leaving the tape context so the gradient ops are not recorded.
        grads = tape.gradient(generator_loss_value, generator_model.trainable_weights)
        generator_optimizer.apply_gradients(zip(grads, generator_model.trainable_weights))
        gen_loss_metric(generator_loss_value)

    if total_batches % 2 == 0:
        print('step %s: generator mean loss = %s' % (total_batches, gen_loss_metric.result()))

step 0: generator mean loss = tf.Tensor(1.1801524, shape=(), dtype=float32)
step 2: generator mean loss = tf.Tensor(1.0724399, shape=(), dtype=float32)


KeyboardInterrupt: 