In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import tensorflow as tf
import tensorflow_addons as tfa

from matplotlib.colors import LogNorm
from scipy import stats
from scipy.optimize import minimize
from copy import deepcopy
from sklearn.model_selection import train_test_split

from freedom.toy_model import advanced_toy_model, NNs

In [None]:
# Matplotlib defaults (font sizes, figure size) applied to all figures of this notebook.
params = {'legend.fontsize': 17,
          'figure.figsize': (15, 9.3),
          'axes.labelsize': 24,
          'axes.titlesize': 24,
          'xtick.labelsize': 22,
          'ytick.labelsize': 22}
plt.rcParams.update(params)

# Names of the five hypothesis parameters; used as axis labels in the 1d scans.
par_names = ['x', 'y', 't', 'E', 'azi']
# Relative directories for the toy data files and for the plots.
data_path = '../../freedom/resources/toy_data/'
plot_path = '../../plots/toy_model/'

def correct_azi(azi):
    """Shift azimuth values by one full turn towards the interval [-pi, pi].

    Values below -pi get 2*pi added, then values above pi get 2*pi subtracted.
    Only a single turn is applied, so inputs further than 2*pi outside the
    interval are not fully wrapped. Returns a numpy array (0-d for scalars).
    """
    full_turn = 2*np.pi
    raised = np.where(azi < -np.pi, azi + full_turn, azi)
    above = raised > np.pi
    return np.where(above, raised - full_turn, raised)

In [None]:
# 5x5 grid of detector positions: every (x, y) combination of five evenly spaced values in [-10, 10];
# after the transpose each row holds one (x, y) pair.
detectors = np.vstack([np.repeat(np.linspace(-10, 10, 5), 5), np.tile(np.linspace(-10, 10, 5), 5)]).T
# Toy experiment using this detector layout, with isotrop=False.
toy_experiment = advanced_toy_model.advanced_toy_experiment(detectors=detectors, isotrop=False) #, time_dist=advanced_toy_model.pandel

In [None]:
# Allowed [low, high] range for each hypothesis parameter (x, y, t, N_src, ang).
bounds = np.array([[-12,12], [-12,12], [-5,5], [3,40], [0, 2*np.pi]])

def LLH(X, event, only_c=False, only_h=False, fix=[None], bounds=bounds):
    """Negative log-likelihood of a hypothesis, evaluated with the analytic toy model.

    Parameters
    ----------
    X : array-like
        Hypothesis (hypo_x, hypo_y, hypo_t, hypo_N_src, hypo_ang). When ``fix``
        is used, X holds only the free parameters.
    event : sequence
        ``event[0]`` is handed to ``toy_experiment.charge_term`` and
        ``event[1]`` to ``toy_experiment.hit_term``.
    only_c, only_h : bool
        Return only the charge term / only the hit term. They are mutually exclusive.
    fix : sequence
        ``[None]`` (nothing fixed) or ``[index, value]``; ``value`` is inserted
        into X at ``index`` before the evaluation.
    bounds : ndarray of shape (n_params, 2)
        Lower and upper limit of every parameter.

    Returns
    -------
    float
        ``1e9`` if any parameter lies outside ``bounds``, otherwise the negated
        charge term, hit term, or their sum.
    """
    assert only_c + only_h < 2

    # identity test: `!= None` would compare element-wise if fix[0] were an array
    if fix[0] is not None:
        X = np.insert(X, fix[0], fix[1])

    # np.all replaces np.alltrue, which is deprecated and no longer exists in NumPy 2
    inside = np.logical_and(bounds[:,0] <= X, X <= bounds[:,1])
    if not np.all(inside):
        return 1e9

    pos = np.array([X[0], X[1]])

    # evaluate only the terms that are actually requested
    if only_c:
        return -toy_experiment.charge_term(event[0], pos, X[3], X[4])

    h_term = -toy_experiment.hit_term(event[1], pos, X[2], X[3], X[4])
    if only_h:
        return h_term

    c_term = -toy_experiment.charge_term(event[0], pos, X[3], X[4])
    return c_term + h_term

