In [None]:
# Show the GPU(s) visible on this machine (sanity check before training).
!nvidia-smi

In [None]:
import os
# Select the Keras backend; this must happen before `import keras` below to take effect.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import numpy as np
import keras

import sys
# Project-local modules (e.g. PET, utils) are presumably located in scripts/ — make them importable.
sys.path.append("scripts")
import utils
from PET import PET

import matplotlib.pyplot as plt
import sklearn
import sklearn.metrics
import json

In [None]:
# Training hyperparameters
EPOCHS = 500        # upper bound; the EarlyStopping callback may end training earlier
BATCH_SIZE = 512
LR = 1e-4           # Adam learning rate

# Dataset split: jets [0, NTRAIN) are used for training, [NTRAIN, NTRAIN + NVAL) for validation
NTRAIN = 200000
NVAL = 100000

# Where the training histories are written
OUTFILE = "out/test.json"

# Seed Python, NumPy and TensorFlow RNGs so weight initialisation and shuffling are reproducible
SEED = 42
keras.utils.set_random_seed(SEED)

In [None]:
import awkward
import vector

def to_p4(p4_obj):
    """Wrap a record array as a `vector` four-momentum array.

    Reads only the `tau`, `x`, `y` and `z` fields of ``p4_obj``; the `tau` field is
    passed to vector as the "mass" coordinate, `x`/`y`/`z` as the momentum components.
    """
    fields = {
        "mass": p4_obj.tau,
        "x": p4_obj.x,
        "y": p4_obj.y,
        "z": p4_obj.z,
    }
    return vector.awk(awkward.zip(fields))

def deltaphi(phi1, phi2):
    """Signed azimuthal difference ``phi1 - phi2`` wrapped into [-pi, pi].

    Element-wise; works on scalars, numpy arrays and other arrays that support
    numpy ufuncs (the notebook passes awkward arrays).
    """
    wrapped = np.arctan2(np.sin(phi1 - phi2), np.cos(phi1 - phi2))
    return wrapped

def deltar(eta1, phi1, eta2, phi2):
    """Angular distance sqrt(deta^2 + dphi^2) between (eta1, phi1) and (eta2, phi2).

    The azimuthal difference is wrapped into [-pi, pi] via ``deltaphi``.
    """
    return np.sqrt((eta1 - eta2) ** 2 + deltaphi(phi1, phi2) ** 2)

def set_matching_weights(model1, model2):
    """Initialize the weights of ``model1`` from ``model2``, layer by layer, by position.

    Layers are paired by index (``model1.layers[i]`` with ``model2.layers[i]``); pairs
    beyond the shorter of the two layer lists are ignored. A pair is copied only when
    the target layer has weights, both layers have the same number of weight tensors,
    and every pair of tensors has the same shape. All other layers are left untouched.

    Args:
        model1: model whose layer weights are overwritten.
        model2: model providing the source weights.

    Returns:
        Number of layers whose weights were copied.
    """
    num_copied = 0
    # zip stops at the shorter layer list, so unmatched trailing layers are skipped.
    for layer1, layer2 in zip(model1.layers, model2.layers):
        w1, w2 = layer1.weights, layer2.weights
        if not w1 or len(w1) != len(w2):
            continue
        if any(a.shape != b.shape for a, b in zip(w1, w2)):
            continue
        # get_weights() yields plain arrays, which is what set_weights expects.
        layer1.set_weights(layer2.get_weights())
        num_copied += 1
    return num_copied

Define the backbone (PET) model whose pretrained weights will be reused.

In [None]:
# Reference PET model: built only to obtain its architecture and pretrained weights.
# num_classes=10 is the class count of the original pretraining task; it is not relevant here.
model = PET(
    num_feat=13,            # per-particle input features
    num_jet=4,              # per-jet input features
    num_classes=10,
    num_layers=8,
    local=True,
    simple=False,
    layer_scale=True,
    talking_head=False,
    drop_probability=0,
    mode="all",
)

Run the backbone model once on dummy (all-zero) inputs so that Keras builds its layers and creates their weights.

