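# Interactive basis plaything

Drag the white benchmark points to reshape the morphing basis for the parameters $f_W$ and $f_{WW}$. The heat map shows $\log_{10}$ of the L2 norm of the morphing weights, $\sqrt{\sum_i w_i^2}$, at each point of the parameter plane.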

In [1]:
import numpy as np
import bqplot
from bqplot import pyplot as plt
from IPython.display import display
from ipywidgets import Layout
import numpy as np
import pandas as pd
%matplotlib notebook

from madminer.madminer import MadMiner

## Preparation

In [2]:
# Number of grid points per parameter axis; the heat map is evaluated on an
# n_resolution x n_resolution grid.
n_resolution = 50

In [3]:
# Two morphing parameters, fW and fWW, in LHA block 'dim6', each with
# morphing_max_power=4 over the range [-1, 1].
miner = MadMiner()
miner.add_parameter(lha_block='dim6', lha_id=1, parameter_name='fW',
                    morphing_max_power=4, morphing_parameter_range=(-1., 1.))
miner.add_parameter(lha_block='dim6', lha_id=2, parameter_name='fWW',
                    morphing_max_power=4, morphing_parameter_range=(-1., 1.))

# Regular evaluation grid over the (fW, fWW) plane, flattened to one
# (fW, fWW) row per grid point. With meshgrid's default 'xy' indexing, a
# per-point result reshaped back to (n_resolution, n_resolution) has fWW (yi)
# along the rows and fW (xi) along the columns.
xi = np.linspace(-1., 1., n_resolution)
yi = np.linspace(-1., 1., n_resolution)
xx, yy = (grid.reshape(-1, 1) for grid in np.meshgrid(xi, yi))
theta_evaluation = np.concatenate((xx, yy), axis=1)

# Bookkeeping for the scatter callback: it re-evaluates only after both the
# x and the y coordinates have been updated.
x_updated = y_updated = False

## Basis evaluation

In [11]:
def evaluate_basis(basis=None):
    """Install a benchmark basis in ``miner`` and evaluate it on the global grid.

    ``miner.set_benchmarks()`` is called first (presumably clearing any
    previously registered benchmarks -- TODO confirm against MadMiner). The
    given points are then added, and ``set_benchmarks_from_morphing`` is run
    with ``keep_existing_benchmarks=True`` so that it may complete the basis.

    Parameters
    ----------
    basis : iterable of (fW, fWW) pairs, optional
        Benchmark points; element 0 of each pair is fW and element 1 is fWW.
        If None, no points are added before the morphing setup.

    Returns
    -------
    actual_basis : np.ndarray
        The benchmarks registered in ``miner`` after the morphing setup, one
        row per benchmark. This may contain more points than ``basis``. The
        columns follow the order of each benchmark's parameter dict.
    squared_weights : np.ndarray, shape (n_resolution, n_resolution)
        Despite the historical name, the L2 norm ``sqrt(sum_i w_i**2)`` of the
        morphing weight vector at every point of ``theta_evaluation``. Rows
        follow ``yi`` (fWW) and columns follow ``xi`` (fW).
    """
    # Register the requested basis points
    miner.set_benchmarks()
    if basis is not None:
        for theta in basis:
            miner.add_benchmark({'fW': theta[0], 'fWW': theta[1]})

    # Let the morphing setup keep our points and append any that are missing
    miner.set_benchmarks_from_morphing(keep_existing_benchmarks=True,
                                       n_trials=1)

    # Read the basis back, since the morphing setup may have appended points
    actual_basis = np.array(
        [list(parameters.values()) for parameters in miner.benchmarks.values()]
    )

    # Norm of the morphing weight vector at each grid point; a large norm
    # indicates large weights at that point of the parameter plane.
    morpher = miner.current_morpher
    weight_norms = np.array([
        np.linalg.norm(morpher._calculate_morphing_weights(theta, None))
        for theta in theta_evaluation
    ]).reshape((n_resolution, n_resolution))

    return actual_basis, weight_norms

## Initial setup

In [15]:
# Hand-placed starting basis: 15 benchmark points, column 0 = fW, column 1 = fWW.
initial_basis = np.array([
    [ 0.        ,  0.        ],
    [-0.8318245 ,  0.85645093],
    [-0.82002127, -0.85191237],
    [ 0.76870769, -0.81272456],
    [ 0.7819962 ,  0.86242685],
    [-0.57243257,  0.37755934],
    [-0.29730939,  0.74563426],
    [ 0.13777926,  0.35254704],
    [ 0.46330191,  0.51783982],
    [ 0.64649576, -0.01232633],
    [ 0.16629182, -0.29365045],
    [ 0.39752054, -0.64235507],
    [-0.19238158, -0.59962178],
    [-0.30730345, -0.09697784],
    [-0.70631846, -0.18913046],
])

# The working basis is whatever evaluate_basis reads back from the miner,
# which may include points appended by the morphing setup.
basis, squared_weights = evaluate_basis(initial_basis)

## Interactive tool

In [16]:
def update(change):
    """Scatter observer: copy dragged coordinates into ``basis`` and refresh the heat map.

    ``change`` is a traitlets change dict for the scatter's ``x`` or ``y``
    trait. The new values are written into column 0 (x, fW) or column 1
    (y, fWW) of the global ``basis``. Once both an x and a y change have
    arrived, the basis is re-evaluated, and the heat-map colours are set to
    log10 of the weight norms.
    """
    global basis, squared_weights, x_updated, y_updated

    variable = change['name']
    values = change['new']

    if variable == 'x':
        basis[:, 0] = values
        x_updated = True
    elif variable == 'y':
        basis[:, 1] = values
        y_updated = True

    # x and y are observed separately; wait until both have been updated.
    if not (x_updated and y_updated):
        return

    # Reset the flags *before* evaluating. If evaluate_basis raises, stale
    # True flags would otherwise make the next lone x (or y) change trigger an
    # evaluation with half-updated coordinates, desynchronising the heat map.
    x_updated = False
    y_updated = False

    basis, squared_weights = evaluate_basis(basis)
    heat.color = np.log10(squared_weights)

In [17]:
x_sc = bqplot.scales.LinearScale(min=-1., max=1.)
y_sc = bqplot.scales.LinearScale(min=-1., max=1.)
# Colour encodes log10 of the weight norm: 0 -> norm 1, 2 -> norm 100.
c_sc = bqplot.scales.ColorScale(min=0., mid=1., max=2., scheme='YlOrRd')

# Rows of `squared_weights` follow yi (fWW, vertical axis) and columns follow
# xi (fW, horizontal axis), so the row dimension must use y_sc and the column
# dimension x_sc. The previous swapped mapping only rendered correctly
# because both scales happen to share the same range.
heat = bqplot.GridHeatMap(color=np.log10(squared_weights),
                          scales={'row': y_sc, 'column': x_sc, 'color': c_sc},
                          row=yi,
                          column=xi,
                          stroke=None)

# Pass copies so the widget's data never aliases the global `basis`, which
# `update` modifies in place.
scatter = bqplot.Scatter(colors=['white'],
                         x=basis[:, 0].copy(),
                         y=basis[:, 1].copy(),
                         scales={'x': x_sc, 'y': y_sc})

ax_x = bqplot.Axis(scale=x_sc, label='fW')
ax_y = bqplot.Axis(scale=y_sc, orientation='vertical', label='fWW')
ax_c = bqplot.ColorAxis(scale=c_sc,
                        orientation='vertical',
                        side='right')

fig = bqplot.Figure(marks=[heat, scatter],
                    axes=[ax_x, ax_y, ax_c],
                    layout=Layout(width='600px', height='600px'))

display(fig)

# Re-evaluate the basis whenever the user drags a benchmark point.
scatter.observe(update, ['y', 'x'])
scatter.enable_move = True