def LLH_NN(X, event, chargeNet=None, hitNet=None, fix=[None], bounds=bounds):
    """Negative log-likelihood of a hypothesis, evaluated with the neural networks.

    Parameters
    ----------
    X : array-like
        Hypothesis (hypo_x, hypo_y, hypo_t, hypo_N_src, hypo_ang). When ``fix``
        is used, X holds only the free parameters.
    event : sequence
        ``event[0]`` is reshaped to (1, 2) and fed to ``chargeNet``; the first
        three columns of ``event[1]`` are fed to ``hitNet``, one row per hit.
    chargeNet, hitNet : model or None
        Objects with a ``predict([x, t])`` method. A network that is None
        contributes 0.
    fix : sequence
        ``[None]`` (nothing fixed) or ``[index, value]``; ``value`` is inserted
        into X at ``index`` before the evaluation.
    bounds : ndarray of shape (n_params, 2)
        Lower and upper limit of every parameter.

    Returns
    -------
    float
        ``1e9`` if any parameter lies outside ``bounds``, otherwise the sum of
        the negated network outputs.
    """
    # identity test: `!= None` would compare element-wise if fix[0] were an array
    if fix[0] is not None:
        X = np.insert(X, fix[0], fix[1])

    # np.all replaces np.alltrue, which is deprecated and no longer exists in NumPy 2
    inside = np.logical_and(bounds[:,0] <= X, X <= bounds[:,1])
    if not np.all(inside):
        return 1e9

    c_term = 0
    if chargeNet is not None:
        x, t = event[0].reshape((1,2)), np.array([X])
        c_term = -chargeNet.predict([x, t])[0, 0]

    h_term = 0
    # an event without hits has no hit contribution
    if hitNet is not None and len(event[1]) > 0:
        # the same hypothesis is repeated for every hit
        x, t = event[1][:,:3], np.repeat([X], len(event[1]), axis=0)
        h_term = -np.sum(hitNet.predict([x, t]))

    return c_term + h_term

In [None]:
#loc = '../../freedom/resources/models/toy/'
#cmodel = tf.keras.models.load_model(loc+'chargeNet_new.hdf5', custom_objects={'charge_trafo':NNs.charge_trafo})
#hmodel = tf.keras.models.load_model(loc+'hitNet.hdf5', custom_objects={'hit_trafo':NNs.hit_trafo})

## Create events

In [None]:
# Number of toy events to simulate.
N = 200000
events, Truth = toy_experiment.generate_events(N, xlims=(-12,12), blims=(-12,12), N_lims=(3,40))
# Insert a column of zeros at index 2 of the truth array
# (presumably the time, giving the order x, y, t, N_src, ang - TODO confirm).
Truth = np.insert(Truth, 2, 0, axis=1)

# Use the shared data_path instead of repeating the directory literal.
np.save(data_path + 'toy_events_test', events)
np.save(data_path + 'toy_truth_test', Truth)

## Train NNs

In [None]:
# NOTE: allow_pickle=True executes pickled objects on load - only use it on files you created yourself.
# NOTE(review): these are the 'toy_events'/'toy_truth' files, not the '*_test' files written by the
# generation cell - confirm this is intended.
events = np.load(data_path + 'toy_events.npy', allow_pickle=True)#[:100000]
Truth = np.load(data_path + 'toy_truth.npy', allow_pickle=True)#[:100000]

In [None]:
# TensorFlow distribution strategy used to build the models below.
strategy = tf.distribute.MirroredStrategy()
# Number of replicas in sync; the batch sizes below are scaled by it.
nGPUs = strategy.num_replicas_in_sync

hitNet

In [None]:
# hitNet inputs as returned by NNs.get_hit_data (presumably per-hit features x and parameters t - TODO confirm).
x, t = NNs.get_hit_data(events, Truth)
# Keep 10 % of the samples for validation; fixed random_state for a reproducible split.
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

# Batch generators for training and validation; the batch size scales with the number of replicas.
d_train = NNs.DataGenerator(x_train, t_train, batch_size=2048*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=2048*nGPUs)

In [None]:
# Build and compile hitNet (3 input features, 5 parameters) inside the strategy scope.
with strategy.scope():
    hmodel = NNs.get_model(x_shape=3, t_shape=5, trafo=NNs.hit_trafo)
    optimizer = tf.keras.optimizers.Adam(1e-3)
    hmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train hitNet for 100 epochs; the loss history is kept in `hist` for plotting.
