In [1]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow.keras.backend as kb
from backwardcompatibilityml import scores
from backwardcompatibilityml.tensorflow import helpers as tf_helpers
from backwardcompatibilityml.tensorflow.loss.strict_imitation import BCStrictImitationKLDivLoss
import copy

tf.enable_v2_behavior()
tf.random.set_seed(0)

In [2]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [3]:
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    label_one_hot = tf.one_hot(label, 10)
    return tf.cast(image, tf.float32) / 255., label_one_hot

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

In [4]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

In [5]:
kldiv_loss = tf.keras.losses.KLDivergence()
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
    loss=kldiv_loss,
    optimizer=tf.keras.optimizers.Adam(0.001),
    metrics=['accuracy'],
)

model.fit(
    ds_train,
    epochs=3,
    validation_data=ds_test,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f0adc08fd68>

In [6]:
lambda_c = 0.9
model.trainable = False

h2 = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

bc_loss = BCStrictImitationKLDivLoss(model, h2, lambda_c)

In [7]:
optimizer = tf.keras.optimizers.Adam(0.001)

In [8]:
tf_helpers.bc_fit(h2, training_set=ds_train, testing_set=ds_test, epochs=6, bc_loss=bc_loss, optimizer=optimizer)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6
Training done.


In [9]:
model.trainable = False
h2.trainable = False

In [10]:
h1_predicted_labels = []
h2_predicted_labels = []
ground_truth_labels = []
for x_batch_test, y_batch_test in ds_test:
    h1_batch_predictions = tf.argmax(model(x_batch_test), axis=1)
    h2_batch_predictions = tf.argmax(h2(x_batch_test), axis=1)
    h1_predicted_labels += h1_batch_predictions.numpy().tolist()
    h2_predicted_labels += h2_batch_predictions.numpy().tolist()
    ground_truth_labels += y_batch_test.numpy().tolist()

In [11]:
btc = scores.trust_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)
bec = scores.error_compatibility_score(h1_predicted_labels, h2_predicted_labels, ground_truth_labels)

print(f"lambda_c: {lambda_c}")
print(f"BTC: {btc}")
print(f"BEC: {bec}")

lambda_c: 0.9
BTC: 0
BEC: 1.0
