# Per-sample fit times on binary MNIST

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import dotenv
import pandas as pd
import mlflow
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.subplots
import plotly.io as pio

import typing
import os
import shutil
import sys

In [None]:
# Name under which the runs analysed in this notebook are looked up
EXPERIMENT_NAME = 'neural_networks_teaser_plot_fit_times'

# Whether figures are written to disk
EXPORT = True

# Titles are only enabled when figures are not exported
SHOW_TITLES = not EXPORT

In [None]:
# Load environment variables via python-dotenv
dotenv.load_dotenv()

# Make the `src` directory located two levels above the current working
# directory importable, so that the project module can be loaded
_two_levels_up = os.path.abspath(os.path.join(os.curdir, os.pardir, os.pardir))
MODULE_DIR = os.path.join(_two_levels_up, 'src')
sys.path.append(MODULE_DIR)

In [None]:
# IPython magics: enable the autoreload extension in mode 2 so that edits to
# imported modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2
# Project package; presumably resolved through the `src` directory that was
# appended to sys.path earlier in this notebook
import interpolation_robustness as ir

In [None]:
# When exporting, start from an empty export directory below `<repo root>/logs`
if EXPORT:
    EXPORT_DIR = os.path.join(ir.util.REPO_ROOT_DIR, 'logs', f'export_{EXPERIMENT_NAME}')
    print('Using export directory', EXPORT_DIR)
    # NOTE: this deletes everything exported by a previous execution
    if os.path.exists(EXPORT_DIR):
        shutil.rmtree(EXPORT_DIR)
    os.makedirs(EXPORT_DIR)

def export_fig(fig: go.Figure, filename: str):
    """Write `fig` as an image called `filename` into `EXPORT_DIR`.

    Does nothing when `EXPORT` is disabled. The path of the written file is
    printed after a successful export.
    """
    if not EXPORT:
        return

    export_path = os.path.join(EXPORT_DIR, filename)
    fig.write_image(export_path)
    print('Exported figure at', export_path)


In [None]:
# Project helper; presumably configures the plotly template used by the
# figures below (implementation not visible here)
ir.plots.setup_plotly_template()

## Load experiment data

In [None]:
# Fetch all runs of the experiment from the MLflow tracking server
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
runs = mlflow.search_runs(experiment.experiment_id)

# Index by run id, but keep the column so that code depending on it keeps working
runs = runs.set_index('run_id', drop=False)

# The number of MLP units is stored as "[NUM_UNITS]"; convert it to an integer
# and sort the runs by it
runs = runs.assign(**{
    'params.mlp_units': runs['params.mlp_units'].str.strip('[] \t').astype(int)
}).sort_values(['params.mlp_units'])

print('Loaded', len(runs), 'runs')
assert runs['status'].eq('FINISHED').all()

In [None]:
# Directory holding one `fit_times_<run_id>.npy` file per run
LOG_DIR = os.path.join(ir.util.REPO_ROOT_DIR, 'logs', EXPERIMENT_NAME)
# Expected number of training samples, i.e. the expected length of every
# per-run fit time array
NUM_SAMPLES = 12873

# Row i of `fit_times` and entry i of `mlp_units` both describe the i-th run
# in the current order of `runs`
fit_times = np.zeros((len(runs), NUM_SAMPLES))
# Use the builtin `int`: the `np.int` alias was deprecated in NumPy 1.20 and
# removed in NumPy 1.24
mlp_units = np.zeros((len(runs),), dtype=int)
for idx, run_id in enumerate(runs.index):
    mlp_units[idx] = runs.loc[run_id, 'params.mlp_units']
    current_fit_times = np.load(os.path.join(LOG_DIR, f'fit_times_{run_id}.npy'))
    # Fail loudly instead of silently broadcasting a wrongly shaped array
    # (e.g. a single value) into the whole row
    if current_fit_times.shape != (NUM_SAMPLES,):
        raise ValueError(
            f'Fit times of run {run_id} have shape {current_fit_times.shape}, '
            f'expected ({NUM_SAMPLES},)'
        )
    fit_times[idx] = current_fit_times
print('Loaded all fit times')

## Fit times for different widths

