In [1]:
#  Required imports
print("Importing standard library")
import os, time, sys

print("Importing python data libraries")
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt

print("Importing custom backends")
from   backends.stats import special_whiten_dataset, special_unwhiten_dataset
from   backends.utils import joint_shuffle, make_sure_dir_exists_for_filename
from   backends.ParameterisedSimulator import ParameterisedSimulator, Simulator_Model3
from   backends.SamplingSimulator      import SamplingSimulator

print("Importing keras objects")
from keras.layers      import BatchNormalization, Dense, Dropout, Input, LeakyReLU, Concatenate, Lambda, Reshape
from keras.models      import Model, Sequential
from keras.optimizers  import Adam, SGD, RMSprop
from keras.constraints import Constraint
from keras.callbacks   import EarlyStopping

print("Importing keras backend")
import keras.backend as K

Importing standard library
Importing python data libraries
Importing standard library
Importing python data libraries
Importing custom backends
Importing keras objects
Importing custom backends
Importing keras objects


Using TensorFlow backend.
Using TensorFlow backend.


Importing keras backend
Importing keras backend


In [2]:
#  Program constants

#  Values of the model parameter c at which datasets are generated
mu_scan_points = np.linspace(-2, 3, 10)

#  Sample sizes per value of c and per dataset
n_gen_points_per_c_per_ds        = 50000
n_train_points_per_c_per_ds_true = 50000
n_train_points_per_c_per_ds_fake = 50000

#  Optimiser choice for critic and GAN (False -> RMSprop)
use_Adam_optimiser = True

#  Training settings (NOTE: max_epochs is reassigned by the later training-configuration cell)
train_batch_size = 1000
max_epochs       = 1000

#  Whitening settings
do_whiten_data        = True
white_linear_fraction = 0.5

plot_tag = None

#  Plot ranges per observable: a fixed +-5 window in whitened coordinates, raw ranges otherwise
axis_lims = [(-5., 5.), (-5., 5.), (-5., 5.)] if do_whiten_data else [(50, 250), (50, 500), (-1*np.pi, np.pi)]

In [3]:
#  Set up "true" model
model = Simulator_Model3

#  Make sure the SM point (c = 0) is one of the scan points *before* generating anything, so that every
#  per-point dict below is filled in the same (sorted) order as mu_scan_points. get_train_data pairs
#  dataset rows with condition values by position, so dict order and scan-point order must agree.
if not np.any(mu_scan_points == 0) :
    mu_scan_points = np.sort(np.concatenate([mu_scan_points, [0]]))

#  Generate one dataset per scan point of the "true" model; every event gets weight 1/n_gen_points_per_c_per_ds
xsections, datasets, weights = {}, {}, {}
for mu in mu_scan_points :
    model.set_param_value("c", mu)
    xsec, dataset = model.generate(n_gen_points_per_c_per_ds)
    xsections [mu] = xsec
    datasets  [mu] = dataset
    #  Size the weights from *this* dataset
    weights   [mu] = np.full(shape=(len(dataset),), fill_value=1./n_gen_points_per_c_per_ds)

#  Leave the model at the SM point and pick out the SM dataset
model.set_param_value("c", 0)
xsec_SM, dataset_SM, weights_SM = xsections[0], datasets[0], weights[0]

#  Whiten the data using the "special" hard-boundary-respecting method; the transform is fitted on the SM
#  dataset and later re-applied to every scan point through whitening_funcs / whitening_params
if do_whiten_data :
    white_dataset_SM, whitening_funcs, whitening_params = special_whiten_dataset (dataset_SM,
                                                                                  [50, 250, 201, white_linear_fraction],
                                                                                  [50, 500, 301, white_linear_fraction],
                                                                                  [-1.*np.pi, np.pi, 101, white_linear_fraction],
                                                                                  rotate=True)
else :
    white_dataset_SM, whitening_funcs, whitening_params = dataset_SM, None, None

def whiten_data (dataset) :
    """Apply the SM-fitted whitening transform to every dataset in the dict `dataset`.

    `dataset` maps scan point -> data array. When the notebook flag `do_whiten_data` is off the input is
    returned unchanged. Otherwise a new dict with the same keys is returned, holding the whitened array
    that special_whiten_dataset produces with the global whitening_funcs / whitening_params and rotate=True.
    """
    if do_whiten_data != False :
        whitened = {}
        for key in dataset :
            whitened_array, _unused_funcs, _unused_params = special_whiten_dataset(dataset[key],
                                                                                  whitening_funcs =whitening_funcs,
                                                                                  whitening_params=whitening_params,
                                                                                  rotate=True)
            whitened[key] = whitened_array
        return whitened
    return dataset

def unwhiten_data (dataset) :
    """Invert the whitening transform for every dataset in the dict `dataset`.

    `dataset` maps scan point -> whitened data array. When `do_whiten_data` is off the input is returned
    unchanged; otherwise a new dict with the same keys holds special_unwhiten_dataset's result for each
    entry, computed with the global whitening_funcs / whitening_params.
    """
    if do_whiten_data != False :
        restored = {}
        for key in dataset :
            restored[key] = special_unwhiten_dataset(dataset[key],
                                                     whitening_funcs =whitening_funcs,
                                                     whitening_params=whitening_params)
        return restored
    return dataset
        
#  Whiten every scan-point dataset with the SM-fitted transform, then invert it again (round-trip check)
white_datasets   = whiten_data(datasets)
unwhite_datasets = unwhiten_data(white_datasets)
num_datasets     = len(white_datasets)
print(f"Datasets generated for scan points: {', '.join(f'{mu:.3f}' for mu in mu_scan_points)}")

Datasets generated for scan points: -2.000, -1.444, -0.889, -0.333, 0.000, 0.222, 0.778, 1.333, 1.889, 2.444, 3.000
Datasets generated for scan points: -2.000, -1.444, -0.889, -0.333, 0.000, 0.222, 0.778, 1.333, 1.889, 2.444, 3.000


In [None]:
def _plot_dataset_normalise_weights (weights, mu_scan_points, datasets, description) :
    """Return {mu: weight array normalised to unit sum} for plot_dataset.

    `weights` may be None (uniform weights 1/len(datasets[mu]) for every scan point), a single ndarray
    (shared by every scan point) or a dict keyed by mu. Any other type raises TypeError naming `description`.
    """
    if weights is None :
        return {mu: np.full(fill_value=1./len(datasets[mu]), shape=(len(datasets[mu]),)) for mu in mu_scan_points}
    if isinstance(weights, np.ndarray) :
        return {mu: weights/np.sum(weights) for mu in mu_scan_points}
    if isinstance(weights, dict) :
        return {mu: mu_weights/np.sum(mu_weights) for mu, mu_weights in weights.items()}
    raise TypeError(f"Don't know what to do with {description} of type {type(weights)}")


def _plot_dataset_per_mu (value, mu_scan_points) :
    """Return `value` if it is already a dict keyed by mu, otherwise {mu: value} for every scan point."""
    if isinstance(value, dict) :
        return value
    return {mu: value for mu in mu_scan_points}