hist = hmodel.fit(d_train, epochs=100, verbose=1, validation_data=d_valid)

In [None]:
# Training and validation loss per epoch, labelled so the figure is readable on its own.
plt.plot(hist.history['loss'], label='training')
plt.plot(hist.history['val_loss'], label='validation')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend();

In [None]:
# Replace the activation of the last layer by a linear one and re-compile, so that
# predict() returns the raw (pre-activation) output of that layer.
hmodel.layers[-1].activation = tf.keras.activations.linear
hmodel.compile()
#hmodel.save('../../../freedom/resources/models/toy/hitNet.hdf5')

chargeNet

In [None]:
# chargeNet inputs as returned by NNs.get_charge_data (presumably per-event charge features x and parameters t - TODO confirm).
# NOTE: this overwrites x, t, d_train and d_valid from the hitNet section.
x, t = NNs.get_charge_data(events, Truth)
# Keep 10 % of the samples for validation; fixed random_state for a reproducible split.
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.1, random_state=42)

# Batch generators for training and validation; the batch size scales with the number of replicas.
d_train = NNs.DataGenerator(x_train, t_train, batch_size=2048*nGPUs)
d_valid = NNs.DataGenerator(x_test, t_test, batch_size=2048*nGPUs)

In [None]:
# define function here again (easier to modify)
def get_model(x_shape, t_shape, trafo, activation=tfa.activations.mish, dets=None): #'elu'
    """Build the (uncompiled) Keras classifier used for the charge network.

    Parameters
    ----------
    x_shape, t_shape : int
        Number of features of the observation input and of the parameter input.
    trafo : callable
        Layer class; an instance is called as ``trafo()(x_input, t_input)`` or,
        when ``dets`` is given, ``trafo()(x_input, t_input, dets=dets)``. Its
        output is split into blocks of width 1, 1 and 5 along axis 1.
    activation : callable or str
        Activation of the hidden layers.
    dets : array-like or None
        Optional detector information forwarded to the transformation layer.

    Returns
    -------
    tf.keras.Model
        Model with inputs ``[x_input, t_input]`` and one sigmoid output.
    """
    x_input = tf.keras.Input(shape=(x_shape,))
    t_input = tf.keras.Input(shape=(t_shape,))

    # explicit identity test; `np.all(dets) == None` only worked by accident
    if dets is None:
        inp = trafo()(x_input, t_input)
    else:
        inp = trafo()(x_input, t_input, dets=dets)

    # nch is currently unused (see the commented-out concat further down)
    c, nch, ts = tf.split(inp, [1, 1, 5], 1)

    # densely connected stack: every new layer sees the concatenation of all previous outputs
    ls = [ts]
    ls.append(tf.keras.layers.Dense(5, activation=activation)(ts))
    for i in range(50):
        stacked = tf.concat(ls, axis=-1)
        if i == 49:
            # the last layer of the stack is wider and uses an exponential activation
            ls.append(tf.keras.layers.Dense(50, activation='exponential')(stacked))
        else:
            ls.append(tf.keras.layers.Dense(5, activation=activation)(stacked))

    h = tf.keras.layers.Dropout(0.01)(tf.concat(ls, axis=-1))
    h = tf.keras.layers.Dense(100, activation=activation)(h)
    h = tf.keras.layers.Dropout(0.01)(h)
    h = tf.keras.layers.Dense(50, activation=activation)(h)
    h = tf.keras.layers.Dropout(0.01)(h)
    h = tf.keras.layers.Dense(25, activation=activation)(h)
    h = tf.keras.layers.Dense(5, activation=activation)(h)

    # combine the processed parameter block with the first block of the transformed input
    h = tf.concat([h, c], axis=-1)
    h = tf.keras.layers.Dense(30, activation=activation)(h)
    h = tf.keras.layers.Dense(30, activation=activation)(h)
    #h = tf.concat([h, nch], axis=-1)
    h = tf.keras.layers.Dense(30, activation=activation)(h)
    h = tf.keras.layers.Dense(30, activation=activation)(h)
    h = tf.keras.layers.Dense(30, activation='exponential')(h)

    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(h)

    model = tf.keras.Model(inputs=[x_input, t_input], outputs=outputs)

    return model

