In [23]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow.keras.backend as kb
from backwardcompatibilityml import scores
from backwardcompatibilityml.tensorflow.models import BCNewErrorCompatibilityModel

# Opt in to TensorFlow v2 behaviour first, then set TensorFlow's random seed
# to 0 so the runs below start from a fixed seed.
tf.enable_v2_behavior()
tf.random.set_seed(0)

In [24]:
# Load the MNIST 'train' and 'test' splits as supervised (image, label)
# datasets, together with the dataset info object.
mnist_splits, ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)
ds_train, ds_test = mnist_splits

In [25]:
def normalize_img(image, label):
  """Cast `image` to `float32` and divide it by 255; `label` passes through unchanged.

  Returns:
    An `(image, label)` tuple with the rescaled image first.
  """
  scaled_image = tf.cast(image, tf.float32) / 255.
  return scaled_image, label

# Training input pipeline, applied in this order:
# normalize -> cache -> shuffle (buffer = number of train examples)
# -> batches of 128 -> prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [26]:
# Test input pipeline, applied in this order:
# normalize -> batches of 128 -> cache -> prefetch (no shuffling).
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [27]:
# Baseline classifier, built layer by layer: flatten the (28, 28, 1) input,
# a 128-unit ReLU dense layer, then a 10-unit softmax output.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28, 1)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [28]:
# Compile the baseline with Adam (0.001) and sparse categorical cross-entropy,
# then fit it for 3 epochs with ds_test passed as the validation data.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)

model.fit(ds_train, validation_data=ds_test, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f2b8c6abf28>

In [29]:
# Mark the baseline as non-trainable, then wrap a fresh copy of the same layer
# stack in BCNewErrorCompatibilityModel with h1=model and lambda_c = 0.0.
model.trainable = False
lambda_c = 0.0

h2_layers = [
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
]
h2 = BCNewErrorCompatibilityModel(h2_layers, h1=model, lambda_c=lambda_c)

In [30]:
# Number of trainable weight tensors in the baseline and in h2, in that order.
tuple(len(net.trainable_weights) for net in (model, h2))

(0, 4)

In [31]:
# Compile h2 with the same optimizer/loss/metric settings as the baseline,
# then fit it for 6 epochs with ds_test passed as the validation data.
h2.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)

h2.fit(ds_train, validation_data=ds_test, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<tensorflow.python.keras.callbacks.History at 0x7f2b8c579128>

In [32]:
# Mark both models as non-trainable (targets are assigned left to right).
model.trainable = h2.trainable = False

In [33]:
# Walk the test pipeline once, accumulating the argmax (axis=1) predictions of
# the baseline and of h2 alongside the labels yielded with each batch.
h1_predicted_labels, h2_predicted_labels, ground_truth_labels = [], [], []
for images, labels in ds_test:
    h1_batch = tf.argmax(model(images), axis=1)
    h2_batch = tf.argmax(h2(images), axis=1)
    h1_predicted_labels.extend(h1_batch.numpy().tolist())
    h2_predicted_labels.extend(h2_batch.numpy().tolist())
    ground_truth_labels.extend(labels.numpy().tolist())

In [34]:
# Both scores take the same three label lists: h1 predictions, h2 predictions,
# and the ground truth.
score_args = (h1_predicted_labels, h2_predicted_labels, ground_truth_labels)
btc = scores.trust_compatibility_score(*score_args)
bec = scores.error_compatibility_score(*score_args)

print(f"lambda_c: {lambda_c}", f"BTC: {btc}", f"BEC: {bec}", sep="\n")

lambda_c: 0.0
BTC: 0.9915437764256987
BEC: 0.5808580858085809


In [35]:
# Mark the baseline as non-trainable, then wrap a fresh copy of the same layer
# stack in BCNewErrorCompatibilityModel with h1=model and lambda_c = 0.9.
model.trainable = False
lambda_c = 0.9

h3_layers = [
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
]
h3 = BCNewErrorCompatibilityModel(h3_layers, h1=model, lambda_c=lambda_c)

In [36]:
# Compile h3 with the same optimizer/loss/metric settings as the baseline,
# then fit it for 6 epochs with ds_test passed as the validation data.
h3.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)

h3.fit(ds_train, validation_data=ds_test, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<tensorflow.python.keras.callbacks.History at 0x7f2c58193320>

In [37]:
# Mark both models as non-trainable (targets are assigned left to right).
model.trainable = h3.trainable = False

In [38]:
# Walk the test pipeline once, accumulating the argmax (axis=1) predictions of
# the baseline and of h3 alongside the labels yielded with each batch.
h1_predicted_labels, h3_predicted_labels, ground_truth_labels = [], [], []
for images, labels in ds_test:
    h1_batch = tf.argmax(model(images), axis=1)
    h3_batch = tf.argmax(h3(images), axis=1)
    h1_predicted_labels.extend(h1_batch.numpy().tolist())
    h3_predicted_labels.extend(h3_batch.numpy().tolist())
    ground_truth_labels.extend(labels.numpy().tolist())

In [39]:
# Both scores take the same three label lists: h1 predictions, h3 predictions,
# and the ground truth.
score_args = (h1_predicted_labels, h3_predicted_labels, ground_truth_labels)
btc = scores.trust_compatibility_score(*score_args)
bec = scores.error_compatibility_score(*score_args)

print(f"lambda_c: {lambda_c}", f"BTC: {btc}", f"BEC: {bec}", sep="\n")

lambda_c: 0.9
BTC: 0.9914406517479633
BEC: 0.6732673267326733
