Copyright 2020 The dnn-predict-accuracy Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# README

This notebook contains code for training predictors of DNN accuracy. 

Contents:

(1) Loading the Small CNN Zoo dataset

(2) Figure 2 of the paper

(3) Examples of training Logit-Linear / GBM / DNN predictors

(4) Transfer of predictors across CNN collections

(5) Various visualizations of CNN collections

Code dependencies:
the LightGBM package (`lightgbm`)


In [0]:
from __future__ import division

import time
import os
import json
import sys
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors
import pandas as pd
import seaborn as sns
from scipy import stats
from tensorflow import keras
from tensorflow.io import gfile
import lightgbm as lgb

# Columns of the metrics DataFrame that hold the hyperparameters of each run.
DATAFRAME_CONFIG_COLS = [
    'config.w_init',
    'config.activation',
    'config.learning_rate',
    'config.init_std',
    'config.l2reg',
    'config.train_fraction',
    'config.dropout']
# Hyperparameter columns that are categorical, and the prefixes used for their
# one-hot columns when they are binarized (same order as the columns).
CATEGORICAL_CONFIG_PARAMS = ['config.w_init', 'config.activation']
CATEGORICAL_CONFIG_PARAMS_PREFIX = ['winit', 'act']
# Columns of the metrics DataFrame that hold the measured metrics. Code below
# indexes the resulting arrays by position: 0 is test accuracy, 2 is train
# accuracy.
DATAFRAME_METRIC_COLS = [
    'test_accuracy',
    'test_loss',
    'train_accuracy',
    'train_loss']
# Number of (shuffled) CNNs per collection that go into the train split; the
# remaining ones form the test split.
TRAIN_SIZE = 15000

# TODO: point the following paths at your local copy of the data.
# CONFIGS_PATH_BASE is the directory containing best_configs.json; each
# *_OUTDIR is a directory containing all_weights.npy and all_metrics.csv.
CONFIGS_PATH_BASE = 'path_to_the_file_with_best_configs'
MNIST_OUTDIR = "path_to_files_with_mnist_collection"
FMNIST_OUTDIR = 'path_to_files_with_fmnist_collection'
CIFAR_OUTDIR = 'path_to_files_with_cifar10gs_collection'
SVHN_OUTDIR = 'path_to_files_with_svhngs_collection'