def plot_dataset (mu_scan_points, xsections, datasets, weights=None, ref=None) :
    """Plot 1D and 2D distributions of the three observables (A, B, C) for every scan point.

    Layout: one column per entry of `mu_scan_points`.
      * rows 1-3: histograms of A, B and C, weighted by cross section * normalised weight (black outline);
        if a reference is given it is overlaid as a filled red histogram;
      * then one hist2d row per observable pair (A/B, A/C, B/C), each followed by the same pair for the
        reference when one is given (6 rows without reference, 9 with).

    Parameters:
        mu_scan_points : iterable of scan-point values, used as keys of the dicts below.
        xsections      : dict mu -> cross section (scales the 1D histograms).
        datasets       : dict mu -> array whose columns 0, 1, 2 hold A, B, C.
        weights        : None, one ndarray for every mu, or dict mu -> ndarray; normalised to unit sum.
        ref            : optional (xsections_ref, datasets_ref, weights_ref). xsections_ref and datasets_ref
                         may each be one object shared by every mu or a dict keyed by mu; weights_ref
                         accepts the same forms as `weights`.

    Raises:
        TypeError    : unsupported type for weights, reference weights or reference datasets.
        RuntimeError : reference given with xsections_ref = None.
    """
    num_datasets     = len(mu_scan_points)
    observable_names = ("A", "B", "C")
    tmp_weights      = _plot_dataset_normalise_weights(weights, mu_scan_points, datasets, "weights")

    plot_reference = ref is not None
    if plot_reference :
        xsections_ref, datasets_ref, weights_ref = ref
        if not isinstance(datasets_ref, (np.ndarray, dict)) :
            raise TypeError(f"Don't know what to do with reference dataset of type {type(datasets_ref)}")
        #  Each reference component is broadcast according to its *own* type, so e.g. a per-mu dict of
        #  reference datasets can be combined with a single shared reference cross section.
        tmp_datasets_ref = _plot_dataset_per_mu(datasets_ref, mu_scan_points)
        tmp_weights_ref  = _plot_dataset_normalise_weights(weights_ref, mu_scan_points, tmp_datasets_ref, "reference weights")
        if xsections_ref is None :
            raise RuntimeError("Reference datasets must be provided reference cross sections too")
        tmp_xsections_ref = _plot_dataset_per_mu(xsections_ref, mu_scan_points)

    num_plot_rows = 9 if plot_reference else 6
    fig     = plt.figure(figsize=(4*num_datasets, 4*num_plot_rows))
    row_idx = 0

    #  1D cross-section-weighted distributions, one row per observable
    for obs_idx, obs_name in enumerate(observable_names) :
        for ax_idx, mu in enumerate(mu_scan_points) :
            ax = fig.add_subplot(num_plot_rows, num_datasets, row_idx*num_datasets + ax_idx + 1)
            ax.hist(datasets[mu][:, obs_idx], alpha=0.5, weights=xsections[mu]*tmp_weights[mu],
                    fill=False, edgecolor="k", linestyle="-", linewidth=3)
            if plot_reference :
                ax.hist(tmp_datasets_ref[mu][:, obs_idx], alpha=0.5, weights=tmp_xsections_ref[mu]*tmp_weights_ref[mu],
                        fill=True, color="r", linestyle="-", linewidth=3)
            if row_idx == 0 :
                ax.set_title(f"$c = {mu:.3f}$", fontsize=30)
            if ax_idx == 0 :
                ax.set_ylabel(rf"$\frac{{d\sigma}}{{d{obs_name}}}$", fontsize=30, rotation=0, labelpad=40)
        row_idx += 1

    #  2D distributions, one row per observable pair, each followed by the reference row if requested
    for i, j in ((0, 1), (0, 2), (1, 2)) :
        pair_label = f"${observable_names[i]}$ \n / \n ${observable_names[j]}$"
        row_specs  = [(datasets, tmp_weights, pair_label, {})]
        if plot_reference :
            row_specs.append((tmp_datasets_ref, tmp_weights_ref, pair_label + " \n ref", {"va": "center"}))
        for row_datasets, row_weights, row_label, label_kwargs in row_specs :
            for ax_idx, mu in enumerate(mu_scan_points) :
                ax = fig.add_subplot(num_plot_rows, num_datasets, row_idx*num_datasets + ax_idx + 1)
                ax.hist2d(row_datasets[mu][:, i], row_datasets[mu][:, j], weights=row_weights[mu])
                if ax_idx == 0 :
                    ax.set_ylabel(row_label, fontsize=30, rotation=0, labelpad=40, **label_kwargs)
            row_idx += 1

    plt.show()



In [None]:
#  Sanity check of the whitening: the generated datasets and unwhiten(whiten(datasets)) should give
#  identical figures.
print("Plotting datasets")
plot_dataset(mu_scan_points, xsections, datasets, weights=weights)

print("Plotting unwhitened(whitened(datasets)) to show this reconstructs original datasets")
plot_dataset(mu_scan_points, xsections, unwhiten_data(white_datasets), weights=weights)


In [None]:
#  The scan-point datasets in the whitened coordinates that the training cells sample from
print("Plotting whitened datasets")
plot_dataset(mu_scan_points, xsections, white_datasets, weights=weights)

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein (WGAN) loss: the batch mean of label * critic score.

    With labels of +/-1, minimising this pushes the score of +1-labelled samples down and the score
    of -1-labelled samples up.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

class ClipConstraint(Constraint):
    """Keras kernel constraint that clips every weight into the hypercube [-clip_value, clip_value].

    Used for WGAN weight clipping on the critic's Dense kernels.
    """

    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        bound = self.clip_value
        return K.clip(weights, -bound, bound)

    def get_config(self):
        return dict(clip_value=self.clip_value)

def _wgan_optimizer () :
    """Return a fresh optimizer for one WGAN component.

    Adam with Keras defaults when the notebook flag `use_Adam_optimiser` is true, otherwise
    RMSprop(learning_rate=5e-6, rho=0). A new instance is built per call so critic and GAN never share state.
    """
    if use_Adam_optimiser :
        return Adam()
    return RMSprop(learning_rate=5e-6, rho=0)


def _dense_leaky_dropout_stack (tensor, num_layers, width, leaky_alpha, dropout_rate, clip_value=None) :
    """Apply `num_layers` x [Dense(width) -> LeakyReLU(leaky_alpha) -> Dropout(dropout_rate)] to `tensor`.

    If `clip_value` is given, each Dense kernel gets its own ClipConstraint(clip_value).
    """
    for _ in range(num_layers) :
        constraint = None if clip_value is None else ClipConstraint(clip_value)
        tensor = Dense    (width, kernel_constraint=constraint)(tensor)
        tensor = LeakyReLU(leaky_alpha)(tensor)
        tensor = Dropout  (dropout_rate)(tensor)
    return tensor