In [None]:
#radam = tfa.optimizers.RectifiedAdam(lr=2e-3)
#optimizer = tfa.optimizers.Lookahead(radam)
with strategy.scope():
    # create the optimizer inside the strategy scope, consistent with the hitNet cell
    optimizer = tf.keras.optimizers.Adam(2e-3)
    cmodel = get_model(x_shape=2, t_shape=5, trafo=NNs.charge_trafo, dets=toy_experiment.detectors.astype(np.float32))
    cmodel.compile(loss='binary_crossentropy', optimizer=optimizer)

In [None]:
# Train chargeNet for 25 epochs; this overwrites the `hist` of the hitNet training.
hist = cmodel.fit(d_train, epochs=25, verbose=1, validation_data=d_valid)

In [None]:
# Training and validation loss per epoch, labelled so the figure is readable on its own.
plt.plot(hist.history['loss'], label='training')
plt.plot(hist.history['val_loss'], label='validation')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend();

In [None]:
# Replace the sigmoid of the output layer by a linear activation and re-compile, so that
# predict() returns the raw (pre-sigmoid) output.
cmodel.layers[-1].activation = tf.keras.activations.linear
cmodel.compile()
#cmodel.save('../../../freedom/resources/models/toy/chargeNet_new.hdf5')

### Simple checks: compare the NN output with the analytic likelihood terms

In [None]:
# hitnet: scan the time parameter from -5 to 5 at fixed x=1, y=1, N_src=10, ang=0
x = np.zeros((100, 3))
t = np.column_stack([np.ones(100), np.ones(100), np.linspace(-5,5,100), 10*np.ones(100), np.zeros(100)])
pred = -hmodel.predict([x, t])

# analytic reference for a single hit built from x[0] with 12 appended as fourth entry
# hit_term arguments: hit_times, pos_src, t_src, N_src, ang_src
true = [-toy_experiment.hit_term(np.array([np.append(x[0], 12)]), [1,1], T, 10, 0)
        for T in np.linspace(-5,5,100)]

In [None]:
# Both curves are shifted to a minimum of zero so that only their shapes are compared.
plt.plot(np.linspace(-5,5,100), pred-np.min(pred), label='NN LLH')
plt.plot(np.linspace(-5,5,100), np.array(true)-np.min(true), label='True LLH')
plt.xlabel('t')
plt.legend();

In [None]:
# charge net: scan the N_src/energy parameter from 3 to 40 at fixed x=1, y=1, t=0, ang=0
x = np.full((100, 2), 10.0)
t = np.column_stack([np.ones(100), np.ones(100), np.zeros(100), np.linspace(3,40,100), np.zeros(100)])
pred = -cmodel.predict([x, t])

# analytic reference evaluated for the same charge input x[0]
true = [-toy_experiment.charge_term(x[0], [1,1], E, 0) for E in np.linspace(3,40,100)]

In [None]:
# Both curves are shifted to a minimum of zero so that only their shapes are compared.
plt.plot(np.linspace(3,40,100), pred-np.min(pred), label='NN LLH')
plt.plot(np.linspace(3,40,100), np.array(true)-np.min(true), label='True LLH')
plt.xlabel('E')
plt.legend();

## Test event

In [None]:
# generate one test event

example_pos_src = np.array([1, 1])
example_N_src = 10
example_ang_src = np.pi
test_event = toy_experiment.generate_event(example_pos_src, N_src=example_N_src, ang_src=example_ang_src)
# truth vector in the order (x, y, t, N_src, ang); the time entry is set to 0
truth = np.array([example_pos_src[0], example_pos_src[1], 0, example_N_src, example_ang_src])

# Alternative: take a stored event instead. Kept as comments, because a bare string
# literal at the end of a cell is echoed as the cell output.
#test_event = events[4]
#example_pos_src = Truth[4][:2]
#example_N_src = Truth[4][3]
#example_ang_src = Truth[4][4]

In [None]:
# Unique values of column 3 of the hits, the index of their first occurrence and how often each occurs.
u, idx, c = np.unique(test_event[1][:,3], return_counts=True, return_index=True)

