# Medic Network Notebook

In this notebook we use the MEDIC network defined in the utility file `medica.py` to train and test on our generated dataset. The dataset was created with the `Data_create.ipynb` notebook.

In [None]:
# collider_doctor_v1
# Neural network for collider data: track, tower, missinget → 4-class probability distribution

import awkward as ak
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt

from medica import *

## Loading the data for training/testing

In [None]:
# Params (hyperparameters consumed by the later cells)
json_path = "Data/training_data.json"
batch_size = 64
lr = 1e-3
epochs = 300
patience = 20  # passed to train_model; presumably early-stopping patience — confirm in medica

# Load the raw events (the train/val/test split happens in the next cell)
data = read_json_to_awkward(json_path)

# Convert the awkward array into a torch Dataset
dataset = ColliderDataset(data)

# Printing the dataset information
print("Total events in dataset:", len(dataset))
if len(dataset) == 0:
    # Fail with a clear message instead of an IndexError on dataset[0] below.
    raise ValueError(f"No events loaded from {json_path!r}; cannot infer feature sizes.")

# Peek at one example to infer the per-input feature dimensions used to size MEDIC.
# NOTE(review): assumes each returned tensor carries its feature axis at dim 2 —
# confirm against ColliderDataset.__getitem__.
track, tower, met, y = dataset[0]
track_features = track.shape[2]
tower_features = tower.shape[2]
met_features = met.shape[2]
print("Track features:", track_features)
print("Tower features:", tower_features)
print("Missing ET features:", met_features)


In [None]:
# Splitting the dataset into train (70%), validation (10%), and test (remainder) sets.
# The split is seeded so re-running the notebook reproduces the same partition;
# otherwise reported metrics would come from a different test set each run.
import torch

split_seed = 42
n_total = len(dataset)
n_train = int(0.7 * n_total)
n_val = int(0.1 * n_total)
n_test = n_total - n_train - n_val  # remainder absorbs rounding so sizes sum to n_total
train_set, val_set, test_set = random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(split_seed),
)
# Only the training loader is shuffled; evaluation order does not affect the metrics.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

In [None]:

# --- Configure the compute device, the MEDIC model, its optimizer and the loss ---
embed_dim = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = MEDIC(
    d_track=track_features,
    d_tower=tower_features,
    d_met=met_features,
    embed_dim=embed_dim,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# NOTE(review): nn.KLDivLoss expects its *input* to be log-probabilities (and, by
# default, its target to be probabilities). Confirm MEDIC's forward returns
# log-probabilities (e.g. log_softmax), not plain probabilities.
criterion = nn.KLDivLoss(reduction="batchmean")

# --- Train; train_model's return value replaces `model` for the test step ---
model = train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs, patience)

# --- Evaluate on the held-out test split (kept as the last expression so any
# returned value is shown as the cell output) ---
test_model(model, test_loader, criterion, device)
