In [None]:
import sys
import os.path as osp
# Make the L1DeepMETv2 checkout importable (presumably the source of graphmetnetwork,
# model.* and utils imported in this cell). Guarded so that re-running the cell does not
# keep appending duplicate entries to sys.path.
L1DEEPMET_DIR = '../../L1DeepMETv2/'
if L1DEEPMET_DIR not in sys.path:
    sys.path.append(L1DEEPMET_DIR)
from graphmetnetwork import GraphMetNetwork

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.utils import to_undirected
from torch_cluster import radius_graph, knn_graph
from torch_geometric.datasets import MNISTSuperpixels
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from tqdm import tqdm
import model.net as net
import model.data_loader as data_loader
import utils

### Load the Data

In [None]:
# Location of the ttbar dataset handed to data_loader.fetch_dataloader below.
data_dir = '/'.join(['..', '..', 'L1DeepMETv2', 'data_ttbar'])

In [None]:
# Build the train/test loaders; batch_size=6 and validation_split=0.2 are forwarded to the
# project's fetch_dataloader (presumably 20% of events end up in 'test' — confirm there).
dataloaders = data_loader.fetch_dataloader(
    data_dir=data_dir,
    batch_size=6,
    validation_split=.2,
)
train_dl, test_dl = dataloaders['train'], dataloaders['test']
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Training dataloader: {len(train_dl)}, Test dataloader: {len(test_dl)}')
print(device)

In [None]:
# Pick one fixed batch from the test loader to drive the PyTorch/C++ comparison.
# NOTE(review): if test_dl shuffles, a different batch is selected on every run.
TEST_BATCH_INDEX = 5  # 0-based, i.e. the 6th batch
test_data = None
for cnt, batch_data in enumerate(test_dl):
    if cnt == TEST_BATCH_INDEX:
        test_data = batch_data
        break
# Fail here with a clear message instead of silently using the last batch (too few
# batches) or crashing later on `None.x` (empty loader).
if test_data is None:
    raise RuntimeError(
        f'test_dl yielded fewer than {TEST_BATCH_INDEX + 1} batches; '
        'cannot select the comparison batch'
    )

#### Build the Input Tensors

In [None]:
# Split the node features of the selected batch into model inputs: the first
# n_features_cont columns are continuous, the remaining columns are categorical (int64).
n_features_cont = 6
node_feats = test_data.x
x_cont_test = node_feats[:, :n_features_cont].to(device)  # include puppi
x_cat_test = node_feats[:, n_features_cont:].long().to(device)
# Columns 3 and 4 (presumably eta and phi, per the variable name) span the radius graph.
etaphi_test = node_feats[:, [3, 4]].to(device=device)
batch_test = test_data.batch.to(device)
edge_index_test = radius_graph(
    etaphi_test, r=0.4, batch=batch_test, loop=False, max_num_neighbors=255
).to(device=device)

for _label, _t in (
    ('x_cont_test', x_cont_test),
    ('x_cat_test', x_cat_test),
    ('etaphi', etaphi_test),
    ('batch', batch_test),
    ('edge_index', edge_index_test),
):
    print(f'{_label}: {_t.shape}')

#### Convert the Input Tensors to NumPy Arrays

In [None]:
# Convert the PyTorch inputs to C-contiguous NumPy arrays for the C++ bindings.
# No .squeeze(0): these tensors have no leading singleton axis, and squeezing would
# wrongly drop the node axis of a batch that happens to contain a single node.
x_cont = np.ascontiguousarray(x_cont_test.cpu().numpy())
x_cat = np.ascontiguousarray(x_cat_test.cpu().numpy())
batch = np.ascontiguousarray(batch_test.cpu().numpy())
etaphi = etaphi_test.cpu().numpy()
# PyG edge_index is (2, num_edges); transpose to (num_edges, 2), presumably the layout
# printed next to cmodel.get_edge_index() further down.
edge_index = edge_index_test.cpu().numpy().transpose()
num_nodes = x_cont.shape[0]
batch_size = batch.shape[0]  # length of the per-node batch vector, not the number of graphs
print(f'Number of nodes: {num_nodes}')
assert num_nodes == batch_size, (
    f'x_cont has {num_nodes} rows but the batch vector has {batch_size} entries'
)

