In [22]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from mmd import mix_rbf_mmd2
from torch import Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import gc

In [23]:
# Sample label -> path of the HDF5 feature file for that sample.
# Going by the file names: two QCD samples (Herwig, Pythia8) and one Wprime sample (Pythia8).
filenames = {
    "herwig": "../data/events_anomalydetection_DelphesHerwig_qcd_features.h5",
    "pythiabg": "../data/events_anomalydetection_DelphesPythia8_v2_qcd_features.h5",
    "pythiasig": "../data/events_anomalydetection_DelphesPythia8_v2_Wprime_features.h5",
}

# Every known sample label, in the insertion order of `filenames`.
datatypes = list(filenames)

In [24]:
# Sample labels (same list as the one assigned in the previous cell).
datatypes = [
    "herwig",
    "pythiabg",
    "pythiasig",
]

# Columns of the frame returned by load_data that are selected as training features.
train_features = [
    "ptj1", "etaj1", "mj1",
    "ptj2", "etaj2", "phij2", "mj2",
]
# Currently disabled:
# condition_features = ["mjj"]

def load_data(datatype, stop = None, rotate = True, flip_eta = True):
    """Load one sample and derive per-jet and dijet kinematic features.

    Parameters
    ----------
    datatype : str
        Key into the module-level ``filenames`` dict selecting the HDF5 file.
    stop : int or None
        Forwarded unchanged to ``pd.read_hdf`` as its ``stop`` argument.
    rotate : bool
        If True, ``phij2`` is replaced by ``|phij2 - phij1|`` and ``phij1`` is
        set to 0, so every event has jet 1 along phi = 0.
    flip_eta : bool
        If True, a copy of every event with the signs of ``etaj1`` and
        ``etaj2`` flipped is appended, doubling the number of rows.

    Returns
    -------
    pandas.DataFrame
        float32 frame with, per jet, ``pt, eta, phi, m, p, e, tau21, tau32``
        followed by ``px, py, pz`` (recomputed after rotation / flipping) and
        the dijet columns ``pxjj, pyjj, pzjj, ejj, pjj, mjj``. Rows containing
        NaN are removed and the index is reset to 0..n-1.

    Raises
    ------
    KeyError
        If ``datatype`` is not a key of ``filenames``.
    """
    input_frame = pd.read_hdf(filenames[datatype], stop = stop)
    # Start from the input index so every derived column aligns with its event.
    output_frame = pd.DataFrame(index = input_frame.index)

    for jet in ["j1", "j2"]:
        output_frame["pt" + jet] = np.sqrt(input_frame["px" + jet]**2 + input_frame["py" + jet]**2)
        output_frame["eta" + jet] = np.arcsinh(input_frame["pz" + jet] / output_frame["pt" + jet])
        output_frame["phi" + jet] = np.arctan2(input_frame["py" + jet], input_frame["px" + jet])
        output_frame["m" + jet] = input_frame["m" + jet]
        output_frame["p" + jet] = np.sqrt(input_frame["pz" + jet]**2 + output_frame["pt" + jet]**2)
        output_frame["e" + jet] = np.sqrt(output_frame["m" + jet]**2 + output_frame["p" + jet]**2)
        output_frame["tau21" + jet] = input_frame["tau2" + jet] / input_frame["tau1" + jet]
        output_frame["tau32" + jet] = input_frame["tau3" + jet] / input_frame["tau2" + jet]

    # The raw frame is no longer needed; release it before allocating more columns.
    del input_frame
    gc.collect()

    # Not exact rotation, since negative angles for phi2 are flipped across the x-axis. Should be OK due to symmetry.
    if rotate:
        output_frame["phij2"] = np.abs(output_frame["phij2"] - output_frame["phij1"])
        output_frame["phij1"] = 0.0

    if flip_eta:
        flipped_frame = output_frame.copy()
        flipped_frame["etaj1"] *= -1
        flipped_frame["etaj2"] *= -1
        # DataFrame.append was removed in pandas 2.0; pd.concat stacks the rows identically.
        # The index is rebuilt after the NaN filter below, so it is simply ignored here.
        output_frame = pd.concat([output_frame, flipped_frame], ignore_index = True)
        del flipped_frame
        gc.collect()

    # Recompute Cartesian momenta so they reflect the rotated / flipped angles.
    for jet in ["j1", "j2"]:
        output_frame["px" + jet] = output_frame["pt" + jet] * np.cos(output_frame["phi" + jet])
        output_frame["py" + jet] = output_frame["pt" + jet] * np.sin(output_frame["phi" + jet])
        output_frame["pz" + jet] = output_frame["pt" + jet] * np.sinh(output_frame["eta" + jet])

    # Dijet properties
    output_frame["pxjj"] = output_frame["pxj1"] + output_frame["pxj2"]
    output_frame["pyjj"] = output_frame["pyj1"] + output_frame["pyj2"]
    output_frame["pzjj"] = output_frame["pzj1"] + output_frame["pzj2"]
    output_frame["ejj"] = output_frame["ej1"] + output_frame["ej2"]
    output_frame["pjj"] = np.sqrt(output_frame["pxjj"]**2 + output_frame["pyjj"]**2 + output_frame["pzjj"]**2)
    output_frame["mjj"] = np.sqrt(output_frame["ejj"]**2 - output_frame["pjj"]**2)

    # NaNs may arise from overly sparse jets with tau1 = 0, tau2 = 0, etc.
    output_frame = output_frame.dropna().reset_index(drop = True)

    return output_frame.astype('float32')