In [None]:
# Dummy batch of 32 jets x 100 particles with all-zero inputs, used only to build the model.
x = {
    "input_features": tf.zeros((32, 100, 13)),
    "input_points": tf.zeros((32, 100, 2)),
    "input_mask": tf.zeros((32, 100)),
    "input_jet": tf.zeros((32, 4)),   # number of per-jet features
    "input_time": tf.zeros((32, 1)),  # not used for this task
}

model(x)
out = model.body(x)
print(len(out), out.shape)

Load the pretrained weights of the backbone model from the checkpoint.

In [None]:
# Print the architecture, then load the pretrained checkpoint into it.
# NOTE(review): the PET hyperparameters above must match those the checkpoint was trained
# with (the file name suggests 8 layers, local, layer scale, mode "all") — confirm.
model.summary()
model.load_weights("checkpoints/PET_jetclass_8_local_layer_scale_token_baseline_all.weights.h5", by_name=True)

In [None]:
# Load the dataset; presumably one record per reconstructed jet, with fields such as
# reco_cand_p4s, reco_jet_p4s, gen_jet_tau_p4s and gen_jet_tau_decaymode used below.
data = awkward.from_parquet("zh.parquet")

In [None]:
# Four-vectors of the reconstructed candidates and of the jet each one belongs to
reco_cand_p4s = to_p4(data["reco_cand_p4s"])
reco_jet_p4s = to_p4(data["reco_jet_p4s"])

# Candidate kinematics relative to the jet axis
delta_eta = reco_cand_p4s.eta - reco_jet_p4s.eta
delta_phi = deltaphi(reco_cand_p4s.phi, reco_jet_p4s.phi)
delta_r = deltar(reco_cand_p4s.eta, reco_cand_p4s.phi, reco_jet_p4s.eta, reco_jet_p4s.phi)

# Log-scaled momentum and energy; log(1 - fraction of jet pt / energy) can be -inf or nan,
# which is cleaned up when the padded array is built
log_pt = np.log(reco_cand_p4s.pt)
log_e = np.log(reco_cand_p4s.energy)
log_ptjet = np.log(1 - reco_cand_p4s.pt / reco_jet_p4s.pt)
log_ejet = np.log(1 - reco_cand_p4s.energy / reco_jet_p4s.energy)

# Charge and particle-type flags from the absolute PDG id
charge = data["reco_cand_charge"]
abs_pdg = np.abs(data["reco_cand_pdg"])
is_ele = abs_pdg == 11
is_mu = abs_pdg == 13
is_photon = abs_pdg == 22
is_chhad = abs_pdg == 211
is_nhad = abs_pdg == 130

In [None]:
# Number of reconstructed candidates per jet (guides the choice of the padding size)
plt.hist(awkward.num(reco_cand_p4s.pt), bins=np.linspace(0,32,33))
plt.title("Candidate multiplicity")
plt.xlabel("number of candidates per jet")
plt.ylabel("jets");

In [None]:
pad_size = 32 #max number of particles per jet
fill_val = 0 #fill value of padded data

#create particle array in the shape [njets, pad_size, 13]:
#every jet is padded (or clipped) to exactly pad_size candidates
vals = [
    awkward.to_numpy(
        awkward.fill_none(
            awkward.pad_none(
                feat, pad_size, clip=True), fill_val
        )
    ) for feat in [delta_eta, delta_phi, log_ptjet, log_pt, log_ejet, log_e , delta_r, charge, is_chhad, is_nhad, is_photon, is_ele, is_mu]
]
particles = np.stack(vals, axis=-1)
#the log features can be -inf/nan (e.g. log(1 - 1)); replace them with 0
particles[np.isnan(particles)] = 0
particles[np.isinf(particles)] = 0

#1 for real candidates, 0 for padded slots. np.ma.getmaskarray always returns a full
#boolean array, also when nothing was padded (where `.mask` would be the scalar `nomask`).
padded_deta = awkward.to_numpy(awkward.pad_none(delta_eta, pad_size, clip=True))
particles_mask = (~np.ma.getmaskarray(padded_deta)).astype(np.float32)
particles_valid = particles_mask == 1  #[njets, pad_size] boolean

