# GEMS (Neural Network) Demo

A brief demonstration of GEMS for neural networks, applied to MNIST partitioned across 5 nodes.

In [18]:
# Import relevant libraries 
import os

from utils.data_utils import *
from utils.model_utils import *
from utils.dnn3 import *

%matplotlib notebook
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
# Set experiment configurations.
# `dataset` and `num_nodes` select the data directory (../datasets/<dataset>_<num_nodes>)
# and the number of local models trained; `holdout_sizes` are the holdout sizes
# reported in the GEMS evaluation output. The remaining keys are consumed by the
# helpers in utils.* — TODO document their exact meaning there.
exp = dict(
    dataset='mnist',
    num_nodes=5,
    model='dnn2',
    local_epochs=40,
    global_epochs=50,
    dropout=0.5,
    hidden=50,
    secondary_epochs=10,
    optimizer='adam',
    loss='categorical_crossentropy',
    final_epsilon=0.70,
    max_radius=4000.0,
    num_samples=500,
    fisher_prop=1.0,
    fisher_count=1,
    ellipse=True,
    delta=0.5,
    ellipse_type='fisher',
    hidden_slack=1.0,
    k=75,
    holdout_trials=5,
    holdout_sizes=list(range(100, 1001, 100)),  # 100, 200, ..., 1000
)

In [3]:
# Load data for this experiment from ../datasets/<dataset>_<num_nodes>
# (with the config above: ../datasets/mnist_5).
data_dir = os.path.join(
    "../datasets",
    "%s_%d" % (exp['dataset'], exp['num_nodes']),
)
(Xtr, ytr,
 Xval, yval,
 Xts, yts) = load_data(data_dir)

# Flattened copies of the test and training splits; presumably the per-node
# splits pooled into single arrays — confirm against utils.data_utils.flatten.
Xts_flat, yts_flat = flatten(Xts, yts)
Xtr_flat, ytr_flat = flatten(Xtr, ytr)

In [8]:
# Train one local model per node t in 0..num_nodes-1.
# After this cell:
#   models[t]      -> model returned by train_local_model for node t (node order)
#   all_results[t] -> results returned by train_local_model for node t
# `model`, `results` and `t` stay bound to the last node's values.
# NOTE(review): the test split (Xts, yts) is passed into local training;
# confirm train_local_model uses it only for reporting, not model selection.
models, all_results = [], {}
for t in range(exp['num_nodes']):
    print("Training local model on node-%d..." % t)
    model, results = train_local_model(Xtr, ytr, Xts, yts, t, exp)
    models.append(model)
    all_results[t] = results

Training local model on node-0...
Training local model on node-1...
Training local model on node-2...
Training local model on node-3...
Training local model on node-4...


In [26]:
# Compute the GEMS model from the local models trained above.
# run_dnn_gems prints its progress and the per-holdout-size evaluations
# shown in the output below.
# NOTE: this rebinds `results`, which previously held the last node's results
# from the local-training loop; per-node results remain in `all_results`.
gems_model, results = run_dnn_gems(
    exp,
    Xtr, ytr,     # training data
    Xval, yval,   # presumably the validation split returned by load_data
    Xts, yts,     # test data
    models,       # one local model per node, in node order
)

Calculating good-enough spaces.


Number of hidden neurons: 74
Training new logistic models
Pure GEMS accuracy: 0.453400
Holdout Size: 100.
	Raw: 0.422720.
	Gems (Tuned): 0.701960.
	Average (Tuned): 0.469520.
	 Node-0 Loc Model (Tuned): 0.227480
	 Node-1 Loc Model (Tuned): 0.207440
	 Node-2 Loc Model (Tuned): 0.182880
	 Node-3 Loc Model (Tuned): 0.195880
	 Node-4 Loc Model (Tuned): 0.195440
Holdout Size: 200.
	Raw: 0.587400.
	Gems (Tuned): 0.788080.
	Average (Tuned): 0.636560.
	 Node-0 Loc Model (Tuned): 0.309400
	 Node-1 Loc Model (Tuned): 0.200440
	 Node-2 Loc Model (Tuned): 0.196280
	 Node-3 Loc Model (Tuned): 0.197880
	 Node-4 Loc Model (Tuned): 0.165160
Holdout Size: 300.
	Raw: 0.719000.
	Gems (Tuned): 0.827160.
	Average (Tuned): 0.694560.
	 Node-0 Loc Model (Tuned): 0.326120
	 Node-1 Loc Model (Tuned): 0.199400
	 Node-2 Loc Model (Tuned): 0.211640
	 Node-3 Loc Model (Tuned): 0.205560
	 Node-4 Loc Model (Tuned): 0.154360
Holdout Size: 400.
	Raw: 0.762240.
	Gems (Tuned): 0.846760.
	Average (Tuned): 0.737840.
	 Node