In [None]:
import tensorflow as tf
import numpy as np
import sys
sys.path.append("/home/joosep/particleflow/mlpf")
import tfmodel.model
import tfmodel.data
import tfmodel.model_setup

import yaml
import matplotlib.pyplot as plt
import matplotlib 
import os
os.chdir("/home/joosep/particleflow")

import pandas
import networkx
import glob

from matplotlib import cm

In [None]:
# Load the CMS configuration file and set the two options this notebook relies on.
with open("/home/joosep/particleflow/parameters/cms.yaml") as f:
    # yaml.load() without an explicit Loader warns on PyYAML 5.x and raises a
    # TypeError on PyYAML >= 6. FullLoader is what 5.x used implicitly.
    # NOTE(review): if the config only holds plain scalars/lists/maps,
    # yaml.safe_load(f) is the safer choice — confirm and switch.
    config = yaml.load(f, Loader=yaml.FullLoader)
config["setup"]["multi_output"] = True
config["parameters"]["debug"] = True

In [None]:
# Build the model from the loaded config via make_gnn_dense; tf.float32 is
# passed as the second argument (presumably the model dtype — TODO confirm).
model = tfmodel.model_setup.make_gnn_dense(config, tf.float32)

In [None]:
# Dataset section of the config; all dataset settings below are read through it.
cds = config["dataset"]

# Dataset definition: paths, schema and feature counts come from the config,
# while the padded size and the validation file pattern are fixed here.
dataset_def = tfmodel.data.Dataset(
    schema=cds["schema"],
    num_input_features=int(cds["num_input_features"]),
    num_output_features=int(cds["num_output_features"]),
    padded_num_elem_size=6400,
    raw_path=cds.get("raw_path"),
    raw_files=cds.get("raw_files"),
    processed_path=cds["processed_path"],
    validation_file_path="data/TTbar_14TeV_TuneCUETP8M1_cfi/val/pfntuple_*.pkl.bz2",
)

# Callable applied to (X, y, w) triples further down to produce the target dict.
dataset_transform = tfmodel.model_setup.targets_multi_output(cds["num_output_classes"])

In [None]:
# Per-file accumulators for inputs, generator-level targets and candidate targets.
Xs, ygens, ycands = [], [], []

# Only the first two validation files are read.
for fi in dataset_def.val_filelist[:2]:
    print(fi)
    X, ygen, ycand = dataset_def.prepare_data(fi)

    # Each returned item is a sequence of arrays; join it along the first axis.
    for collected, per_file in ((Xs, X), (ygens, ygen), (ycands, ycand)):
        collected.append(np.concatenate(per_file))

# Join the per-file arrays into one array per quantity.
X_val, ygen_val, ycand_val = (np.concatenate(parts) for parts in (Xs, ygens, ycands))

# Convert both target arrays with the transform built from the config.
# NOTE(review): X_val is passed through dataset_transform twice; this is only
# harmless if the transform returns X unchanged — confirm.
X_val, ycand_val, _ = dataset_transform(X_val, ycand_val, None)
X_val, ygen_val, _ = dataset_transform(X_val, ygen_val, None)


In [None]:
# Standard deviation of the candidate "energy" target over all elements whose
# candidate class (argmax of "cls") equals 2.
is_cls2 = np.argmax(ycand_val["cls"], axis=-1) == 2
energy_cls2 = ycand_val["energy"][is_cls2].numpy().flatten()
np.std(energy_cls2)

In [None]:
# Histogram of the shifted and scaled candidate "energy" target for class-2 elements.
cand_cls = np.argmax(ycand_val["cls"], axis=-1)
cls2_energy = ycand_val["energy"][cand_cls == 2].numpy().flatten()
# NOTE(review): the shift 1/59 and the scale 1.3 are unexplained constants — document their origin.
plt.hist((cls2_energy - 1/59) / 1.3, bins=100);

In [None]:
# Call the model once on a single event before loading weights.
# NOTE(review): presumably this creates the model variables so that
# load_weights can restore them — confirm.
ret = model(X_val[:1])