def _select_checkpoint_offset(steps, stage):
  """Return the position within `steps` of the checkpoint chosen for `stage`."""
  ordered = sorted(steps)
  if stage == 'final':
    target_step = ordered[-1]
  elif stage == 'early':
    target_step = ordered[0]
  else:  # any other value is treated as 'middle'
    target_step = ordered[len(ordered) // 2]
  # First row (within the run) that carries the chosen step
  return steps.index(target_step)


def filter_checkpoints(weights, dataframe,
                       target='test_accuracy',
                       stage='final', binarize=True):
  """Take one checkpoint per run and do some pre-processing.

  The index of `dataframe` holds checkpoint paths of the form
  '.../<unit id>/<prefix>-<step>'. Rows belonging to the same run (same unit
  id) are expected to be contiguous; every maximal group of consecutive rows
  sharing a unit id is treated as one run, including the last group.

  Args:
    weights: numpy array of shape (num_runs, num_weights)
    dataframe: pandas DataFrame which has num_runs rows, indexed by checkpoint
      path. It must contain the columns listed in DATAFRAME_METRIC_COLS and
      DATAFRAME_CONFIG_COLS.
    target: string, what to use as an output. It is only validated against
      DATAFRAME_METRIC_COLS; all metric columns are returned regardless.
    stage: flag defining which checkpoint out of potentially many we will take
      for the run: 'final' (largest step), 'early' (smallest step); any other
      value selects the middle step.
    binarize: Do we want to binarize the categorical hyperparams?

  Returns:
    tuple (weights_new, metrics, hyperparams, ckpts), where
      weights_new is a numpy array of shape (num_remaining_ckpts, num_weights),
      metrics is a float32 numpy array of shape (num_remaining_ckpts,
        num_metrics) with num_metrics being the length of
        DATAFRAME_METRIC_COLS,
      hyperparams is a pandas DataFrame of num_remaining_ckpts rows built from
        the columns listed in DATAFRAME_CONFIG_COLS (categorical ones are
        one-hot encoded if binarize is True, otherwise cast to the pandas
        "category" dtype),
      ckpts is an instance of pandas Index, keeping filenames of the checkpoints
    All the num_remaining_ckpts rows correspond to one checkpoint out of each
    run we had.
  """

  assert target in DATAFRAME_METRIC_COLS, 'unknown target'
  ckpt_names = dataframe.axes[0]

  ids_to_take = []
  run_start = 0       # row id of the first checkpoint of the current run
  current_uid = None  # unit id of the current run
  steps = []          # training steps of the current run, in row order
  for i, ckpt in enumerate(ckpt_names):
    parts = ckpt.split('/')
    uid = parts[-2]
    if steps and uid != current_uid:
      # A new run starts here: pick the checkpoint of the run that just ended
      ids_to_take.append(run_start + _select_checkpoint_offset(steps, stage))
      run_start = i
      steps = []
    current_uid = uid
    steps.append(int(parts[-1].split('-')[-1]))
  if steps:
    # The last run has no successor to trigger its processing, so handle it
    # here; otherwise it would be silently dropped.
    ids_to_take.append(run_start + _select_checkpoint_offset(steps, stage))
  # Integer dtype keeps the indexing below valid even for an empty selection
  ids_to_take = np.asarray(ids_to_take, dtype=np.int64)

  # Fetch the hyperparameters of the corresponding checkpoints
  hyperparams = dataframe[DATAFRAME_CONFIG_COLS].iloc[ids_to_take]
  if binarize:
    # Binarize categorical features
    hyperparams = pd.get_dummies(
        hyperparams,
        columns=CATEGORICAL_CONFIG_PARAMS,
        prefix=CATEGORICAL_CONFIG_PARAMS_PREFIX)
  else:
    # Make the categorical features have pandas type "category"
    # Then LGBM can use those as categorical.
    # Work on an explicit copy so the caller's DataFrame is never modified.
    hyperparams = hyperparams.copy()
    for col in CATEGORICAL_CONFIG_PARAMS:
      hyperparams[col] = hyperparams[col].astype('category')

  # Fetch the file paths of the corresponding checkpoints
  ckpts = ckpt_names[ids_to_take]

  return (weights[ids_to_take, :],
          dataframe[DATAFRAME_METRIC_COLS].values[ids_to_take, :].astype(
              np.float32),
          hyperparams,
          ckpts)

def build_fcn(n_layers, n_hidden, n_outputs, dropout_rate, activation,
              w_regularizer, w_init, b_init, last_activation='softmax'):
  """Fully connected deep neural network.

  Args:
    n_layers: number of hidden Dense layers. With 0 the model is a single
      Dense output layer (the "linear" predictor).
    n_hidden: number of units in every hidden layer.
    n_outputs: number of units in the output layer.
    dropout_rate: dropout rate applied after every hidden layer. None or any
      value <= 0 disables dropout.
    activation: activation of the hidden layers.
    w_regularizer: kernel regularizer of the hidden layers (and of the output
      layer when n_layers == 0).
    w_init: kernel initializer of the hidden layers (and of the output layer
      when n_layers == 0).
    b_init: bias initializer of the hidden layers (and of the output layer
      when n_layers == 0).
    last_activation: activation of the output layer when n_layers > 0.

  Returns:
    An uncompiled keras.Sequential model.
  """
  # `None > 0.0` raises TypeError on Python 3, so treat None as "no dropout".
  use_dropout = dropout_rate is not None and dropout_rate > 0.0

  model = keras.Sequential()
  model.add(keras.layers.Flatten())
  for _ in range(n_layers):
    model.add(
        keras.layers.Dense(
            n_hidden,
            activation=activation,
            kernel_regularizer=w_regularizer,
            kernel_initializer=w_init,
            bias_initializer=b_init))
    if use_dropout:
      model.add(keras.layers.Dropout(dropout_rate))
  if n_layers > 0:
    # Output layer on top of hidden layers: Keras default initializers and no
    # regularizer.
    model.add(keras.layers.Dense(n_outputs, activation=last_activation))
  else:
    # Linear model: the output layer is the only layer, so it receives the
    # regularizer / initializers. NOTE: the activation is always 'sigmoid'
    # here; `last_activation` is ignored in this branch.
    model.add(keras.layers.Dense(
        n_outputs,
        activation='sigmoid',
        kernel_regularizer=w_regularizer,
        kernel_initializer=w_init,
        bias_initializer=b_init))
  return model

def extract_summary_features(w, qts=(0, 25, 50, 75, 100)):
  """Summarize the flat vector w by its percentiles, std and mean.

  Args:
    w: array of values to summarize.
    qts: percentiles (in [0, 100]) to compute.

  Returns:
    1-D numpy array: the requested percentiles (in the order of qts), followed
    by the standard deviation and then the mean of w.
  """
  percentile_values = np.ravel(np.percentile(w, qts))
  spread_and_center = np.ravel([np.std(w), np.mean(w)])
  return np.concatenate((percentile_values, spread_and_center))


def extract_per_layer_features(w, qts=None, layers=(0, 1, 2, 3)):
  """Compute summary statistics of every selected layer and concatenate them.

  Args:
    w: flattened weight vector of one CNN.
    qts: percentiles forwarded to extract_summary_features; a falsy value
      (None, empty) means "use that function's defaults".
    layers: ids (0..3) of the layers to include, in output order.

  Returns:
    1-D numpy array with the summary features of the bias block followed by
    those of the kernel block, for each requested layer in turn.
  """
  # (start, end) positions in the flattened vector: for every layer the first
  # pair locates the biases and the second one the kernel.
  layer_slices = {
      0: [(0, 16), (16, 160)],
      1: [(160, 176), (176, 2480)],
      2: [(2480, 2496), (2496, 4800)],
      3: [(4800, 4810), (4810, 4970)]}
  spans = [span for layer in layers for span in layer_slices[layer]]

  # Only forward qts when the caller actually supplied some
  extra_args = (qts,) if qts else ()
  per_block = [extract_summary_features(w[start:end], *extra_args)
               for (start, end) in spans]
  return np.concatenate(per_block)


# 1. Loading the Small CNN Zoo dataset

The following code loads the dataset (trained weights from *.npy files and all the relevant metrics, including accuracy, from *.csv files). 

In [0]:
# Directories of the four collections, in the same order as dataset_names
all_dirs = [MNIST_OUTDIR, FMNIST_OUTDIR, CIFAR_OUTDIR, SVHN_OUTDIR]
dataset_names = ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']

# Both dicts map collection name -> loaded data (None until loaded)
weights = dict.fromkeys(dataset_names)
metrics = dict.fromkeys(dataset_names)

for dataname, dirname in zip(dataset_names, all_dirs):
  print('Loading %s' % dataname)
  weights_path = os.path.join(dirname, "all_weights.npy")
  metrics_path = os.path.join(dirname, "all_metrics.csv")
  # Weights of the trained models
  with gfile.GFile(weights_path, "rb") as f:
    weights[dataname] = np.load(f)
  # pandas DataFrame with metrics; its first column becomes the index
  with gfile.GFile(metrics_path) as f:
    metrics[dataname] = pd.read_csv(f, index_col=0)

Next, we filter the dataset, keeping only the checkpoints corresponding to 18 epochs and discarding runs that resulted in numerical instabilities (non-finite weights). Finally, we perform the train / test split.

In [0]:
# All six dicts map collection name -> data of the corresponding split
weights_train, weights_test = {}, {}
configs_train, configs_test = {}, {}
outputs_train, outputs_test = {}, {}

for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:
  # Take one checkpoint per each run
  # If using GBM as predictor, set binarize=False
  weights_flt, metrics_flt, configs_flt, ckpts = filter_checkpoints(
      weights[dataset], metrics[dataset], binarize=True)

  # Keep only the DNNs whose weights are all finite (no NaNs, no Infs)
  idx_valid = np.isfinite(weights_flt).all(axis=1)
  inputs = np.asarray(weights_flt[idx_valid], dtype=np.float32)
  outputs = np.asarray(metrics_flt[idx_valid], dtype=np.float32)
  configs = configs_flt.iloc[idx_valid]
  ckpts = ckpts[idx_valid]

  # Shuffle the row ids; the first TRAIN_SIZE go to train, the rest to test
  random_idx = list(range(inputs.shape[0]))
  np.random.shuffle(random_idx)
  train_idx = random_idx[:TRAIN_SIZE]
  test_idx = random_idx[TRAIN_SIZE:]

  weights_train[dataset] = inputs[train_idx]
  weights_test[dataset] = inputs[test_idx]
  outputs_train[dataset] = 1. * outputs[train_idx]
  outputs_test[dataset] = 1. * outputs[test_idx]
  configs_train[dataset] = configs.iloc[train_idx]
  configs_test[dataset] = configs.iloc[test_idx]

# 2. Figure 2 of the paper

Next, we plot the distribution of the CNNs from the 4 collections of the Small CNN Zoo according to their train / test accuracy.

In [0]:
# Figure 2: for every collection, a histogram of test accuracies (top row)
# and a test-vs-train accuracy scatter plot (bottom row).
plt.figure(figsize = (16, 8))
NUM_POINTS = 1000  # max number of CNNs drawn in each scatter plot
accuracy_ticks = [0.0, 0.2, .4, .6, .8, 1.]
pic_id = 0

for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:
  pic_id += 1  # column of this collection (1-based)
  sp = plt.subplot(2, 4, pic_id)

  # Columns follow DATAFRAME_METRIC_COLS: 0 is test acc, 2 is train acc
  outputs = outputs_train[dataset]

  plt.title({'mnist': 'MNIST',
             'fashion_mnist': 'Fashion MNIST',
             'cifar10': 'CIFAR10-GS',
             'svhn_cropped': 'SVHN-GS'}[dataset], fontsize=24)

  # 1. test accuracy hist plots
  # NOTE(review): sns.distplot is deprecated in recent seaborn releases.
  sns.distplot(np.array(outputs[:, 0]), bins=15, kde=False, color='green')
  plt.xlim((0.0, 1.0))
  sp.axes.get_xaxis().set_ticklabels([])
  sp.axes.get_yaxis().set_ticklabels([])

  # 2. test / train accuracy scatter plots (second row, same column)
  sp = plt.subplot(2, 4, pic_id + 4)
  # np.random.shuffle works in place, so it needs a mutable list
  # (a Python 3 `range` object cannot be shuffled).
  random_idx = list(range(len(outputs)))
  np.random.shuffle(random_idx)
  subset = random_idx[:NUM_POINTS]
  plt.plot([0.0, 1.0], [0.0, 1.0], 'r--')
  # x / y are passed by keyword: recent seaborn versions do not accept them
  # positionally.
  sns.scatterplot(x=np.array(outputs[subset, 0]),  # test acc
                  y=np.array(outputs[subset, 2]),  # train acc
                  s=30
                  )
  plt.xlim((0.0, 1.0))
  plt.ylim((0.0, 1.0))
  # Fix the tick positions on both axes so that the displayed labels always
  # correspond to the actual accuracy values.
  sp.axes.get_xaxis().set_ticks(accuracy_ticks)
  sp.axes.get_yaxis().set_ticks(accuracy_ticks)
  if pic_id == 1:
    # Only the leftmost scatter plot carries the y label and y tick labels
    plt.ylabel('Train accuracy', fontsize=22)
  else:
    sp.axes.get_yaxis().set_ticklabels([])
  sp.axes.tick_params(axis='both', labelsize=18)
  plt.xlabel('Test accuracy', fontsize=22)

plt.tight_layout()

# 3. Examples of training Logit-Linear / GBM / DNN predictors

Next we train 3 models on all 4 CNN collections with the best hyperparameter configurations we found during our studies (documented in Table 2 and Section 4 of the paper).

First, we load the best hyperparameter configurations we found.
The file best_configs.json contains a list. 
Each entry of that list corresponds to a single hyperparameter configuration. 
It consists of: 

  (1) name of the CNN collection (mnist/fashion mnist/cifar10/svhn) 
  
  (2) predictor type (linear/dnn/lgbm)
  
  (3) type of inputs, (refer to Table 2)
  
  (4) value of MSE you will get training with these settings, 
  
  (5) dictionary of "parameter name"-> "parameter value" for the given type of predictor.

In [0]:
# Read the list of best hyperparameter configurations
best_configs_path = os.path.join(CONFIGS_PATH_BASE, 'best_configs.json')
with gfile.GFile(best_configs_path, 'r') as config_file:
  best_configs = json.loads(config_file.read())



# 3.1 Training GBM predictors

GBM code below requires the lightgbm package.

This is an example of training GBM on CIFAR10-GS CNN collection using per-layer weights statistics as inputs.

In [0]:
# Best LightGBM hyperparameters we found for CIFAR10-GS with per-layer weight
# statistics as inputs (the parameter dict is the last item of the entry)
config = [el[-1] for el in best_configs
          if (el[0], el[1], el[2]) == ('cifar10', 'lgbm', 'wstats-perlayer')][0]

# Inputs: per-layer statistics of every CNN in the train / test split
train_x, test_x = (
    np.apply_along_axis(
        extract_per_layer_features, 1,
        split_weights,
        qts=None,
        layers=(0, 1, 2, 3))
    for split_weights in (weights_train['cifar10'], weights_test['cifar10']))
# Targets: test accuracy (column 0 of the metrics)
train_y = outputs_train['cifar10'][:, 0]
test_y = outputs_test['cifar10'][:, 0]

# Hyperparameters taken verbatim from the config
tuned_params = {
    name: config[name]
    for name in ('num_leaves', 'max_depth', 'learning_rate',
                 'min_child_weight', 'reg_lambda', 'reg_alpha',
                 'subsample', 'colsample_bytree')}
lgbm_model = lgb.LGBMRegressor(
    max_bin=int(config['max_bin']),
    subsample_freq=1,  # it means always subsample
    n_estimators=2000,
    first_metric_only=True,
    **tuned_params
)

# Fit the model. The test split doubles as the evaluation set, and early
# stopping watches the first metric (rmse) on it.
# NOTE(review): newer LightGBM releases replace the `verbose` and
# `early_stopping_rounds` arguments of fit() by callbacks -- confirm against
# the installed version.
eval_metric = ['rmse', 'l1']
eval_set = [(test_x, test_y)]
lgbm_model.fit(train_x, train_y,
               eval_set=eval_set,
               eval_names=['test'],
               eval_metric=eval_metric,
               early_stopping_rounds=500,
               verbose=100)

# Report the metrics recorded at the best iteration
assert hasattr(lgbm_model, 'best_iteration_')
# best_iteration_ is 1-based, the recorded history is 0-based
best_iter = lgbm_model.best_iteration_ - 1
lgbm_history = lgbm_model.evals_result_
test_history = lgbm_history['test']
mse = test_history['rmse'][best_iter] ** 2.
mad = test_history['l1'][best_iter]
var = np.mean((test_y - np.mean(test_y)) ** 2.)
r2  = 1. - mse / var
print('Test MSE = ', mse)
print('Test MAD = ', mad)
print('Test R2 = ', r2)

# 3.2 Training DNN predictors

This is an example of training DNN on MNIST CNN collection using all weights as inputs.

In [0]:
# Take the best config we found for a DNN on the MNIST collection with the
# raw weights as inputs
config = [el[-1] for el in best_configs if
          el[0] == 'mnist' and
          el[1] == 'dnn' and
          el[2] == 'weights'][0]

# Use the MNIST collection, i.e. the one the config above was tuned for.
# Inputs are the flattened weights, targets the test accuracy (column 0).
train_x, test_x = weights_train['mnist'], weights_test['mnist']
train_y, test_y = outputs_train['mnist'][:, 0], outputs_test['mnist'][:, 0]

# Get the optimizer, initializers, and regularizers
optimizer = keras.optimizers.get(config['optimizer_name'])
optimizer.learning_rate = config['learning_rate']
w_init = keras.initializers.get(config['w_init_name'])
if config['w_init_name'].lower() in ['truncatednormal', 'randomnormal']:
  # Only these initializers have a standard deviation to set
  w_init.stddev = config['init_stddev']
b_init = keras.initializers.get('zeros')
w_reg = (keras.regularizers.l2(config['l2_penalty']) 
         if config['l2_penalty'] > 0 else None)

# Get the fully connected DNN architecture
dnn_model = build_fcn(int(config['n_layers']),
                      int(config['n_hiddens']),
                      1,  # number of outputs
                      config['dropout_rate'],
                      'relu',
                      w_reg, w_init, b_init,
                      'sigmoid')  # Last activation
dnn_model.compile(
    optimizer=optimizer,
    loss='mean_squared_error',
    metrics=['mse', 'mae'])

# Train the model; the test split is used as validation data and training
# stops once val_loss has not improved for 10 epochs
dnn_model.fit(
    train_x, train_y,
    batch_size=int(config['batch_size']),
    epochs=300,
    validation_data=(test_x, test_y),
    verbose=1,
    callbacks=[keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=10,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=False)]
    )

