In [1]:
import os,pickle
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from process_data import preprocess_conditional_flow_data_mass,simuate_conditional_flow_data_mass
from Model.ConditionalRealNVP import ConditionalRealNVP
from Utils.ObjDict import ObjDict
from Utils.mkdir_p import mkdir_p

In [27]:
# __________________________________________________________________ ||
# Basic configurables
# __________________________________________________________________ ||

input_csv_path = "data/train_mass.npy"
saved_model_path = "output/train_condrealnvp_mass_210209_v1/saved_model_700.h5"
output_dir = os.path.dirname(saved_model_path)
ndim = 3
ncond = 1

In [28]:
# __________________________________________________________________ ||
# Load models
# __________________________________________________________________ ||

nf_model = ConditionalRealNVP(num_coupling_layers=3,ndim=ndim,ncond=ncond)
samples = nf_model.distribution.sample(event_size)
condition = 1.0 * np.ones((event_size,1))
_,_ = nf_model.predict([samples,condition,])
nf_model.load_weights(saved_model_path)

In [13]:
arr = np.load(input_csv_path)
arr_list = preprocess_conditional_flow_data_mass(arr)

In [5]:
# __________________________________________________________________ ||
# Make plots for different conditions
# __________________________________________________________________ ||

n_dim = 5
mass_grid = [arr_list[idx_mass] for idx_mass in np.random.randint(0,len(arr_list),n_dim*n_dim)]
mass_grid.sort(key=lambda x: x.condition[0])
figsize = (50,50)

samples = nf_model.distribution.sample(event_size)
fig_pt1,ax_pt1 = plt.subplots(n_dim,n_dim,figsize=figsize)
fig_pt2,ax_pt2 = plt.subplots(n_dim,n_dim,figsize=figsize)
fig_mll,ax_mll = plt.subplots(n_dim,n_dim,figsize=figsize)