#normalize particles with statistics of the real candidates only;
#features 7+ (charge and particle-type flags) are left unscaled
means_particle = particles[particles_valid].mean(axis=0)
means_particle[7:] = 0
stds_particle = particles[particles_valid].std(axis=0)
stds_particle[7:] = 1
stds_particle[stds_particle==0] = 1
#multiply by the mask so padded slots stay exactly zero after normalization
#(otherwise they would become -mean/std and look like real particles to the model)
particles = particles_mask[..., None] * (particles - means_particle) / stds_particle

#create jet array in the shape [njets, 4]: pt, eta, mass, number of candidates
jets = awkward.to_numpy(np.stack([
    reco_jet_p4s.pt,
    reco_jet_p4s.eta,
    reco_jet_p4s.mass,
    awkward.num(reco_cand_p4s)], axis=-1)
)
jets[np.isnan(jets)] = 0
jets[np.isinf(jets)] = 0

#normalize jets
means_jet = jets.mean(axis=0)
stds_jet = jets.std(axis=0)
stds_jet[stds_jet==0] = 1
jets = (jets - means_jet)/stds_jet

In [None]:
# Shape of the array of real (unpadded) candidates: (number of candidates, 13)
np.shape(particles[np.squeeze(particles_mask == 1)])

In [None]:
# Per-feature normalization constants of the particle features
(means_particle, stds_particle)

In [None]:
# Per-feature normalization constants of the jet features
(means_jet, stds_jet)

In [None]:
# Sanity check of the model-input array shapes
(particles.shape, particles_mask.shape, jets.shape)

In [None]:
# Targets: the gen-level tau decay mode, and log(gen tau pT / reco jet pT)
targets_dm = awkward.to_numpy(data["gen_jet_tau_decaymode"])
targets_pt = np.log(to_p4(data["gen_jet_tau_p4s"]).pt / reco_jet_p4s.pt)

In [None]:
def prepare_data(start, stop):
    """Slice jets [start, stop) into the model input dict and the two targets.

    Reads the notebook-level arrays `particles`, `particles_mask`, `jets`,
    `targets_dm` and `targets_pt`. If `stop` exceeds the number of jets the slice is
    truncated like any Python slice, and every returned array has that same length.

    Returns:
        x: dict with "input_features" [n, pad_size, 13], "input_points" [n, pad_size, 2]
           (the first two features: normalized delta eta / delta phi), "input_mask"
           [n, pad_size, 1], "input_jet" [n, 4] and "input_time" [n, 1] (all zeros).
        y_dm: decay-mode targets for the slice (presumably shape [n]).
        y_pt: log(gen tau pT / reco jet pT) targets for the slice (presumably shape [n]).
    """
    features = particles[start:stop]
    # Actual number of jets in the slice; can be smaller than stop - start near the end.
    n = len(features)
    x = {}
    x["input_features"] = features
    x["input_points"] = features[:, :, :2]
    x["input_mask"] = np.expand_dims(particles_mask[start:stop], axis=-1)
    x["input_jet"] = jets[start:stop]
    x["input_time"] = np.zeros((n, 1))
    y_dm = awkward.to_numpy(targets_dm[start:stop])
    y_pt = awkward.to_numpy(targets_pt[start:stop])
    return x, y_dm, y_pt

In [None]:
from tensorflow.keras import layers
from layers import StochasticDepth, TalkingHeadAttention, LayerScale, RandomDrop

def get_encoding(x, projection_dim, use_bias=True):
    """Two-layer GELU MLP: widen to 2*projection_dim, then project down to projection_dim."""
    hidden = layers.Dense(2*projection_dim, use_bias=use_bias, activation='gelu')(x)
    return layers.Dense(projection_dim, use_bias=use_bias, activation='gelu')(hidden)

def FourierProjection(x, projection_dim, num_embed=64):
    """Sinusoidal embedding of ``x`` followed by a two-layer swish MLP.

    ``num_embed // 2`` geometrically spaced frequencies (scaled by 1000) give sin and cos
    features; the concatenation is multiplied by ``x`` itself, so an all-zero input
    yields an all-zero embedding before the Dense layers.

    Returns:
        Tensor with last dimension ``projection_dim``.
    """
    half_dim = num_embed // 2
    log_step = tf.cast(tf.math.log(10000.0) / (half_dim - 1), tf.float32)
    frequencies = tf.exp(-log_step * tf.range(start=0, limit=half_dim, dtype=tf.float32))

    angle = x * frequencies * 1000.0
    embedding = tf.concat([tf.math.sin(angle), tf.math.cos(angle)], -1) * x
    embedding = layers.Dense(2*projection_dim, activation="swish", use_bias=False)(embedding)
    return layers.Dense(projection_dim, activation="swish", use_bias=False)(embedding)

