In [None]:
import torch
import torch_sparse
import torch_scatter
import torch_cluster
import torch_geometric

import numpy as np

In [None]:
class GraphBuildingLSH(torch.nn.Module):
    """Build a sparse k-nearest-neighbour graph using locality-sensitive hashing.

    Nodes are projected onto random directions (the ``codebook``), assigned to
    the hash bucket with the largest signed projection, sorted by bucket and
    cut into equally sized bins. Inside each bin the pairwise similarity
    ``sigmoid(x_i . x_j)`` is computed and, for every node, the ``k`` most
    similar nodes of the same bin are kept as edges.

    Args:
        feature_dim: size of the last dimension of the input ``x``.
        bin_size: target number of nodes per bin; the number of bins used in
            ``forward`` is ``nodes // bin_size``.
        max_num_bins: upper bound on the number of bins; fixes the number of
            random projections held in the codebook.
        k: number of neighbours kept per node.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, feature_dim, bin_size, max_num_bins, k, **kwargs):
        super().__init__(**kwargs)

        self.k = k
        self.bin_size = bin_size
        self.max_num_bins = max_num_bins

        # Registered as a buffer (not a plain attribute) so the random
        # projections follow the module across .to()/.double()/.cuda() and are
        # stored in the state_dict; otherwise a reloaded model would hash
        # differently. At least one projection is kept so that a single bin
        # still has something to hash with.
        self.register_buffer("codebook", torch.randn((feature_dim, max(max_num_bins // 2, 1))))

        self.reset_parameters()

    def reset_parameters(self):
        """No-op: the module has no trainable parameters."""
        pass

    def forward(self, x):
        """Return the batched sparse adjacency for ``x``.

        Args:
            x: dense tensor of shape (batches, nodes, features).

        Returns:
            Coalesced sparse COO tensor of shape (batches, nodes, nodes) whose
            stored values are ``sigmoid(x_i . x_j)`` for the ``k`` selected
            neighbours ``j`` of every node ``i``.

        Raises:
            AssertionError: if more than ``max_num_bins`` bins would be needed.
            ValueError: if there are fewer nodes than ``bin_size``, if the
                nodes cannot be split evenly into bins, or if ``k`` exceeds
                the number of nodes per bin.
        """
        n_batches, n_nodes = x.shape[0], x.shape[1]
        n_bins = n_nodes // self.bin_size

        if n_bins < 1:
            raise ValueError(f"need at least bin_size={self.bin_size} nodes, got {n_nodes}")
        assert n_bins <= self.max_num_bins, (
            f"{n_bins} bins required but max_num_bins={self.max_num_bins}"
        )
        nodes_per_bin = n_nodes // n_bins
        if n_bins * nodes_per_bin != n_nodes:
            raise ValueError(f"{n_nodes} nodes cannot be split evenly into {n_bins} bins")
        if self.k > nodes_per_bin:
            raise ValueError(f"k={self.k} exceeds the {nodes_per_bin} nodes available per bin")

        # Signed random projections: bucket index = argmax over [p, -p].
        n_proj = max(n_bins // 2, 1)
        mul = torch.matmul(x, self.codebook[:, :n_proj])
        cmul = torch.cat([mul, -mul], dim=-1)
        bin_idx = torch.argmax(cmul, dim=-1)  # (batches, nodes)

        # Sorting by bucket puts similar nodes next to each other; the sorted
        # node indices are then cut into equally sized bins.
        bins_split = torch.argsort(bin_idx, dim=-1).reshape(n_batches, n_bins, nodes_per_bin)

        batch_ids = torch.arange(n_batches, device=x.device).view(n_batches, 1, 1)
        points_binned = x[batch_ids, bins_split]  # (batches, bins, nodes_per_bin, features)

        # Pairwise similarity inside each bin.
        dm_binned = torch.einsum("...ij,...kj->...ik", points_binned, points_binned)
        dm = torch.sigmoid(dm_binned)  # (batches, bins, nodes_per_bin, nodes_per_bin)

        # (batches, bins, nodes_per_bin, k); indices are local to the bin.
        topk = torch.topk(dm, self.k, dim=-1)

        # Map bin-local neighbour indices back to global node indices.
        local_dst = topk.indices.reshape(n_batches, n_bins, nodes_per_bin * self.k)
        dst = torch.gather(bins_split, 2, local_dst).reshape(topk.indices.shape)
        src = bins_split.unsqueeze(-1).expand_as(dst)
        batch = batch_ids.unsqueeze(-1).expand_as(dst)

        indices = torch.stack([batch.reshape(-1), src.reshape(-1), dst.reshape(-1)])
        sp = torch.sparse_coo_tensor(
            indices,
            topk.values.reshape(-1),
            requires_grad=True,
            size=(n_batches, n_nodes, n_nodes),
        )

        # Sparse (batches, nodes, nodes)
        return sp.coalesce()

In [None]:
def stacked_sparse(dm):
    """Flatten a batched sparse adjacency into one block-diagonal edge list.

    Args:
        dm: sparse COO tensor of shape (num_batch, nodes, nodes) with three
            sparse dimensions. It does not need to be coalesced.

    Returns:
        A tuple ``(edge_index, edge_values)``:
        ``edge_index`` has shape (2, nnz) and holds row/column indices shifted
        by ``batch * nodes``, so every batch element occupies its own diagonal
        block of a (num_batch*nodes, num_batch*nodes) matrix;
        ``edge_values`` has shape (nnz,) and holds the matching stored values.
    """
    # values()/indices() are only defined on coalesced tensors; coalescing also
    # orders the entries by (batch, row, col) so both outputs line up.
    dm = dm.coalesce()
    indices = dm.indices()  # (3, nnz): batch, row, col

    # Out-of-place add: the index tensor owned by `dm` is left untouched.
    edge_index = indices[1:] + indices[0] * dm.shape[1]  # (2, nnz)
    edge_values = dm.values()  # (nnz,)
    return edge_index, edge_values


class Net(torch.nn.Module):
    """Per-node regression model on a graph that is learned from the input.

    The input features are embedded with a linear layer, the embedding is
    turned into a sparse k-NN adjacency by ``GraphBuildingLSH``, and a single
    ``GCNConv`` layer followed by a linear layer produces one value per node.

    Args:
        num_node_features: size of the last dimension of the input.
        feature_dim: size of the embedding used for graph building.
        bin_size: nodes per LSH bin (passed to ``GraphBuildingLSH``).
        max_num_bins: maximum number of LSH bins (passed to ``GraphBuildingLSH``).
        k: neighbours kept per node (passed to ``GraphBuildingLSH``).
        hidden_dim: output width of the graph convolution.
    """

    def __init__(self, num_node_features, feature_dim=16, bin_size=100, max_num_bins=200, k=16, hidden_dim=32):
        super().__init__()

        self.lin1 = torch.nn.Linear(num_node_features, feature_dim)
        self.dm = GraphBuildingLSH(feature_dim=feature_dim, bin_size=bin_size, max_num_bins=max_num_bins, k=k)
        # The convolution consumes the raw input features; the embedding from
        # lin1 is only used to decide which nodes are connected.
        self.gcn = torch_geometric.nn.GCNConv(num_node_features, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Run the model on ``x`` of shape (n_batches, nodes, num_node_features).

        Returns:
            A tuple ``(prediction, dm)`` where ``prediction`` has shape
            (n_batches, nodes, 1) and ``dm`` is the sparse
            (n_batches, nodes, nodes) adjacency built for this input.
        """
        n_batches = x.shape[0]
        n_points = x.shape[1]

        i1 = self.lin1(x)  # (n_batches, nodes, feature_dim)
        dm = self.dm(i1)  # sparse (n_batches, nodes, nodes)

        # Block-diagonal edge list so the whole batch is one big graph.
        edge_index, edge_vals = stacked_sparse(dm)

        xflat = torch.reshape(x, (n_batches * n_points, x.shape[-1]))
        i2 = self.gcn(xflat, edge_index, edge_vals)  # (n_batches * nodes, hidden_dim)
        i2 = torch.reshape(i2, (n_batches, n_points, i2.shape[-1]))

        i3 = self.lin2(i2)  # (n_batches, nodes, 1)

        return i3, dm