# NOTE(review): detectors[0] / detectors[1] assumes toy_experiment.detectors is laid out as
# (2, n_detectors), i.e. one row of x and one row of y values - confirm against the class.
plt.scatter(toy_experiment.detectors[0], toy_experiment.detectors[1], color='grey')
# one marker per distinct column-3 value, placed at columns 1 and 2 of its first hit and sized by its count
plt.scatter(test_event[1][idx, 1], test_event[1][idx, 2], s=30*c, marker='+', linewidth=3, color='r')
# true source position
plt.scatter(example_pos_src[0], example_pos_src[1], color='black', marker='$T$', s=70)
#plt.savefig('../../../plots/toy_model/test_event', bbox_inches='tight')

#### Reconstruct the test event

In [None]:
# random starting point: the truth smeared by a unit-width Gaussian in every parameter
seed = np.random.normal(truth) #
# Pass the event itself in a proper one-element tuple. LLH / LLH_NN only index event[0]
# and event[1], and np.array() on a (charge, hits) pair of different shapes raises in
# recent NumPy versions.
mini = minimize(LLH, seed, method='Nelder-Mead', args=(test_event,))
args = (test_event, cmodel, hmodel, [None])
mini2 = minimize(LLH_NN, seed, method='Nelder-Mead', args=args)

# display truth next to both best-fit points
truth, mini.x, mini2.x

#### LLH scans

In [None]:
# 1d LLH space
point, point2 = truth, truth #mini.x, mini2.x #Reco[3046]

# scan ranges around the expansion point; the lower edge of E is clipped at 3
X = np.linspace(point[0]-3, point[0]+3, 100)
Y = np.linspace(point[1]-3, point[1]+3, 100)
T, E = np.linspace(point[2]-2, point[2]+2, 100), np.linspace(max(point[3]-7,3), point[3]+7, 100)
ranges = [X, Y, T, E]

# true LLH (charge term only) per scan, plus the list of scanned hypotheses for the NN
llhs, scan_points = [], []
for i in range(len(ranges)):
    llh = []
    # start every scan from the unmodified points; only parameter i is varied
    p, p2 = deepcopy(point), deepcopy(point2)
    for value in ranges[i]:
        p[i], p2[i] = value, value
        llh.append(LLH(p, test_event, only_c=True)) #
        # store a copy, p2 is modified in place on the next step
        scan_points.append(np.array(p2))
    llhs.append(np.array(llh)-np.min(llh))
llhs = np.array(llhs)

#NN
# one row per scanned hypothesis; shape follows from the ranges instead of a hard-coded (400, 5)
c_ts = np.array(scan_points)
c_xs = np.tile(test_event[0], len(c_ts)).reshape(len(c_ts), 2)
h_ts = np.repeat(c_ts, test_event[1].shape[0], axis=0)
h_xs = np.tile(test_event[1][:, :3], (len(c_ts),1))

nn_c = -cmodel.predict([c_xs, c_ts], batch_size=4096).reshape(llhs.shape)
# the hit term is switched off, matching only_c=True above
nn_h = 0 #-hmodel.predict([h_xs, h_ts], batch_size=4096).reshape((len(c_ts), test_event[1].shape[0]))
nn_h = 0 #np.sum(nn_h, axis=1).reshape(llhs.shape)
llhs_nn = nn_c + nn_h
# shift every scan so that its minimum is zero
llhs_nn -= np.min(llhs_nn, axis=1, keepdims=True)

In [None]:
# One panel per scanned parameter (x, y, t, E): true vs. NN likelihood, truth marked by a dashed line.
plt.figure(figsize=(15, 11))
#plt.suptitle('At bf', y=0.91, size=23)
for i in range(4):
    plt.subplot(2,2,i+1)
    plt.plot(ranges[i], llhs[i], label='True LLH')
    #plt.axvline(mini.x[i], label='Best-fit true llh', color='blue')
    plt.plot(ranges[i], llhs_nn[i], label='NN LLH')
    #plt.axvline(mini2.x[i], label='Best-fit nn llh', color='r')
    plt.axvline(truth[i], color='black', linestyle='--', label='Truth')
    
    plt.legend(fontsize=15)
    plt.xlabel(par_names[i])
    #plt.ylim(0,10)