# Evaluate the model
eval_train = dnn_model.evaluate(train_x, train_y, batch_size=128, verbose=0)
eval_test = dnn_model.evaluate(test_x, test_y, batch_size=128, verbose=0)
# evaluate() returns [loss, mse, mae]. Depending on the Keras version the
# metrics are reported under their short or their long names.
assert dnn_model.metrics_names[1] in ('mse', 'mean_squared_error')
assert dnn_model.metrics_names[2] in ('mae', 'mean_absolute_error')
mse = eval_test[1]
var = np.mean((test_y - np.mean(test_y)) ** 2.)
r2  = 1. - mse / var
print('Test MSE = ', mse)
print('Test MAD = ', eval_test[2])
print('Test R2 = ', r2)

# 3.3 Train Logit-Linear predictors

This is an example of training Logit-Linear model on CIFAR10 CNN collection using hyperparameters as inputs.

In [0]:
# Take the best config we found for a linear model on the CIFAR10 collection
# with the hyperparameters as inputs
config = [el[-1] for el in best_configs if
          el[0] == 'cifar10' and
          el[1] == 'linear' and
          el[2] == 'hyper'][0]

# Turn DataFrames to numpy arrays. 
# This requires the categorical columns to be binarized, i.e. the
# configs_* DataFrames must come from filter_checkpoints(..., binarize=True).
train_x = configs_train['cifar10'].values.astype(np.float32)
test_x = configs_test['cifar10'].values.astype(np.float32)
# Targets: test accuracy (column 0 of the metrics)
train_y, test_y = outputs_train['cifar10'][:, 0], outputs_test['cifar10'][:, 0]

