In [4]:
import os,pickle
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from process_data import preprocess_conditional_flow_data_cww
from Model.ConditionalRealNVP import ConditionalRealNVP
from Utils.ObjDict import ObjDict
from Utils.mkdir_p import mkdir_p

In [12]:
# __________________________________________________________________ ||
# Basic configurables
# __________________________________________________________________ ||

input_csv_path = "data/train_cww.npy"
saved_model_path = "output/train_condrealnvp_cww_210118_v1/saved_model_1700.h5"
output_dir = os.path.dirname(saved_model_path)
event_size = 4000
ndim = 3
ncond = 1

In [7]:
# __________________________________________________________________ ||
# Load models
# __________________________________________________________________ ||

nf_model = ConditionalRealNVP(num_coupling_layers=5,dim=ndim,ncond=ncond)
samples = nf_model.distribution.sample(event_size)
condition = 1.0 * np.ones((event_size,1))
_,_ = nf_model.predict([samples,condition,])
nf_model.load_weights(saved_model_path)

In [8]:
arr = np.load(input_csv_path)
arr_list = preprocess_conditional_flow_data_cww(arr)

In [13]:
# __________________________________________________________________ ||
# Make plots for different conditions
# __________________________________________________________________ ||

n_dim = 5
param_grid = [arr_list[idx_param] for idx_param in np.random.randint(0,len(arr_list),n_dim*n_dim)]
param_grid.sort(key=lambda x: x.condition[0])
figsize = (50,50)

samples = nf_model.distribution.sample(event_size)
fig_m4l,ax_m4l = plt.subplots(n_dim,n_dim,figsize=figsize)
fig_mz1,ax_mz1 = plt.subplots(n_dim,n_dim,figsize=figsize)
fig_mz2,ax_mz2 = plt.subplots(n_dim,n_dim,figsize=figsize)

