# Dense MIL

Now we're going to load and train an MIL network on MNIST (presence of 0 = positive bag), then evaluate the same model on the other `mil_benchmarks` datasets.

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import utils
import mil_benchmarks

# Number of bag-level output classes; used as the width of the final softmax
# layer in define_model and passed to utils.compile.
N_CLASSES = 2
# Maximum bag size: the model input is fixed to MAX_BAG instances per bag.
# NOTE(review): presumably smaller bags are padded by mil_benchmarks — confirm.
MAX_BAG = 7

# Project helper; presumably configures GPU memory behaviour — see utils.
utils.gpu_fix()

# Seed TensorFlow's global RNG for reproducible runs. This does not seed
# Python's `random` or NumPy, if the data pipeline uses them.
tf.random.set_seed(42)
# Record the TensorFlow version alongside the results.
print(tf.__version__)

In [None]:
def define_model(shape=(28, 28, 1)):
  """Build and compile the fully-connected ("dense") MIL baseline.

  The model consumes a whole bag at once: its input shape is
  ``(MAX_BAG,) + shape``. The layers returned by ``utils.baseline_layers``
  are applied first, the result is flattened, and a softmax ``Dense`` layer
  with ``N_CLASSES`` units produces the bag-level prediction.

  Args:
    shape: Shape of a single instance; defaults to 28x28 with 1 channel.

  Returns:
    A ``tf.keras.Sequential`` model, compiled by ``utils.compile``.
  """
  stack = [layers.Input((MAX_BAG,) + shape)]
  stack.extend(utils.baseline_layers(shape, N_CLASSES))
  # Flatten folds every remaining axis (including the bag axis, assuming
  # baseline_layers keeps it — confirm) into one vector, so the classifier
  # below sees instances in a fixed positional order.
  stack.append(layers.Flatten())
  stack.append(layers.Dense(N_CLASSES, activation='softmax'))

  model = tf.keras.Sequential(stack)
  utils.compile(model, N_CLASSES)
  return model

# Model.summary() prints the architecture table itself and returns None, so it
# is called directly rather than wrapped in print().
define_model().summary()

In [None]:
# Standard MNIST: run utils.evaluate_all with the dense model, then plot the result.
standard_mnist = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.standard.mnist,
)
utils.plot_histories(
  standard_mnist,
  title='Fully-Connected / Standard MNIST',
  filename='img/dense-standard-mnist.jpg',
)

In [None]:
# Standard Fashion: run utils.evaluate_all with the dense model, then plot the result.
standard_fashion = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.standard.fashion,
)
utils.plot_histories(
  standard_fashion,
  title='Fully-Connected / Standard Fashion',
  filename='img/dense-standard-fashion.jpg',
)

In [None]:
# Standard CIFAR-10: run utils.evaluate_all with the dense model, then plot the result.
standard_cifar10 = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.standard.cifar10,
  (32, 32, 3),  # CIFAR-10 instance shape; presumably forwarded to define_model
)
utils.plot_histories(
  standard_cifar10,
  title='Fully-Connected / Standard CIFAR-10',
  filename='img/dense-standard-cifar10.jpg',
)

In [None]:
# Presence MNIST: run utils.evaluate_all with the dense model, then plot the result.
presence_mnist = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.presence.mnist,
)
utils.plot_histories(
  presence_mnist,
  title='Fully-Connected / Presence MNIST',
  filename='img/dense-presence-mnist.jpg',
)

In [None]:
# Presence Fashion: run utils.evaluate_all with the dense model, then plot the result.
presence_fashion = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.presence.fashion,
)
utils.plot_histories(
  presence_fashion,
  title='Fully-Connected / Presence Fashion',
  filename='img/dense-presence-fashion.jpg',
)

In [None]:
# Presence CIFAR-10: run utils.evaluate_all with the dense model, then plot the result.
presence_cifar10 = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.presence.cifar10,
  (32, 32, 3),  # CIFAR-10 instance shape; presumably forwarded to define_model
)
utils.plot_histories(
  presence_cifar10,
  title='Fully-Connected / Presence CIFAR-10',
  filename='img/dense-presence-cifar10.jpg',
)

In [None]:
# Absence MNIST: run utils.evaluate_all with the dense model, then plot the result.
absence_mnist = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.absence.mnist,
)
utils.plot_histories(
  absence_mnist,
  title='Fully-Connected / Absence MNIST',
  filename='img/dense-absence-mnist.jpg',
)

In [None]:
# Absence Fashion: run utils.evaluate_all with the dense model, then plot the result.
absence_fashion = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.absence.fashion,
)
utils.plot_histories(
  absence_fashion,
  title='Fully-Connected / Absence Fashion',
  filename='img/dense-absence-fashion.jpg',
)

In [None]:
# Absence CIFAR-10: run utils.evaluate_all with the dense model, then plot the result.
absence_cifar10 = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.absence.cifar10,
  (32, 32, 3),  # CIFAR-10 instance shape; presumably forwarded to define_model
)
utils.plot_histories(
  absence_cifar10,
  title='Fully-Connected / Absence CIFAR-10',
  filename='img/dense-absence-cifar10.jpg',
)

In [None]:
# Complex Fashion: run utils.evaluate_all with the dense model, then plot the result.
complex_fashion = utils.evaluate_all(
  define_model,
  'Fully-Connected',
  mil_benchmarks.complex.fashion,
)
utils.plot_histories(
  complex_fashion,
  title='Fully-Connected / Complex Fashion',
  filename='img/dense-complex-fashion.jpg',
)