### Load the Torch Model

In [None]:
# Checkpoint directory and the checkpoint file restored in the next cell.
prefix = '../../L1DeepMETv2/ckpts_April30_scale_sigmoid'
restore_ckpt = osp.join(prefix, 'last.pth.tar')
# All-ones `norm`, one entry per continuous input (presumably an identity scaling).
norm = torch.ones(n_features_cont, device=device)
# continuous_dim is tied to n_features_cont so it always matches x_cont_test's width.
# NOTE(review): categorical_dim=2 presumably equals x_cat_test.shape[1]; confirm.
torch_model = net.Net(continuous_dim=n_features_cont, categorical_dim=2, norm=norm).to(device)
print(torch_model)

In [None]:
# Restore the checkpoint into torch_model. Display only the top-level keys: the object
# holds the full state dict (see next cell), so displaying it dumps every tensor.
param_restored_new = utils.load_checkpoint(restore_ckpt, torch_model)
list(param_restored_new.keys())

In [None]:
# Model parameters restored from the checkpoint.
weights = param_restored_new['state_dict']
# Show name -> shape rather than printing every tensor's values.
{name: tuple(tensor.shape) for name, tensor in weights.items()}

#### Export the Weights to Binary Files

In [None]:
import os
from collections import OrderedDict

output_dir = "weights_files/"


def save_weights_as_binary(weights_dict, output_dir):
    """Write every tensor of a state dict to its own raw binary file.

    Each tensor is written with ``numpy.ndarray.tofile``: raw values in C order and
    native byte order, with no header. The reader must therefore already know each
    tensor's dtype and shape. The file name is the key with ``.`` replaced by ``_``
    plus ``.bin``. For example, ``graphnet.embed_continuous.0.weight`` becomes
    ``graphnet_embed_continuous_0_weight.bin``.

    Args:
        weights_dict: mapping of parameter name -> torch.Tensor (e.g. a state_dict).
        output_dir: destination directory. It is created if missing and may be given
            with or without a trailing separator.

    Returns:
        List of the file paths written, in the iteration order of ``weights_dict``.
    """
    os.makedirs(output_dir, exist_ok=True)
    written = []
    for key, tensor in weights_dict.items():
        # detach() so tensors that require grad can still be converted to NumPy.
        np_array = tensor.detach().cpu().numpy()
        # os.path.join instead of '+' so a missing trailing '/' does not glue names together.
        file_name = os.path.join(output_dir, key.replace('.', '_') + '.bin')
        np_array.tofile(file_name)
        written.append(file_name)
    return written

In [None]:
# Save all weights in the OrderedDict to binary files
# Export every tensor of the restored state dict into output_dir as raw .bin files
# (trailing ';' keeps Jupyter from echoing the return value).
save_weights_as_binary(weights_dict=weights, output_dir=output_dir);

### Load the C++ Model

In [None]:
# Directory holding the exported .bin files, derived from output_dir so the two paths
# cannot drift apart.
# NOTE(review): weights_dir is not referenced below; cmodel.load_weights is given the
# in-memory state dict `weights` instead.
weights_dir = osp.normpath(output_dir)

In [None]:
# Create an instance of the C++ GraphMetNetwork model
cmodel = GraphMetNetwork()

# Load the weights
cmodel.load_weights(weights)

### Test the weights

In [None]:
# Reference value: the `graphnet.embed_continuous.0.weight` parameter from the checkpoint.
state_dict = param_restored_new['state_dict']
torch_emb_cont_weights = state_dict['graphnet.embed_continuous.0.weight'].cpu().numpy()
torch_emb_cont_weights.shape

