# Train SALLY ensemble

Johann Brehmer, Kyle Cranmer, Marco Farina, Felix Kling, Duccio Pappadopulo, Josh Ruderman 2018

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
import logging
import os

from madminer.sampling import SampleAugmenter
from madminer.sampling import multiple_benchmark_thetas
from madminer.sampling import constant_morphing_theta, multiple_morphing_thetas, random_morphing_thetas
from madminer.ml import MLForge, EnsembleForge


In [3]:
# Configure the root logger: INFO level, short HH:MM timestamps, and
# fixed-width logger-name / level columns so the training log lines up.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%H:%M',
    format='%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s'
)

# Raise every already-registered logger whose name does not contain
# "madminer" to WARNING, so only madminer messages appear at INFO level.
for logger_name in logging.Logger.manager.loggerDict:
    if "madminer" in logger_name:
        continue
    logging.getLogger(logger_name).setLevel(logging.WARNING)

In [4]:
# Root folders of the analysis checkout and of the MadGraph installation.
# Both default to the original author's local layout, but can be overridden
# with the environment variables DIBOSON_MINING_DIR and DIBOSON_MINING_MG_DIR
# so the notebook can run on other machines without editing this cell.
# os.path.join(path, '') appends a trailing separator only when it is missing;
# code that builds sub-paths by string concatenation needs that separator.
base_dir = os.path.join(
    os.environ.get('DIBOSON_MINING_DIR')
    or '/Users/johannbrehmer/work/projects/madminer/diboson_mining/',
    ''
)
mg_dir = os.path.join(
    os.environ.get('DIBOSON_MINING_MG_DIR')
    or '/Users/johannbrehmer/work/projects/madminer/MG5_aMC_v2_6_2/',
    ''
)

In [5]:
# Folder layout derived from base_dir / mg_dir. os.path.join is used instead of
# plain '+' so the results stay correct even if a root folder is given without
# a trailing slash. Components that end in '/' keep that trailing slash.

# Inputs and outputs of this analysis
sample_dir = os.path.join(base_dir, 'data/samples/wgamma_sys/')
model_dir = os.path.join(base_dir, 'data/models/wgamma_sys/')
log_dir = os.path.join(base_dir, 'logs/wgamma_sys/')
temp_dir = os.path.join(base_dir, 'data/temp')

# MadGraph cards, model and process folders
card_dir = os.path.join(base_dir, 'cards/wgamma/')
ufo_model_dir = os.path.join(card_dir, 'SMWgamma_UFO')
run_card_dir = os.path.join(card_dir, 'run_cards/')
mg_process_dir = os.path.join(base_dir, 'data/mg_processes/wgamma_sys/')

# Delphes folder inside the MadGraph directory
delphes_dir = os.path.join(mg_dir, 'Delphes')

## Settings

In [6]:
# Number of estimators in the ensemble. Used as the default of
# train_ensemble(n_estimators=...) and to size the per-estimator
# `features` lists passed in the cells below.
n_estimators = 3

## Training function

In [7]:
def train_ensemble(filename, use_tight_cuts=True, n_estimators=n_estimators, **kwargs):
    """
    Train a SALLY ensemble, calculate its expectation on the validation sample, and save it.

    Parameters
    ----------
    filename : str
        Label of the run; the ensemble is saved to
        model_dir + 'sally_ensemble_' + filename.
    use_tight_cuts : bool, optional
        If True, samples are read from the 'train_local_tight' and
        'validation_tight' folders below sample_dir; otherwise from
        'train_local' and 'validation'. Default: True.
    n_estimators : int, optional
        Number of estimators in the ensemble. Estimator i is trained on
        x_train_i.npy and t_xz_train_i.npy. Defaults to the value of the
        module-level n_estimators at the time this function is defined.
    **kwargs
        Passed through unchanged to EnsembleForge.train_all.

    Returns
    -------
    None
    """
    # Folder suffix selecting the tight-cut or the default samples
    suffix = ''
    if use_tight_cuts:
        suffix = '_tight'

    ensemble = EnsembleForge(n_estimators, debug=True)

    # One training file pair (observables x, joint score t_xz) per estimator
    x_files = []
    t_xz_files = []
    for index in range(n_estimators):
        x_files.append(sample_dir + 'train_local{}/x_train_{}.npy'.format(suffix, index))
        t_xz_files.append(sample_dir + 'train_local{}/t_xz_train_{}.npy'.format(suffix, index))

    ensemble.train_all(
        method='sally',
        x_filename=x_files,
        t_xz0_filename=t_xz_files,
        **kwargs
    )

    # Expectation is calculated on the validation sample with the same cut choice
    validation_file = sample_dir + 'validation{}/x_validation.npy'.format(suffix)
    ensemble.calculate_expectation(x_filename=validation_file)

    output_path = model_dir + 'sally_ensemble_' + filename
    ensemble.save(output_path)

