In [1]:
import os,pickle
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from process_data import preprocess_conditional_flow_data_mass,simuate_conditional_flow_data_mass
from Model.ConditionalRealNVP import ConditionalRealNVP
from Utils.ObjDict import ObjDict
from Utils.mkdir_p import mkdir_p

In [2]:
# __________________________________________________________________ ||
# Basic configurables
# __________________________________________________________________ ||

input_csv_path = "data/train_mass.npy"
saved_model_path = "output/train_condrealnvp_mass_210113_v1/saved_model_300.h5"
output_dir = os.path.dirname(saved_model_path)
event_size = 1000
ndim = 3
ncond = 1

In [4]:
# __________________________________________________________________ ||
# Load models
# __________________________________________________________________ ||

nf_model = ConditionalRealNVP(num_coupling_layers=5,ndim=ndim,ncond=ncond)
samples = nf_model.distribution.sample(event_size)
condition = 1.0 * np.ones((event_size,1))
_,_ = nf_model.predict([samples,condition,])
nf_model.load_weights(saved_model_path)

ValueError: Cannot assign to variable dense_7/kernel:0 due to variable shape (4, 16) and value shape (4, 64) are incompatible

In [4]:
arr = np.load(input_csv_path)
arr_list = preprocess_conditional_flow_data_mass(arr)

In [9]:
# __________________________________________________________________ ||
# Make plots for different conditions
# __________________________________________________________________ ||

n_dim = 5
mass_grid = [arr_list[idx_mass] for idx_mass in np.random.randint(0,len(arr_list),n_dim*n_dim)]
mass_grid.sort(key=lambda x: x.condition[0])
figsize = (50,50)

samples = nf_model.distribution.sample(event_size)
fig_pt1,ax_pt1 = plt.subplots(n_dim,n_dim,figsize=figsize)
fig_pt2,ax_pt2 = plt.subplots(n_dim,n_dim,figsize=figsize)
fig_mll,ax_mll = plt.subplots(n_dim,n_dim,figsize=figsize)