In [None]:
def generate_event(
    mean_num_particles_per_event=1000,
    max_particle_energy=10.0,
    deposit_fraction=0.1,
    lowest_energy_threshold=0.5,
    deposit_pos_spread=0.02,
    rng=None,
):
    """Generate one toy event of particles and the energy deposits they leave.

    Particles get a uniform energy in ``[0, max_particle_energy)`` and a
    uniform position in ``[-1, 1)^2``. Each particle keeps emitting deposits,
    smeared uniformly within ``deposit_pos_spread`` of its position, until its
    remaining energy is no longer above ``lowest_energy_threshold``.

    Args:
        mean_num_particles_per_event: Poisson mean of the particle count.
        max_particle_energy: upper bound of the uniform particle energy.
        deposit_fraction: the mean of each deposit draw is this fraction of the
            particle's remaining energy (the standard deviation is NumPy's
            default of 1.0).
        lowest_energy_threshold: deposits at or below this energy are rejected
            and redrawn; a particle stops once its energy is at or below it.
        deposit_pos_spread: half-width of the uniform position smearing.
        rng: optional random source exposing ``poisson``, ``uniform`` and
            ``normal`` (e.g. ``np.random.default_rng(seed)``). Defaults to the
            global ``np.random`` state.

    Returns:
        A tuple ``(deposits, particles)`` of float arrays, both of shape
        (num_deposits, 3). ``deposits`` rows are (energy, pos_x, pos_y) of each
        deposit. ``particles`` holds the generating particle's
        (energy, pos_x, pos_y) on the row of that particle's highest-energy
        deposit and zeros on every other row. Both are empty (0, 3) arrays if
        no deposit was produced.
    """
    if rng is None:
        rng = np.random

    particles = []
    all_deposits = []
    for ipart in range(rng.poisson(mean_num_particles_per_event)):
        energy = rng.uniform(0, max_particle_energy)
        pos_x = rng.uniform(-1.0, 1.0)
        pos_y = rng.uniform(-1.0, 1.0)
        # Store the energy before any deposit is subtracted from it.
        particles.append([energy, pos_x, pos_y])

        deposits = []
        while energy > lowest_energy_threshold:
            deposit_energy = rng.normal(energy * deposit_fraction)
            if deposit_energy > lowest_energy_threshold:
                energy -= deposit_energy
                deposit_x = rng.uniform(pos_x - deposit_pos_spread, pos_x + deposit_pos_spread)
                deposit_y = rng.uniform(pos_y - deposit_pos_spread, pos_y + deposit_pos_spread)
                # Column 3 is the "target particle" slot, -1 meaning none.
                deposits.append([deposit_energy, deposit_x, deposit_y, -1, ipart])

        if deposits:
            # Only the highest-energy deposit carries the particle as target.
            top_deposit_index = int(np.argmax([d[0] for d in deposits]))
            deposits[top_deposit_index][3] = ipart
            all_deposits.append(deposits)

    if not all_deposits:
        # No particle, or no particle energetic enough to leave a deposit.
        return np.zeros((0, 3)), np.zeros((0, 3))

    particles_array = np.asarray(particles, dtype=float)
    deposits_array = np.concatenate(all_deposits)

    target_index = deposits_array[:, 3].astype(int)
    has_target = target_index >= 0
    particles_array_resized = np.zeros((deposits_array.shape[0], 3))
    particles_array_resized[has_target] = particles_array[target_index[has_target]]

    return deposits_array[:, :3], particles_array_resized


