In [None]:
import tensorflow as tf
import numpy as np
import os

In [None]:
# Hyperparameters consumed by the model cells further down.
nodes_number = 32       # width of the hidden Dense layer in the full (CNN) model
learning_rate = 1e-4    # step size passed to the Adam optimizer of both models

# MNIST binary classification: digit 3 vs. digit 6

In [None]:
# Fetch the MNIST train/test splits through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Append a trailing channel axis (the models below declare input_shape=(28, 28, 1))
# and divide by 255.0 -- presumably mapping 8-bit pixel values into [0, 1].
x_train = x_train[..., np.newaxis] / 255.0
x_test = x_test[..., np.newaxis] / 255.0

In [None]:
def filter_36(x, y):
    """Keep only the samples labelled 3 or 6 and binarise their labels.

    Args:
        x: Array of samples; indexed along its first axis with a boolean mask.
        y: Array of labels aligned with ``x``; compared element-wise to 3 and 6.

    Returns:
        A tuple ``(x_kept, is_three)`` where ``x_kept`` holds the samples whose
        label was 3 or 6 and ``is_three`` is a boolean array that is ``True``
        where the label was 3 and ``False`` where it was 6.
    """
    is_three = y == 3
    is_six = y == 6
    mask = is_three | is_six
    # Within the kept subset, "label == 3" is the binary target.
    return x[mask], y[mask] == 3

# Dataset sizes before restricting to the digits 3 and 6.
n_train_unfiltered = len(x_train)
n_test_unfiltered = len(x_test)
print("Number of unfiltered training examples:", n_train_unfiltered)
print("Number of unfiltered test examples:", n_test_unfiltered)

# NOTE: this rebinds x_train/y_train/x_test/y_test to the filtered, boolean-
# labelled versions. Re-running this cell without re-running the loading cell
# would compare boolean labels against 3 and 6 and leave the arrays empty.
x_train, y_train = filter_36(x_train, y_train)
x_test, y_test = filter_36(x_test, y_test)

# Dataset sizes after filtering.
n_train_filtered = len(x_train)
n_test_filtered = len(x_test)
print("Number of filtered training examples:", n_train_filtered)
print("Number of filtered test examples:", n_test_filtered)

In [None]:
# Full CNN baseline for the binary 3-vs-6 task: two conv layers, max-pooling and
# dropout, followed by a dense head whose width is set by `nodes_number`.
model_full = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(nodes_number, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    # Sigmoid output: the string loss 'binary_crossentropy' defaults to
    # from_logits=False and the 'accuracy' metric thresholds predictions at
    # 0.5, so both expect a probability in [0, 1], not an unbounded linear
    # output.
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model_full.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate),
    metrics=['accuracy'],
)
model_full.summary()

# The test split is also passed as validation data for the single training epoch.
model_full.fit(
    x_train,
    y_train,
    batch_size=256,
    epochs=1,
    verbose=1,
    validation_data=(x_test, y_test),
)

# evaluate() returns [loss, accuracy] for the metrics compiled above; keep the accuracy.
model_full_test_acc = model_full.evaluate(x_test, y_test)[1]

In [None]:
# Deliberately tiny dense model: flatten the 28x28x1 image, a 2-unit hidden
# layer, and a single output unit.
model_fair = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(2, activation='relu'),
    # Sigmoid output: the string loss 'binary_crossentropy' defaults to
    # from_logits=False and the 'accuracy' metric thresholds predictions at
    # 0.5, so both expect a probability in [0, 1], not an unbounded linear
    # output.
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model_fair.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate),
    metrics=['accuracy'],
)
model_fair.summary()

# The test split is also passed as validation data for the single training epoch.
model_fair.fit(
    x_train,
    y_train,
    batch_size=256,
    epochs=1,
    verbose=1,
    validation_data=(x_test, y_test),
)

# evaluate() returns [loss, accuracy] for the metrics compiled above; keep the accuracy.
model_fair_test_acc = model_fair.evaluate(x_test, y_test)[1]

In [None]:
# Report both test accuracies, then keep the larger one in `test_acc`.
print('model_full_test_acc:', model_full_test_acc)
print('model_fair_test_acc:', model_fair_test_acc)

# Strict comparison: on a tie the fair model is reported and selected.
full_model_wins = model_full_test_acc > model_fair_test_acc
print('Full model provided better result' if full_model_wins
      else 'Fair model provided better result')
test_acc = model_full_test_acc if full_model_wins else model_fair_test_acc

In [None]:
# Show the test accuracy selected by the full-vs-fair comparison cell.
print(test_acc)