def create_autoreg_wgan_segment (name, **kwargs) :
    """Build one segment (critic, generator, combined GAN) of an autoregressive conditional WGAN.

    Segment `layer_idx` generates observable number `layer_idx`. Its generator takes noise of width
    layer_idx+1 plus the condition input. For layer_idx > 0 the first layer_idx noise columns and the
    conditions are fed to `prev_generator`. Its output is appended to the conditions of this segment and
    prepended to the single newly generated column. The critic scores (layer_idx+1)-wide data points given
    the conditions. All of its Dense kernels except the output layer's are clipped to [-clip_value, clip_value].

    Keyword args:
        layer_idx         (int, required) : index of the observable generated by this segment (>= 0).
        num_conditions    (int, required) : width of the condition input.
        prev_generator    (keras Model)   : generator of segment layer_idx-1; required when layer_idx > 0.
        verbose           (bool, True)    : print status messages and model summaries.
        clip_value        (float, 0.02)   : weight-clipping bound for the critic kernels.
        input_width       (int, 20)       : width of the first Dense layer on each input branch.
        hidden_width      (int, 50)       : width of the hidden Dense layers.
        num_hidden_layers (int, 3)        : number of hidden Dense -> LeakyReLU -> Dropout blocks.
        dropout_rate      (float, 0.1)    : dropout rate used throughout.
        leaky_alpha       (float, 0.2)    : LeakyReLU negative slope.

    Returns:
        (critic, generator, GAN). The critic is compiled while trainable. The GAN (generator -> critic) is
        built afterwards and compiled after setting critic.trainable = False. Both use wasserstein_loss.

    Raises:
        TypeError  : layer_idx or num_conditions missing.
        ValueError : layer_idx < 0, or layer_idx > 0 without prev_generator.
    """
    #  Parse and validate arguments
    #
    if kwargs.get("layer_idx") is None or kwargs.get("num_conditions") is None :
        raise TypeError("create_autoreg_wgan_segment requires the keyword arguments 'layer_idx' and 'num_conditions'")
    layer_idx         = int  (kwargs.get("layer_idx"        ))
    num_conditions    = int  (kwargs.get("num_conditions"   ))
    verbose           = bool (kwargs.get("verbose"          , True))
    prev_generator    = kwargs.get("prev_generator", None)
    clip_value        = float(kwargs.get("clip_value"       , 0.02))
    input_width       = int  (kwargs.get("input_width"      , 20  ))
    hidden_width      = int  (kwargs.get("hidden_width"     , 50  ))
    num_hidden_layers = int  (kwargs.get("num_hidden_layers", 3   ))
    dropout_rate      = float(kwargs.get("dropout_rate"     , 0.1 ))
    leaky_alpha       = float(kwargs.get("leaky_alpha"      , 0.2 ))
    if layer_idx < 0 :
        raise ValueError(f"layer_idx must be non-negative, got {layer_idx}")
    if layer_idx > 0 and prev_generator is None :
        raise ValueError(f"layer_idx = {layer_idx} > 0 requires prev_generator (the generator of segment {layer_idx-1})")

    #  Print a status message
    #
    if verbose :
        print(f"Creating autoregressive WGAN segment: {name}")
        print(f"  - layer_idx      is {layer_idx}")
        print(f"  - num_conditions is {num_conditions}")

    #  Create Inputs for the GAN
    #
    condition_input   = Input((num_conditions,))
    critic_data_input = Input((layer_idx+1,))
    gen_noise_input   = Input((layer_idx+1,))

    #  Create critic: separate clipped branches for data and conditions, then a clipped hidden stack
    #
    critic_data      = Dense    (input_width, kernel_constraint=ClipConstraint(clip_value))(critic_data_input)
    critic_data      = LeakyReLU(leaky_alpha)(critic_data)
    critic_condition = Dense    (input_width, kernel_constraint=ClipConstraint(clip_value))(condition_input)
    critic_condition = LeakyReLU(leaky_alpha)(critic_condition)
    critic           = Concatenate()([critic_data, critic_condition])
    critic           = Dropout(dropout_rate)(critic)
    critic           = _dense_leaky_dropout_stack(critic, num_hidden_layers, hidden_width, leaky_alpha, dropout_rate,
                                                  clip_value=clip_value)
    #  NOTE(review): the output layer is not weight-clipped; WGAN weight clipping normally bounds every
    #  layer to keep the critic Lipschitz — confirm this is intended.
    critic           = Dense(1, activation="linear")(critic)
    critic           = Model(name=name+"_critic",
                             inputs=[critic_data_input, condition_input],
                             outputs=[critic])
    critic.compile(loss=wasserstein_loss, optimizer=_wgan_optimizer())
    if verbose : critic.summary()

    #  Route the inputs into the generator: for layer_idx > 0 the last noise column drives the new
    #  observable and the earlier columns drive the previous generator, whose output becomes a condition
    #
    if layer_idx == 0 :
        gen_observable_noise = gen_noise_input
        gen_all_conditions   = condition_input
    else :
        gen_observable_noise = Lambda(lambda x: x[:, layer_idx])(gen_noise_input)
        gen_observable_noise = Reshape((1,))(gen_observable_noise)
        prev_gen_noise       = Lambda(lambda x: x[:, :layer_idx])(gen_noise_input)
        prev_gen_noise       = Reshape((layer_idx,))(prev_gen_noise)
        prev_observables     = prev_generator([prev_gen_noise, condition_input])
        gen_all_conditions   = Concatenate()([condition_input, prev_observables])

    #  Create generator
    #
    generator_noise     = Dense    (input_width)(gen_observable_noise)
    generator_noise     = LeakyReLU(leaky_alpha)(generator_noise)
    generator_condition = Dense    (input_width)(gen_all_conditions)
    generator_condition = LeakyReLU(leaky_alpha)(generator_condition)
    generator           = Concatenate()([generator_noise, generator_condition])
    generator           = Dropout(dropout_rate)(generator)
    generator           = _dense_leaky_dropout_stack(generator, num_hidden_layers, hidden_width, leaky_alpha, dropout_rate)
    generator           = Dense(1, activation="linear")(generator)
    if layer_idx > 0 : generator = Concatenate()([prev_observables, generator])
    generator           = Model(name=name+"_generator",
                                inputs=[gen_noise_input, condition_input],
                                outputs=[generator])
    if verbose : generator.summary()

    #  Combined model: generator -> critic, compiled with the critic frozen so only the generator learns
    #
    GAN = critic([generator([gen_noise_input, condition_input]), condition_input])
    GAN = Model([gen_noise_input, condition_input], GAN, name=name)
    critic.trainable = False
    GAN.compile(loss=wasserstein_loss, optimizer=_wgan_optimizer())
    if verbose : GAN.summary()

    return critic, generator, GAN
           
def create_autoreg_wgan (name, **kwargs) :
    """Build an autoregressive WGAN as a chain of one segment per observable.

    Segment i is built by create_autoreg_wgan_segment with layer_idx = i, and segment i-1's generator
    is passed as `prev_generator`. After each segment is built, its critic and generator are marked
    non-trainable before the next segment is created.

    Keyword args:
        num_conditions  (int, required) : width of the condition input.
        num_observables (int, required) : number of segments to build.
        verbose         (bool, True)    : print progress and model summaries.

    Returns:
        (layer_critics, layer_generators, layer_GANs): lists with one entry per observable.
    """
    num_conditions  = int (kwargs.get("num_conditions" ))
    num_observables = int (kwargs.get("num_observables"))
    verbose         = bool(kwargs.get("verbose"  , True))

    if verbose :
        print(f"Creating WGAN: {name}")
        print(f"  - num_observables is {num_observables}")
        print(f"  - num_conditions  is {num_conditions}")

    layer_critics, layer_generators, layer_GANs = [], [], []
    for observable_idx in range(num_observables) :
        critic_i, generator_i, GAN_i = create_autoreg_wgan_segment(
            name           = f"{name}_observable{observable_idx}",
            layer_idx      = observable_idx,
            num_conditions = num_conditions,
            prev_generator = layer_generators[-1] if layer_generators else None,
            verbose        = verbose,
        )
        critic_i.trainable    = False
        generator_i.trainable = False
        layer_critics   .append(critic_i)
        layer_generators.append(generator_i)
        layer_GANs      .append(GAN_i)

    return layer_critics, layer_generators, layer_GANs


In [None]:

#  One WGAN segment per observable (A, B, C), each conditioned on the single model parameter c
layer_critics, layer_generators, layer_GANs = create_autoreg_wgan(
    "autoreg_GAN",
    num_observables = 3,
    num_conditions  = 1,
    verbose         = False,
)