weights_path = "/home/joosep/particleflow/experiments/cms_20210828_144012_433706.joosep-desktop//weights/weights-03-28.697701.hdf5"
model.load_weights(weights_path)

# Run the model over the whole validation array, one event per batch.
ret = model.predict(X_val, batch_size=1, verbose=1)

In [None]:
def get_bin_index(bs, elem_types=None):
    """Return the index of the first bin containing each leading non-padded element.

    Args:
        bs: array-like indexed as ``bs[ibin]``, where each row holds the
            element indices assigned to that bin.
        elem_types: optional 1D array-like with one entry per element; a zero
            entry marks padding. Defaults to ``X_val[0, :, 0]`` (the first
            event of the notebook-level validation array), which is what the
            function previously read implicitly.

    Returns:
        list of int: bin indices in element order. Scanning stops at the
        first element whose type is zero. An element that appears in no bin
        contributes no entry, so the result can be shorter than the number
        of non-padded elements.
    """
    if elem_types is None:
        elem_types = X_val[0, :, 0]
    elem_types = np.asarray(elem_types)
    bs = np.asarray(bs)

    # Map element index -> first bin containing it, built once instead of
    # re-scanning every bin for every element.
    first_bin = {}
    for ibin in range(bs.shape[0]):
        for ielem in np.unique(bs[ibin]).tolist():
            first_bin.setdefault(ielem, ibin)

    bin_index = []
    # Iterate over however many elements are present rather than a fixed 6400.
    for ielem, elem_type in enumerate(elem_types.tolist()):
        if elem_type == 0:
            break
        if ielem in first_bin:
            bin_index.append(first_bin[ielem])
    return bin_index

In [None]:
def plot_binning_in_layer(layer_name):
    """Scatter the first event's non-padded elements, colored by bin, and save the figure.

    Reads the notebook-level ``X_val`` and ``ret``. Columns 2, 3 and 4 of the
    first event are used as eta, phi and energy; column 0 being zero marks a
    padded element. The figure is written to ``bins_<layer_name>.pdf``.

    Args:
        layer_name: key into ``ret`` whose ``"bins"`` entry holds the binning.
    """
    event = X_val[0]
    msk = event[:, 0] != 0
    eta = event[msk, 2]
    phi = event[msk, 3]
    energy = event[msk, 4]

    layer_bins = ret[layer_name]["bins"]

    # One color per bin, sampled evenly from the Dark2 colormap.
    evenly_spaced_interval = np.linspace(0, 1, layer_bins.shape[1])
    colorlist = [cm.Dark2(x) for x in evenly_spaced_interval]
    bin_idx = get_bin_index(layer_bins[0])

    plt.figure(figsize=(4,4))
    # Marker area is set from the element energy.
    plt.scatter(eta, phi, c=[colorlist[bi] for bi in bin_idx], marker=".", s=energy)
    plt.xlabel("eta")
    plt.ylabel("phi")
    plt.title("Binning in {}".format(layer_name))
    plt.savefig("bins_{}.pdf".format(layer_name))

In [None]:
# Draw and save the binning plot for every layer in one loop instead of six
# copy-pasted single-call cells.
for layer_name in ["cg_0", "cg_1", "cg_2", "cg_energy_0", "cg_energy_1", "cg_energy_2"]:
    plot_binning_in_layer(layer_name)