def generate_events(padded_size=5000, num_events=10):
    """Generate ``num_events`` events and bring them all to ``padded_size`` rows.

    Events with more rows than ``padded_size`` are truncated; shorter ones are
    zero-padded at the end.

    Returns:
        A tuple ``(X, y)`` of arrays stacked along a new leading event axis,
        i.e. each of shape (num_events, padded_size, columns).
    """

    def _fit_rows(arr):
        # Clip first, then append zero rows up to the common length.
        clipped = arr[:padded_size]
        missing = padded_size - clipped.shape[0]
        return np.pad(clipped, ((0, missing), (0, 0)))

    padded_inputs = []
    padded_targets = []
    for _ in range(num_events):
        deposits, targets = generate_event()
        padded_inputs.append(_fit_rows(deposits))
        padded_targets.append(_fit_rows(targets))

    return np.stack(padded_inputs), np.stack(padded_targets)

In [None]:
# Seed NumPy's global generator so the synthetic events (and everything
# derived from them further down) are identical on every Restart & Run All.
np.random.seed(0)
X, y = generate_events(num_events=10)

In [None]:
# Padded inputs: every event has the same number of rows; trailing all-zero
# rows are padding added by generate_events.
X.shape  # (num_events, num_signals_per_event, num_features)

In [None]:
# First deposit of the first event: (energy, pos_x, pos_y)
X[0, 0, :]