# Get the optimizer, initializers, and regularizers
optimizer = keras.optimizers.get(config['optimizer_name'])
optimizer.learning_rate = config['learning_rate']
w_init = keras.initializers.get(config['w_init_name'])
if config['w_init_name'].lower() in ['truncatednormal', 'randomnormal']:
  # Only these initializers have a standard deviation to set
  w_init.stddev = config['init_stddev']
b_init = keras.initializers.get('zeros')
w_reg = (keras.regularizers.l2(config['l2_penalty']) 
         if config['l2_penalty'] > 0 else None)

# Get the linear architecture (DNN with 0 layers)
dnn_model = build_fcn(int(config['n_layers']),
                      int(config['n_hiddens']),
                      1,  # number of outputs
                      0.0,  # Dropout is not used (a numeric rate, not None,
                            # so that comparing it against 0 cannot fail)
                      'relu',
                      w_reg, w_init, b_init,
                      'sigmoid')  # Last activation
dnn_model.compile(
    optimizer=optimizer,
    loss='mean_squared_error',
    metrics=['mse', 'mae'])

# Train the model; the test split is used as validation data and training
# stops once val_loss has not improved for 10 epochs
dnn_model.fit(
    train_x, train_y,
    batch_size=int(config['batch_size']),
    epochs=300,
    validation_data=(test_x, test_y),
    verbose=1,
    callbacks=[keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=10,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=False)]
    )