In [None]:
def get_noise (batch_size, train_points_c, GAN_noise_size) :
    """Draw generator inputs for `batch_size` samples per condition value.

    Returns (noise, hyperparams):
        noise       : standard-normal array of shape (batch_size*len(train_points_c), GAN_noise_size);
        hyperparams : array of shape (batch_size*len(train_points_c), 1), with each value of
                      `train_points_c` repeated `batch_size` times, in order.
    """
    condition_blocks = []
    for c in train_points_c :
        condition_blocks.append(np.full(fill_value=c, shape=(batch_size, 1)))
    hyperparams = np.concatenate(condition_blocks)
    noise = np.random.normal(size=(batch_size*len(train_points_c), GAN_noise_size))
    return noise, hyperparams

def get_train_data (batch_size, train_points_c, datasets, max_axis) :
    """Draw a batch of real training points: `batch_size` rows per condition value.

    For each c in `train_points_c` (in that order), `batch_size` rows of `datasets[c]` are drawn uniformly
    with replacement and truncated to their first `max_axis` columns. Datasets are looked up by key rather
    than by the iteration order of `datasets`. Row i of `data_batch` therefore always comes from the dataset
    of hyperparams[i], even if the dict was filled in a different order than `train_points_c`.

    Parameters:
        batch_size     : rows drawn per condition value.
        train_points_c : sequence of condition values (keys of `datasets`).
        datasets       : dict c -> array of shape (n, d); a 1-D array is treated as d = 1.
        max_axis       : number of leading columns to keep.

    Returns:
        data_batch  : array of shape (batch_size*len(train_points_c), min(max_axis, d)).
        hyperparams : array of shape (batch_size*len(train_points_c), 1) holding the matching c values.

    Raises:
        KeyError : some value in `train_points_c` has no entry in `datasets`.
    """
    data_blocks, condition_blocks = [], []
    for c in train_points_c :
        if c not in datasets :
            raise KeyError(f"No dataset for condition value {c!r}")
        ds = np.asarray(datasets[c])
        if ds.ndim == 1 : ds = ds.reshape((ds.shape[0], 1))
        data_blocks     .append(ds[np.random.randint(0, len(ds), batch_size), :max_axis])
        condition_blocks.append(np.full(fill_value=c, shape=(batch_size, 1)))
    return np.concatenate(data_blocks), np.concatenate(condition_blocks)
    
def get_train_fakes (batch_size, train_points_c, GAN_noise_size, generator) :
    """Generate `batch_size` fake points per value of `train_points_c` with `generator`.

    Returns (fakes_batch, hyperparams): generator.predict applied to [noise, hyperparams] from get_noise
    (noise width `GAN_noise_size`), and the matching condition column.
    """
    noise, hyperparams = get_noise(batch_size, train_points_c, GAN_noise_size)
    return generator.predict([noise, hyperparams]), hyperparams

#  NOTE(review): disabled batch-size schedule, kept only as a string literal (it is never executed).
#  The training loops instead grow batch_size by batch_update_factor once the energy tests pass.
'''def get_batch_size (epoch_idx, min_batch_size, max_batch_size, batch_update_per_epoch) :
    return int(np.min([min_batch_size + batch_update_per_epoch*epoch_idx, max_batch_size]))'''


In [None]:
def _mean_pairwise_distance (a, b, chunk_size=256) :
    """Mean Euclidean distance over all (row of `a`, row of `b`) pairs (i == j pairs included when a is b).

    `a` is processed `chunk_size` rows at a time, so the temporary difference array is at most
    chunk_size x len(b) x d.
    """
    total = 0.
    for start in range(0, len(a), chunk_size) :
        diff  = a[start:start+chunk_size, None, :] - b[None, :, :]
        total = total + np.sqrt(np.sum(diff*diff, axis=-1)).sum()
    return total / (len(a) * len(b))


def energy_test_stat (ds1, ds2) :
    """Energy-distance two-sample statistic between the point sets `ds1` and `ds2`.

    Computes   T = 2 <|x - y|> - <|x - x'|> - <|y - y'|>
    where |.| is the Euclidean norm and <.> is the mean over all ordered pairs. That is the V-statistic
    normalisation 1/(n1*n2), 1/n1**2 and 1/n2**2, with i == j pairs included. T is 0 for identical samples
    and grows as the two distributions differ.

    The distances are averaged directly. Taking the sqrt of the summed *squared* distances gives neither the
    mean distance nor a quantity that is independent of sample size. Each within-sample term is normalised
    by its own sample size.

    Parameters:
        ds1, ds2 : array-like of shape (n1, d) and (n2, d); 1-D input is treated as d = 1.

    Returns:
        float, the statistic T.

    Raises:
        ValueError : a sample is empty, has more than 2 dimensions, or the two have different d.
    """
    a = np.asarray(ds1, dtype=float)
    b = np.asarray(ds2, dtype=float)
    if a.ndim == 1 : a = a.reshape(-1, 1)
    if b.ndim == 1 : b = b.reshape(-1, 1)
    if a.ndim != 2 or b.ndim != 2 :
        raise ValueError(f"energy_test_stat needs 1-D or 2-D samples, got shapes {a.shape} and {b.shape}")
    if len(a) == 0 or len(b) == 0 :
        raise ValueError(f"energy_test_stat needs two non-empty samples, got sizes {len(a)} and {len(b)}")
    if a.shape[1] != b.shape[1] :
        raise ValueError(f"energy_test_stat needs samples of equal dimension, got {a.shape[1]} and {b.shape[1]}")
    return 2.*_mean_pairwise_distance(a, b) - _mean_pairwise_distance(a, a) - _mean_pairwise_distance(b, b)

def get_Et_thresholds (batch_size, threshold, num_toys=10, max_axis=1) :
    """Calibrate per-scan-point energy-test thresholds from toy real-vs-real comparisons.

    For each of `num_toys` toys, two independent batches of `batch_size` points per scan point are drawn
    with get_train_data from the notebook-level `white_datasets`, keeping the first `max_axis`
    observables. energy_test_stat is then evaluated between the two batches for every mu in
    `mu_scan_points`. Each mu's threshold is the `threshold`-quantile of its sorted toy values, with the
    index clamped into range (threshold = 1 -> the largest toy value). A histogram of the toy values with
    the threshold marked is shown per mu.

    Parameters:
        batch_size : points drawn per scan point per toy batch (match the batch size being tested).
        threshold  : quantile in [0, 1] of the toy distribution used as threshold.
        num_toys   : number of toy comparisons per scan point (>= 1).
        max_axis   : number of leading observables used (should match the training stage's axis_idx+1).

    Returns:
        dict mu -> threshold value.

    Raises:
        ValueError : num_toys < 1.
    """
    if num_toys < 1 :
        raise ValueError(f"num_toys must be at least 1, got {num_toys}")
    toy_energy_values = {mu: [] for mu in mu_scan_points}

    sys.stdout.write("Throwing toys...")
    for i in range(num_toys) :
        sys.stdout.write(f"\rThrowing {num_toys} toys... [{int(100.*(i+1)/num_toys)}%]")
        toy1_data, toy1_conditions = get_train_data(batch_size, mu_scan_points, white_datasets, max_axis=max_axis)
        toy2_data, toy2_conditions = get_train_data(batch_size, mu_scan_points, white_datasets, max_axis=max_axis)
        for mu in mu_scan_points :
            toy1_data_mu = toy1_data[toy1_conditions[:, 0] == mu]
            toy2_data_mu = toy2_data[toy2_conditions[:, 0] == mu]
            toy_energy_values[mu].append(energy_test_stat(toy1_data_mu, toy2_data_mu))
    sys.stdout.write("\n")

    num_mu_scan_points = len(mu_scan_points)

    fig = plt.figure(figsize=(2.5*num_mu_scan_points, 2))
    Et_thresholds = {}
    for ax_idx, mu in enumerate(mu_scan_points) :
        vals = sorted(toy_energy_values[mu])
        #  `threshold`-quantile of the toys, clamped to a valid index
        Et_threshold      = vals[min(max(int(threshold*len(vals)), 0), len(vals)-1)]
        Et_thresholds[mu] = Et_threshold
        ax = fig.add_subplot(1, num_mu_scan_points, 1+ax_idx)
        ax.hist(vals, color="green", alpha=0.5)
        ax.axvline(Et_threshold, linestyle="--", c="k", linewidth=2)
        ax.set_title(rf"$\mu = {mu:.2f}$")
        ax.set_xlabel("Energy test", fontsize=12, labelpad=15)
    plt.show()

    return Et_thresholds

