# Initial Exploration

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Make the parent directory importable so project packages can be imported
# from this notebook. Guarded so that re-running the cell (or the whole
# notebook in a live kernel) does not keep appending duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Roadmap

1. Load data
2. Train
3. Debug train loop
4. Validate

## Load data

In [2]:
from architectures.EquivariantGNN.utils import *

In [4]:
%%time
# Directory holding the raw HDF5 inputs; defined once instead of being
# repeated inside every path.
RAW_INPUT_DIR = '/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/raw_input'


def load_table(h5_path):
    """Return the object stored under the 'table' key of an HDF5 file.

    The store is opened read-only and is closed again (by the context
    manager) before the function returns.
    """
    with pd.HDFStore(h5_path, mode = 'r') as store:
        return store['table']


# The *_file names are kept so that anything referring to them still works.
train_file = f'{RAW_INPUT_DIR}/train.h5'
val_file = f'{RAW_INPUT_DIR}/val.h5'
test_file = f'{RAW_INPUT_DIR}/test.h5'

train_df = load_table(train_file)
val_df = load_table(val_file)
test_df = load_table(test_file)

# all_p, all_y = build_dataset(train_df, 1000)
# train_dataset = JetDataset(all_p, all_y)
# train_loader = DataLoader(train_dataset, batch_size = 100, shuffle = True)

# val_all_p, val_all_y = build_dataset(val_df, 100)
# val_dataset = JetDataset(val_all_p, val_all_y)
# val_loader = DataLoader(val_dataset)

CPU times: user 26.1 s, sys: 8.68 s, total: 34.8 s
Wall time: 35.4 s


In [4]:
%%time
# Convert val_df into per-jet Data objects and cache them on disk in chunks of
# `sample_size` jets, one .pt file per chunk. Chunks whose file already exists
# are not rebuilt.
sample_size = 100000
processed_val_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/processed_input/val"

