# ZeroShot (Algo 1)

    for total_batches in range(total_n_pseudo_batches):
    
        z = tf.random.normal([batch_size, z_dim])
        pseudo_images = get_gen_images(z)
        teacher_logits, *teacher_activations = get_model_outputs(teacher_model, pseudo_images, mode=0)

        #generator training
        for ng in range(ng_batches):
            student_logits, *student_activations = get_model_outputs(student_model, pseudo_images, mode=1)
            generator_loss = generator_loss(teacher_logits, student_logits)

            #################################
            # BACK PROP AND tick schedulers #
            #################################  

        for ns in range(ns_batches):
            student_logits, *student_activations = get_model_outputs(student_model, pseudo_images, mode=1)
            student_loss = student_loss(teacher_logits, teacher_activations, 
                                        student_logits, student_activations, attn_beta)

            #################################
            # BACK PROP AND tick schedulers #
            #################################   

        ######################################################
        ### Val accuracy computation and best model saving ###
        ######################################################    

In [2]:
import tensorflow as tf
from net.generator import NavieGenerator
from utils.cosine_anealing import CosineAnnealingScheduler
from utils.losses import kd_loss
from utils.losses import student_loss_fn
from tensorflow.keras.optimizers import Adam
from net.wide_resnet import WideResidualNetwork
from tensorflow.keras.experimental import CosineDecay
import numpy as np

In [3]:
tf.__version__

'2.0.0'

In [9]:
z_dim = 100
batch_size = 128
ng_batches = 1
ns_batches = 5
attn_beta = 250
total_n_pseudo_batches = 3
n_generator_items = ng_batches + ns_batches
student_lr = 2e-3
generator_lr = 1e-3
number_of_batches = 3

teacher = WideResidualNetwork(16, 1, input_shape=(32, 32, 3), dropout_rate=0.0, output_activations=True)
teacher.load_weights('saved_models/cifar10_WRN-16-1_model.005.h5')
teacher.trainable = False

student = WideResidualNetwork(16, 1, input_shape=(32, 32, 3), dropout_rate=0.0, output_activations=True)
student_optimizer = Adam(learning_rate=CosineDecay(student_lr, number_of_batches))

generator = NavieGenerator(input_dim=100)
generator_optimizer = Adam(learning_rate=CosineDecay(generator_lr, number_of_batches))

# Generator loss metrics
g_loss_met = tf.keras.metrics.Mean()
# Student loss metrics
stu_loss_met = tf.keras.metrics.Mean()

In [10]:
for total_batches in range(total_n_pseudo_batches):
    # sample from latern space to make an image
    z = tf.random.normal([batch_size, z_dim])

    # Generator training
    generator.trainable = True
    student.trainable = False
    for ng in range(ng_batches):
        with tf.GradientTape() as tape:
            pseudo_imgs = generator(z)
            t_logits, *_ = teacher(pseudo_imgs)
            s_logits, *_ = student(pseudo_imgs)

            # calculate the generator loss
            loss = -kd_loss(tf.math.softmax(t_logits),
                                     tf.math.softmax(s_logits))

        # The grad for generator
        grads = tape.gradient(loss, generator.trainable_weights)

        # update the generator paramter with the gradient
        generator_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

        g_loss_met(loss)

        print('step %s: generator mean loss = %s' % (total_batches, g_loss_met.result()))
    # ==========================================================================

    # Student training
    generator.trainable = False
    student.trainable = True
    for ns in range(ns_batches):

        t_logits, *t_acts = teacher(pseudo_imgs)
        with tf.GradientTape() as tape:
            s_logits, *s_acts = student(pseudo_imgs)
            loss = student_loss_fn(tf.math.softmax(t_logits), t_acts, tf.math.softmax(s_logits), s_acts, attn_beta)

        # The grad for student
        grads = tape.gradient(loss, student.trainable_weights)

        # Apply grad for student
        student_optimizer.apply_gradients(zip(grads, student.trainable_weights))

        stu_loss_met(loss)

        print('step %s-%s: studnt mean loss = %s' % (total_batches, ns, stu_loss_met.result()))


step 0: generator mean loss = tf.Tensor(0.016730428, shape=(), dtype=float32)
step 0-0: studnt mean loss = tf.Tensor(0.016759966, shape=(), dtype=float32)
step 0-1: studnt mean loss = tf.Tensor(0.016232062, shape=(), dtype=float32)
step 0-2: studnt mean loss = tf.Tensor(0.015124943, shape=(), dtype=float32)
step 0-3: studnt mean loss = tf.Tensor(0.014083632, shape=(), dtype=float32)
step 0-4: studnt mean loss = tf.Tensor(0.013458845, shape=(), dtype=float32)
step 1: generator mean loss = tf.Tensor(0.01038218, shape=(), dtype=float32)
step 1-0: studnt mean loss = tf.Tensor(0.01189079, shape=(), dtype=float32)
step 1-1: studnt mean loss = tf.Tensor(0.010770751, shape=(), dtype=float32)
step 1-2: studnt mean loss = tf.Tensor(0.0099307215, shape=(), dtype=float32)
step 1-3: studnt mean loss = tf.Tensor(0.009277365, shape=(), dtype=float32)
step 1-4: studnt mean loss = tf.Tensor(0.00875468, shape=(), dtype=float32)
step 2: generator mean loss = tf.Tensor(0.008025792, shape=(), dtype=float32

In [11]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [24]:
(x_train.astype('float32')/255.).std(0)

array([[[0.2878878 , 0.28590465, 0.31548318],
        [0.28407958, 0.2819691 , 0.3121397 ],
        [0.28329298, 0.28107843, 0.3115068 ],
        ...,
        [0.28379878, 0.28180113, 0.312042  ],
        [0.2850664 , 0.28308594, 0.31296697],
        [0.28699955, 0.28499833, 0.3142848 ]],

       [[0.28503835, 0.28301004, 0.31331933],
        [0.28181496, 0.2797545 , 0.3105574 ],
        [0.28050894, 0.27834976, 0.30936906],
        ...,
        [0.2802267 , 0.2783556 , 0.30931875],
        [0.2818469 , 0.27998322, 0.31062594],
        [0.28387433, 0.28186673, 0.31181505]],

       [[0.28229436, 0.28026143, 0.3105508 ],
        [0.27865094, 0.27673486, 0.30747628],
        [0.27674127, 0.27466908, 0.30549774],
        ...,
        [0.2763604 , 0.27458864, 0.3056128 ],
        [0.27819574, 0.27650353, 0.3071304 ],
        [0.28067303, 0.2787829 , 0.30858272]],

       ...,

       [[0.2520456 , 0.24317743, 0.2571024 ],
        [0.24767458, 0.23855966, 0.2517179 ],
        [0.24632809, 0