In [None]:
# The same parameter as read back from the C++ model after load_weights().
# NOTE(review): the binding's return type is not visible here; `.shape` assumes a NumPy array.
cmodel_emb_cont_weights = cmodel.get_graphmet_embed_continuous_0_weight()
cmodel_emb_cont_weights.shape

In [None]:
print(f'{torch_emb_cont_weights!s}')  # PyTorch-side embedding weights

In [None]:
print(f'{cmodel_emb_cont_weights!s}')  # C++-side embedding weights

In [None]:
# Same tolerances as np.allclose(..., atol=1e-5) (whose default rtol is 1e-5), but the
# shapes must match exactly (no broadcasting), and a failure reports the mismatching entries.
np.testing.assert_allclose(torch_emb_cont_weights, cmodel_emb_cont_weights, rtol=1e-5, atol=1e-5)

### Run the Torch Model

In [None]:
# Inference mode for the reference pass. eval() switches any dropout/batch-norm layers to
# their deterministic behaviour. no_grad() keeps the intermediates stored on the model
# (e.g. graphnet.emb_cont_) free of autograd history, so calling .numpy() on them works.
torch_model.eval()
with torch.no_grad():
    results = torch_model(x_cont_test, x_cat_test, edge_index_test, batch_test)

### Run the C++ Model

In [None]:
# C++ forward pass on the NumPy copies of the same inputs; the intermediates it stores
# are read back through the cmodel.get_* accessors in the cells below.
cmodel.GraphMetNetworkLayers(
    x_cont,
    x_cat,
    batch,
    num_nodes,
)

#### Test Inputs

In [None]:
# Read back the inputs the C++ model received and report their shapes.
c_x_cont = cmodel.get_x_cont()
c_x_cat = cmodel.get_x_cat()
c_batch = cmodel.get_batch()
c_num_nodes = cmodel.get_num_nodes()
for _label, _arr in (('c_x_cont', c_x_cont), ('c_x_cat', c_x_cat), ('c_batch', c_batch)):
    print(f'Shape of {_label}: {_arr.shape}')
print(f'Value of c_num_nodes: {c_num_nodes}')

In [None]:
# The C++ model must see exactly the inputs it was given.
# Continuous features: same tolerances as np.allclose(..., atol=1e-5), plus a strict shape check.
np.testing.assert_allclose(x_cont, c_x_cont, rtol=1e-5, atol=1e-5)
# Categorical features (cast with .long()) and batch indices are integers: compare exactly.
np.testing.assert_array_equal(x_cat, c_x_cat)
np.testing.assert_array_equal(batch, c_batch)
assert num_nodes == c_num_nodes, f'num_nodes: torch {num_nodes} vs C++ {c_num_nodes}'



#### Test Internal Variables

In [None]:
c_etaphi = cmodel.get_etaphi()
print(f'etaphi shape: torch {etaphi.shape}, C++ {c_etaphi.shape}')
# The C++ model's etaphi buffer must match the PyTorch (eta, phi) input columns.
np.testing.assert_allclose(etaphi, c_etaphi, rtol=1e-5, atol=1e-5)

In [None]:
c_num_edges = cmodel.get_num_edges()
c_edge_index = cmodel.get_edge_index()
# PyG edge_index is (2, num_edges); the transpose gives (num_edges, 2). Because dim 0 is
# always 2, the former .squeeze(0) never changed anything and is omitted.
edge_index_np = edge_index_test.cpu().numpy().transpose()
for _shape in (edge_index_np.shape, c_edge_index.shape):
    print(_shape)
print(f'Number of C edges: {c_num_edges}')

In [None]:
# PyTorch edge list first, then the C++ one.
print(edge_index_np, c_edge_index, sep='\n')

#### Test Intermediate Variables