def knn(num_points, k, topk_indices, features):
    """Gather the features of each point's neighbours.

    Args:
        num_points: number of points P per example.
        k: number of neighbours K.
        topk_indices: (N, P, K) indices into the point axis of ``features``.
        features: (N, P, C) per-point features.

    Returns:
        (N, P, K, C) tensor with out[n, p, j] = features[n, topk_indices[n, p, j]].
    """
    batch_size = tf.shape(features)[0]
    # Pair every neighbour index with the index of the example it belongs to.
    batch_ids = tf.tile(tf.reshape(tf.range(batch_size), (-1, 1, 1)), (1, num_points, k))
    gather_indices = tf.stack([batch_ids, topk_indices], axis=-1)
    return tf.gather_nd(features, gather_indices)

def get_neighbors(points, features, projection_dim, K):
    """Aggregate features over the K nearest neighbours of every point.

    For each point, the edge features [neighbour - centre, centre] pass through a
    two-layer GELU MLP and are averaged over the K neighbours.

    Returns:
        (N, P, projection_dim) tensor.
    """
    distances = pairwise_distance(points)  # (N, P, P)
    # top_k on the negated distance picks the closest points; the first (closest)
    # index is normally the point itself and is dropped.
    _, neighbor_idx = tf.nn.top_k(-distances, k=K + 1)  # (N, P, K+1)
    neighbor_idx = neighbor_idx[:, :, 1:]  # (N, P, K)
    neighbor_fts = knn(tf.shape(points)[1], K, neighbor_idx, features)  # (N, P, K, C)
    center_fts = tf.broadcast_to(tf.expand_dims(features, 2), tf.shape(neighbor_fts))
    edges = tf.concat([neighbor_fts - center_fts, center_fts], -1)
    edges = layers.Dense(2*projection_dim, activation='gelu')(edges)
    edges = layers.Dense(projection_dim, activation='gelu')(edges)
    return tf.reduce_mean(edges, -2)

def pairwise_distance(point_cloud):
    """Squared Euclidean distance between every pair of points, plus 1e-5.

    Args:
        point_cloud: (N, P, D) coordinates.

    Returns:
        (N, P, P) tensor, computed as |a|^2 - 2 a.b + |b|^2 + 1e-5.
    """
    sq_norms = tf.reduce_sum(point_cloud * point_cloud, axis=2, keepdims=True)  # (N, P, 1)
    inner = tf.matmul(point_cloud, point_cloud, transpose_b = True)  # (N, P, P)
    return sq_norms - 2 * inner + tf.transpose(sq_norms, perm=(0, 2, 1)) + 1e-5