for i,m in enumerate(param_grid):
    
    ix = int(i / n_dim)
    iy = i % n_dim
    
    condition_str = str(m.condition[0])
    condition = np.ones((event_size,1)) * m.condition[0]

    idx_batch = np.random.randint(0,m.x.shape[0],event_size)
    
    x_true = m.x[idx_batch]
    x_gen,_ = nf_model.predict([samples,condition,])

    ax_m4l[ix,iy].hist(x_true[:,0],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_m4l[ix,iy].hist(x_gen[:,0],bins=100,density=1.,histtype='step',range=[-10.,10.],label='Flow '+condition_str)
    ax_m4l[ix,iy].legend(loc='best')
    ax_m4l[ix,iy].set_title(condition_str)
    
    ax_mz1[ix,iy].hist(x_true[:,1],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_mz1[ix,iy].hist(x_gen[:,1],bins=100,density=1.,histtype='step',range=[-10.,10.],label='Flow '+condition_str)
    ax_mz1[ix,iy].legend(loc='best')
    ax_mz1[ix,iy].set_title(condition_str)
    
    ax_mz2[ix,iy].hist(x_true[:,2],bins=100,density=1.,histtype='step',range=[-10.,10.],label='True '+condition_str)
    ax_mz2[ix,iy].hist(x_gen[:,2],bins=100,density=1.,histtype='step',range=[-5.,5.],label='Flow '+condition_str)
    ax_mz2[ix,iy].legend(loc='best')
    ax_mz2[ix,iy].set_title(condition_str)
    
fig_m4l.savefig(os.path.join(output_dir,'m4l.png'))
fig_mz1.savefig(os.path.join(output_dir,'mZ1.png'))
fig_mz2.savefig(os.path.join(output_dir,'mZ2.png'))

  return n/db/n.sum(), bin_edges


In [16]:
# __________________________________________________________________ ||
# Make plots for likelihood
# __________________________________________________________________ ||

import time

n_dim = 5
param_grid = [arr_list[idx_param] for idx_param in np.random.randint(0,len(arr_list),n_dim*n_dim)]
param_grid.sort(key=lambda x: x.condition[0])

plot_low = 0.0
plot_high = 0.2
n_grid = 10
x_grid = [plot_low+(plot_high-plot_low)/n_grid*i for i in range(n_grid+1)]

event_size = 5000

z = np.zeros(n_grid+1)
fig, ax = plt.subplots(n_dim,n_dim,figsize=figsize)

for i,p in enumerate(param_grid):
    
    print("-"*100)
    print("Drawing plot ",i," with param ",p.condition[0])
    
    ix = int(i / n_dim)
    iy = i % n_dim
    
    start_time = time.time()
    
    condition_str = str(p)

    nf_model.direction = -1
    
    idx_batch = np.random.randint(0,p.x.shape[0],event_size)
    
    condition_concat = np.concatenate([np.ones((event_size,1)) * x for ix,x in enumerate(x_grid)])
    x_data_concat = np.concatenate([p.x[idx_batch] for ix,x in enumerate(x_grid)])
    
    z_concat = nf_model.batch_log_loss([x_data_concat,condition_concat])
    
    for ig,x in enumerate(x_grid):
        z[ig] = tf.reduce_mean(z_concat[ig*idx_batch.shape[0]:(ig+1)*idx_batch.shape[0]])

    ax[ix,iy].plot(x_grid,z,)
    ylims = ax[ix,iy].get_ylim()
    ax[ix,iy].arrow(p.condition[0], ylims[1], 0., ylims[0]-ylims[1],)
    
    elapsed_time = time.time() - start_time
    print("Time used: "+str(elapsed_time)+"s")
    
fig.savefig(os.path.join(output_dir,'log_loss.png'))

----------------------------------------------------------------------------------------------------
Drawing plot  0  with param  [0.00231224]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.5815858840942383s
----------------------------------------------------------------------------------------------------
Drawing plot  1  with param  [0.00720713]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.6718010902404785s
----------------------------------------------------------------------------------------------------
Drawing plot  2  with param  [0.02237665]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.281798839569092s
----------------------------------------------------------------------------------------------------
Drawing plot  3  with param  [0.0329956]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 4.299259901046753s
----------------------------------------------------------------------------------------------------
Drawing plot  4  with param  [0.03596633]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.7942662239074707s
----------------------------------------------------------------------------------------------------
Drawing plot  5  with param  [0.04389669]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.9933221340179443s
----------------------------------------------------------------------------------------------------
Drawing plot  6  with param  [0.04472527]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.976522922515869s
----------------------------------------------------------------------------------------------------
Drawing plot  7  with param  [0.0542747]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 3.096670150756836s
----------------------------------------------------------------------------------------------------
Drawing plot  8  with param  [0.05492049]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.9330949783325195s
----------------------------------------------------------------------------------------------------
Drawing plot  9  with param  [0.05652886]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.9986660480499268s
----------------------------------------------------------------------------------------------------
Drawing plot  10  with param  [0.0570449]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.6774237155914307s
----------------------------------------------------------------------------------------------------
Drawing plot  11  with param  [0.06052305]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.757598876953125s
----------------------------------------------------------------------------------------------------
Drawing plot  12  with param  [0.0606764]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.850069999694824s
----------------------------------------------------------------------------------------------------
Drawing plot  13  with param  [0.0610229]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.8595669269561768s
----------------------------------------------------------------------------------------------------
Drawing plot  14  with param  [0.06862886]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.653409957885742s
----------------------------------------------------------------------------------------------------
Drawing plot  15  with param  [0.08153868]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.9275760650634766s
----------------------------------------------------------------------------------------------------
Drawing plot  16  with param  [0.11241495]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.654127836227417s
----------------------------------------------------------------------------------------------------
Drawing plot  17  with param  [0.14428673]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.6973860263824463s
----------------------------------------------------------------------------------------------------
Drawing plot  18  with param  [0.15202926]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.7087562084198s
----------------------------------------------------------------------------------------------------
Drawing plot  19  with param  [0.15671654]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.643663167953491s
----------------------------------------------------------------------------------------------------
Drawing plot  20  with param  [0.17319047]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.617650032043457s
----------------------------------------------------------------------------------------------------
Drawing plot  21  with param  [0.17319047]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.5568950176239014s
----------------------------------------------------------------------------------------------------
Drawing plot  22  with param  [0.17562874]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.6027889251708984s
----------------------------------------------------------------------------------------------------
Drawing plot  23  with param  [0.18344812]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.747225761413574s
----------------------------------------------------------------------------------------------------
Drawing plot  24  with param  [0.18935812]


  verts = np.dot(coords, M) + (x + dx, y + dy)


Time used: 2.6483850479125977s