In [None]:
# Continuous-feature embedding stored on the PyTorch model by the forward pass.
# detach() so this also works if that forward pass ran with autograd enabled.
torch_emb_cont = torch_model.graphnet.emb_cont_.detach().cpu().numpy()
cmodel_emb_cont = cmodel.get_emb_cont()
print(torch_emb_cont.shape)
print(cmodel_emb_cont.shape)
# Report agreement without failing the notebook while the C++ port is in progress.
if torch_emb_cont.shape == cmodel_emb_cont.shape:
    emb_diff = np.abs(torch_emb_cont - cmodel_emb_cont)
    max_abs_diff = float(emb_diff.max()) if emb_diff.size else 0.0
    emb_match = np.allclose(torch_emb_cont, cmodel_emb_cont, rtol=1e-5)
    print(f'emb_cont allclose (rtol=1e-5): {emb_match}, max |diff| = {max_abs_diff:.3g}')
else:
    print('emb_cont shapes differ; skipping value comparison')

In [None]:
print(f'{torch_emb_cont!s}')  # PyTorch emb_cont

In [None]:
print(f'{cmodel_emb_cont!s}')  # C++ emb_cont

### Other Tests: `radius_graph` (C++ vs. PyTorch)

In [None]:
import c_radius_graph

# Toy 2-D points on the diagonal, (0,0) ... (6,6), split into three batches:
# points 0-2 -> batch 0, points 3-4 -> batch 1, points 5-6 -> batch 2.
batch_indices = [0, 0, 0, 1, 1, 2, 2]
points = [[float(i), float(i)] for i in range(len(batch_indices))]
radius = 1.5

In [None]:
# Call the C++ function
neighbors = c_radius_graph.find_neighbors_by_batch(points, batch_indices, radius)

# Print neighbors
for pair in neighbors:
    print(f"Point {pair[0]} is within radius of point {pair[1]}")

neighbors

In [None]:
# Same toy input through torch_cluster.radius_graph as the reference.
points_tensor, batch_tensor = torch.tensor(points), torch.tensor(batch_indices)

edge_index_pts = radius_graph(
    points_tensor, r=radius, batch=batch_tensor, loop=False, max_num_neighbors=255
)

In [None]:
print(f'{edge_index_pts!s}')  # torch_cluster edge list for the toy points

In [None]:
# Example usage in Python
import graphmetnetwork_bindings as gmn

# Create an instance of the model
model = gmn.GraphMetNetwork()

# Load the weights
model.load_weights("path_to_weights_file.txt")

# Now, you can run the model with input data
for i, (x_cont, x_cat, edge_index, batch) in enumerate(dataloader):
    num_nodes = x_cont.shape[0]

    # Run the PyTorch model
    with torch.no_grad():
        output = torch_model(x_cont.squeeze(0), x_cat.squeeze(0), edge_index.squeeze(0), batch.squeeze(0))

    # Run the C++ model
    model.GraphMetNetworkLayer(x_cont.squeeze(0).numpy(), x_cat.squeeze(0).numpy(), num_nodes, batch.squeeze(0).numpy())

    # Compare intermediate values as before
    np.testing.assert_allclose(torch_model._emb_cont.numpy(), model.get_emb_cont(), rtol=1e-5)
    np.testing.assert_allclose(torch_model._emb_chrg.numpy(), model.get_emb_chrg(), rtol=1e-5)
    np.testing.assert_allclose(torch_model._emb_pdg.numpy(), model.get_emb_pdg(), rtol=1e-5)
    np.testing.assert_allclose(torch_model._emb_cat.numpy(), model.get_emb_cat(), rtol=1e-5)
    np.testing.assert_allclose(torch_model._emb.numpy(), model.get_emb(), rtol=1e-5)
    np.testing.assert_allclose(torch_model._emb1.numpy(), model.get_emb1(), rtol=1e-5)
    np.testing.assert_allclose(torch_model._emb2.numpy(), model.get_emb2(), rtol=1e-5)
    np.testing.assert_allclose(output.numpy(), model.get_output(), rtol=1e-5)