In [None]:

#  Batch-size schedule: start small, multiply by batch_update_factor whenever every scan point passes
#  its energy test, never exceeding max_batch_size
initial_batch_size  = 60
batch_update_factor = 1.5
max_batch_size      = 1000

#  Epoch budget and updates per epoch (several critic steps per generator step, as usual for WGANs)
max_epochs                   = 100000000   # effectively unbounded
max_critic_updates_per_epoch = 10
max_gen_updates_per_epoch    = 1

#  Monitoring cadence (plots + energy tests)
epoch_print_interval = 200

#  Energy-test calibration: quantile of the toy distribution used as threshold (1 -> largest toy value)
#  and number of toys thrown per calibration
energy_test_threshold = 1
energy_test_ntoys     = 20


In [None]:
#  Stage 0: adversarially train the generator/critic pair for the first observable (axis_idx = 0).
#
#  Every epoch:
#    - draw `batch_size` real and `batch_size` generated points per scan point;
#    - update the critic `max_critic_updates_per_epoch` times on that labelled batch;
#    - update the generator through the GAN model (compiled with the critic frozen)
#      `max_gen_updates_per_epoch` times on fresh noise.
#  Every `epoch_print_interval` epochs the real and generated distributions are plotted and compared with the
#  energy test. If every scan point is below its toy-calibrated threshold, the batch size grows by
#  `batch_update_factor` (capped at `max_batch_size`) and the thresholds are recalibrated.
generator, critic, GAN = layer_generators[0], layer_critics[0], layer_GANs[0]

epoch_idx  = -1
axis_idx   = 0
batch_size = initial_batch_size
print("Getting initial Et thresholds")
Et_thresholds = get_Et_thresholds(batch_size, threshold=energy_test_threshold, num_toys=energy_test_ntoys)
saved_epochs, saved_energy_tests, epoch_batch_size_transitions = [], {}, []
GAN_loss = None
while epoch_idx < max_epochs :
    epoch_idx = epoch_idx + 1

    num_conditions = len(mu_scan_points)
    data_batch , data_conditions  = get_train_data (batch_size, mu_scan_points, white_datasets, max_axis=axis_idx+1)
    fakes_batch, fakes_conditions = get_train_fakes(batch_size, mu_scan_points, axis_idx+1, generator)

    #  Real points are labelled +1, generated points -1
    data_labels =  np.ones(num_conditions*batch_size)
    fake_labels = -np.ones(num_conditions*batch_size)

    critic_train_X_datapoints = np.concatenate([data_batch     , fakes_batch     ])
    critic_train_X_conditions = np.concatenate([data_conditions, fakes_conditions])
    critic_train_Y_labels     = np.concatenate([data_labels    , fake_labels     ])
    critic_train_X_datapoints, critic_train_X_conditions, critic_train_Y_labels = joint_shuffle(critic_train_X_datapoints, critic_train_X_conditions, critic_train_Y_labels)
    critic_train_X = [critic_train_X_datapoints, critic_train_X_conditions]

    #  Train critic: with wasserstein_loss = mean(y_true*y_pred) and targets -label, minimising the loss
    #  raises the critic score on real points and lowers it on generated ones
    critic.trainable = True
    for _ in range(max_critic_updates_per_epoch) :
        critic.train_on_batch(critic_train_X, -1.*critic_train_Y_labels)

    #  Train generator: targets of -1 make the loss -mean(critic(fake)), pushing generated points towards
    #  what the critic scores as real. Uses the generator update budget, not the critic's.
    new_noise, new_conditions = get_noise(batch_size, mu_scan_points, axis_idx+1)
    generator_train_X = [new_noise, new_conditions]
    critic   .trainable = False
    generator.trainable = True
    for _ in range(max_gen_updates_per_epoch) :
        GAN_loss = GAN.train_on_batch(generator_train_X, fake_labels)

    if epoch_idx % epoch_print_interval != 0 : continue
    print(f"Epoch {epoch_idx} GAN loss is {GAN_loss}")
    print(f"Batch size is {batch_size}")

    #  Monitoring: fresh real and generated samples, one panel per scan point
    data_batch , data_conditions  = get_train_data (batch_size, mu_scan_points, white_datasets, max_axis=axis_idx+1)
    fakes_batch, fakes_conditions = get_train_fakes(batch_size, mu_scan_points, axis_idx+1, generator)

    num_mu_scan_points = len(mu_scan_points)
    fig     = plt.figure(figsize=(2.5*num_mu_scan_points, 2))
    ax_lims = axis_lims[axis_idx]
    bins    = np.linspace(ax_lims[0], ax_lims[1], 21)
    update_batch_size = batch_size < max_batch_size
    saved_epochs.append(epoch_idx+1)
    for ax_idx, mu in enumerate(mu_scan_points) :
        ax = fig.add_subplot(1, num_mu_scan_points, 1+ax_idx)
        plot_data  = data_batch [data_conditions [:, 0] == mu]
        plot_fakes = fakes_batch[fakes_conditions[:, 0] == mu]
        ax.hist(plot_data [:, axis_idx], bins=bins, color="blue", alpha=0.5)
        ax.hist(plot_fakes[:, axis_idx], bins=bins, color="red" , alpha=0.5)
        #  Fraction of generated points falling outside the plotted range on either side
        n_above = np.sum(plot_fakes[:, axis_idx] > ax_lims[1])
        n_below = np.sum(plot_fakes[:, axis_idx] < ax_lims[0])
        ax.text(0.95, 0.9, f"{100.*n_above/len(plot_data):.0f}% ->", ha="right", color="red", transform=ax.transAxes)
        ax.text(0.05, 0.9, f"<- {100.*n_below/len(plot_data):.0f}%", ha="left" , color="red", transform=ax.transAxes)
        ax.set_xlabel("Observable", fontsize=12, labelpad=15)
        ax.set_title(rf"$\mu = {mu:.2f}$")
        Et           = energy_test_stat(plot_data, plot_fakes)
        Et_threshold = Et_thresholds[mu]
        print(f"Energy test [mu = {mu:.2f}] = {Et:.2f}, with threshold of {Et_threshold:.2f}")
        if Et > Et_threshold : update_batch_size = False
        saved_energy_tests.setdefault(mu, []).append(Et)
    plt.show()

    #  Energy-test history per scan point; dashed lines mark batch-size increases
    fig = plt.figure(figsize=(2.5*num_mu_scan_points, 2))
    for ax_idx, mu in enumerate(mu_scan_points) :
        ax = fig.add_subplot(1, num_mu_scan_points, 1+ax_idx)
        ax.plot(saved_epochs, saved_energy_tests[mu], "x-")
        for transition_epoch, _ in epoch_batch_size_transitions :
            ax.axvline(transition_epoch, color="gray", linestyle="--", linewidth=1)
        if epoch_batch_size_transitions :
            ax.set_yscale("log")
        ax.set_xlabel("Epoch", fontsize=12, labelpad=15)
        if ax_idx == 0 : ax.set_ylabel("Energy test", fontsize=12, labelpad=15)
        ax.set_title(rf"$\mu = {mu:.2f}$")
    plt.show()

    if not update_batch_size : continue

    print(f"Updating batch factor and throwing toys for Et thresholds")
    batch_size    = min(max_batch_size, int(batch_size * batch_update_factor))
    epoch_batch_size_transitions.append([epoch_idx+1, batch_size])
    Et_thresholds = get_Et_thresholds(batch_size, threshold=energy_test_threshold, num_toys=energy_test_ntoys)

    print("Continuing")
    