## All observables

In [None]:
# Default-cut run: use_tight_cuts=False reads from train_local/ and validation/.
# No `features` argument is given; validation_split and early_stopping are
# forwarded to EnsembleForge.train_all. Saved as 'sally_ensemble_all' in model_dir.
train_ensemble(
    'all',
    use_tight_cuts=False,
    validation_split=0.5,
    early_stopping=True
)

In [8]:
# Tight-cut run: use_tight_cuts=True reads from train_local_tight/ and
# validation_tight/. No `features` argument is given; validation_split and
# early_stopping are forwarded to EnsembleForge.train_all.
# Saved as 'sally_ensemble_all_tight' in model_dir.
train_ensemble(
    'all_tight',
    use_tight_cuts=True,
    validation_split=0.5,
    early_stopping=True
)

16:40 madminer.ml          INFO    Training 3 estimators in ensemble
16:40 madminer.ml          INFO    Training estimator 1 / 3 in ensemble
16:40 madminer.ml          INFO    Starting training
16:40 madminer.ml          INFO      Method:                 sally
16:40 madminer.ml          INFO      Training data: x at /Users/johannbrehmer/work/projects/madminer/diboson_mining/data/samples/wgamma_sys/train_local_tight/x_train_0.npy
16:40 madminer.ml          INFO                     t_xz (theta0) at  /Users/johannbrehmer/work/projects/madminer/diboson_mining/data/samples/wgamma_sys/train_local_tight/t_xz_train_0.npy
16:40 madminer.ml          INFO      Features:               all
16:40 madminer.ml          INFO      Method:                 sally
16:40 madminer.ml          INFO      Hidden layers:          (100, 100)
16:40 madminer.ml          INFO      Activation function:    tanh
16:40 madminer.ml          INFO      Batch size:             128
16:40 madminer.ml          INFO      Trainer

16:55 madminer.utils.ml.sc INFO      Epoch 36: train loss 141.4577 (mse_score: 141.4577)
16:55 madminer.utils.ml.sc INFO                val. loss  141.9442 (mse_score: 141.9442) (*)
16:56 madminer.utils.ml.sc INFO      Epoch 37: train loss 141.2071 (mse_score: 141.2071)
16:56 madminer.utils.ml.sc INFO                val. loss  141.9816 (mse_score: 141.9816)
16:56 madminer.utils.ml.sc INFO      Epoch 38: train loss 141.0516 (mse_score: 141.0516)
16:56 madminer.utils.ml.sc INFO                val. loss  141.8954 (mse_score: 141.8954) (*)
16:56 madminer.utils.ml.sc INFO      Epoch 39: train loss 140.8780 (mse_score: 140.8780)
16:56 madminer.utils.ml.sc INFO                val. loss  141.7160 (mse_score: 141.7160) (*)
16:57 madminer.utils.ml.sc INFO      Epoch 40: train loss 140.7259 (mse_score: 140.7259)
16:57 madminer.utils.ml.sc INFO                val. loss  142.4466 (mse_score: 142.4466)
16:57 madminer.utils.ml.sc INFO      Epoch 41: train loss 140.5936 (mse_score: 140.5936)
16:57 mad

17:09 madminer.utils.ml.sc INFO      Epoch 21: train loss 147.3945 (mse_score: 147.3945)
17:09 madminer.utils.ml.sc INFO                val. loss  148.6041 (mse_score: 148.6041)
17:09 madminer.utils.ml.sc INFO      Epoch 22: train loss 146.9756 (mse_score: 146.9756)
17:09 madminer.utils.ml.sc INFO                val. loss  148.1880 (mse_score: 148.1880) (*)
17:09 madminer.utils.ml.sc INFO      Epoch 23: train loss 146.6536 (mse_score: 146.6536)
17:09 madminer.utils.ml.sc INFO                val. loss  148.0531 (mse_score: 148.0531) (*)
17:10 madminer.utils.ml.sc INFO      Epoch 24: train loss 146.2626 (mse_score: 146.2626)
17:10 madminer.utils.ml.sc INFO                val. loss  147.7134 (mse_score: 147.7134) (*)
17:10 madminer.utils.ml.sc INFO      Epoch 25: train loss 145.9794 (mse_score: 145.9794)
17:10 madminer.utils.ml.sc INFO                val. loss  147.4437 (mse_score: 147.4437) (*)
17:11 madminer.utils.ml.sc INFO      Epoch 26: train loss 145.6842 (mse_score: 145.6842)
17:11