class TransformerModel(keras.Model):
    """PET-style model with a classification head and a regression head.

    Hyperparameters (projection_dim, num_heads, layer scale, dropout, ...) are copied
    from a reference PET model. Its weights can then be transferred layer by layer
    with ``set_matching_weights``.

    Args:
        use_backbone: if True, build the full PET body (local kNN blocks + transformer
            layers) in front of the heads; otherwise a plain MLP encoding is used.
        num_feat: number of per-particle input features.
        num_jet: number of per-jet input features.
        num_classes: number of classes of the (softmax) classification head.
        K: number of nearest neighbours in the local kNN blocks of the body.
            Independent of ``num_classes``.
        num_local: number of stacked local kNN blocks in the body.
        base_model: reference PET model to copy hyperparameters from; defaults to
            the notebook-level ``model``.
    """
    def __init__(self,
                 use_backbone,
                 num_feat,
                 num_jet,
                 num_classes=2,
                 K=10,
                 num_local=2,
                 base_model=None):

        super(TransformerModel, self).__init__()

        # NOTE(review): falls back to the notebook-global `model`, which must exist.
        ref = model if base_model is None else base_model

        self.projection_dim = ref.projection_dim
        self.num_heads = ref.num_heads
        self.num_classes = num_classes
        self.class_activation = "softmax"

        self.feature_drop = ref.feature_drop
        self.num_keep = ref.num_keep
        self.mode = ref.mode
        self.num_layers = ref.num_layers
        self.layer_scale = ref.layer_scale
        self.layer_scale_init = ref.layer_scale_init
        self.drop_probability = ref.drop_probability
        self.dropout = ref.dropout
        self.num_class_layers = 2
        # kNN settings of the body. K was previously bound to num_classes.
        self.K = K
        self.num_local = num_local

        self._input_features = layers.Input(shape=(None, num_feat), name='input_features')
        self._input_points = layers.Input(shape=(None, 2), name='input_points')
        self._input_mask = layers.Input(shape=(None, 1), name='input_mask')
        self._input_jet = layers.Input((num_jet, ),name='input_jet')
        self._input_time = layers.Input((None, ),name='input_time')

        if use_backbone:
            self.backbone_body = self.PET_body(
                self._input_features,
                self._input_points,
                self._input_mask,
                self._input_time,
                local=True,
                K=self.K,
                num_local=self.num_local,
                talking_head=False,
            )
            self.backbone = keras.Model(
                inputs=[self._input_features, self._input_points, self._input_mask, self._input_time],
                outputs=[self.backbone_body], name="backbone"
            )
            particles_encoded = self.backbone_body
        else:
            particles_encoded = get_encoding(self._input_features, self.projection_dim)

        # One regression output (num_jet=1 in PET_classifier).
        classifier_out, regression_out = self.PET_classifier(
            particles_encoded, self._input_jet, self.num_class_layers, 1
        )
        self.classifier_and_regression = keras.Model(
            inputs=[self._input_features, self._input_points, self._input_mask, self._input_jet, self._input_time],
            outputs=[classifier_out, regression_out], name="classifier_and_regression"
        )

    def PET_classifier(
            self,
            encoded,
            input_jet,
            num_class_layers,
            num_jet,
            simple = False
    ):
        """Build the classification and regression heads on top of the particle encoding.

        Args:
            encoded: per-particle encoding (last dimension projection_dim).
            input_jet: per-jet features.
            num_class_layers: number of class-token attention blocks (non-simple mode).
            num_jet: output size of the regression head.
            simple: if True, use global average pooling plus a jet encoding instead
                of class-token attention.

        Returns:
            (outputs_pred, outputs_mse): class probabilities and regression outputs.
        """
        #Include event information as a representative particle
        if simple:
            encoded = layers.GroupNormalization(groups=1)(encoded)
            representation = layers.GlobalAveragePooling1D()(encoded)
            jet_encoded = get_encoding(input_jet,self.projection_dim)
            representation = layers.Dense(self.projection_dim,activation='gelu')(representation+jet_encoded)
            outputs_pred = layers.Dense(self.num_classes,activation=self.class_activation)(representation)
            outputs_mse = layers.Dense(num_jet)(representation)
        else:
            # Jet features produce a per-channel scale and shift applied to every particle.
            conditional = layers.Dense(2*self.projection_dim,activation='gelu')(input_jet)
            conditional = tf.tile(conditional[:,None, :], [1,tf.shape(encoded)[1], 1])
            scale,shift = tf.split(conditional,2,-1)
            encoded = encoded*(1.0 + scale) + shift

            # NOTE(review): a bare tf.Variable created during graph construction may not be
            # tracked as a trainable weight of the functional model — verify.
            class_tokens = tf.Variable(tf.zeros(shape=(1, self.projection_dim)),trainable = True)
            class_tokens = tf.tile(class_tokens[None, :, :], [tf.shape(encoded)[0], 1, 1])

            for _ in range(num_class_layers):
                concatenated = tf.concat([class_tokens, encoded],1)

                # The class token (position 0) attends over itself and all particles.
                x1 = layers.GroupNormalization(groups=1)(concatenated)
                updates = layers.MultiHeadAttention(num_heads=self.num_heads,
                                                    key_dim=self.projection_dim//self.num_heads)(
                                                        query=x1[:,:1], value=x1, key=x1)
                updates = layers.GroupNormalization(groups=1)(updates)
                if self.layer_scale:
                    updates = LayerScale(self.layer_scale_init, self.projection_dim)(updates)

                x2 = layers.Add()([updates,class_tokens])
                x3 = layers.GroupNormalization(groups=1)(x2)
                x3 = layers.Dense(2*self.projection_dim,activation="gelu")(x3)
                x3 = layers.Dropout(self.dropout)(x3)
                x3 = layers.Dense(self.projection_dim)(x3)
                if self.layer_scale:
                    x3 = LayerScale(self.layer_scale_init, self.projection_dim)(x3)
                class_tokens = layers.Add()([x3,x2])


            class_tokens = layers.GroupNormalization(groups=1)(class_tokens)
            outputs_pred = layers.Dense(self.num_classes,activation=self.class_activation)(class_tokens[:,0])
            outputs_mse = layers.Dense(num_jet)(class_tokens[:,0])

        return outputs_pred, outputs_mse

    def PET_body(self,
                 input_features,
                 input_points,
                 input_mask,
                 input_time,
                 local, K,num_local,
                 talking_head,
                 ):
        """Build the PET body: feature encoding, optional local kNN blocks, transformer layers.

        Args:
            input_features, input_points, input_mask, input_time: model input tensors.
            local: if True, add num_local stacked kNN aggregation blocks.
            K: number of nearest neighbours per kNN block.
            num_local: number of stacked kNN blocks.
            talking_head: use TalkingHeadAttention instead of MultiHeadAttention.

        Returns:
            Per-particle encoding (last dimension projection_dim), with a skip
            connection from the input encoding.
        """
        #Randomly drop features not present in other datasets
        encoded = RandomDrop(self.feature_drop if  'all' in self.mode else 0.0,num_skip=self.num_keep)(input_features)
        encoded = get_encoding(encoded,self.projection_dim)

        # Time conditioning: scale/shift from a Fourier embedding of input_time.
        time = FourierProjection(input_time,self.projection_dim)
        time = tf.tile(time[:,None, :], [1,tf.shape(encoded)[1], 1])*input_mask
        time = layers.Dense(2*self.projection_dim,activation='gelu',use_bias=False)(time)
        scale,shift = tf.split(time,2,-1)

        encoded = encoded*(1.0+scale) + shift

        if local:
            # Shift padded particles (mask == 0) far away so they are not chosen as neighbours.
            coord_shift = tf.multiply(999., tf.cast(tf.equal(input_mask, 0), dtype='float32'))
            points = input_points[:,:,:2]
            local_features = input_features
            for _ in range(num_local):
                local_features = get_neighbors(coord_shift+points,local_features,self.projection_dim,K)
                points = local_features

            encoded = layers.Add()([local_features,encoded])

        skip_connection = encoded
        for i in range(self.num_layers):
            x1 = layers.GroupNormalization(groups=1)(encoded)
            if talking_head:
                updates, _ = TalkingHeadAttention(self.projection_dim, self.num_heads, 0.0)(x1)
            else:
                updates = layers.MultiHeadAttention(num_heads=self.num_heads,
                                                    key_dim=self.projection_dim//self.num_heads)(x1,x1)

            if self.layer_scale:
                updates = LayerScale(self.layer_scale_init, self.projection_dim)(updates,input_mask)
            updates = StochasticDepth(self.drop_probability)(updates)
            x2 = layers.Add()([updates,encoded])
            x3 = layers.GroupNormalization(groups=1)(x2)
            x3 = layers.Dense(2*self.projection_dim,activation="gelu")(x3)
            x3 = layers.Dropout(self.dropout)(x3)
            x3 = layers.Dense(self.projection_dim)(x3)
            if self.layer_scale:
                x3 = LayerScale(self.layer_scale_init, self.projection_dim)(x3,input_mask)
            x3 = StochasticDepth(self.drop_probability)(x3)
            # Zero out padded particles after every layer.
            encoded = layers.Add()([x3,x2])*input_mask
        return encoded + skip_connection

    def call(self, x):
        """Run the model on a dict of inputs; returns [class probabilities, regression output]."""
        ret = self.classifier_and_regression([
            x["input_features"], x["input_points"], x["input_mask"], x["input_jet"], x["input_time"]
        ])
        return ret

In [None]:
# One small batch of 256 jets, used below to build the models before training
x, y_dm, y_pt = prepare_data(0, 256)

# Validation set: the NVAL jets that follow the training range
X_val, y_dm_val, y_pt_val = prepare_data(NTRAIN, NTRAIN + NVAL)

In [None]:
# Shape of the particle features of the small batch
np.shape(x["input_features"])

In [None]:
X_train, y_dm_train, y_pt_train = prepare_data(0, NTRAIN)

NUM_PARTICLE_FEATURES = 13  # per-candidate features stacked into `particles`
NUM_JET_FEATURES = 4        # jet pt, eta, mass, number of candidates
NUM_DM_CLASSES = 16         # size of the decay-mode classification head


def build_tau_model(use_backbone):
    """Create, build and compile a TransformerModel for decay mode + pT regression.

    The model is called once on the small batch `x` so that all layers (and their
    weights) exist before compiling and before any weight transfer. Losses: sparse
    categorical cross-entropy on the softmax decay-mode output, and mean absolute
    error on the log-pT regression output.
    """
    net = TransformerModel(use_backbone, NUM_PARTICLE_FEATURES, NUM_JET_FEATURES, NUM_DM_CLASSES)
    net(x)
    net.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LR),
        loss=[keras.losses.SparseCategoricalCrossentropy(from_logits=False), keras.losses.MeanAbsoluteError()]
    )
    return net