In [None]:
#  NOTE(review): disabled (string literal, never executed) — would save the stage-0 critic and generator weights.
'''critic.save_weights(".critic0_1.h5")
generator.save_weights(".generator0_1.h5")'''

In [None]:
#  High-statistics comparison (n_gen_points_per_c_per_ds points per scan point) of real (blue) vs generated (red)
#  distributions of the observable trained by the most recent training cell (column axis_idx)
data_batch , data_conditions  = get_train_data (n_gen_points_per_c_per_ds, mu_scan_points, white_datasets, max_axis=axis_idx+1)
fakes_batch, fakes_conditions = get_train_fakes(n_gen_points_per_c_per_ds, mu_scan_points, axis_idx+1, generator)

num_mu_scan_points = len(mu_scan_points)
x_lo, x_hi = axis_lims[axis_idx]
bins = np.linspace(x_lo, x_hi, 51)
fig  = plt.figure(figsize=(2.5*num_mu_scan_points, 2))
for ax_idx, mu in enumerate(mu_scan_points) :
    ax = fig.add_subplot(1, num_mu_scan_points, 1+ax_idx)
    plot_data  = data_batch [data_conditions [:, 0] == mu, axis_idx]
    plot_fakes = fakes_batch[fakes_conditions[:, 0] == mu, axis_idx]
    ax.hist(plot_data , bins=bins, color="blue", alpha=0.5)
    ax.hist(plot_fakes, bins=bins, color="red" , alpha=0.5)
    ax.set_xlabel("Observable", fontsize=12, labelpad=15)
    ax.set_title(rf"$\mu = {mu:.2f}$")
    #  NOTE(review): fixed upper limit of 8000 counts, presumably tuned for n_gen_points_per_c_per_ds = 50000
    #  and 50 bins — TODO scale with the sample size if that changes
    ax.set_ylim(0, 8000.)
    ax.axvline(0, linestyle="--", linewidth=1, c="gray")
plt.show()

In [None]:
#  Stage 1: freeze the stage-0 networks and switch to the generator/critic pair for the second observable
#  (axis_idx = 1), restarting the epoch counter and the batch-size schedule.
#  NOTE(review): in standalone Keras a `trainable` change only takes effect at the next compile(). The
#  stage-0 generator was already marked non-trainable inside create_autoreg_wgan before the stage-1 models
#  were built and compiled — confirm that is what freezes it here.
#  NOTE(review): get_Et_thresholds throws its toys on the first observable only (max_axis=1), whereas this
#  stage compares samples of the first two observables (max_axis = axis_idx+1 = 2). The thresholds below are
#  therefore not calibrated for this stage — TODO confirm / fix.
critic.trainable    = False
generator.trainable = False
critic, generator, GAN = layer_critics[1], layer_generators[1], layer_GANs[1]

epoch_idx, axis_idx = -1, 1
batch_size          = initial_batch_size
print("Getting initial Et thresholds")
Et_thresholds = get_Et_thresholds(batch_size, threshold=energy_test_threshold, num_toys=energy_test_ntoys)
while epoch_idx < max_epochs :
    epoch_idx = epoch_idx + 1
    
    num_conditions = len(mu_scan_points)
    data_batch , data_conditions  = get_train_data (batch_size, mu_scan_points, white_datasets, max_axis=axis_idx+1)
    fakes_batch, fakes_conditions = get_train_fakes(batch_size, mu_scan_points, axis_idx+1, generator)
    
    data_labels = np.array([1.  for i in range(num_conditions*batch_size)])
    fake_labels = np.array([-1. for i in range(num_conditions*batch_size)])
    
    critic_train_X_datapoints = np.concatenate([data_batch     , fakes_batch     ])
    critic_train_X_conditions = np.concatenate([data_conditions, fakes_conditions])
    critic_train_Y_labels     = np.concatenate([data_labels    , fake_labels     ])
    critic_train_X_datapoints, critic_train_X_conditions, critic_train_Y_labels = joint_shuffle(critic_train_X_datapoints, critic_train_X_conditions, critic_train_Y_labels)
    critic_train_X = [critic_train_X_datapoints, critic_train_X_conditions]
    
    # train critic
    #print(f"Training critic at epoch {epoch_idx}")
    critic.trainable = True
    for i in range(max_critic_updates_per_epoch) :
        critic.train_on_batch(critic_train_X, -1.*critic_train_Y_labels)
    
    # train generator
    #print(f"Training generator at epoch {epoch_idx}")
    new_noise, new_conditions = get_noise (batch_size, mu_scan_points, axis_idx+1)
    generator_train_X = [new_noise, new_conditions]
    critic   .trainable = False
    generator.trainable = True
    for i in range(max_critic_updates_per_epoch) :
        GAN_loss = GAN.train_on_batch(generator_train_X, fake_labels)
    if epoch_idx % epoch_print_interval != 0 : continue
    print(f"Epoch {epoch_idx} GAN loss is {GAN_loss}")
    print(f"Batch size is {batch_size}")
    
    data_batch , data_conditions  = get_train_data (batch_size, mu_scan_points, white_datasets, max_axis=axis_idx+1)
    fakes_batch, fakes_conditions = get_train_fakes(batch_size, mu_scan_points, axis_idx+1, generator)
    
    num_mu_scan_points = len(mu_scan_points)
    fig = plt.figure(figsize=(2.5*num_mu_scan_points, 4))
    update_batch_size = True if batch_size < max_batch_size else False
    for ax_idx, mu in enumerate(mu_scan_points) :
        plot_data  = np.array([x for x,c in zip(data_batch , data_conditions ) if c == mu])
        plot_fakes = np.array([x for x,c in zip(fakes_batch, fakes_conditions) if c == mu])
        ax = fig.add_subplot(2, num_mu_scan_points, 1+ax_idx)
        ax.hist2d(plot_data [:,0], plot_data [:,1], bins=[np.linspace(axis_lims[0][0], axis_lims[0][1], 21), np.linspace(axis_lims[1][0], axis_lims[1][1], 21)])
        ax.set_title(f"$\mu = {mu:.2f}$")
        ax = fig.add_subplot(2, num_mu_scan_points, num_mu_scan_points+1+ax_idx)
        ax.hist2d(plot_fakes[:,0], plot_fakes[:,1], bins=[np.linspace(axis_lims[0][0], axis_lims[0][1], 21), np.linspace(axis_lims[1][0], axis_lims[1][1], 21)])
        ax.set_xlabel("Observable", fontsize=12, labelpad=15)
        if update_batch_size is False : continue
        Et           = energy_test_stat(plot_data, plot_fakes)
        Et_threshold = Et_thresholds [mu]
        print(f"Energy test [mu = {mu:.2f}] = {Et:.2f}, with threshold of {Et_threshold:.2f}")
        if Et > Et_threshold : update_batch_size = False
    plt.show()
    
    if update_batch_size is False : continue
        
    print(f"Updating batch factor and throwing toys for Et thresholds")
    batch_size    = np.min([max_batch_size, int(batch_size * batch_update_factor)])
    Et_thresholds = get_Et_thresholds(batch_size, threshold=energy_test_threshold, num_toys=energy_test_ntoys)
        
    print("Continuing")
    

