In [80]:
# preamble copied from Task1.ipynb

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union

tf.enable_v2_behavior()

builder = tfds.builder('mnist')
BATCH_SIZE = 256
NUM_EPOCHS = 200
NUM_CLASSES = 10  # 10 total classes.

# number of subclasses per class
SUBCLASSES = 2
BETA = 0.7 # hyperparameter regarding auxillary loss
AUXILLIARY_TEMPERATURE = 4

# Load train and test splits.
def preprocess(x):
  image = tf.image.convert_image_dtype(x['image'], tf.float32)
  subclass_labels = tf.one_hot(x['label'], builder.info.features['label'].num_classes)
  return image, subclass_labels
  
mnist_train = tfds.load('mnist', split='train', shuffle_files=False).cache()
mnist_train = mnist_train.map(preprocess)
mnist_train = mnist_train.shuffle(builder.info.splits['train'].num_examples)
mnist_train = mnist_train.batch(BATCH_SIZE, drop_remainder=True)

mnist_test = tfds.load('mnist', split='test').cache()
mnist_test = mnist_test.map(preprocess).batch(BATCH_SIZE)

# this layer compresses the subclasses to the classes via a summation
class SubclassCollapseLayer(tf.keras.layers.Layer):
    def __init__(self, n, **kwargs):
        self.n = n
        super().__init__(**kwargs)
    
    def call(self, vals):
        # difficult to reshape a tensor with unknown first dimension
        vals = tf.expand_dims(vals, axis=-1)
        shape = [tf.shape(vals)[k] for k in range(len(vals.shape))]
        shape[-1] = 2
        shape[-2] = -1 # infer
        vals = tf.reshape(vals, shape=shape)
        vals = tf.reduce_sum(vals, axis=-1)
        return vals

cnn_input = tf.keras.Input(shape=(28, 28, 1), dtype=tf.float32)

# 3x3 convolution
cnn_layer0 = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")(cnn_input)
# 2x2 spatial pool
cnn_layer1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(cnn_layer0)
# another convolution
cnn_layer2 = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu")(cnn_layer1)
# another pool
cnn_layer3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(cnn_layer2)
# flatten
cnn_layer4 = tf.keras.layers.Flatten()(cnn_layer3)
# dropout
cnn_layer5 = tf.keras.layers.Dropout(0.5)(cnn_layer4)
# fully connected
cnn_layer6 = tf.keras.layers.Dense(128, activation="relu")(cnn_layer5)
# dropout again
cnn_layer7 = tf.keras.layers.Dropout(0.5)(cnn_layer6)

cnn_layer8 = tf.keras.layers.Dense(NUM_CLASSES * SUBCLASSES)(cnn_layer7) # no activation (logit output)

cnn_layer9 = SubclassCollapseLayer(SUBCLASSES)(cnn_layer8) # still logits

cnn_model = tf.keras.Model(inputs=cnn_input, outputs=[cnn_layer8, cnn_layer9])

# student model
fc_input = tf.keras.Input(shape=(28, 28, 1), dtype=tf.float32)
fc_layer0 = tf.keras.layers.Flatten()(fc_input)
fc_layer1 = tf.keras.layers.Dense(784, activation="relu")(fc_layer0)
fc_layer2 = tf.keras.layers.Dense(784, activation="relu")(fc_layer1)
fc_layer3 = tf.keras.layers.Dense(NUM_CLASSES * SUBCLASSES)(fc_layer2) # no activation (logit output)
fc_layer4 = SubclassCollapseLayer(SUBCLASSES)(fc_layer3) # still logits
fc_model = tf.keras.Model(inputs=fc_input, outputs=[fc_layer3, fc_layer4])

# ================================================================================================

def compute_subclass_loss(subclass_logits):
    # using auxiliary loss equation described in the paper

    # input is 256 x (c x 10)
    n = tf.constant(subclass_logits.shape[0], dtype=float)
    T = AUXILLIARY_TEMPERATURE

    # normalized it. as described in the paper, mean 0 and var 1
    mean, variance = tf.nn.moments(subclass_logits, axes=[1], keepdims=True)
    vals = (subclass_logits - mean) / tf.sqrt(variance + 1e-8) # normalized_subclass_logits
    
    # I learned einstein summation notation for this assignment. very cool very good
    term = tf.reduce_sum(tf.math.log(tf.exp(tf.einsum('ik,jk->ij', vals, vals) / T)))

    ret = (1 / n) * term - (1 / T) - tf.math.log(n)
    return ret

@tf.function
def compute_teacher_loss(images, labels):
  subclass_logits, logits = cnn_model(images, training=True)

  typical_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, logits))
  auxiliary_loss = compute_subclass_loss(subclass_logits)

  return typical_loss + BETA * auxiliary_loss

# ================================================================================================

@tf.function
def compute_num_correct(model, images, labels):
  """Compute number of correctly classified images in a batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Number of correctly classified images.
  """
  class_logits = model(images, training=False)[1]
  return tf.reduce_sum(
      tf.cast(tf.math.equal(tf.argmax(class_logits, -1), tf.argmax(labels, -1)),
              tf.int32)), tf.argmax(class_logits, -1), tf.argmax(labels, -1)

def train_and_evaluate(model, compute_loss_fn):
  """Perform training and evaluation for a given model.

  Args:
    model: Instance of tf.keras.Model.
    compute_loss_fn: A function that computes the training loss given the
      images, and labels.
  """

  # your code start from here for step 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

  for epoch in range(1, NUM_EPOCHS + 1):
    # Run training.
    do_print = epoch % 5 == 0
    if do_print:
      print('Epoch {}: '.format(epoch), end='')
    for images, labels in mnist_train:
      with tf.GradientTape() as tape:
         # your code start from here for step 4
        loss_value = compute_loss_fn(images, labels)
      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Run evaluation.
    num_correct = tf.constant(0, dtype=tf.int32)
    num_total = builder.info.splits['test'].num_examples
    for images, labels in mnist_test:
      # your code start from here for step 4
      num_correct += tf.reduce_sum(compute_num_correct(model, images, labels)[0])
    
    last_accuracy = num_correct / num_total * 100
    if do_print:
      print("Class_accuracy: " + '{:.2f}%'.format(last_accuracy))
  return last_accuracy