In [None]:
# Heatmap of the fit epoch: one column per training sample, one row per run
fit_time_heatmap = go.Heatmap(
    x=np.arange(NUM_SAMPLES),
    y=mlp_units,
    z=fit_times
)
fig = go.Figure(data=fit_time_heatmap)
fig.update_layout(
    title='Epoch at which training samples were fit',
    xaxis_title='Training sample',
    yaxis_title='Number of hidden units',
    # Treat the widths as categories instead of a numeric axis
    yaxis_type='category'
)
fig.show()

In [None]:
# Since the behavior seems to differ for very small and very large models, take
# the max epoch per sample over the first, second-to-last and last run
# (the original note names these as p=10, p=200 and p=10k)
selected_runs = [0, -2, -1]
max_fit_time = np.max(fit_times[selected_runs, :], axis=0)

# Order the samples by decreasing max fit time
sorted_max_fit_time_indices = np.flip(np.argsort(max_fit_time))
sorted_max_fit_time = max_fit_time[sorted_max_fit_time_indices]

# Throw away all samples which take longer than THRESHOLD epochs to fit.
# NOTE(review): the original comment stated that learning rate decay happens at
# epoch 300, which does not match the threshold of 100 -- confirm which is intended.
THRESHOLD = 100
# The fit times are sorted in decreasing order, so the number of entries above
# the threshold equals the length of the prefix to discard. Counting (instead of
# np.argmin on the mask) also gives the right answer when every sample exceeds
# the threshold, where argmin would return 0.
num_discard = int(np.count_nonzero(sorted_max_fit_time > THRESHOLD))
print('Discarding', num_discard, 'training samples')

fig = go.Figure(
    data=go.Scatter(
        x=np.arange(NUM_SAMPLES),
        y=sorted_max_fit_time,
        name='Last epoch the sample was fit'
    )
)
# Vertical line between the last discarded and the first retained sample
fig.add_vline(num_discard - 0.5, line_width=1.0, line_dash='dash', line_color='red')
fig.update_layout(
    # Name exactly the widths that entered the maximum above
    title='Max fit times over ' + ', '.join(f'p={p}' for p in mlp_units[selected_runs]),
    xaxis_title='Training sample',
    yaxis_title='Last epoch the sample was fit'
)
fig.show()

In [None]:
# The first `num_discard` entries of the ordering are the samples to discard
discard_indices = sorted_max_fit_time_indices[:num_discard]
print(
    'Training sample indices to discard:',
    np.sort(discard_indices),
    sep='\n'
)

## Plot the samples that will be discarded

In [None]:
# Load MNIST restricted to the classes 1 and 3 without splitting the data;
# only the first returned element (training inputs and labels) is used
dataset_parts = ir.data.make_image_dataset(
    dataset=ir.data.Dataset.MNIST,
    data_split=ir.data.DataSplit.NoSplit,
    binarized_classes=(1, 3),
    seed=1
)
(train_xs, train_ys), _, _ = dataset_parts

# The fit times must refer to exactly this set of training samples
assert train_xs.shape[0] == NUM_SAMPLES

In [None]:
# Show every discarded training sample in a grid with `num_cols` images per row
num_cols = 10
num_rows = int(np.ceil(num_discard / num_cols))

if num_rows == 0:
    # plt.subplots cannot create a grid without rows
    print('No training samples are discarded, nothing to plot')
else:
    # squeeze=False keeps `axes` two-dimensional even if there is only a single
    # row; otherwise indexing with (row, column) fails for <= num_cols samples
    fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches((2 * num_cols, 2 * num_rows))

    # Hide the axis ticks of all cells, including those that stay empty
    for ax in axes.flat:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    # Map the binary training labels back to the digits shown in the titles
    binary_label_to_digit = {0: 1, 1: 3}
    for idx, sample_idx in enumerate(discard_indices):
        row_idx, col_idx = divmod(idx, num_cols)

        actual_label = binary_label_to_digit[train_ys[sample_idx]]
        axes[row_idx, col_idx].set_title(f'y={actual_label}, idx={sample_idx}')
        axes[row_idx, col_idx].imshow(train_xs[sample_idx], cmap='gray')