# Notebook to evaluate a trained model
### Dashboards to be produced:
* Radial distributions
* Phi distributions
* z distributions

## Imports and global definitions

In [None]:
import os

import pandas as pd
from bokeh.io import output_notebook

from RootInteractive.InteractiveDrawing.bokeh.bokehDrawSA import *
from tpcwithdnn.tree_df_utils import tree_to_pandas, tree_to_pandas_ri

# Render the Bokeh dashboards produced below inline in this notebook.
output_notebook()

In [None]:
# Input locations: the training-run directory, the tree directory of the model
# configuration under study, and the validation ROOT file inside it.
base_dir = "/lustre/alice/users/hellbaer/NOTESData/JIRA/ATO-593/HitoshiData/with_vs_noNorm/20220916_withNorm"
tree_dir = (
    f"{base_dir}/trees/"
    "phi180_r33_z33_nest100_depth3_lr1.000_tm-hist_g0.00_weight1.0_d0.0_"
    "sub0.80_colTree0.8_colLvl1.0_colNode1.0_a0.0_l0.00005_scale1.0_base0.50_"
    "pred_doR1_dophi0_doz0_input_z0.0-251.0_rndaugment_dpoints10_ftrain40_fapply40/"
)
# tree_dir already ends with '/', so the joined path contains '//' (harmless on POSIX).
model_file = f"{tree_dir}/validation_mean1.00_nEv200000.root"
display(model_file)

## Load validation data

In [None]:
# Read the 'validation' tree from model_file and add the residual column
# diffCorrR = flucCorrRPred - flucCorrR (prediction minus true value).
branches = [
    "randomId", "meanId",
    "r", "phi", "z",
    "derRefMeanCorrR", "deltaSC",
    "flucCorrR", "flucCorrRPred",
]
df = tree_to_pandas(model_file, 'validation', columns=branches)
df = df.assign(diffCorrR=df['flucCorrRPred'] - df['flucCorrR'])
display(df)

## Radial distributions
#### First row 
* true correction vs r
* predicted correction vs r
* difference between prediction and true vs r
##### TODO: convert these into scatter plots.

#### Second row (histogrammed version)
* 2D histogram: difference vs r
    * distribution, mean, std
    
#### Second dashboard: z distributions
* Cut on the radius (r < 84, i.e. around 83.5) to draw the full z distribution

In [None]:
# Dashboard 1: r and phi distributions of the true correction (flucCorrR), the
# predicted correction (flucCorrRPred) and their difference (diffCorrR), plus a
# summary table of the histogram statistics.
figureDir = "%s/figures" % base_dir
os.makedirs(figureDir, exist_ok=True)  # make sure the target directory exists before the HTML is written
output_file("%s/model_eval.html" % figureDir)

nBinsR = 33
# NOTE(review): phi histograms reuse the radial bin count; the model tree name
# encodes phi180 -- confirm whether a finer phi binning is intended.
nBinsPhi = nBinsR
nBinsCorr = 50  # bins along the correction / difference axis

histoArray = [
    # hists as function of r
    {"name": "RCorrR", "variables": ["r", "flucCorrR"], "nbins": [nBinsR, nBinsCorr], "axis": [1]},          # true correction
    {"name": "RCorrPredR", "variables": ["r", "flucCorrRPred"], "nbins": [nBinsR, nBinsCorr], "axis": [1]},  # predicted correction
    {"name": "hisdiffR", "variables": ["r", "diffCorrR"], "nbins": [nBinsR, nBinsCorr], "axis": [1]},        # predicted minus true correction
    # hists as function of phi
    {"name": "RCorrPhi", "variables": ["phi", "flucCorrR"], "nbins": [nBinsPhi, nBinsCorr], "axis": [1]},
    {"name": "RCorrPredPhi", "variables": ["phi", "flucCorrRPred"], "nbins": [nBinsPhi, nBinsCorr], "axis": [1]},
    {"name": "hisdiffPhi", "variables": ["phi", "diffCorrR"], "nbins": [nBinsPhi, nBinsCorr], "axis": [1]},
]

figureArray = [
    # r figures
    [['r'], ['flucCorrR']],                                  # 0: true correction vs r (cf. RCorrR)
    [['r'], ['flucCorrRPred']],                              # 1: predicted correction vs r (cf. RCorrPredR)
    [['r'], ['diffCorrR']],                                  # 2: difference vs r (cf. hisdiffR)
    [['r'], ['hisdiffR']],                                   # 3: 2D hist -- defined but not placed in the "R" layout
    [['hisdiffR_1.bin_center_0'], ['hisdiffR_1.mean']],      # 4: _1 refers to histogram axis 1, bin_center_0 is the r axis
    [['hisdiffR_1.bin_center_0'], ['hisdiffR_1.std']],       # 5
    # phi figures
    [['phi'], ['flucCorrR']],                                # 6
    [['phi'], ['flucCorrRPred']],                            # 7
    [['phi'], ['diffCorrR']],                                # 8
    [['phi'], ['hisdiffPhi']],                               # 9: 2D hist -- defined but not placed in the "Phi" layout
    [['hisdiffPhi_1.bin_center_0'], ['hisdiffPhi_1.mean']],  # 10
    [['hisdiffPhi_1.bin_center_0'], ['hisdiffPhi_1.std']],   # 11
    # table
    ["tableHisto", {"rowwise": False}],                      # 12: statistical properties of the histograms
]