#model without backbone
model_dm_direct = build_tau_model(False)
#model with backbone, initialization from scratch
model_dm_bb = build_tau_model(True)
#model with backbone, initialize from checkpoint
model_dm_bb_cp = build_tau_model(True)

#set backbone (and shape-compatible head) weights from the pretrained model
set_matching_weights(model_dm_bb_cp.backbone, model.body)
set_matching_weights(model_dm_bb_cp.classifier_and_regression, model.classifier)

# Stop after 20 epochs without val_loss improvement and roll back to the best epoch,
# so the models used for the predictions below are the best ones, not the last ones.
callbacks = [tf.keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True)]

In [None]:
print("training direct")
history_direct = model_dm_direct.fit(
    X_train,
    (y_dm_train, y_pt_train),
    validation_data=(X_val, (y_dm_val, y_pt_val)),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=2,
)
print(f"direct val_loss={history_direct.history['val_loss'][-1]:.2f}")

In [None]:
fig, ax = plt.subplots()
ax.plot(history_direct.history["loss"], label="train")
ax.plot(history_direct.history["val_loss"], label="val")
ax.set(title="no backbone", xlabel="epoch", ylabel="loss")
ax.legend();

In [None]:
print("training backbone")
history_bb = model_dm_bb.fit(
    X_train,
    (y_dm_train, y_pt_train),
    validation_data=(X_val, (y_dm_val, y_pt_val)),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=2,
)
print(f"backbone val_loss={history_bb.history['val_loss'][-1]:.2f}")

