In [None]:
import tensorflow as tf
from beyondml import tflow

from sklearn.metrics import classification_report, confusion_matrix

def print_results(truth, preds):
    """Print evaluation reports comparing predicted labels with true labels.

    Parameters
    ----------
    truth : array-like
        Ground-truth labels, passed unchanged to the scikit-learn metrics.
    preds : array-like
        Predicted labels, passed unchanged to the scikit-learn metrics.

    Prints the confusion matrix first, then the classification report,
    and finally a blank-line separator so consecutive reports stay apart.
    """
    for build_report in (confusion_matrix, classification_report):
        print(build_report(truth, preds))
    print('\n\n')

In [None]:
# Load the CIFAR-100 train/test splits shipped with Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

# Run both image arrays through the Keras ResNet input preprocessing
# (train first, then test); the label arrays are left untouched.
x_train, x_test = (
    tf.keras.applications.resnet.preprocess_input(images)
    for images in (x_train, x_test)
)

In [None]:
# Build and compile the model inside a MirroredStrategy scope.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():

    # Backbone: three successive 2x upsamplings feed a ResNet101 without its
    # classification top; the result is flattened and batch-normalised.
    input_layer = tf.keras.layers.Input(x_train.shape[1:])
    x = tf.keras.layers.UpSampling2D((2, 2))(input_layer)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.applications.ResNet101(include_top = False)(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.BatchNormalization()(x)

    # First multi-branch block (128 units): the same backbone features are
    # supplied three times, once per branch.
    # NOTE(review): SelectorLayer(i) presumably extracts the i-th branch
    # output of the preceding MultiMaskedDense -- confirm in BeyondML docs.
    x = tflow.layers.MultiMaskedDense(128, activation = 'relu')([x] * 3)
    x1, x2, x3 = [tflow.layers.SelectorLayer(branch)(x) for branch in range(3)]
    x1, x2, x3 = [tf.keras.layers.Dropout(0.5)(t) for t in (x1, x2, x3)]
    x1, x2, x3 = [tf.keras.layers.BatchNormalization()(t) for t in (x1, x2, x3)]

    # Second multi-branch block (64 units), same select/dropout/normalise
    # treatment applied to each branch.
    x = tflow.layers.MultiMaskedDense(64, activation = 'relu')([x1, x2, x3])
    x1, x2, x3 = [tflow.layers.SelectorLayer(branch)(x) for branch in range(3)]
    x1, x2, x3 = [tf.keras.layers.Dropout(0.5)(t) for t in (x1, x2, x3)]
    x1, x2, x3 = [tf.keras.layers.BatchNormalization()(t) for t in (x1, x2, x3)]

    # Output block: 100-way softmax fed by the three branches.
    output_layer = tflow.layers.MultiMaskedDense(100, activation = 'softmax')([x1, x2, x3])

    # Assemble the model, pass it through BeyondML's add_layer_masks, then
    # compile it.
    model = tf.keras.models.Model(input_layer, output_layer)
    model = tflow.utils.add_layer_masks(model)
    model.compile(
        optimizer = tf.keras.optimizers.Adamax(),
        loss = 'sparse_categorical_crossentropy',
        metrics = ['accuracy']
    )

In [None]:
# Initial training pass before any masking: the same label array is used as
# the target three times, and 20% of the training data is held out for
# validation. Progress output is silenced (verbose = 0).
model.fit(
    x = x_train,
    y = [y_train] * 3,
    epochs = 10,
    batch_size = 64,
    validation_split = 0.2,
    verbose = 0
)

In [None]:
# Apply BeyondML masking to the trained model using the 'magnitude' method.
# NOTE(review): 70 is presumably the pruning percentile -- confirm against
# the BeyondML mask_model documentation.
model = tflow.utils.mask_model(model, 70, method = 'magnitude')

# Compile again after masking, with a fresh Adamax optimizer and the same
# loss and metric as the first compilation.
model.compile(
    optimizer = tf.keras.optimizers.Adamax(),
    loss = 'sparse_categorical_crossentropy',
    metrics = ['accuracy']
)

# Training callbacks for the fine-tuning run:
#   - stop early after 5 epochs without an improvement of at least 0.001,
#     restoring the best weights seen;
#   - reduce the learning rate after 3 stagnant epochs, never below 1e-8.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        patience = 5,
        min_delta = 0.001,
        restore_best_weights = True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(patience = 3, min_lr = 1e-8)
]

# Fine-tune for up to 100 epochs with the same three-fold label targets and
# 20% validation hold-out as the initial fit.
model.fit(
    x = x_train,
    y = [y_train] * 3,
    epochs = 100,
    batch_size = 64,
    validation_split = 0.2,
    callbacks = callbacks,
    verbose = 2
)

# Pass the fine-tuned model through BeyondML's remove_layer_masks before
# inference and saving.
model = tflow.utils.remove_layer_masks(model)

# One prediction array per model output (three in total).
preds = model.predict(x_test)

# Report each output on its own, then the element-wise sum of all three,
# taking the arg-max over axis 1 as the predicted label.
for head_probs in [*preds, sum(preds)]:
    print_results(y_test, head_probs.argmax(axis = 1))

# Persist the final model to an .h5 file.
model.save('sparse_resnet_cifar100.h5')