In [29]:
import os,pickle
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from process_data import simulate_conditional_flow_data_ptscale
from Model.ConditionalRealNVP import ConditionalRealNVP
from Utils.ObjDict import ObjDict
from Utils.mkdir_p import mkdir_p

In [30]:
# __________________________________________________________________ ||
# Basic configurables
# __________________________________________________________________ ||

# Both paths keep their previous values as defaults but can be overridden
# through environment variables, so the notebook is not tied to one machine's
# directory layout.
# NOTE: despite its name, input_csv_path is read with np.load (a .npy file).
input_csv_path = os.environ.get("CONDREALNVP_INPUT_PATH", "data/train.npy")
saved_model_path = os.environ.get(
    "CONDREALNVP_SAVED_MODEL",
    "/Users/lucien/Downloads/train_condrealnvp_210112_v1/saved_model.h5",
)

# Number of base-distribution samples drawn and of data rows selected.
event_size = 1000
# Passed to ConditionalRealNVP as `dim`.
ndim = 3
# Passed to ConditionalRealNVP as `ncond`.
ncond = 2

In [31]:
# __________________________________________________________________ ||
# Load models
# __________________________________________________________________ ||

nf_model = ConditionalRealNVP(num_coupling_layers=5,dim=ndim,ncond=ncond)

# Dummy inputs for one forward pass before the weights are loaded
# (presumably needed so the model's variables exist -- TODO confirm).
samples = nf_model.distribution.sample(event_size)
# The condition width follows ncond, the same value the model was built with,
# instead of a hard-coded 2.
condition = np.ones((event_size, ncond))
_,_ = nf_model.predict([samples,condition,])

nf_model.load_weights(saved_model_path)

In [32]:
# Forward pass of the flow on the current `samples` / `condition`; predict
# returns two values and only the first one is kept.
x, _ = nf_model.predict([samples, condition])

In [33]:
# Read the input array and keep only the rows whose last column lies strictly
# within 1 of 90 (presumably a mass window -- TODO confirm units).
# NOTE(review): this rebinds `condition`, which the model-loading cell also sets.
arr = np.load(input_csv_path)
condition = arr[:, -1] - 90.0
arr = arr[np.squeeze(np.abs(condition) < 1)]

# Pick event_size rows at random, with replacement (unseeded global RNG).
idx_select = np.random.randint(low=0, high=arr.shape[0], size=event_size)
arr = arr[idx_select]

# Means of columns 0 and 3 of the selected rows; they are handed to
# simulate_conditional_flow_data_ptscale further down.
pt1_mean, pt2_mean = (np.mean(arr[:, column]) for column in (0, 3))

In [131]:
# __________________________________________________________________ ||
# Make plots for different conditions
# __________________________________________________________________ ||

# Scale-factor values scanned on both axes of the panel grid
# (an earlier, coarser choice was [-0.5,-0.1,0.0,0.1,0.5]).
sf_grid = [-0.75,-0.5,-0.1,0.0,0.1,0.5,0.75,]
figsize = (50,50)

# One config per (sf1, sf2) pair; x / y are the row / column of its panel.
cond_cfgs = [ObjDict(sf1=sf1,sf2=sf2,x=ix,y=iy) for ix,sf1 in enumerate(sf_grid) for iy,sf2 in enumerate(sf_grid)]

# The same base-distribution samples are pushed through the flow for every
# condition, so panels differ only through the condition.
samples = nf_model.distribution.sample(event_size)
fig_pt1,ax_pt1 = plt.subplots(len(sf_grid),len(sf_grid),figsize=figsize)
fig_pt2,ax_pt2 = plt.subplots(len(sf_grid),len(sf_grid),figsize=figsize)
fig_mll,ax_mll = plt.subplots(len(sf_grid),len(sf_grid),figsize=figsize)

# (axes grid, column of the feature array, histogram range) for each figure.
hist_panels = [
    (ax_pt1, 0, [-10., 10.]),
    (ax_pt2, 1, [-10., 10.]),
    (ax_mll, 2, [0., 20.]),
]

for cfg in cond_cfgs:

    condition_str = str(cfg.sf1)+" "+str(cfg.sf2)

    x_true,condition = simulate_conditional_flow_data_ptscale(
        arr,
        pt1_mean=pt1_mean,
        pt2_mean=pt2_mean,
        batch_size=1,
        event_size=event_size,
        sf1=cfg.sf1,
        sf2=cfg.sf2,
    )

    x_gen,_ = nf_model.predict([samples,condition,])

    # Overlay the simulated ("True") and generated ("Flow") distributions.
    for panel_axes, column, hist_range in hist_panels:
        panel_ax = panel_axes[cfg.x, cfg.y]
        for label_prefix, values in (('True ', x_true), ('Flow ', x_gen)):
            panel_ax.hist(
                values[:, column],
                bins=100,
                density=True,
                histtype='step',
                range=hist_range,
                label=label_prefix+condition_str,
            )
        panel_ax.legend(loc='best')