In [None]:
fig, ax = plt.subplots()
ax.plot(history_bb.history["loss"], label="train")
ax.plot(history_bb.history["val_loss"], label="val")
ax.set(title="OmniLearn naive", xlabel="epoch", ylabel="loss")
ax.legend();

In [None]:
print("training backbone checkpoint")
history_bb_cp = model_dm_bb_cp.fit(
    X_train,
    (y_dm_train, y_pt_train),
    validation_data=(X_val, (y_dm_val, y_pt_val)),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=2,
)
print(f"backbone checkpoint val_loss={history_bb_cp.history['val_loss'][-1]:.2f}")

In [None]:
fig, ax = plt.subplots()
ax.plot(history_bb_cp.history["loss"], label="train")
ax.plot(history_bb_cp.history["val_loss"], label="val")
ax.set(title="OmniLearn checkpoint", xlabel="epoch", ylabel="loss")
ax.legend();

In [None]:
# Validation loss per epoch for the three training strategies
fig, ax = plt.subplots()
ax.set_title(f"Ntrain={NTRAIN}")
for hist, label in (
    (history_direct, "no backbone"),
    (history_bb, "OmniLearn naive"),
    (history_bb_cp, "OmniLearn checkpoint"),
):
    ax.plot(hist.history["val_loss"], label=label)
ax.legend()
ax.set_xlabel("epoch")
ax.set_ylabel("validation loss")
plt.show()

