# Fully-Connected MIL

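Now we're going to load and train an MIL network with a fully-connected head. The per-instance features of the whole bag are flattened and fed to a single dense softmax layer instead of being pooled. On MNIST, the presence of a 0 makes a bag positive. We also evaluate on Fashion and CIFAR-10 under the standard, presence, absence and complex benchmark settings.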

In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import utils
from noisy_and import dense
import mil_benchmarks

N_CLASSES = 2  # number of bag-level output classes
MAX_BAG = 7    # instances per bag: the leading (non-batch) input axis used by define_model
SEED = 42      # global TensorFlow op seed

# Project GPU setup helper; it appears to print the physical/logical GPU counts.
utils.gpu_fix()

tf.random.set_seed(SEED)
print(tf.__version__)

1 Physical GPUs, 1 Logical GPUs
2.4.1


In [2]:
def define_model(shape=(28, 28, 1), max_bag=None, n_classes=None):
  """Build and compile a fully-connected MIL model.

  Every instance of a bag goes through the shared per-instance stack from
  `utils.baseline_layers`. The resulting per-instance features are flattened
  into one vector and classified by a single softmax layer, so the head sees
  the whole bag at once instead of a pooled summary.

  Args:
    shape: Shape of a single instance (default is a 28x28x1 image).
    max_bag: Number of instances per bag; falls back to MAX_BAG when None.
    n_classes: Number of output classes; falls back to N_CLASSES when None.

  Returns:
    A model compiled via `utils.compile`, taking inputs of shape
    `(batch, max_bag) + shape`.
  """
  # Resolve defaults at call time (like the original global lookups) so that
  # edits to the config cell are picked up without redefining this function.
  if max_bag is None:
    max_bag = MAX_BAG
  if n_classes is None:
    n_classes = N_CLASSES

  model = tf.keras.Sequential([
    layers.Input((max_bag,) + shape),
    # Project-defined per-instance feature extractor; the printed summary
    # shows these as TimeDistributed layers over the bag axis.
    *utils.baseline_layers(shape, n_classes),
    # Concatenate the features of all instances so the dense head sees the whole bag.
    layers.Flatten(),
    layers.Dense(n_classes, activation='softmax'),
  ])

  utils.compile(model, n_classes)
  return model

# summary() prints the layer table itself and returns None; wrapping it in print() would also print "None".
define_model().summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
time_distributed (TimeDistri (None, 7, 28, 28, 64)     320       
_________________________________________________________________
time_distributed_1 (TimeDist (None, 7, 14, 14, 64)     0         
_________________________________________________________________
time_distributed_2 (TimeDist (None, 7, 12, 12, 32)     18464     
_________________________________________________________________
time_distributed_3 (TimeDist (None, 7, 6, 6, 32)       0         
_________________________________________________________________
time_distributed_4 (TimeDist (None, 7, 6, 6, 32)       0         
_________________________________________________________________
time_distributed_5 (TimeDist (None, 7, 1152)           0         
_________________________________________________________________
time_distributed_6 (TimeDist (None, 7, 256)            2

In [None]:
# Standard setting. No input shape is passed for MNIST/Fashion; CIFAR-10 passes (32, 32, 3).
standard_mnist = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.standard.mnist)
standard_fashion = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.standard.fashion)
standard_cifar10 = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.standard.cifar10, (32, 32, 3))

In [None]:
# Presence setting. No input shape is passed for MNIST/Fashion; CIFAR-10 passes (32, 32, 3).
presence_mnist = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.presence.mnist)
presence_fashion = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.presence.fashion)
presence_cifar10 = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.presence.cifar10, (32, 32, 3))

In [None]:
# Absence setting. No input shape is passed for MNIST/Fashion; CIFAR-10 passes (32, 32, 3).
absence_mnist = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.absence.mnist)
absence_fashion = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.absence.fashion)
absence_cifar10 = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.absence.cifar10, (32, 32, 3))

In [None]:
# Complex setting: only the Fashion dataset is evaluated here.
complex_fashion = utils.evaluate_all(
    define_model, 'Fully-Connected', mil_benchmarks.complex.fashion)

In [None]:
# One entry per benchmark run: (training histories, plot title, file stem).
# Each figure is saved as img/dense-<stem>.jpg, in the same order as before.
plot_specs = [
    (standard_mnist,   'Standard MNIST',    'standard-mnist'),
    (standard_fashion, 'Standard Fashion',  'standard-fashion'),
    (standard_cifar10, 'Standard CIFAR-10', 'standard-cifar10'),

    (presence_mnist,   'Presence MNIST',    'presence-mnist'),
    (presence_fashion, 'Presence Fashion',  'presence-fashion'),
    (presence_cifar10, 'Presence CIFAR-10', 'presence-cifar10'),

    (absence_mnist,    'Absence MNIST',     'absence-mnist'),
    (absence_fashion,  'Absence Fashion',   'absence-fashion'),
    (absence_cifar10,  'Absence CIFAR-10',  'absence-cifar10'),

    (complex_fashion,  'Complex Fashion',   'complex-fashion'),
]

for histories, title, stem in plot_specs:
    utils.plot_histories(histories, title=title, filename=f'img/dense-{stem}.jpg')