# Evaluate the model
eval_train = dnn_model.evaluate(train_x, train_y, batch_size=128, verbose=0)
eval_test = dnn_model.evaluate(test_x, test_y, batch_size=128, verbose=0)
# evaluate() returns [loss, mse, mae]. Depending on the Keras version the
# metrics are reported under their short or their long names.
assert dnn_model.metrics_names[1] in ('mse', 'mean_squared_error')
assert dnn_model.metrics_names[2] in ('mae', 'mean_absolute_error')
mse = eval_test[1]
var = np.mean((test_y - np.mean(test_y)) ** 2.)
r2  = 1. - mse / var
print('Test MSE = ', mse)
print('Test MAD = ', eval_test[2])
print('Test R2 = ', r2)

# 4. Figure 4: Transfer across datasets

Train a GBM predictor, using statistics of all layers as inputs, on each of the 4 CNN collections. Then evaluate each predictor on each of the 4 CNN collections (without fine-tuning). Store all results.

In [0]:
collection_names = ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']

# Per-layer statistics of every test split. They do not depend on the model
# being trained, so compute them once instead of once per source collection.
test_features = {
    name: np.apply_along_axis(
        extract_per_layer_features, 1,
        weights_test[name], qts=None, layers=(0, 1, 2, 3))
    for name in collection_names}