for i,m in enumerate(mass_grid):
    
    ix = int(i / n_dim)
    iy = i % n_dim
    
    condition_str = str(m.condition[0])
    condition = np.ones((event_size,1)) * m.condition[0]

    idx_batch = np.random.randint(0,m.x.shape[0],event_size)
    
    x_true = m.x[idx_batch]
    x_gen,_ = nf_model.predict([samples,condition,])

    ax_pt1[ix,iy].hist(x_true[:,0],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_pt1[ix,iy].hist(x_gen[:,0],bins=100,density=1.,histtype='step',range=[-10.,10.],label='Flow '+condition_str)
    ax_pt1[ix,iy].legend(loc='best')
    ax_pt1[ix,iy].set_title(condition_str)
    
    ax_pt2[ix,iy].hist(x_true[:,1],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_pt2[ix,iy].hist(x_gen[:,1],bins=100,density=1.,histtype='step',range=[-10.,10.],label='Flow '+condition_str)
    ax_pt2[ix,iy].legend(loc='best')
    ax_pt2[ix,iy].set_title(condition_str)
    
    ax_mll[ix,iy].hist(x_true[:,2],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_mll[ix,iy].hist(x_gen[:,2],bins=100,density=1.,histtype='step',range=[-5.,5.],label='Flow '+condition_str)
    ax_mll[ix,iy].legend(loc='best')
    ax_mll[ix,iy].set_title(condition_str)
    
fig_pt1.savefig(os.path.join(output_dir,'pt1.png'))
fig_pt2.savefig(os.path.join(output_dir,'pt2.png'))
fig_mll.savefig(os.path.join(output_dir,'mll.png'))

In [32]:
# __________________________________________________________________ ||
# Make plots for likelihood
# __________________________________________________________________ ||

import time

n_dim = 5
mass_grid = [arr_list[idx_mass] for idx_mass in np.random.randint(0,len(arr_list),n_dim*n_dim)]
mass_grid.sort(key=lambda x: x.condition[0])

plot_low = -1.5
plot_high = 1.5
n_grid = 10
x_grid = [plot_low+(plot_high-plot_low)/n_grid*i for i in range(n_grid+1)]

event_size = 5000

z = np.zeros(n_grid+1)
fig, ax = plt.subplots(n_dim,n_dim,figsize=figsize)

for i,m in enumerate(mass_grid):
    
    print("-"*100)
    print("Drawing plot ",i," with mass ",m.condition[0])
    
    ix = int(i / n_dim)
    iy = i % n_dim
    
    start_time = time.time()
    
    condition_str = str(m)

    nf_model.direction = -1
    
    idx_batch = np.random.randint(0,m.x.shape[0],event_size)
    
    condition_concat = np.concatenate([np.ones((event_size,1)) * x for ix,x in enumerate(x_grid)])
    x_data_concat = np.concatenate([m.x[idx_batch] for ix,x in enumerate(x_grid)])
    
    z_concat = nf_model.batch_log_loss([x_data_concat,condition_concat])
    
    for ig,x in enumerate(x_grid):
        z[ig] = tf.reduce_mean(z_concat[ig*idx_batch.shape[0]:(ig+1)*idx_batch.shape[0]])

    ax[ix,iy].plot(x_grid,z,)
    ylims = ax[ix,iy].get_ylim()
    ax[ix,iy].arrow(m.condition[0], ylims[1], 0., ylims[0]-ylims[1],)
    
    elapsed_time = time.time() - start_time
    print("Time used: "+str(elapsed_time)+"s")
    
fig.savefig(os.path.join(output_dir,'log_loss.png'))

----------------------------------------------------------------------------------------------------
Drawing plot  0  with mass  [-0.97735305]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.053668975830078s
----------------------------------------------------------------------------------------------------
Drawing plot  1  with mass  [-0.96899617]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.5634102821350098s
----------------------------------------------------------------------------------------------------
Drawing plot  2  with mass  [-0.91823386]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.7363088130950928s
----------------------------------------------------------------------------------------------------
Drawing plot  3  with mass  [-0.91823386]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.798591136932373s
----------------------------------------------------------------------------------------------------
Drawing plot  4  with mass  [-0.89797003]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.207273006439209s
----------------------------------------------------------------------------------------------------
Drawing plot  5  with mass  [-0.88521594]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.152392864227295s
----------------------------------------------------------------------------------------------------
Drawing plot  6  with mass  [-0.84047532]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.071143865585327s
----------------------------------------------------------------------------------------------------
Drawing plot  7  with mass  [-0.72078026]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.7842519283294678s
----------------------------------------------------------------------------------------------------
Drawing plot  8  with mass  [-0.6850653]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.801119089126587s
----------------------------------------------------------------------------------------------------
Drawing plot  9  with mass  [-0.54612159]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 4.1568520069122314s
----------------------------------------------------------------------------------------------------
Drawing plot  10  with mass  [-0.39936156]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.2875802516937256s
----------------------------------------------------------------------------------------------------
Drawing plot  11  with mass  [-0.38566189]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.3052361011505127s
----------------------------------------------------------------------------------------------------
Drawing plot  12  with mass  [-0.37736623]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.2544050216674805s
----------------------------------------------------------------------------------------------------
Drawing plot  13  with mass  [-0.32993482]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.4447720050811768s
----------------------------------------------------------------------------------------------------
Drawing plot  14  with mass  [-0.26489593]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.7945621013641357s
----------------------------------------------------------------------------------------------------
Drawing plot  15  with mass  [-0.13855339]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.834876775741577s
----------------------------------------------------------------------------------------------------
Drawing plot  16  with mass  [-0.1306833]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.8006341457366943s
----------------------------------------------------------------------------------------------------
Drawing plot  17  with mass  [-0.09870492]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.7808282375335693s
----------------------------------------------------------------------------------------------------
Drawing plot  18  with mass  [-0.02121313]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.883500099182129s
----------------------------------------------------------------------------------------------------
Drawing plot  19  with mass  [0.0814019]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.200653076171875s
----------------------------------------------------------------------------------------------------
Drawing plot  20  with mass  [0.0814019]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.0035669803619385s
----------------------------------------------------------------------------------------------------
Drawing plot  21  with mass  [0.31316736]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.9579110145568848s
----------------------------------------------------------------------------------------------------
Drawing plot  22  with mass  [0.63667242]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.8417069911956787s
----------------------------------------------------------------------------------------------------
Drawing plot  23  with mass  [0.90375983]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.9515180587768555s
----------------------------------------------------------------------------------------------------
Drawing plot  24  with mass  [0.91052179]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.9235470294952393s