# Save each figure, then close it so the large canvases do not stay registered
# with pyplot for the rest of the session.
for out_fig, out_name in ((fig_pt1, 'pt1.png'), (fig_pt2, 'pt2.png'), (fig_mll, 'mll.png')):
    out_fig.savefig(out_name)
    plt.close(out_fig)

In [38]:
# __________________________________________________________________ ||
# Make plots for likelihood
# __________________________________________________________________ ||

from matplotlib import ticker
import time

n_plots = 10        # number of pseudo-datasets, one contour plot each
ngrid = 10          # scan points per axis
plot_low = -0.6
plot_high = 0.6
levels = 100
output_dir = 'train_condrealnvp_210111_v4'

# linspace(endpoint=False) yields exactly ngrid points with the same spacing as
# the previous float-step arange, which may produce one extra point through
# rounding and would then overflow Z.
X = np.linspace(plot_low, plot_high, ngrid, endpoint=False)
Y = np.linspace(plot_low, plot_high, ngrid, endpoint=False)

Z = np.zeros((ngrid,ngrid))
X_grid, Y_grid = np.meshgrid(X, Y)

# All scanned (sf1, sf2) pairs, sf1-major, so pair number ix*ngrid+iy is
# (X[ix], Y[iy]).
grid_conditions = np.array([(grid_sf1, grid_sf2) for grid_sf1 in X for grid_sf2 in Y])

# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

for i in range(n_plots):

    print("-"*100)
    print("Drawing plot ",i)

    start_time = time.time()

    # Scale factors of the pseudo-dataset, uniform in [-0.5, 0.5).
    data_sf1 = 0.5 * (2.*np.random.random_sample()-1.)
    data_sf2 = 0.5 * (2.*np.random.random_sample()-1.)

    sf_str = str(data_sf1)+'_'+str(data_sf2)

    x_data,condition = simulate_conditional_flow_data_ptscale(
        arr,
        pt1_mean=pt1_mean,
        pt2_mean=pt2_mean,
        batch_size=1,
        event_size=event_size,
        sf1=data_sf1,
        sf2=data_sf2,
        )

    # NOTE(review): presumably switches the flow to its data -> latent
    # direction for the loss evaluation; it is never reset afterwards, so
    # confirm this does not affect later predict() calls.
    nf_model.direction = -1

    # Evaluate every grid point in a single call: the dataset is stacked once
    # per grid point and each copy is paired with that point's condition.
    n_events = x_data.shape[0]
    condition_concat = np.repeat(grid_conditions, n_events, axis=0)
    x_data_concat = np.concatenate([x_data] * len(grid_conditions))

    z_concat = nf_model.batch_log_loss([x_data_concat,condition_concat])

    # Average the per-event values of each grid point. Index-only loops are
    # used so the notebook-level `x` is not overwritten.
    for ix in range(ngrid):
        for iy in range(ngrid):
            start = (ix*ngrid+iy)*n_events
            Z[ix,iy] = tf.reduce_mean(z_concat[start:start+n_events])

    fig, ax = plt.subplots()
    c = ax.contourf(X_grid, Y_grid, Z, levels)
    fig.colorbar(c)
    # Rows of Z index sf1 and contourf draws rows along the vertical axis, so
    # the marker is placed at (sf2, sf1).
    ax.plot([data_sf2,],[data_sf1,],marker='*',color='red')
    ax.set_xlabel('sf2')
    ax.set_ylabel('sf1')
    ax.set_title(sf_str)
    fig.savefig(os.path.join(output_dir, 'log_loss_'+sf_str+'.png'))
    # Release the figure; otherwise one figure per iteration stays open.
    plt.close(fig)

    elapsed_time = time.time() - start_time
    print("Time used: "+str(elapsed_time)+"s")

----------------------------------------------------------------------------------------------------
Drawing plot  0
Time used: 7.118891000747681s
----------------------------------------------------------------------------------------------------
Drawing plot  1
Time used: 6.2698140144348145s
----------------------------------------------------------------------------------------------------
Drawing plot  2
Time used: 6.3250932693481445s
----------------------------------------------------------------------------------------------------
Drawing plot  3
Time used: 6.399147033691406s
----------------------------------------------------------------------------------------------------
Drawing plot  4
Time used: 5.807127952575684s
----------------------------------------------------------------------------------------------------
Drawing plot  5
Time used: 5.991120100021362s
----------------------------------------------------------------------------------------------------
Drawing plot  6