# transfer_results[source][destination] holds the predictions, on the test
# split of `destination`, of the GBM trained on `source`
transfer_results = {}

for dataset in collection_names:
  print('Training on %s' % dataset)
  transfer_results[dataset] = {}

  # Pre-process the weights by taking the statistics across layers
  train_x = np.apply_along_axis(
      extract_per_layer_features, 1,
      weights_train[dataset], qts=None, layers=(0, 1, 2, 3))
  test_x = test_features[dataset]
  # Targets: test accuracy (column 0 of the metrics)
  train_y = outputs_train[dataset][:, 0]
  test_y = outputs_test[dataset][:, 0]

  # Take the best config we found
  config = [el[-1] for el in best_configs if
            el[0] == dataset and
            el[1] == 'lgbm' and
            el[2] == 'wstats-perlayer'][0]

  lgbm_model = lgb.LGBMRegressor(
      num_leaves=config['num_leaves'],
      max_depth=config['max_depth'], 
      learning_rate=config['learning_rate'], 
      max_bin=int(config['max_bin']),
      min_child_weight=config['min_child_weight'],
      reg_lambda=config['reg_lambda'],
      reg_alpha=config['reg_alpha'],
      subsample=config['subsample'],
      subsample_freq=1,  # Always subsample
      colsample_bytree=config['colsample_bytree'],
      n_estimators=4000,
      first_metric_only=True,
      )
  
  # Train the GBM model; early stopping watches the first metric (rmse) on
  # the test split of the source collection.
  # NOTE(review): newer LightGBM releases replace the `verbose` and
  # `early_stopping_rounds` arguments of fit() by callbacks -- confirm
  # against the installed version.
  lgbm_model.fit(
      train_x,
      train_y,
      verbose=100,
      early_stopping_rounds=500,
      eval_metric=['rmse', 'l1'],
      eval_set=[(test_x, test_y)],
      eval_names=['test'])
  
  # Evaluate on the test split of all 4 CNN collections
  for transfer_to in collection_names:
    print('Evaluating on %s' % transfer_to)
    y_hat = lgbm_model.predict(test_features[transfer_to])
    transfer_results[dataset][transfer_to] = y_hat