In [None]:
#  train on observable 1
# NOTE(review): this cell uses a different helper API from the layer-wise cell (train_points_c,
# datasets, GAN_noise_size, get_batch_size, original_batch_size, critic_itrs_per_generator_itr,
# plot_progress); those names are defined outside this view — confirm they still exist.
# Bounded by max_epochs so that Restart & Run All terminates.
for epoch_idx in range(max_epochs) :
    batch_size = get_batch_size(original_batch_size, max_batch_size, epoch_idx)
    data_batch , data_conditions  = get_train_data (batch_size, train_points_c, datasets)
    fakes_batch, fakes_conditions = get_train_fakes(batch_size, train_points_c, GAN_noise_size, generator)

    # Size the labels from the batches themselves rather than from a num_conditions set by another cell
    data_labels = np.ones(len(data_batch))
    fake_labels = np.full(len(fakes_batch), -1.)

    critic_train_X_datapoints = np.concatenate([data_batch     , fakes_batch     ])
    critic_train_X_conditions = np.concatenate([data_conditions, fakes_conditions])
    critic_train_Y_labels     = np.concatenate([data_labels    , fake_labels     ])
    critic_train_X = [critic_train_X_datapoints, critic_train_X_conditions]

    # train discriminator (its targets are the negated labels: -1 for real, +1 for generated)
    critic.trainable = True
    for _ in range(critic_itrs_per_generator_itr) :
        critic_loss = critic.train_on_batch(critic_train_X, -1.*critic_train_Y_labels)

    # train generator: target -1, the same target the critic uses for real data
    new_noise, new_conditions = get_noise (batch_size, train_points_c, GAN_noise_size)
    generator_train_X = [new_noise, new_conditions]
    critic   .trainable = False
    generator.trainable = True
    GAN_loss = GAN.train_on_batch(generator_train_X, np.full(len(new_noise), -1.))

    if epoch_idx % epoch_print_interval != 0 : continue
    print(f"Epoch {epoch_idx} GAN loss is {GAN_loss}")

    data_batch , data_conditions  = get_train_data (batch_size, train_points_c, datasets)
    fakes_batch, fakes_conditions = get_train_fakes(batch_size, train_points_c, GAN_noise_size, generator)
    plot_progress (train_points_c, data_batch, data_conditions, fakes_batch, fakes_conditions)

In [None]:
# Toy 2D dataset: independent Gaussians in (A, B) with widths v_sigmas, rotated by rotate_angle,
# shifted to v_means, and truncated to the box [Amin, Amax] x [Bmin, Bmax].
A_mean, A_sigma = 50, 100
B_mean, B_sigma = 125, 8
v_means, v_sigmas = np.array([A_mean, B_mean]), np.array([A_sigma, B_sigma])

# Box limits; the *_npoints values are the CDF grid sizes passed to the whitening code.
Amin, Amax, A_npoints = 50 , 300, 251
Bmin, Bmax, B_npoints = 100, 150, 251

rotate_angle = np.pi/4
rotate = np.array([[np.cos(rotate_angle), -1.*np.sin(rotate_angle)], [np.sin(rotate_angle), np.cos(rotate_angle)]])

# Covariance of the (untruncated) samples rotate @ (v_sigmas * z), z ~ N(0, I), is R diag(sigma^2) R^T.
# An element-wise product rotate[i,j]*sigma_i*sigma_j is not symmetric and so is not a covariance.
cov = rotate @ np.diag(v_sigmas**2) @ rotate.T
eigval, eigvec = np.linalg.eigh(cov)   # symmetric matrix -> real eigenpairs, eigenvalues ascending
eigval = np.sqrt(eigval)               # standard deviations along the principal axes

n_data = 100000
fake_data_seed = None   # set to an int for a reproducible toy dataset; None keeps using the global np.random state
toy_rng = np.random if fake_data_seed is None else np.random.RandomState(fake_data_seed)

rnd_variations = np.column_stack([toy_rng.normal(0, 1, n_data), toy_rng.normal(0, 1, n_data)])
# Vectorised form of v_means + rotate @ (v_sigmas * row) for every row
candidates = v_means + (v_sigmas * rnd_variations) @ rotate.T
# Keep points inside the closed box (boundaries included)
in_box = ((candidates[:,0] >= Amin) & (candidates[:,0] <= Amax) &
          (candidates[:,1] >= Bmin) & (candidates[:,1] <= Bmax))
fake_dataset = candidates[in_box]
print(len(fake_dataset))

plt.hist2d(fake_dataset[:,0], fake_dataset[:,1])
plt.xlabel("A")
plt.ylabel("B")
plt.show()


In [None]:
def get_special_encoding_constants_for_axis (dataset, axis, axmin, axmax, ax_npoints, frac_constant, n_gauss_points=201, gauss_max=5.) :
    """Build a monotone mapping between one column of `dataset` and a Gaussian-like coordinate.

    The empirical CDF of column `axis`, using only values strictly inside (axmin, axmax), is
    evaluated on ``1 + ax_npoints`` evenly spaced points in [axmin, axmax] and mixed with a
    uniform CDF:

        z(A) = frac_constant * (A - axmin) / (axmax - axmin) + (1 - frac_constant) * F_data(A)

    z is then mapped through the standard-normal CDF tabulated on ``n_gauss_points`` points in
    [-gauss_max, gauss_max], with the table end points pinned to exactly 0 and 1. Every step is a
    linear np.interp, so inputs outside a grid are clamped to its end values.

    With frac_constant == 0 the combined CDF can be flat where there is no data, which makes the
    inverse ambiguous there; any frac_constant > 0 keeps it strictly increasing.

    Args:
        dataset: 2D array-like; only column `axis` is used.
        axis: column index.
        axmin, axmax: range of the CDF grid (axmax > axmin).
        ax_npoints: the grid has 1 + ax_npoints points.
        frac_constant: weight of the uniform CDF in the mixture.
        n_gauss_points: size of the tabulated Gaussian grid (default 201).
        gauss_max: half-width of the Gaussian grid (default 5.).

    Returns:
        (A_to_g, g_to_A): forward and inverse mappings; both accept scalars or arrays.

    Raises:
        ValueError: if no value of the column lies strictly inside (axmin, axmax).
    """
    column = np.asarray(dataset)[:, axis]
    column = np.sort(column[(column > axmin) & (column < axmax)])
    if len(column) == 0 :
        raise ValueError(f"No values on axis {axis} lie strictly inside ({axmin}, {axmax})")

    ax_scan_points = np.linspace(axmin, axmax, 1+ax_npoints)
    # On sorted values, searchsorted(side="left") counts the points strictly below each scan point
    data_cdf     = np.searchsorted(column, ax_scan_points, side="left") / len(column)
    constant_cdf = (ax_scan_points - axmin) / (axmax - axmin)
    combined_cdf = frac_constant*constant_cdf + (1-frac_constant)*data_cdf

    Gauss_x   = np.linspace(-gauss_max, gauss_max, n_gauss_points)
    Gauss_cdf = stats.norm.cdf(Gauss_x)
    # Pin the tails so the whole z range [0, 1] maps onto [-gauss_max, gauss_max]
    Gauss_cdf[0], Gauss_cdf[-1] = 0., 1.

    def A_to_g (A) :
        # A -> z (mixed CDF) -> g (Gaussian coordinate)
        return np.interp(np.interp(A, ax_scan_points, combined_cdf), Gauss_cdf, Gauss_x)

    def g_to_A (g) :
        # g -> z -> A, the piecewise-linear inverse of A_to_g
        return np.interp(np.interp(g, Gauss_x, Gauss_cdf), combined_cdf, ax_scan_points)

    return A_to_g, g_to_A
    