In [81]:
drop = train_and_evaluate(cnn_model, compute_teacher_loss)

Epoch 5: Class_accuracy: 27.25%
Epoch 10: Class_accuracy: 33.73%
Epoch 15: Class_accuracy: 46.05%
Epoch 20: Class_accuracy: 48.10%
Epoch 25: Class_accuracy: 58.72%
Epoch 30: Class_accuracy: 65.64%
Epoch 35: Class_accuracy: 77.02%
Epoch 40: Class_accuracy: 80.96%
Epoch 45: Class_accuracy: 85.88%
Epoch 50: Class_accuracy: 90.19%
Epoch 55: Class_accuracy: 93.06%
Epoch 60: Class_accuracy: 93.83%
Epoch 65: Class_accuracy: 94.43%
Epoch 70: Class_accuracy: 95.22%
Epoch 75: Class_accuracy: 95.46%
Epoch 80: Class_accuracy: 95.80%
Epoch 85: Class_accuracy: 96.04%
Epoch 90: Class_accuracy: 96.51%
Epoch 95: Class_accuracy: 96.64%
Epoch 100: Class_accuracy: 96.62%
Epoch 105: Class_accuracy: 96.72%
Epoch 110: Class_accuracy: 96.70%
Epoch 115: Class_accuracy: 96.72%
Epoch 120: Class_accuracy: 96.92%
Epoch 125: Class_accuracy: 97.00%
Epoch 130: Class_accuracy: 96.73%
Epoch 135: Class_accuracy: 97.02%
Epoch 140: Class_accuracy: 97.05%
Epoch 145: Class_accuracy: 97.04%
Epoch 150: Class_accuracy: 97.11%


In [82]:
# checking that the properties hold for the subclass weights

# this is the dense layer weights that go into creating the 
weights = cnn_model.layers[-2].get_weights()[0] # at 0 since there are two outputs in the graph
# 128 x 20 output weights

# n dimensional length
first_subclass_total_length = 0
second_subclass_total_length = 0

dot_pr = 0

for i in range(len(weights)):
    for j in range(0, len(weights[0]), 2):
        first_subclass_total_length += weights[i][j] ** 2
        second_subclass_total_length += weights[i][j+1] ** 2
        dot_pr += weights[i][j] * weights[i][j+1]

first_subclass_total_length **= 0.5
second_subclass_total_length **= 0.5
print(first_subclass_total_length, second_subclass_total_length)
print(dot_pr / (first_subclass_total_length * second_subclass_total_length))

6.552832026923353 6.585251884982124
0.16282136076192422


In [87]:
# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE_ORIGINAL = 4
DISTILLATION_TEMPERATURE = DISTILLATION_TEMPERATURE_ORIGINAL #temperature hyperparameter

# ================================================================================================

# distillation loss uses only the subclass logits
def distillation_loss(teacher_subclass_logits: tf.Tensor, student_subclass_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  # soften the teacher's logits
  soft_targets = tf.nn.softmax(teacher_subclass_logits / temperature)

  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          soft_targets, student_subclass_logits / temperature)) * temperature ** 2

def compute_student_loss(images, labels):
  student_subclass_logits, student_logits = fc_model(images, training=True)
  teacher_subclass_logits, teacher_logits = cnn_model(images, training=False)

  distillation_loss_value = distillation_loss(teacher_subclass_logits, student_subclass_logits, DISTILLATION_TEMPERATURE)

  cross_entropy_loss_value = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, student_logits))

  return distillation_loss_value * ALPHA + cross_entropy_loss_value * (1 - ALPHA)

In [88]:
drop = train_and_evaluate(fc_model, compute_student_loss)

Epoch 5: Class_accuracy: 98.58%
Epoch 10: Class_accuracy: 98.54%
Epoch 15: Class_accuracy: 98.57%
Epoch 20: Class_accuracy: 98.49%
Epoch 25: Class_accuracy: 98.51%
Epoch 30: Class_accuracy: 98.55%
Epoch 35: Class_accuracy: 98.51%
Epoch 40: Class_accuracy: 98.62%
Epoch 45: Class_accuracy: 98.53%
Epoch 50: Class_accuracy: 98.50%
Epoch 55: Class_accuracy: 98.52%
Epoch 60: Class_accuracy: 98.55%
Epoch 65: Class_accuracy: 98.42%
Epoch 70: Class_accuracy: 98.52%
Epoch 75: Class_accuracy: 98.53%
Epoch 80: Class_accuracy: 98.51%
Epoch 85: Class_accuracy: 98.49%
Epoch 90: Class_accuracy: 98.51%
Epoch 95: Class_accuracy: 98.53%
Epoch 100: Class_accuracy: 98.51%
Epoch 105: Class_accuracy: 98.56%
Epoch 110: Class_accuracy: 98.58%
Epoch 115: Class_accuracy: 98.49%
Epoch 120: Class_accuracy: 98.46%
Epoch 125: Class_accuracy: 98.44%
Epoch 130: Class_accuracy: 98.55%
Epoch 135: Class_accuracy: 98.51%
Epoch 140: Class_accuracy: 98.49%
Epoch 145: Class_accuracy: 98.51%
Epoch 150: Class_accuracy: 98.48%