And plot everything

In [0]:
# 4x4 grid: rows are the collections the predictor was trained on, columns
# the collections it is evaluated on.
plt.figure(figsize = (15, 15))
column_titles = ['MNIST', 'Fashion-MNIST', 'CIFAR10-GS', 'SVHN-GS']
pic_id = 0
for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:
  for transfer_to in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:
    pic_id += 1
    sp = plt.subplot(4, 4, pic_id)
    # Take true labels
    y_true = outputs_test[transfer_to][:, 0]
    # Take the predictions of the model
    y_hat = transfer_results[dataset][transfer_to]
    plt.plot([0.01, .99], [0.01, .99], 'r--', linewidth=2)
    # x / y are passed by keyword: recent seaborn versions do not accept
    # them positionally.
    sns.scatterplot(x=y_true, y=y_hat)
    # Compute the Kendall's tau coefficient
    tau = stats.kendalltau(y_true, y_hat)[0]
    plt.text(0.05, 0.9, r"$\tau=%.3f$" % tau, fontsize=25)
    plt.xlim((0.0, 1.0))
    plt.ylim((0.0, 1.0))

    # Only the first column shows y tick labels and the y label
    if pic_id % 4 != 1:
      sp.axes.get_yaxis().set_ticklabels([])
    else:
      plt.ylabel('Predictions', fontsize=22)
      sp.axes.tick_params(axis='both', labelsize=15)

    # Only the last row shows x tick labels and the x label
    if pic_id < 13:
      sp.axes.get_xaxis().set_ticklabels([])
    else:
      plt.xlabel('Test accuracy', fontsize=22)
      sp.axes.tick_params(axis='both', labelsize=15)

    # Only the first row carries the column titles
    if pic_id <= 4:
      plt.title(column_titles[pic_id - 1], fontsize=22)