# Only complete chunks are processed: a trailing partial chunk of val_df
# (fewer than sample_size rows) is not written, exactly as before.
for i in range(len(val_df) // sample_size):

    # Single definition of the output path, used for both the cache check and
    # the save, so the two can never drift apart.
    save_file = f"{processed_val_dir}/val_{i}.pt"

    if not os.path.exists(save_file):

        subsample = val_df.iloc[i*sample_size:(i+1)*sample_size]
        dataset = []
        n_skipped = 0

        for j, jet in enumerate(subsample.itertuples()):

            try:
                p = get_four_momenta(jet)
                y = torch.tensor(jet.is_signal_new)

                pt, jet_pt, delta_eta, delta_phi, jet_E = get_higher_features(p)
                # Column 0 of p is used as the energy here (it is what gets
                # compared with jet_E and stored as log_E).
                delta_pt = torch.log(pt / jet_pt)
                delta_E = torch.log(p[:, 0] / jet_E)
                delta_R = torch.sqrt( delta_eta**2 + delta_phi**2 )

                dataset.append(Data(x=p,
                                    y=y,
                                    log_pt = torch.log(pt),
                                    log_E = torch.log(p[:, 0]),
                                    delta_eta = delta_eta,
                                    delta_phi = delta_phi,
                                    delta_pt = delta_pt,
                                    delta_E = delta_E,
                                    delta_R = delta_R
                                   ))
            except Exception as err:
                # Still best-effort: a jet that cannot be converted is skipped.
                # Unlike a bare `except:`, this lets KeyboardInterrupt stop the
                # (long-running) loop, and the failures are counted instead of
                # disappearing silently. Only the first error per chunk is
                # printed to avoid flooding the output.
                n_skipped += 1
                if n_skipped == 1:
                    print("Skipping jet:", i, j, repr(err))

            if (j%10000) == 0:
                print("Built event:", i, j)

        if n_skipped:
            print("Skipped jets:", i, n_skipped)

        # Write to a temporary name and rename afterwards: if the save is
        # interrupted, no half-written val_{i}.pt is left behind for the
        # os.path.exists check above to mistake for a finished chunk.
        tmp_file = save_file + ".tmp"
        torch.save(dataset, tmp_file)
        os.replace(tmp_file, save_file)

    # Printed for every chunk index, including chunks that were already cached.
    print("Processed file ", i)

Built event: 0 0
Built event: 0 10000
Built event: 0 20000
Built event: 0 30000
Built event: 0 40000
Built event: 0 50000
Built event: 0 60000
Built event: 0 70000
Built event: 0 80000
Built event: 0 90000
Processed file  0
Built event: 1 0
Built event: 1 10000
Built event: 1 20000
Built event: 1 30000
Built event: 1 40000
Built event: 1 50000
Built event: 1 60000
Built event: 1 70000
Built event: 1 80000
Built event: 1 90000
Processed file  1
Built event: 2 0
Built event: 2 10000
Built event: 2 20000
Built event: 2 30000
Built event: 2 40000
Built event: 2 50000
Built event: 2 60000
Built event: 2 70000
Built event: 2 80000
Built event: 2 90000
Processed file  2
Built event: 3 0
Built event: 3 10000
Built event: 3 20000
Built event: 3 30000
Built event: 3 40000
Built event: 3 50000
Built event: 3 60000
Built event: 3 70000
Built event: 3 80000
Built event: 3 90000
Processed file  3
CPU times: user 10min, sys: 23.1 s, total: 10min 23s
Wall time: 10min 25s


In [None]:
%%time


In [2]:
%%time
input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/processed_input/test"

# os.listdir returns entries in arbitrary order, so sort before taking the
# first one; otherwise "the first file" can differ from run to run.
jet_files = sorted(os.listdir(input_dir))

# Only the first file is opened (the timing below the cell is for one file).
jet_paths = [os.path.join(input_dir, file) for file in jet_files[:1]]

# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, which can
# execute code on load. Only use this on files from a trusted source.
opened_files = [np.load(file, allow_pickle=True) for file in jet_paths]

CPU times: user 1min 1s, sys: 2.34 s, total: 1min 4s
Wall time: 1min 4s


In [15]:
# Inspect the shape of the first (and, given the [:1] slice, only) loaded object.
opened_files[0].shape

(100000, 9, 2)

In [18]:
# Re-serialize the whole `opened_files` list with torch.save. The target is a
# relative path, so the file lands in the current working directory.
torch.save(opened_files, "testfile.pt")

In [5]:
%%time
# Load a single processed test chunk back from disk with torch.load.
test_chunk_path = "/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/processed_input/test/test_0.pt"
loaded_file = torch.load(test_chunk_path)

CPU times: user 35.2 s, sys: 2.54 s, total: 37.7 s
Wall time: 42 s


In [10]:
%%time
# Build the processed dataset from the chunk loaded above.
# NOTE(review): the meaning of the positional arguments ("static", 1, 16,
# False, 100000) is defined by build_processed_dataset, which is not visible
# in this notebook — confirm against its signature before changing them.
processed_events = build_processed_dataset(loaded_file, "static", 1, 16, False, 100000)

Built event: 0
CPU times: user 19.9 s, sys: 1.26 s, total: 21.2 s
Wall time: 21.2 s


In [3]:
%%time
processed_input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/processed_input/"
# Three equal sizes; presumably one per train/val/test split — TODO confirm
# against load_processed_datasets.
subset_sizes = [90000] * 3
processed_datasets = load_processed_datasets(processed_input_dir, subset_sizes, "static", 0.1, 16, False)

Loading torch files
Building events
Built event: 0
Built event: 0
Built event: 0
CPU times: user 2min 45s, sys: 9.57 s, total: 2min 54s
Wall time: 3min


In [4]:
%%time
processed_input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/processed_input/"
# Same call as the 90000 run, with larger sizes to see how the load time scales.
# Three equal sizes; presumably one per train/val/test split — TODO confirm.
subset_sizes = [190000] * 3
processed_datasets = load_processed_datasets(processed_input_dir, subset_sizes, "static", 0.1, 16, False)

Loading torch files
Building events
Built event: 0
Built event: 100000
Built event: 0
Built event: 100000
Built event: 0
Built event: 100000
CPU times: user 5min 58s, sys: 24.2 s, total: 6min 22s
Wall time: 6min 41s


In [5]:
%%time
processed_input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/top_tagging/processed_input/"
# Largest of the three timing runs.
# Three equal sizes; presumably one per train/val/test split — TODO confirm.
subset_sizes = [290000] * 3
processed_datasets = load_processed_datasets(processed_input_dir, subset_sizes, "static", 0.1, 16, False)

Loading torch files
Building events
Built event: 0
Built event: 100000
Built event: 200000
Built event: 0
Built event: 100000
Built event: 200000
Built event: 0
Built event: 100000
Built event: 200000
CPU times: user 9min 51s, sys: 53.6 s, total: 10min 44s
Wall time: 11min 17s


## Model

### Model Definitions

In [4]:
from lorentz_equivariant_gnn.scripts.legnn_model import L_GCL, LEGNN, unsorted_segment_mean
from lorentz_equivariant_gnn.scripts.jet_tagging_network import train

In [5]:
# Training hyper-parameters. "factor" and "patience" are the values handed to
# the learning-rate scheduler (as gamma and step_size) where it is created.
train_config = dict(
    n_epochs=20,
    lr=1e-3,
    factor=0.3,
    patience=10,
)

In [6]:
# Keyword arguments that are unpacked into the LEGNN constructor.
model_config = dict(
    input_feature_dim=1,
    message_dim=32,
    output_feature_dim=2,
    edge_feature_dim=0,
    n_layers=6,
)

In [7]:
# Instantiate the network: every entry of model_config is passed as a keyword
# argument, together with the selected device.
# NOTE(review): whether LEGNN moves its own parameters to `device` is not
# visible here — confirm, or call model.to(device) explicitly.
model = LEGNN(device = device, **model_config)

In [8]:
# AdamW over all model parameters, starting from the configured learning rate.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config["lr"],
)

# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler
# steps; the config keys "factor" and "patience" supply those two values.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=train_config["patience"],
    gamma=train_config["factor"],
)

# NOTE(review): BCELoss expects inputs that are already probabilities in
# [0, 1]; whether the model output satisfies that is not visible here.
loss_fn = torch.nn.BCELoss()

### Debug

In [9]:
# Debug aid: push a single batch through the model and print every
# intermediate value, then stop.
# NOTE(review): no active cell in the visible part of this notebook defines
# train_loader (its construction is commented out in the data-loading cell),
# so this relies on state left in the kernel.
for i, batch in enumerate(train_loader):
    print(f"Batch: {i}")

    optimizer.zero_grad()

    # Move the batch to the device; singleton dimensions of "p" are squeezed out.
    p = batch["p"].to(device).squeeze()
    y = batch["y"].to(device)
    print("p:", p, "y:", y)

    # The first dimension of p is taken as the number of graph nodes.
    n_nodes = p.shape[0]
    print("n_nodes:", n_nodes)

    edges = get_edges(n_nodes)
    row, column = edges

    print("edges:", edges.shape)

    # Node features come from compute_radials; an earlier alternative was
    # torch.zeros(n_nodes, 1).
    h, _ = L_GCL.compute_radials(edges, p)

    print("h:", h)

    output, x = model(h, p, edges)

    print("output:", output, "x:", x)

    # Only the first batch is inspected.
    break

Batch: 0
p: tensor([[ 2.9199e+02,  9.3843e+01,  1.0585e+02, -2.5544e+02],
        [ 1.6158e+02,  5.1625e+01,  5.9809e+01, -1.4094e+02],
        [ 7.2156e+01,  9.4136e+00,  3.3340e+01, -6.3296e+01],
        [ 4.2710e+01,  5.5029e+00,  2.0023e+01, -3.7322e+01],
        [ 3.3787e+01,  3.2938e+00,  1.9467e+01, -2.7418e+01],
        [ 3.7821e+01,  1.1010e+01,  1.5437e+01, -3.2724e+01],
        [ 2.2993e+01,  4.1118e+00,  1.4969e+01, -1.6962e+01],
        [ 2.4341e+01,  3.0827e+00,  1.4853e+01, -1.9035e+01],
        [ 3.1476e+01,  1.0028e+01,  1.1356e+01, -2.7590e+01],
        [ 1.8812e+01,  5.7134e+00,  1.3936e+01, -1.1270e+01],
        [ 2.8604e+01,  1.0388e+01,  9.7623e+00, -2.4799e+01],
        [ 2.2571e+01,  2.7284e+00,  1.3933e+01, -1.7546e+01],
        [ 2.1750e+01,  2.5852e+00,  1.3901e+01, -1.6527e+01],
        [ 2.7623e+01,  2.7445e+00,  1.2974e+01, -2.4231e+01],
        [ 2.2332e+01,  3.9915e+00,  1.2350e+01, -1.8173e+01],
        [ 1.6093e+01,  4.4986e+00,  1.2076e+01, -9.6401e+0

## Train Loop

In [9]:
# Run the full training loop with the objects built in the cells above.
# NOTE(review): as recorded in this cell's output, the call currently fails in
# epoch 0 with a RuntimeError about mismatched tensor sizes (4 vs 32).
train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, train_config)

Epoch: 0


RuntimeError: Sizes of tensors must match except in dimension 2. Got 4 and 32 (The offending index is 0)

## Model Deconstruction

TODO:
- [ ] Fix edge builder
- [ ] Understand the `compute_radials` concept

## Validate