In [None]:
# Target for the first deposit of the first event: (energy, pos_x, pos_y) of
# the generating particle if this deposit is that particle's highest-energy
# one, otherwise all zeros.
y[0, 0, :]

In [None]:
# NOTE(review): `plt` is not imported in any cell visible here; confirm that
# `import matplotlib.pyplot as plt` runs before this cell.
iev = 5
# Rows whose target energy is non-zero are the ones carrying a particle.
ymsk = y[iev, :, 0] != 0
event_targets = y[iev][ymsk]

plt.figure(figsize=(10, 10))
# Deposits: small red dots.
plt.scatter(X[iev, :, 1], X[iev, :, 2], marker="o", color="red", s=2.0)
# Particles: translucent blue squares with area proportional to energy.
plt.scatter(
    event_targets[:, 1],
    event_targets[:, 2],
    marker="s",
    color="blue",
    s=10 * event_targets[:, 0],
    alpha=0.2,
)
plt.title("Input set (no edges)")

In [None]:
# Seed torch so the random LSH codebook and the layer initialisations are
# reproducible, and take the feature count from the data instead of a literal.
torch.manual_seed(0)
net = Net(X.shape[-1]).float()

In [None]:
# Mean-squared error on the per-node prediction, optimised with SGD + momentum.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
# Convert the NumPy arrays to float32 tensors once, outside the loop.
Xt = torch.from_numpy(X).to(torch.float32)
yt = torch.from_numpy(y).to(torch.float32)

# Full-batch training: every epoch is a single step on all events.
for epoch in range(100):
    optimizer.zero_grad()  # clear gradients left over from the previous step

    outputs, dms = net(Xt)
    # Regress the particle energy only (column 0 of the target).
    loss = criterion(outputs, yt[..., :1])
    loss.backward()
    optimizer.step()

    running_loss = loss.item()
    print(running_loss)

In [None]:
# Dense (nodes, nodes) NumPy copy of the adjacency the network built for the
# first event during the last training step.
dm = dms[0].coalesce().to_dense().detach().numpy()

In [None]:
# True for rows of the first event whose energy is exactly zero, i.e. the
# zero-padding rows (real deposits are always above the energy threshold).
msk = X[0, :, 0] == 0

In [None]:
# Drop every edge that starts or ends on a padding row. This modifies `dm` in
# place, but re-running the cell gives the same result.
dm[msk, :] = 0
dm[:, msk] = 0

In [None]:
# Show the top-left 1000 x 1000 corner of the adjacency matrix.
# NOTE(review): `plt` is not imported in any cell visible here; confirm that
# `import matplotlib.pyplot as plt` runs before this cell.
dm_corner = dm[:1000, :1000]
plt.figure(figsize=(10, 10))
plt.imshow(dm_corner, cmap="Blues")

In [None]:
# Coordinates of all surviving edges as a (2, num_edges) array: row 0 holds
# the source node of each edge, row 1 the destination node.
rows, cols = np.nonzero(dm > 0)
edges = np.vstack((rows, cols))

In [None]:
# Column indices of a random subset of at most 10000 edges, so the plot below
# stays readable.
random_edges = np.random.permutation(edges.shape[1])[:10000]

In [None]:
# NOTE(review): `plt` is not imported in any cell visible here; confirm that
# `import matplotlib.pyplot as plt` runs before this cell.
plt.figure(figsize=(10, 10))

# Real (non-padding) deposits of the first event.
plt.scatter(X[0, ~msk, 1], X[0, ~msk, 2], marker="o", color="red", s=2.0)

# (2, num_sampled) node indices; passing 2-D arrays to plot draws one line
# segment per column, i.e. one segment per sampled edge.
sampled_edges = edges[:, random_edges]
edge_style = dict(
    linestyle="-",
    marker="o",
    color="black",
    markerfacecolor="red",
    markeredgecolor="red",
    markersize=2.0,
    lw=0.1,
)
plt.plot(X[0, sampled_edges, 1], X[0, sampled_edges, 2], **edge_style)

plt.xlim(-1, 1)
plt.ylim(-1, 1)