plt.tight_layout()

# 5. Figure 3: various 2d plots based on subsets of weights statistics

Take weight statistics for the CIFAR10 CNN collection. Plot various 2d plots

In [0]:
# Take the per-layer weights stats for the train split of CIFAR10-GS collection
# Per-layer weight statistics of the train split of the CIFAR10-GS collection.
# Every bias / kernel block contributes 7 consecutive features:
# 5 percentiles (0, 25, 50, 75, 100), then std and mean.
per_layer_stats = np.apply_along_axis(
    extract_per_layer_features, 1,
    weights_train['cifar10'])
# Test accuracy of the CNNs in the train split; used to colour the points
train_test_accuracy = outputs_train['cifar10'][:, 0]

# Feature positions of the bias extremes
b0min, b0max = 0, 4                  # min / max of the first layer's biases
bnmin, bnmax = 6*7 + 0, 6*7 + 4      # min / max of the last layer's biases
# Range (max - min) of the biases of the first and of the last layer
x = per_layer_stats[:, b0max] - per_layer_stats[:, b0min]
y = per_layer_stats[:, bnmax] - per_layer_stats[:, bnmin]

fig, ax = plt.subplots(figsize=(10,8))
points = ax.scatter(x, y, s=15,
                    c=train_test_accuracy,
                    cmap="jet",
                    vmin=0.1,
                    vmax=0.54,
                    linewidths=0)
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_ylim(0.1, 10)
ax.set_xlim(0.1, 10)
ax.set_xlabel("Bias range, first layer", fontsize=22)
ax.set_ylabel("Bias range, final layer", fontsize=22)
cbar = fig.colorbar(points, ax=ax)
cbar.ax.tick_params(labelsize=18)
fig.tight_layout()