In [None]:
def special_whiten_dataset (dataset, *axis_configs) :
    """Gaussianise every column of `dataset` with its own CDF-based map, then apply `whiten_data`.

    Args:
        dataset: 2D array of shape (num_points, num_axes).
        *axis_configs: one ``[axmin, axmax, ax_npoints, frac_constant]`` sequence per column,
            forwarded to get_special_encoding_constants_for_axis (extra entries are ignored).

    Returns:
        (white_dataset, whitening_funcs, whitening_params): whitening_funcs[i] is the
        (forward, inverse) pair for column i; whitening_params is whatever `whiten_data` returns.

    Raises:
        ValueError: if the number of axis configs differs from the number of columns.
    """
    dataset  = np.asarray(dataset)
    num_axes = dataset.shape[1]
    if num_axes != len(axis_configs) : 
        raise ValueError(f"Dataset with shape {dataset.shape} requires {num_axes} axis configs but {len(axis_configs)} provided")
    whitening_funcs = [get_special_encoding_constants_for_axis(dataset, axis_idx, *axis_configs[axis_idx][:4])
                       for axis_idx in range(num_axes)]
    # Transform the dataset that was passed in, one vectorised call per column
    gaussianised = np.column_stack([whitening_funcs[idx][0](dataset[:, idx]) for idx in range(num_axes)])
    # NOTE(review): whiten_data is not defined in this part of the notebook — confirm where it comes from.
    white_dataset, whitening_params = whiten_data(gaussianised)

    return white_dataset, whitening_funcs, whitening_params

def special_unwhiten_dataset (white_dataset, whitening_funcs, whitening_params) :
    """Invert special_whiten_dataset: undo `whiten_data`, then map each column back to its original scale.

    Args:
        white_dataset: 2D array of shape (num_points, num_axes).
        whitening_funcs: per-column (forward, inverse) pairs; the inverse is applied to a whole
            column at once, so it must accept arrays (the np.interp-based maps from
            get_special_encoding_constants_for_axis do).
        whitening_params: parameters passed to `unwhiten_data`.

    Returns:
        2D array of shape (num_points, num_axes) in the original coordinates.
    """
    num_axes = white_dataset.shape[1]
    # NOTE(review): unwhiten_data is not defined in this part of the notebook — confirm where it comes from.
    unwhite_dataset = np.asarray(unwhiten_data(white_dataset, whitening_params))
    return np.column_stack([whitening_funcs[idx][1](unwhite_dataset[:, idx]) for idx in range(num_axes)])


In [None]:
# Whiten the toy dataset (the 4th entry of each axis config is frac_constant = 0.2) and map it
# straight back, so original, whitened and reconstructed data can be compared.
white_dataset, whitening_funcs, whitening_params = special_whiten_dataset(
    fake_dataset,
    [Amin, Amax, A_npoints, 0.2],
    [Bmin, Bmax, B_npoints, 0.2],
)

unwhite_dataset = special_unwhiten_dataset(white_dataset, whitening_funcs, whitening_params)

In [None]:
def plot_A_B_row (dataset, row_label, A_name="A", B_name="B") :
    """Show one row of three panels for a 2-column dataset and display it.

    Panels: histogram of column 0 ("<row_label> <A_name>"), histogram of column 1
    ("<row_label> <B_name>"), and their 2D histogram ("<row_label> <A_name> vs <B_name>").
    """
    fig, (ax_A, ax_B, ax_AB) = plt.subplots(1, 3, figsize=(10, 2))

    ax_A.set_title(f"{row_label} {A_name}")
    ax_A.hist(dataset[:,0])

    ax_B.set_title(f"{row_label} {B_name}")
    ax_B.hist(dataset[:,1])

    ax_AB.set_title(f"{row_label} {A_name} vs {B_name}")
    ax_AB.hist2d(dataset[:,0], dataset[:,1])
    plt.show()


plot_A_B_row(fake_dataset   , "Original")
plot_A_B_row(white_dataset  , "Whitened", "A'", "B'")
plot_A_B_row(unwhite_dataset, "Reconstructed")

In [None]:
tmp_dataset = fake_dataset.T[0]   # first column of the toy dataset (observable A)

In [None]:
# Empirical CDF of observable A on 1 + A_npoints evenly spaced points in [Amin, Amax],
# using only values strictly inside (Amin, Amax).
A_scan_points = np.linspace(Amin, Amax, 1+A_npoints)

tmp_dataset = np.asarray(tmp_dataset)
tmp_dataset = tmp_dataset[(tmp_dataset > Amin) & (tmp_dataset < Amax)]
assert len(tmp_dataset) > 0, "no A values lie strictly inside (Amin, Amax)"

# On sorted values, searchsorted(side="left") counts the points strictly below each scan point
data_cdf = np.searchsorted(np.sort(tmp_dataset), A_scan_points, side="left") / len(tmp_dataset)

In [None]:
# Mix the data CDF with a uniform CDF over [Amin, Amax]; because the uniform part is strictly
# increasing, the mixture is strictly increasing too and can be inverted by interpolation.
constant_cdf = (A_scan_points - Amin) / (Amax - Amin)
frac_constant = 0.2
combined_cdf = frac_constant*constant_cdf + (1-frac_constant)*data_cdf

fig, ax = plt.subplots()
ax.plot(A_scan_points, data_cdf    , label="data CDF")
ax.plot(A_scan_points, constant_cdf, label="uniform CDF")
ax.plot(A_scan_points, combined_cdf, label=f"combined ({frac_constant:.0%} uniform)")
ax.set_xlabel("A")
ax.set_ylabel("CDF")
ax.legend()
plt.show()

In [None]:
# Tabulate the standard-normal CDF on [-5, 5]. The end values are not exactly 0 and 1,
# so they are pinned to make the table span the full probability range [0, 1].
Gauss_x   = np.linspace(-5, 5, 501)
Gauss_cdf = stats.norm.cdf(Gauss_x)
print(Gauss_cdf[0], Gauss_cdf[-1])
Gauss_cdf[[0, -1]] = 0., 1.
print("-------->")
print(*Gauss_cdf[[0, -1]])

In [None]:
def A_to_z (A) :
    """A -> mixed CDF value z, linear between A_scan_points."""
    return np.interp(A, A_scan_points, combined_cdf)

def z_to_A (z) :
    """z -> A, inverse of A_to_z."""
    return np.interp(z, combined_cdf, A_scan_points)


def z_to_g (z) :
    """z -> Gaussian coordinate g via the tabulated normal CDF."""
    return np.interp(z, Gauss_cdf, Gauss_x)

def g_to_z (g) :
    """g -> z, inverse of z_to_g."""
    return np.interp(g, Gauss_x, Gauss_cdf)


def A_to_g (A) :
    """A -> g (composition of A_to_z and z_to_g)."""
    return z_to_g(A_to_z(A))

def g_to_A (g) :
    """g -> A (composition of g_to_z and z_to_A)."""
    return z_to_A(g_to_z(g))

In [None]:
# Histogram the raw A values, their Gaussianised image, and the round trip back to A.
for values in (tmp_dataset, A_to_g(tmp_dataset), g_to_A(A_to_g(tmp_dataset))) :
    plt.hist(values)
    plt.show()