In [None]:
# Persist the run configuration and the per-epoch training histories.
# Create the output directory if needed (e.g. "out/" on a fresh checkout).
os.makedirs(os.path.dirname(OUTFILE) or ".", exist_ok=True)
with open(OUTFILE, "w") as fi:
    json.dump({
        "BATCH_SIZE": BATCH_SIZE,
        "NTRAIN": NTRAIN,
        "NVAL": NVAL,
        "history_direct": history_direct.history,
        "history_bb": history_bb.history,
        "history_bb_cp": history_bb_cp.history,
    }, fi, indent=2, default=float)  # default=float serializes numpy scalars, if any

In [None]:
# Validation predictions (decay-mode probabilities, log-pT correction) of the three models
(pred_dm_probas1, pred_pt1), (pred_dm_probas2, pred_pt2), (pred_dm_probas3, pred_pt3) = [
    net.predict(X_val, batch_size=BATCH_SIZE)
    for net in (model_dm_direct, model_dm_bb, model_dm_bb_cp)
]

In [None]:
# Most probable decay mode per jet for each model
pred_dm1, pred_dm2, pred_dm3 = (
    tf.argmax(probas, axis=-1)
    for probas in (pred_dm_probas1, pred_dm_probas2, pred_dm_probas3)
)

In [None]:
# Decay-mode confusion matrix of the direct model (rows: true, columns: predicted)
fig, ax = plt.subplots()
cm = sklearn.metrics.confusion_matrix(y_dm_val, pred_dm1, labels=range(16))
im = ax.imshow(cm)
fig.colorbar(im, ax=ax, label="jets")
ax.set(title="decay-mode confusion (no backbone)", xlabel="predicted decay mode", ylabel="true decay mode");

In [None]:
# Target vs. predicted log-pT correction of the direct model
fig, ax = plt.subplots(figsize=(5, 5))
b = np.linspace(-0.2, 0.2, 100)
ax.hist2d(y_pt_val, pred_pt1[:, 0], bins=(b, b), cmap="Blues")
ax.set_xlabel("target log(genpt/recopt)")
ax.set_ylabel("predicted log(genpt/recopt)")

In [None]:
# Predicted vs. true pT on the validation jets (direct model). The regression target is
# log(gen tau pT / reco jet pT), so the prediction is compared with the gen tau pT.
val_slice = slice(NTRAIN, NTRAIN + NVAL)
reco_pt_val = awkward.to_numpy(to_p4(data["reco_jet_p4s"][val_slice]).pt)
true_pt_val = awkward.to_numpy(to_p4(data["gen_jet_tau_p4s"][val_slice]).pt)

fig, ax = plt.subplots(figsize=(5, 5))
b = np.linspace(0, 200, 100)
ax.hist2d(true_pt_val, np.exp(pred_pt1[:, 0]) * reco_pt_val, bins=(b, b), cmap="Blues")
ax.plot([0, 200], [0, 200])
ax.set(title="no backbone", xlabel="true pt", ylabel="predicted pt");

In [None]:
# pT response (estimate / true gen tau pT) on the validation jets. The regression target is
# log(gen tau pT / reco jet pT), so the gen tau pT is the matching reference for every curve.
val_slice = slice(NTRAIN, NTRAIN + NVAL)
reco_pt_val = awkward.to_numpy(to_p4(data["reco_jet_p4s"][val_slice]).pt)
true_pt_val = awkward.to_numpy(to_p4(data["gen_jet_tau_p4s"][val_slice]).pt)

responses = {
    "raw recojet": reco_pt_val / true_pt_val,
    "direct": np.exp(pred_pt1[:, 0]) * reco_pt_val / true_pt_val,
    "OmniLearn naive": np.exp(pred_pt2[:, 0]) * reco_pt_val / true_pt_val,
    "OmniLearn checkpoint": np.exp(pred_pt3[:, 0]) * reco_pt_val / true_pt_val,
}

plt.figure(figsize=(5,5))
b = np.linspace(0.75, 1.25, 200)
for label, response in responses.items():
    plt.hist(response, bins=b, histtype="step", lw=1, label=label, density=True)

plt.axvline(1.0, color="black", ls="--", lw=1.0)
plt.legend(loc="best")
plt.xlabel("tau pred / true pt")
plt.ylim(top=30)