In [26]:
# Load the "herwig" sample with load_data's defaults (stop=None, rotate=True, flip_eta=True).
d = load_data("herwig")

In [36]:
# Rows where jet 1 has |eta| above 2.5.
d[d.etaj1.abs() > 2.5]

Unnamed: 0,ptj1,etaj1,phij1,mj1,pj1,ej1,tau21j1,tau32j1,ptj2,etaj2,...,pzj1,pxj2,pyj2,pzj2,pxjj,pyjj,pzjj,ejj,pjj,mjj


In [30]:
# Show the loaded frame; pandas truncates the rich display to the first and last rows/columns.
d

Unnamed: 0,ptj1,etaj1,phij1,mj1,pj1,ej1,tau21j1,tau32j1,ptj2,etaj2,...,pzj1,pxj2,pyj2,pzj2,pxjj,pyjj,pzjj,ejj,pjj,mjj
0,1234.686035,0.000947,0.0,223.070007,1234.686523,1254.675781,0.579902,0.638047,868.532654,-0.128372,...,1.169330,-864.816162,80.261513,-111.802002,369.869843,80.261513,-110.632675,2132.947021,394.316132,2096.181641
1,1249.513794,-0.646499,0.0,242.136002,1519.860840,1539.027832,0.229670,0.612559,1092.055298,-0.293390,...,-865.270020,-1091.875000,19.845547,-325.015015,157.638855,19.845547,-1190.285034,2681.056885,1200.842285,2397.090820
2,1892.548584,-0.168968,0.0,95.748199,1919.629150,1922.015503,0.731174,0.742621,1332.402100,0.928036,...,-321.303009,-1231.342896,-509.008698,1421.810059,661.205688,-509.008698,1100.507080,3956.002441,1381.086060,3707.095459
3,1217.459106,1.295385,0.0,440.657013,2389.990967,2430.274658,0.603133,0.496088,995.841614,-1.301034,...,2056.659912,-985.330383,-144.307114,-1693.349976,232.128693,-144.307114,363.309937,4400.079102,454.645325,4376.527832
4,1332.903809,-0.044832,0.0,51.966702,1334.243530,1335.255127,0.204270,0.410674,1288.684082,-1.192335,...,-59.776699,-1272.782593,-201.819687,-1927.390015,60.121284,-201.819687,-1987.166748,3656.705566,1998.293579,3062.404053
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1999967,1248.614014,0.914459,0.0,796.106995,1808.091431,1975.596313,0.495325,0.571873,1197.468262,0.034504,...,1307.729980,-1197.383911,14.218256,41.325802,51.230145,14.218256,1349.055786,3174.208740,1350.103027,2872.772705
1999968,1334.905396,-0.925830,0.0,333.614014,1949.072510,1977.418091,0.713051,0.760998,1283.567627,0.982051,...,-1420.180054,-1279.354858,-103.908821,1473.140015,55.550514,-103.908821,52.959961,3935.302490,129.180725,3933.181641
1999969,1445.893555,0.677494,0.0,77.524902,1790.612549,1792.290039,0.600804,0.573244,1423.780884,0.111384,...,1056.260010,-1372.896118,377.237518,158.914001,72.997498,377.237518,1215.174072,3285.281494,1274.474243,3028.000732
1999970,1528.553223,0.149591,0.0,632.781006,1545.687744,1670.198242,0.209634,0.455536,1471.029907,0.543611,...,229.511993,-1467.661377,-99.493774,839.638977,60.891781,-99.493774,1069.151001,3374.546143,1075.495483,3198.573242