In [None]:
def plot_dms(dms):
    """Draw each matrix in ``dms`` as a heatmap panel on a 4-column grid.

    A new figure is created and left as the current pyplot figure, so callers
    can add a suptitle or save it afterwards. Colors are normalized to [0, 1].

    Args:
        dms: sequence of 2D arrays, one per panel; panel ``i`` is titled
            ``"bin i"``.
    """
    ncols = 4
    # Keep the 4x4 layout for up to 16 panels; add rows beyond that so that
    # plt.subplot does not fail on a 17th panel.
    nrows = max(4, -(-len(dms) // ncols))
    plt.figure(figsize=(4 * ncols, 3 * nrows))
    for i in range(len(dms)):
        # plt.subplot already makes the new axes current.
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(dms[i], interpolation="none", norm=matplotlib.colors.Normalize(vmin=0, vmax=1), cmap="Blues")
        plt.colorbar()
        plt.title("bin {}".format(i))
    plt.tight_layout()

In [None]:
# Overlay the distribution of nonzero 'dm' entries for the three cg layers.
dm_bins = np.linspace(0, 1, 100)
for layer in ['cg_0', 'cg_1', 'cg_2']:
    dm_vals = ret[layer]['dm'].flatten()
    # Label each histogram so the overlaid layers can be told apart.
    plt.hist(dm_vals[dm_vals != 0], bins=dm_bins, density=True, alpha=0.8, lw=2, label=layer)
plt.xlabel("nonzero dm value")
plt.ylabel("density")
plt.legend();

In [None]:
# Overlay the distribution of nonzero 'dm' entries for the three cg_energy layers.
dm_bins = np.linspace(0, 1, 100)
for layer in ['cg_energy_0', 'cg_energy_1', 'cg_energy_2']:
    dm_vals = ret[layer]['dm'].flatten()
    # Label each histogram so the overlaid layers can be told apart.
    plt.hist(dm_vals[dm_vals != 0], bins=dm_bins, density=True, alpha=0.8, lw=2, label=layer)
plt.xlabel("nonzero dm value")
plt.ylabel("density")
plt.legend();

In [None]:
# Plot and save the per-bin 'dm' matrices of the first event for every layer
# in one loop instead of six copy-pasted cells. Titles and output file names
# are the same as before.
for layer in ['cg_0', 'cg_1', 'cg_2', 'cg_energy_0', 'cg_energy_1', 'cg_energy_2']:
    # First event, all bins, last-axis index 0.
    dmn = ret[layer]['dm'][0, :, :, :, 0]
    plot_dms(dmn)
    plt.suptitle("Learned adjacency, {}".format(layer), y=1.01)
    plt.savefig("dm_{}.pdf".format(layer))

In [None]:
# Rows of the first event's 'dec_output' that belong to non-padded elements
# (column 0 of the input is nonzero).
msk = X_val[0, :, 0] != 0
sel = ret['dec_output'][0][msk]

In [None]:
# Scatter column 40 against column 60 of the selected decoder output.
col_40, col_60 = sel[:, 40], sel[:, 60]
plt.scatter(col_40, col_60, marker=".")

In [None]:
# Float32 mask (1.0 = non-padded element) for the first event, batch axis kept.
np.asarray(X_val[:1, :, 0] != 0, dtype=np.float32)

In [None]:
# Inspect the shape of the 'dec_output_energy' entry of the prediction output.
ret['dec_output_energy'].shape

In [None]:
# Float32 mask of non-padded elements; the 0:1 slice keeps a trailing axis of size 1.
elem_mask = np.array(X_val[:, :, 0:1] != 0, np.float32)

# Re-run only the output decoder on the intermediate outputs stored in `ret`.
decoder_inputs = [
    X_val,
    ret['dec_output'],
    ret['dec_output_energy'],
    elem_mask,
]
pred_debug1 = model.output_dec(decoder_inputs, training=False)

In [None]:
# Class ids (argmax over the last axis of "cls") for the candidate targets
# and for the decoder-only prediction.
true_id, pred_id1 = (
    np.argmax(outputs["cls"], axis=-1) for outputs in (ycand_val, pred_debug1)
)

In [None]:
plt.figure(figsize=(4,4))

# Non-padded elements whose true class id is 2.
msk1 = (X_val[:, :, 0]!=0) & (true_id==2)
pred_energy = pred_debug1["energy"][msk1][:, 0].numpy()
true_energy = ycand_val["energy"][msk1][:, 0].numpy()

# Predicted energy on x, target energy on y.
plt.scatter(pred_energy, true_energy, marker=".", alpha=0.4)

# Diagonal reference line.
plt.plot([0,6], [0,6], color="black")

In [None]:
# Freeze the three cg layers; the cg_energy layers are not modified here.
for icg in range(3):
    model.cg[icg].trainable = False

# In the output decoder, freeze every head except the energy head.
for head_name in ("ffn_id", "ffn_charge", "ffn_phi", "ffn_eta", "ffn_pt"):
    getattr(model.output_dec, head_name).trainable = False
model.output_dec.ffn_energy.trainable = True

model.output_dec.layernorm.trainable = False

In [None]:
# List the names of the weights that are still trainable.
[weight.name for weight in model.trainable_weights]

In [None]:
# Per-class weights (8 entries): 0.0, 0.01 and 2.0 for the first three
# classes, 1.0 for the remaining five.
class_weights = tf.constant([0.0, 0.01, 2.0] + [1.0] * 5)

In [None]:
# Short training loop on the first two validation events, with a Huber loss
# on the "energy" output only.
loss = tf.keras.losses.Huber()
# `lr` is a deprecated alias; `learning_rate` is the supported keyword.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
for epoch in range(100):
    with tf.GradientTape() as tape:
        y_pred = model(X_val[:2], training=True)
        pred_cls = tf.argmax(y_pred["cls"], axis=-1)
        true_cls = tf.argmax(ycand_val["cls"][:2], axis=-1)

        # Only elements with a correctly predicted, nonzero class enter the loss.
        msk_loss = tf.expand_dims(tf.cast((pred_cls==true_cls) & (true_cls!=0), tf.float32), axis=-1)

        # Per-element weight: sharpened softmax of the target class scores
        # times the class weights, summed over the class axis.
        sample_weights = tf.keras.activations.softmax(ycand_val["cls"][:2]*100)*class_weights
        # Fix: the reduction previously ran over `class_weights`, which threw
        # away the per-element weights and left a single constant.
        sample_weights = tf.reduce_sum(sample_weights, axis=-1, keepdims=True)

        loss_val = loss(ycand_val["energy"][:2]*msk_loss, y_pred["energy"][:2]*msk_loss, sample_weight=sample_weights)
    print(loss_val)
    trainable_vars = model.trainable_variables
    gradients = tape.gradient(loss_val, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))


In [None]:
# Evaluate on events 2..5, which the training loop above did not touch.
y_pred = model(X_val[2:6], training=False)

# Class ids (argmax over the last axis) for prediction and candidate target.
# Note that `true_id` is rebound here to cover only these four events.
pred_id = tf.argmax(y_pred["cls"], axis=-1)
true_id = tf.argmax(ycand_val["cls"][2:6], axis=-1)

In [None]:
# `sklearn` is never imported in this notebook, so the previous
# sklearn.metrics.confusion_matrix call raised NameError on a fresh kernel.
# tf.math.confusion_matrix uses the same layout (rows = true class,
# columns = predicted class). Unlike sklearn it always returns a
# num_classes x num_classes matrix, including classes absent from the data.
tf.math.confusion_matrix(
    tf.reshape(true_id, [-1]),
    tf.reshape(pred_id, [-1]),
    num_classes=y_pred["cls"].shape[-1],
).numpy()

In [None]:
plt.figure(figsize=(4,4))
cls = 3

# Elements where both the true and the predicted class id equal `cls`.
both_cls = (true_id==cls) & (pred_id==cls)
print(np.sum(both_cls))

# Predicted energy on x, target energy on y, with a diagonal reference line.
plt.scatter(y_pred["energy"][both_cls], ycand_val["energy"][2:6][both_cls], marker=".")
plt.plot([0,6], [0,6], color="black")
plt.xlim(0,6)
plt.ylim(0,6)

In [None]:
# Predicted minus target energy for every element whose true class id is nonzero.
nonzero_true = true_id != 0
vals = y_pred["energy"][nonzero_true] - ycand_val["energy"][2:6][nonzero_true]

In [None]:
# Histogram of the energy residuals on a log-scaled y axis.
residual_bins = np.linspace(-2, 2, 100)
plt.hist(vals.numpy().flatten(), bins=residual_bins);
plt.yscale("log")