17:28 madminer.utils.ml.sc INFO      Epoch 6: train loss 156.6738 (mse_score: 156.6738)
17:28 madminer.utils.ml.sc INFO                val. loss  158.1348 (mse_score: 158.1348) (*)
17:28 madminer.utils.ml.sc INFO      Epoch 7: train loss 155.5461 (mse_score: 155.5461)
17:28 madminer.utils.ml.sc INFO                val. loss  157.1368 (mse_score: 157.1368) (*)
17:29 madminer.utils.ml.sc INFO      Epoch 8: train loss 154.6864 (mse_score: 154.6864)
17:29 madminer.utils.ml.sc INFO                val. loss  156.5643 (mse_score: 156.5643) (*)
17:29 madminer.utils.ml.sc INFO      Epoch 9: train loss 153.7223 (mse_score: 153.7223)
17:29 madminer.utils.ml.sc INFO                val. loss  156.1281 (mse_score: 156.1281) (*)
17:30 madminer.utils.ml.sc INFO      Epoch 10: train loss 153.1002 (mse_score: 153.1002)
17:30 madminer.utils.ml.sc INFO                val. loss  155.5065 (mse_score: 155.5065) (*)
17:30 madminer.utils.ml.sc INFO      Epoch 11: train loss 152.3693 (mse_score: 152.3693)
17:30

17:46 madminer.utils.ml.sc INFO    Finished training
17:46 madminer.ml          INFO    Calculating expectation for 3 estimators in ensemble
17:46 madminer.ml          INFO    Starting evaluation for estimator 1 / 3 in ensemble
17:46 madminer.ml          INFO    Starting evaluation for estimator 2 / 3 in ensemble
17:46 madminer.ml          INFO    Starting evaluation for estimator 3 / 3 in ensemble


## Shuffled label check

In [None]:
# Same settings as the default-cut 'all' run, plus shuffle_labels=True, which is
# forwarded to EnsembleForge.train_all.
# NOTE(review): presumably this trains on shuffled targets as a sanity check
# (per the section heading) -- the exact semantics are defined by madminer; confirm.
# Saved as 'sally_ensemble_all_shuffled' in model_dir.
train_ensemble(
    'all_shuffled',
    use_tight_cuts=False,
    validation_split=0.5,
    early_stopping=True,
    shuffle_labels=True
)

## Minimal observable basis (no jets)

In [7]:
# Feature indices of the minimal basis: every column 0..32 except 2, 3 and 12-15
# (27 indices in total, in ascending order).
# NOTE(review): per the section heading the excluded columns are presumably the
# jet observables -- confirm against the sample's observable ordering.
min_obs = sorted(set(range(33)) - {2, 3, 12, 13, 14, 15})

In [None]:
# Default-cut run restricted to the min_obs feature indices. `features` is a list
# with one entry per estimator (sized by the module-level n_estimators, which is
# also the function's default) and is forwarded to EnsembleForge.train_all.
# NOTE(review): unlike the other runs in this notebook, no validation_split /
# early_stopping is passed here -- confirm this is intentional.
train_ensemble(
    'minimal',
    use_tight_cuts=False,
    features=[min_obs for _ in range(n_estimators)],
)

In [None]:
# Tight-cut run restricted to the min_obs feature indices, one `features` entry
# per estimator; forwarded to EnsembleForge.train_all.
# NOTE(review): unlike the other runs in this notebook, no validation_split /
# early_stopping is passed here -- confirm this is intentional.
train_ensemble(
    'minimal_tight',
    use_tight_cuts=True,
    features=[min_obs for _ in range(n_estimators)],
)

## Just resurrection phi

In [None]:
# Tight-cut run in which every estimator uses only feature index 32.
# NOTE(review): per the section heading, column 32 is presumably the
# "resurrection phi" observable -- confirm against the sample's observable ordering.
# validation_split and early_stopping are forwarded to EnsembleForge.train_all.
# Saved as 'sally_ensemble_phi_tight' in model_dir.
train_ensemble(
    'phi_tight',
    use_tight_cuts=True,
    features=[[32] for _ in range(n_estimators)],
    validation_split=0.5,
    early_stopping=True,
)