# Imports

In [3]:
from main import build_model, normalize_img
import tensorflow_datasets as tfds
import tensorflow as tf
import pandas as pd

In [4]:
import matplotlib.pyplot as plt

%matplotlib inline

# Utility Functions

In [5]:
def build_pipelines(num_train_examples, num_test_examples, batch_size=128):
    """
    Builds the MNIST training and test input pipelines.

    Args:
        num_train_examples: Number of examples kept from the "train" split.
            Also used as the shuffle buffer size, so the kept examples are
            shuffled as a whole.
        num_test_examples: Number of examples kept from the "test" split.
        batch_size: Number of examples per batch in both pipelines.
            Defaults to 128, the value that used to be hard-coded.

    Returns:
        A ``(ds_train, ds_test)`` tuple of batched, prefetched datasets.
    """
    # Dataset metadata is never used here, so it is not requested.
    ds_train, ds_test = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
    )

    # Truncate first so only the examples that are kept get normalized;
    # map is one-to-one, so the resulting elements are the same.
    ds_train = ds_train.take(num_train_examples)
    ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache before shuffling so the shuffle order is not frozen into the cache.
    ds_train = ds_train.cache()
    ds_train = ds_train.shuffle(num_train_examples)
    ds_train = ds_train.batch(batch_size)
    ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

    ds_test = ds_test.take(num_test_examples)
    ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    # The test set is not shuffled, so whole batches can be cached.
    ds_test = ds_test.batch(batch_size)
    ds_test = ds_test.cache()
    ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

    return ds_train, ds_test


In [6]:
def run_experiment(i, num_train_examples=60000, num_test_examples=10000, epochs=6):
    """
    Trains a model on the MNIST dataset given the number of training and test examples.

    Args:
        i: Trial index; only used in the progress message.
        num_train_examples: Number of training examples passed to ``build_pipelines``.
        num_test_examples: Number of test examples passed to ``build_pipelines``.
        epochs: Number of training epochs. Defaults to 6, the value that
            used to be hard-coded.

    Returns:
        The second value returned by ``model.evaluate`` on the test pipeline
        (presumably accuracy; this depends on the metrics ``build_model``
        compiles the model with -- TODO confirm).
    """
    print(f"Running trial {i} with {num_train_examples:,} training examples and {num_test_examples:,} test examples.")

    # This function is called many times in a loop and builds a new model on
    # every call; drop the Keras global state left behind by earlier trials
    # so memory use does not keep growing across trials.
    tf.keras.backend.clear_session()

    ds_train, ds_test = build_pipelines(num_train_examples, num_test_examples)
    model = build_model()
    # No validation_data: the training history is discarded and verbose=0
    # hides per-epoch metrics, so validating after every epoch only added an
    # unused pass over the test set. The final metric comes from evaluate().
    model.fit(ds_train, epochs=epochs, verbose=0)
    _, test_accuracy = model.evaluate(ds_test, verbose=0)
    return test_accuracy


# Experiments

## Variable training set, fixed test set

In [7]:
%%time
trials = 1000

training_size = [1000, 10000, 30000, 60000]
test_size = 10000

data = [[run_experiment(i, ts, test_size) for i in range(trials)] for ts in training_size]

col_names = [str(f"{i:,}") for i in training_size]
df = pd.DataFrame(data).transpose()
df.columns = col_names

df.boxplot(col_names, figsize=(10,10))\
    .set(xlabel="Training Set Size", ylabel="Test Accuracy")

Running trial 0 with 1,000 training examples and 10,000 test examples.
Running trial 1 with 1,000 training examples and 10,000 test examples.
Running trial 2 with 1,000 training examples and 10,000 test examples.
Running trial 3 with 1,000 training examples and 10,000 test examples.
Running trial 4 with 1,000 training examples and 10,000 test examples.
Running trial 5 with 1,000 training examples and 10,000 test examples.
Running trial 6 with 1,000 training examples and 10,000 test examples.
Running trial 7 with 1,000 training examples and 10,000 test examples.
Running trial 8 with 1,000 training examples and 10,000 test examples.
Running trial 9 with 1,000 training examples and 10,000 test examples.
Running trial 10 with 1,000 training examples and 10,000 test examples.
Running trial 11 with 1,000 training examples and 10,000 test examples.
Running trial 12 with 1,000 training examples and 10,000 test examples.
Running trial 13 with 1,000 training examples and 10,000 test examples.
Ru

KeyboardInterrupt: 