for i,m in enumerate(mass_grid):
    
    ix = int(i / n_dim)
    iy = i % n_dim
    
    condition_str = str(m.condition[0])
    condition = np.ones((event_size,1)) * m.condition[0]

    idx_batch = np.random.randint(0,m.x.shape[0],event_size)
    
    x_true = m.x[idx_batch]
    x_gen,_ = nf_model.predict([samples,condition,])

    ax_pt1[ix,iy].hist(x_true[:,0],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_pt1[ix,iy].hist(x_gen[:,0],bins=100,density=1.,histtype='step',range=[-10.,10.],label='Flow '+condition_str)
    ax_pt1[ix,iy].legend(loc='best')
    ax_pt1[ix,iy].set_title(condition_str)
    
    ax_pt2[ix,iy].hist(x_true[:,1],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_pt2[ix,iy].hist(x_gen[:,1],bins=100,density=1.,histtype='step',range=[-10.,10.],label='Flow '+condition_str)
    ax_pt2[ix,iy].legend(loc='best')
    ax_pt2[ix,iy].set_title(condition_str)
    
    ax_mll[ix,iy].hist(x_true[:,2],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_mll[ix,iy].hist(x_gen[:,2],bins=100,density=1.,histtype='step',range=[-5.,5.],label='Flow '+condition_str)
    ax_mll[ix,iy].legend(loc='best')
    ax_mll[ix,iy].set_title(condition_str)
    
fig_pt1.savefig(os.path.join(output_dir,'pt1.png'))
fig_pt2.savefig(os.path.join(output_dir,'pt2.png'))
fig_mll.savefig(os.path.join(output_dir,'mll.png'))

NameError: name 'arr_list' is not defined

In [114]:
# __________________________________________________________________ ||
# Make plots for likelihood
# __________________________________________________________________ ||

import time

n_dim = 5
mass_grid = [arr_list[idx_mass] for idx_mass in np.random.randint(0,len(arr_list),n_dim*n_dim)]
mass_grid.sort(key=lambda x: x.condition[0])

plot_low = -1.5
plot_high = 1.5
n_grid = 10
x_grid = [plot_low+(plot_high-plot_low)/n_grid*i for i in range(n_grid+1)]
figsize = (50,50)

event_size = 5000

z = np.zeros(n_grid+1)
fig, ax = plt.subplots(n_dim,n_dim,figsize=figsize)

for i,m in enumerate(mass_grid):
    
    print("-"*100)
    print("Drawing plot ",i," with mass ",m.condition[0])
    
    ix = int(i / n_dim)
    iy = i % n_dim
    
    start_time = time.time()
    
    condition_str = str(m)

    nf_model.direction = -1
    
    idx_batch = np.random.randint(0,m.x.shape[0],event_size)
    
    condition_concat = np.concatenate([np.ones((event_size,1)) * x for ix,x in enumerate(x_grid)])
    x_data_concat = np.concatenate([m.x[idx_batch] for ix,x in enumerate(x_grid)])
    
    z_concat = nf_model.batch_log_loss([x_data_concat,condition_concat])
    
    for ig,x in enumerate(x_grid):
        z[ig] = tf.reduce_mean(z_concat[ig*idx_batch.shape[0]:(ig+1)*idx_batch.shape[0]])

    ax[ix,iy].plot(x_grid,z,)
    ylims = ax[ix,iy].get_ylim()
    ax[ix,iy].arrow(m.condition[0], ylims[1], 0., ylims[0]-ylims[1],)
    
    elapsed_time = time.time() - start_time
    print("Time used: "+str(elapsed_time)+"s")
    
fig.savefig(os.path.join(output_dir,'log_loss.png'))

----------------------------------------------------------------------------------------------------
Drawing plot  0  with mass  [-0.89797003]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7671208381652832s
----------------------------------------------------------------------------------------------------
Drawing plot  1  with mass  [-0.85885536]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7628359794616699s
----------------------------------------------------------------------------------------------------
Drawing plot  2  with mass  [-0.73132628]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7342588901519775s
----------------------------------------------------------------------------------------------------
Drawing plot  3  with mass  [-0.70267433]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7076148986816406s
----------------------------------------------------------------------------------------------------
Drawing plot  4  with mass  [-0.5804934]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7123360633850098s
----------------------------------------------------------------------------------------------------
Drawing plot  5  with mass  [-0.50172869]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7333872318267822s
----------------------------------------------------------------------------------------------------
Drawing plot  6  with mass  [-0.46160543]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.72086501121521s
----------------------------------------------------------------------------------------------------
Drawing plot  7  with mass  [-0.45847269]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.6950399875640869s
----------------------------------------------------------------------------------------------------
Drawing plot  8  with mass  [-0.32993482]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.742542028427124s
----------------------------------------------------------------------------------------------------
Drawing plot  9  with mass  [-0.26489593]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7245728969573975s
----------------------------------------------------------------------------------------------------
Drawing plot  10  with mass  [-0.02515446]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7205100059509277s
----------------------------------------------------------------------------------------------------
Drawing plot  11  with mass  [-0.02121313]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.8390500545501709s
----------------------------------------------------------------------------------------------------
Drawing plot  12  with mass  [0.01815648]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.9336886405944824s
----------------------------------------------------------------------------------------------------
Drawing plot  13  with mass  [0.01815648]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.8155441284179688s
----------------------------------------------------------------------------------------------------
Drawing plot  14  with mass  [0.24547908]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7188680171966553s
----------------------------------------------------------------------------------------------------
Drawing plot  15  with mass  [0.2754474]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.8371257781982422s
----------------------------------------------------------------------------------------------------
Drawing plot  16  with mass  [0.30650036]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.815392017364502s
----------------------------------------------------------------------------------------------------
Drawing plot  17  with mass  [0.55027851]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.8939840793609619s
----------------------------------------------------------------------------------------------------
Drawing plot  18  with mass  [0.56619628]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7988801002502441s
----------------------------------------------------------------------------------------------------
Drawing plot  19  with mass  [0.60776793]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.9447932243347168s
----------------------------------------------------------------------------------------------------
Drawing plot  20  with mass  [0.60776793]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.6944010257720947s
----------------------------------------------------------------------------------------------------
Drawing plot  21  with mass  [0.63667242]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7165231704711914s
----------------------------------------------------------------------------------------------------
Drawing plot  22  with mass  [0.78392868]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.7386188507080078s
----------------------------------------------------------------------------------------------------
Drawing plot  23  with mass  [0.78392868]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.6903238296508789s
----------------------------------------------------------------------------------------------------
Drawing plot  24  with mass  [0.83711099]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 0.8433568477630615s


In [29]:
# __________________________________________________________________ ||
# Make plots for posterior
# __________________________________________________________________ ||

import time

n_dim = 5
select_param = [40,42,45,47,50,]
param_grid = [arr_list[idx_param] for idx_param in select_param]
param_grid.sort(key=lambda x: x.condition[0])
figsize = (25,10)

plot_low = -1.0
plot_high = 1.0
n_grid = 100
x_grid = [plot_low+(plot_high-plot_low)/n_grid*i for i in range(n_grid+1)]

energy_norm = 5.

fig, ax = plt.subplots(len(select_param),n_dim,figsize=figsize,constrained_layout=True)

def make_array(event_size,x_grid,p,nf_model):
    log_prob = np.zeros(n_grid+1)
    log_prob_sm = np.zeros(n_grid+1)
    
    idx_batch = np.random.randint(0,p.x.shape[0],event_size)
    condition_concat = np.concatenate([np.ones((event_size,1)) * x for ix,x in enumerate(x_grid)])
    x_data_concat = np.concatenate([p.x[idx_batch] for ix,x in enumerate(x_grid)])
    
    z_concat = nf_model.batch_log_loss([x_data_concat,condition_concat])

    for ig,x in enumerate(x_grid):
        arg = tf.expand_dims(z_concat[ig*idx_batch.shape[0]:(ig+1)*idx_batch.shape[0]],axis=1)
        log_prob[ig] = tf.reduce_sum(nf_model.distribution.log_prob(arg))
    return log_prob.astype(np.float64),idx_batch

for i,p in enumerate(param_grid):
    
    print("-"*100)
    print("Drawing plot ",i," with param ",p.condition[0])
    
    ix = i
    
    start_time = time.time()
    
    condition_str = str(p)
    
    for event_size in [5,50,100]:
        nf_model.direction = -1
        log_prob,idx_batch = make_array(event_size,x_grid,p,nf_model)
        
        y_grid = tf.nn.softmax(log_prob)
        ax[ix,0].plot([x*energy_norm for x in x_grid],y_grid,label=str(event_size)+' events',)
        #ax[ix,0].set_title("mass: "+condition_str)
        ax[ix,0].legend(loc='best')
        ax[ix,0].grid(True)
        ax[ix,0].set_ylim(0.,1.)
        ylims = ax[ix,0].get_ylim()
        ax[ix,0].arrow(p.condition[0]*energy_norm, ylims[1], 0., ylims[0]-ylims[1],)
        
        ax[ix,1].plot([x*energy_norm for x in x_grid],-log_prob-np.min(-log_prob),label=str(event_size)+' events',)
        #ax[ix,1].set_title("mass: "+condition_str)
        ax[ix,1].legend(loc='best')
        ax[ix,1].grid(True)
        ax[ix,1].set_ylim(0.,40.)
        ylims = ax[ix,1].get_ylim()
        ax[ix,1].arrow(p.condition[0]*energy_norm, ylims[1], 0., ylims[0]-ylims[1],)
        
        if event_size == 100:
            nf_model.direction = 1
            samples = nf_model.distribution.sample(idx_batch.shape[0])
            x_pred,_ = nf_model.predict([samples,np.ones((idx_batch.shape[0],1))*p.condition[0]])
            ax[ix,2].hist(p.x[idx_batch,0],bins=50,histtype='step',range=[-5.,5.],label='True',)
            ax[ix,2].hist(x_pred[:,0],bins=50,histtype='step',range=[-5.,5.],label='Flow',)
            ax[ix,2].legend(loc='best')
            
            ax[ix,3].hist(p.x[idx_batch,1],bins=20,histtype='step',range=[-5.,5.],label='True',)
            ax[ix,3].hist (x_pred[:,1],bins=20,histtype='step',range=[-5.,5.],label='Flow',)
            ax[ix,3].legend(loc='best')
            
            ax[ix,4].hist(p.x[idx_batch,2],bins=20,histtype='step',range=[-5.,5.],label='True',)
            ax[ix,4].hist(x_pred[:,2],bins=20,histtype='step',range=[-5.,5.],label='Flow',)
            ax[ix,4].legend(loc='best')
    
    elapsed_time = time.time() - start_time
    print("Time used: "+str(elapsed_time)+"s")
    
fig.savefig(os.path.join(output_dir,'lratio.png'))

----------------------------------------------------------------------------------------------------
Drawing plot  0  with param  [-0.29155951]


  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 7.5941550731658936s
----------------------------------------------------------------------------------------------------
Drawing plot  1  with param  [-0.23136288]


  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 10.903495788574219s
----------------------------------------------------------------------------------------------------
Drawing plot  2  with param  [-0.21552924]


  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 8.235575914382935s
----------------------------------------------------------------------------------------------------
Drawing plot  3  with param  [-0.15319038]


  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 9.340811014175415s
----------------------------------------------------------------------------------------------------
Drawing plot  4  with param  [-0.1027987]


  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)
  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 7.14969277381897s


In [93]:
# __________________________________________________________________ ||
# Correlation
# __________________________________________________________________ ||

import time

n_dim = 3
select_param = [40,42,45,47,50,]
param_grid = [arr_list[idx_param] for idx_param in select_param]
param_grid.sort(key=lambda x: x.condition[0])
figsize = (6,10)
event_size = 10000

energy_norm = 10.

fig, ax = plt.subplots(len(select_param),n_dim,figsize=figsize,constrained_layout=True)

for i,p in enumerate(param_grid):
    
    print("-"*100)
    print("Drawing plot ",i," with param ",p.condition[0])
    
    ix = i
    
    start_time = time.time()
    
    condition_str = str(p)
    
    idx_batch = np.random.randint(0,p.x.shape[0],event_size)
    
    nf_model.direction = 1
    samples = nf_model.distribution.sample(idx_batch.shape[0])
    
    x_pred,_ = nf_model.predict([samples,np.ones((idx_batch.shape[0],1))*p.condition[0]])
    
    corr_pred = np.corrcoef(x_pred.T)
    corr_orig = np.corrcoef(p.x[idx_batch].T)

    im1 = ax[ix,0].imshow(corr_orig)
    cbar1 = plt.colorbar(im1, ax=ax[ix,0])
    cbar1.set_ticks(np.arange(-1., 1., 0.1))

    im2 = ax[ix,1].matshow(corr_pred)
    cbar2 = plt.colorbar(im2, ax=ax[ix,1])
    cbar2.set_ticks(np.arange(-1., 1., 0.1))
    
    im3 = ax[ix,2].matshow(corr_orig / corr_pred)
    cbar3 = plt.colorbar(im3, ax=ax[ix,2])
    cbar3.set_ticks(np.arange(0.0, 2., 0.1))
    
    elapsed_time = time.time() - start_time
    print("Time used: "+str(elapsed_time)+"s")
    
fig.savefig(os.path.join(output_dir,'correlation.png'))

----------------------------------------------------------------------------------------------------
Drawing plot  0  with param  [-0.29155951]
Time used: 0.4996671676635742s
----------------------------------------------------------------------------------------------------
Drawing plot  1  with param  [-0.23136288]
Time used: 0.5248293876647949s
----------------------------------------------------------------------------------------------------
Drawing plot  2  with param  [-0.21552924]
Time used: 0.5536932945251465s
----------------------------------------------------------------------------------------------------
Drawing plot  3  with param  [-0.15319038]
Time used: 0.4945089817047119s
----------------------------------------------------------------------------------------------------
Drawing plot  4  with param  [-0.1027987]
Time used: 0.4963679313659668s