#plt.savefig('../../../plots/toy_model/llh_scans', bbox_inches='tight')

In [None]:
# Grid scan
# 100 x 100 grid of (x, y) hypotheses; X and Y from the 1d scan cell are overwritten here.
X = np.linspace(-10, 10, 100)
Y = np.linspace(-10, 10, 100)
x, y = np.meshgrid(X, Y)

# true likelihood terms on the grid
g = {}
g['hit_terms'] = np.empty(x.shape)
g['charge_terms'] = np.empty(x.shape)

for idx in np.ndindex(x.shape):
    hypo_pos =  np.array([x[idx], y[idx]])
    # the remaining parameters are fixed to the values used to generate the test event
    hypo_t = 0 #mini.x[2]
    hypo_N_src = example_N_src #mini.x[3]
    hypo_ang_src = example_ang_src #mini.x[4]
    # the hit term is switched off (set to 0); only the charge term is scanned
    g['hit_terms'][idx] = 0 #-toy_experiment.hit_term(test_event[1], hypo_pos, hypo_t, hypo_N_src, hypo_ang_src)
    g['charge_terms'][idx] = -toy_experiment.charge_term(test_event[0], hypo_pos, hypo_N_src, hypo_ang_src)
    
g['total_llh'] = g['hit_terms'] + g['charge_terms']
# shift so that the minimum is zero
g['total_llh'] -= np.min(g['total_llh'])

#NN
# one parameter row per grid point, in the order x, y, t, N_src, ang
ones = np.ones(np.prod(x.shape))
c_ts = np.vstack([x.flatten(), y.flatten(), ones*0, ones*example_N_src, ones*example_ang_src]).T
#c_ts = np.vstack([x.flatten(), y.flatten(), ones*mini2.x[2], ones*mini2.x[3], ones*mini2.x[4]]).T
# the charge input of the test event, repeated for every grid point
c_xs = np.tile(test_event[0], np.prod(x.shape)).reshape(np.prod(x.shape), 2)
# every grid point paired with every hit of the test event
h_ts = np.repeat(c_ts, test_event[1].shape[0], axis=0)
h_xs = np.tile(test_event[1][:, :3], (np.prod(x.shape),1))

g_nn_c = -cmodel.predict([c_xs, c_ts], batch_size=4096).reshape(g['total_llh'].shape)
# NOTE(review): g_nn_h is computed but not used below (the '+ g_nn_h' is commented out).
g_nn_h = -hmodel.predict([h_xs, h_ts], batch_size=4096).reshape((np.prod(x.shape), test_event[1].shape[0]))
g_nn_h = np.sum(g_nn_h, axis=1).reshape(g['total_llh'].shape)
# g_nn is the same array object as g_nn_c, so the in-place shift below also changes g_nn_c
g_nn = g_nn_c #+ g_nn_h
g_nn -= np.min(g_nn)

In [None]:
#plot 2d LLH space
# left: true likelihood on the (x, y) grid, right: NN likelihood; 'T' marks the true source position
plt.figure(figsize=(20,7))
#plt.suptitle('At bf', y=0.98, size=23)

plt.subplot(121)
plt.pcolormesh(X, Y, g['total_llh']) #, vmax=10
plt.colorbar()
plt.title('true LLH')
plt.scatter(example_pos_src[0], example_pos_src[1], color='white', marker='$T$', s=70)
#plt.scatter(mini.x[0], mini.x[1], color='r')
#plt.scatter(toy_experiment.detectors[0], toy_experiment.detectors[1], color='black')

plt.subplot(122)
plt.pcolormesh(X, Y, g_nn) #, vmax=10
plt.colorbar()
plt.title('NN LLH')
plt.scatter(example_pos_src[0], example_pos_src[1], color='white', marker='$T$', s=70)
#plt.scatter(mini2.x[0], mini2.x[1], color='r')
#plt.scatter(toy_experiment.detectors[0], toy_experiment.detectors[1], color='black')

#plt.savefig('../../../plots/toy_model/LLH_scans/xy_llh_scan', bbox_inches='tight')