figureLayoutDesc = {
    "R": [
        [0, 1, 2, {'plot_height': 300}],  # first row
        [4, 5, {'plot_height': 200}],     # second row
        {'plot_height': 240, 'sizing_mode': 'scale_width', "legend_visible": False}
    ],
    "Phi": [
        [6, 7, 8, {'plot_height': 300}],
        [10, 11, {'plot_height': 200}],
        {'plot_height': 240, 'sizing_mode': 'scale_width', "legend_visible": False}
    ],
    "Summary table": [
        [12, {'plot_height': 100}]
    ],
}

widgetParams = [
    ['range', ['r']],
    ['range', ['phi']],
    ['range', ['z']],
    ['range', ['deltaSC']],
    ['multiSelect', ["randomId"]],
    ['multiSelect', ["meanId"]],
]

widgetLayoutDesc = [
    [0, 1, 2, 3],
    [4, 5],
    {'sizing_mode': 'scale_width'}
]

tooltips = [("phi", "@phi"), ("r", "@r"), ("z", "@z")]

nSampleMax = 1000000  # cap on the number of rows handed to the dashboard
sampleSeed = 42       # fixed seed: the same subset is drawn on every re-run
# DataFrame.sample(n) without replacement raises ValueError when n > len(df),
# so cap n at the number of available rows.
dfSampleRPhi = df.sample(n=min(nSampleMax, len(df)), random_state=sampleSeed)

fig = bokehDrawSA.fromArray(dfSampleRPhi, "r>0 & z<1", figureArray, widgetParams, layout=figureLayoutDesc,
                            tooltips=tooltips, sizing_mode='scale_width', widgetLayout=widgetLayoutDesc,
                            histogramArray=histoArray, rescaleColorMapper=True, nPointRender=6000)

In [None]:
# Dashboard 2: z distributions of the true correction (flucCorrR), the predicted
# correction (flucCorrRPred) and their difference (diffCorrR), plus a summary table.
# Give this dashboard its own output file. Otherwise the cell would depend on the
# output_file target set in the previous cell (model_eval.html) and presumably
# overwrite the r/phi dashboard stored there.
figureDirZ = "%s/figures" % base_dir
os.makedirs(figureDirZ, exist_ok=True)  # make sure the target directory exists before the HTML is written
output_file("%s/model_eval_z.html" % figureDirZ)

nBinsZ = 25
histoArray = [
    # hists as function of z
    {"name": "RCorrZ", "variables": ["z", "flucCorrR"], "nbins": [nBinsZ, 10], "axis": [0, 1]},
    {"name": "RCorrPredZ", "variables": ["z", "flucCorrRPred"], "nbins": [nBinsZ, 10], "axis": [0, 1]},
    {"name": "hisdiffZ", "variables": ["z", "diffCorrR"], "nbins": [nBinsZ, 20], "axis": [0, 1]},
]

figureArray = [
    # z figures
    [['z'], ['flucCorrR']],                                # 0: true correction vs z
    [['z'], ['flucCorrRPred']],                            # 1: predicted correction vs z
    [['z'], ['diffCorrR']],                                # 2: difference vs z
    [['z'], ['hisdiffZ']],                                 # 3: 2D hist
    [['hisdiffZ_1.bin_center_0'], ['hisdiffZ_1.mean']],    # 4: _1 refers to histogram axis 1, bin_center_0 is the z axis
    [['hisdiffZ_1.bin_center_0'], ['hisdiffZ_1.std']],     # 5
    # table
    ["tableHisto", {"rowwise": False}],                    # 6: statistical properties of the histograms
]

figureLayoutDesc = {
    "Z": [
        [0, 1, 2, {'plot_height': 400}],  # first row
        [3, 4, 5, {'plot_height': 400}],  # second row
        [6, {'plot_height': 100}],
        {'plot_height': 240, 'sizing_mode': 'scale_width', "legend_visible": False}
    ]
}

widgetParams = [
    ['range', ['r']],
    ['range', ['phi']],
    ['range', ['z']],
    ['multiSelect', ["randomId"]],
    ['multiSelect', ["meanId"]],
    ['range', ['deltaSC']],
]

widgetLayoutDesc = [
    [0, 1, 2],
    [3, 4, 5],
    {'sizing_mode': 'scale_width'}
]

tooltips = [("phi", "@phi"), ("r", "@r"), ("z", "@z")]

nSampleMax = 1000000  # cap on the number of rows handed to the dashboard
sampleSeed = 42       # fixed seed: the same subset is drawn on every re-run
# DataFrame.sample(n) without replacement raises ValueError when n > len(df),
# so cap n at the number of available rows.
dfSampleZ = df.sample(n=min(nSampleMax, len(df)), random_state=sampleSeed)

# restrict r to view the full z range
fig = bokehDrawSA.fromArray(dfSampleZ, "r<84 & z>0", figureArray, widgetParams, layout=figureLayoutDesc,
                            tooltips=tooltips, sizing_mode='scale_width', widgetLayout=widgetLayoutDesc,
                            histogramArray=histoArray, rescaleColorMapper=True, nPointRender=6000)
