# Quantum Lie-Equivariant GNN (Q-LieEGNN)

#### In this task, we were asked to implement a possible Quantum Graph Neural Network and draw its architecture.

I start from the LorentzNet paper and its official implementation, load the quark–gluon tagging data, and run equivariance tests as a sanity check to confirm that the code works.

Next, I make a simple modification that replaces the classical edge network with a parameterized quantum circuit built with PennyLane, and test equivariance again. Once that test passes, I show that the model can be made equivariant with respect to an arbitrary metric tensor $J$: with a Euclidean $J$, Lorentz boosts become symmetry-breaking, while rotations within a fixed plane are still symmetry-preserving.

In [2]:
# For Colab
!pip install torch_geometric
!pip install torch_sparse
!pip install torch_scatter


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting torch_sparse
  Using cached torch_sparse-0.6.18.tar.gz (209 kB)
  Preparing metadata (setup.py) ... [?25lerror
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m [31m[6 lines of output][0m
  [31m   [0m Traceback (most recent call last):
  [31m   [0m   File "<string>", line 2, in <module>
  [31m   [0m   File "<pip-setuptools-caller>", line 34, in <module>
  [31m   [0m   File "/tmp/pip-install-lcy59mga/torch-sparse_ff3c8ac0008f4b1d8e47a3a9bea01e6f/setup.py", line 8, in <module>
  [31m   [0m     import torch
  [31m   [0m ModuleNotFoundError: No module named 'torch'
  [31m   [0m [31m

## Variational circuits

In [96]:
# Number of wires of the shared PennyLane simulator. QLieGEB builds its quantum
# edge network with width n_input * 2 + n_edge_attr, which is 4 * 2 + 2 = 10 for
# the n_hidden = 4 models below (without include_x), hence 10 here.
n_qubits = 10

In [97]:
import torch
import pennylane as qml
import torch.nn.functional as F
from torch import nn
from torch_geometric.utils import to_dense_adj


# Simulator device shared by every call to `quantum_net` (it is bound in the
# @qml.qnode decorator). It has `n_qubits` wires, so circuit widths requested by
# DressedQuantumNet presumably must not exceed it — NOTE(review): confirm.
dev = qml.device('default.qubit', wires=n_qubits)


def H_layer(nqubits):
    """Put each of the wires 0 .. nqubits-1 through a Hadamard gate, in order."""
    wire = 0
    while wire < nqubits:
        qml.Hadamard(wires=wire)
        wire += 1


def RY_layer(w):
    """Rotate wire k about the y axis by the k-th angle of ``w``.

    One RY gate is emitted per entry of ``w``, on the wire matching the
    entry's position.
    """
    for wire, angle in enumerate(w):
        qml.RY(angle, wires=wire)


def entangling_layer(nqubits):
    """Two staggered rows of nearest-neighbour CNOTs.

    First row:  CNOT(0,1), CNOT(2,3), ...
    Second row: CNOT(1,2), CNOT(3,4), ...
    Every CNOT targets the wire directly after its control.
    """
    for first_control in (0, 1):
        for control in range(first_control, nqubits - 1, 2):
            qml.CNOT(wires=[control, control + 1])


@qml.qnode(dev, interface="torch")
def quantum_net(q_input_features, q_weights_flat, q_depth, n_qubits):
    """
    The variational quantum circuit.

    Layout: Hadamard on every wire, then RY(q_input_features[k]) on wire k,
    then ``q_depth`` repetitions of (staggered CNOT layer, trainable RY layer),
    and finally a Pauli-Z measurement on each of the ``n_qubits`` wires.

    Args:
        q_input_features: embedding angles; one RY per entry, entry k on wire k.
        q_weights_flat: flat trainable angles, reshaped to (q_depth, n_qubits),
            so it must hold exactly q_depth * n_qubits values.
        q_depth: number of variational (entangling + RY) layers.
        n_qubits: number of wires the circuit acts on.

    Returns:
        The tuple of ``n_qubits`` Pauli-Z expectation values. How PennyLane
        packages them for the caller (tuple vs. stacked tensor) depends on the
        installed PennyLane version — NOTE(review): confirm.
    """

    # Reshape weights: one row of RY angles per variational layer
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    RY_layer(q_input_features)

    # Sequence of trainable variational layers
    for k in range(q_depth):
        entangling_layer(n_qubits)
        RY_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.

    Each row of the input batch is squashed with ``tanh`` into rotation angles in
    (-pi/2, pi/2), run through ``quantum_net`` with the shared trainable weights
    ``q_params``, and the ``n_qubits`` Pauli-Z expectation values are returned.
    """

    def __init__(self, n_qubits, q_depth = 1, q_delta=0.001):
        """
        Definition of the *dressed* layout.

        Args:
            n_qubits: number of circuit wires; also the expected number of
                features per input row and the number of outputs per row.
            q_depth: number of variational (entangling + RY) layers.
            q_delta: standard deviation of the random initial circuit weights.
        """

        super().__init__()
        self.n_qubits = n_qubits
        self.q_depth = q_depth
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))

    def forward(self, input_features):
        """
        Run the circuit on every row of ``input_features``.

        Args:
            input_features: 2-D tensor with one row of ``n_qubits`` features per
                sample (rows are iterated one by one).

        Returns:
            float32 tensor of shape (1, n_rows, n_qubits) on the same device as
            ``input_features``.
        """
        import math  # local import: avoids depending on `np` from a later cell

        # Quantum Embedding (U(X))
        q_in = torch.tanh(input_features) * math.pi / 2.0

        # Apply the circuit row by row; collect results and stack once instead of
        # growing a tensor with torch.cat in the loop (quadratic copying).
        rows = []
        for elem in q_in:
            res = quantum_net(elem, self.q_params, self.q_depth, self.n_qubits)
            # Newer PennyLane versions return a tuple of per-wire tensors.
            if isinstance(res, (tuple, list)):
                res = torch.stack(list(res))
            rows.append(res.float())

        if rows:
            q_out = torch.stack(rows)
        else:
            q_out = torch.zeros((0, self.n_qubits), dtype=torch.float32)
        q_out = q_out.to(input_features.device)

        # return the batch measurement of the PQC
        return q_out.unsqueeze(0)

In [173]:
import torch
import numpy as np
import energyflow
from scipy.sparse import coo_matrix
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data.distributed import DistributedSampler

def get_adj_matrix(n_nodes, batch_size, edge_mask):
    """Flatten per-graph dense edge masks into one global edge index list.

    Graph ``b`` owns the global node indices ``b * n_nodes ... (b+1) * n_nodes - 1``.
    Every nonzero entry ``edge_mask[b][r, c]`` yields the directed edge
    ``(b * n_nodes + r, b * n_nodes + c)``, in row-major order per graph.

    Args:
        n_nodes: number of (padded) nodes per graph.
        batch_size: number of graphs in ``edge_mask``.
        edge_mask: indexable by graph; ``edge_mask[b]`` is an (n_nodes, n_nodes)
            dense mask.

    Returns:
        ``[rows, cols]``, two int64 tensors of equal length (empty when the batch
        is empty).
    """
    rows, cols = [], []
    for batch_idx in range(batch_size):
        offset = batch_idx * n_nodes  # global index of this graph's node 0
        adj = coo_matrix(edge_mask[batch_idx])
        rows.append(offset + adj.row)
        cols.append(offset + adj.col)

    # np.concatenate([]) raises, so an empty batch gets explicit empty arrays.
    rows = np.concatenate(rows) if rows else np.empty(0, dtype=np.int64)
    cols = np.concatenate(cols) if cols else np.empty(0, dtype=np.int64)

    edges = [torch.as_tensor(rows, dtype=torch.long), torch.as_tensor(cols, dtype=torch.long)]
    return edges

def collate_fn(data):
    """Batch a list of (label, p4s, nodes, atom_mask) samples.

    Each field is stacked along a new leading batch axis. Two extra entries are
    appended:
      * the pairwise edge mask (both endpoints real particles, no self-loops),
      * the global edge index list built by ``get_adj_matrix``.

    Returns:
        [label, p4s, nodes, atom_mask, edge_mask, edges]
    """
    batch = [torch.stack(field) for field in zip(*data)]
    n_graphs, n_nodes, _ = batch[1].size()

    node_mask = batch[-1]
    pair_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
    no_self_loops = ~torch.eye(pair_mask.size(1), dtype=torch.bool).unsqueeze(0)
    pair_mask *= no_self_loops

    edge_index = get_adj_matrix(n_nodes, n_graphs, pair_mask)
    return batch + [pair_mask, edge_index]

def retrieve_dataloaders(batch_size, num_data = -1, use_one_hot = True, cache_dir = './data', num_workers=4):
    """Load the Pythia quark/gluon jets and wrap them in train/val/test loaders.

    The jets are split 80/10/10 without shuffling. Each sample is
    ``(label, p4s, nodes, atom_mask)``; ``collate_fn`` appends the pairwise edge
    mask and the global edge index list to every batch.

    Args:
        batch_size: jets per batch.
        num_data: number of jets to load (``-1`` loads all of them).
        use_one_hot: if True, node features are one-hot encodings of the 8
            particle ids listed below; otherwise log-compressed (mass, charge).
        cache_dir: energyflow download/cache directory.
        num_workers: DataLoader worker processes.

    Returns:
        dict mapping ``'train'``, ``'val'``, ``'test'`` to DataLoaders. Only the
        training loader drops its last incomplete batch.
    """
    raw = energyflow.qg_jets.load(num_data=num_data, pad=True, ncol=4, generator='pythia',
                                  with_bc=False, cache_dir=cache_dir)
    splits = ['train', 'val', 'test']
    data = {split: {'raw': None, 'label': None} for split in splits}
    (data['train']['raw'], data['val']['raw'], data['test']['raw'],
     data['train']['label'], data['val']['label'], data['test']['label']) = \
        energyflow.utils.data_split(*raw, train=0.8, val=0.1, test=0.1, shuffle=False)

    # Absolute particle ids that get a one-hot slot.
    enc = OneHotEncoder(handle_unknown='ignore').fit([[11], [13], [22], [130], [211], [321], [2112], [2212]])

    for value in data.values():
        # Column 3 of each padded particle row holds the particle id (0 = padding).
        pid = torch.from_numpy(np.abs(np.asarray(value['raw'][..., 3], dtype=int))).unsqueeze(-1)
        p4s = torch.from_numpy(energyflow.p4s_from_ptyphipids(value['raw'], error_on_unknown=True))
        if use_one_hot:
            one_hot = enc.transform(pid.reshape(-1, 1)).toarray().reshape(pid.shape[:2] + (-1,))
            nodes = torch.from_numpy(one_hot)
        else:
            # Mass and charge are only needed for this feature set.
            mass = torch.from_numpy(energyflow.ms_from_p4s(p4s)).unsqueeze(-1)
            charge = torch.from_numpy(energyflow.pids2chrgs(pid))
            nodes = torch.cat((mass, charge), dim=-1)
            nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)
        atom_mask = (pid[..., 0] != 0)
        value['p4s'] = p4s
        value['nodes'] = nodes
        value['label'] = torch.from_numpy(value['label'])
        value['atom_mask'] = atom_mask.to(torch.bool)

    datasets = {split: TensorDataset(value['label'], value['p4s'],
                                     value['nodes'], value['atom_mask'])
                for split, value in data.items()}

    # For distributed training, pass a DistributedSampler as `sampler=` here and
    # also return it to the caller.
    dataloaders = {split: DataLoader(dataset,
                                     batch_size=batch_size,
                                     pin_memory=False,
                                     drop_last=(split == 'train'),
                                     num_workers=num_workers,
                                     collate_fn=collate_fn)
                   for split, dataset in datasets.items()}

    return dataloaders

if __name__ == '__main__':
    # Smoke test: load 20 jets, batch them one at a time and print the shapes of
    # the first training batch.
    # NOTE: the loop variables (label, p4s, nodes, atom_mask, edge_mask, edges)
    # remain bound after `break`; the following notebook cells use them as the
    # example input for the models.
    # train_sampler, dataloaders = retrieve_dataloaders(32, 100)
    dataloaders = retrieve_dataloaders(1, 20)
    for (label, p4s, nodes, atom_mask, edge_mask, edges) in dataloaders['train']:
        print(label.shape, p4s.shape, nodes.shape, atom_mask.shape,
              edge_mask.shape, edges[0].shape, edges[1].shape)
        break

N nodes: N nodes:   N nodes: N nodes: 139139 

 139
139
N nodes:  N nodes: 139 N nodes: 139
 139
N nodes: 
 139
N nodes:  139
torch.Size([1]) torch.Size([1, 139, 4]) torch.Size([1, 139, 8]) torch.Size([1, 139]) torch.Size([1, 139, 139]) torch.Size([306]) torch.Size([306])


In [174]:
# Flatten the single batch taken from the train loader into the
# (batch_size * n_nodes, features) layout that LorentzNet / QLieEGNN expect.
# batch_size and n_nodes are read from the data instead of being hard-coded,
# so this cell keeps working if the loader batch size or padding length changes.
batch_size, n_nodes, _ = p4s.shape
device = 'cpu'
dtype = torch.float32

atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)

atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

edges = [a.to(device) for a in edges]
nodes = nodes.view(batch_size * n_nodes, -1).to(device, dtype)

In [270]:
import torch
from torch import nn
import numpy as np

class LGEB(nn.Module):
    r''' Lorentz Group Equivariant Block: one message-passing layer of LorentzNet.

    Edge messages combine the hidden features of both endpoints with
    ``psi``-compressed pairwise features: the norm of ``x_i - x_j`` and the inner
    product of ``x_i`` and ``x_j``. Both use the Minkowski metric by default, or
    ``p^T A q`` when a metric ``A`` is given. Coordinates are moved along
    ``x_i - x_j`` (except in the last layer). Hidden features get a residual update.

    Args:
        - `n_input` (int): width of the incoming hidden features.
        - `n_output` (int): width of the outgoing hidden features.
        - `n_hidden` (int): width of the edge messages.
        - `n_node_attr` (int): width of the extra node attributes fed to `phi_h`
          (use 0 when calling `forward` with `node_attr=None`).
        - `dropout` (float): accepted for interface compatibility; unused here.
        - `c_weight` (float): scale of the coordinate update.
        - `last_layer` (bool): if True, no coordinate update is performed.
        - `A` (Tensor or None): metric tensor; None selects the Minkowski metric.
        - `include_x` (bool): also feed both endpoints' raw coordinates to the
          edge network (edge-attribute width 10 instead of 2).
    '''
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LGEB, self).__init__()
        self.c_weight = c_weight
        # norm & inner product (2), plus x_i and x_j when include_x (2 + 4 + 4)
        n_edge_attr = 2 if not include_x else 10

        self.include_x = include_x
        self.phi_e = nn.Sequential(
            nn.Linear(n_input * 2 + n_edge_attr, n_hidden, bias=False),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU())

        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        # Tiny initial gain keeps the first coordinate updates close to zero.
        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        self.last_layer = last_layer
        if last_layer:
            # The last layer never updates coordinates.
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4

    def m_model(self, hi, hj, norms, dots):
        """Edge messages, gated by a learned sigmoid weight."""
        out = torch.cat([hi, hj, norms, dots], dim=1)
        out = self.phi_e(out)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        """Like `m_model`, but also feeds the raw endpoint coordinates."""
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.phi_e(out)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        """Residual node update from the summed incoming messages.

        `node_attr` is concatenated only when given, so `forward`'s default
        `node_attr=None` works together with `n_node_attr=0`.
        """
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        parts = [h, agg] if node_attr is None else [h, agg, node_attr]
        out = h + self.phi_h(torch.cat(parts, dim=1))
        return out

    def x_model(self, x, edges, x_diff, m):
        """Move each node along the mean of its weighted edge differences."""
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explodes it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight
        return x

    def minkowski_feats(self, edges, x):
        """Return psi(norm(x_i - x_j)), psi(<x_i, x_j>) and x_i - x_j per edge."""
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        """Run one layer; returns the updated (h, x) and the edge messages m."""
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots)  # one row per edge
        if not self.last_layer:
            x = self.x_model(x, edges, x_diff, m)

        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LorentzNet(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LGEB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
        - `A` (Tensor or None): metric tensor used for norms / inner products;
          None selects the Minkowski metric.
        - `include_x` (bool): also feed raw coordinates to the edge networks.
        - `verbose` (bool): print the first node's hidden features before and
          after the LGEB stack, and the final predictions. Defaults to True,
          which reproduces the original output; set False for training loops.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False, verbose=True):
        super(LorentzNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.verbose = verbose
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LGEBs = nn.ModuleList([LGEB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        """Classify a batch of graphs whose nodes are flattened into one axis.

        Args:
            scalars: node scalar features, one row per node.
            x: node coordinates (four-momenta), one row per node.
            edges: pair (i, j) of index tensors listing directed edges.
            node_mask: multiplied into the final hidden features (zeros padding).
            edge_mask: unused; kept for interface compatibility.
            n_nodes: number of (padded) nodes per graph.

        Returns:
            Logits of shape (batch, n_class); `squeeze(1)` only removes the class
            axis when n_class == 1.
        """
        h = self.embedding(scalars)

        if self.verbose:
            print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for layer in self.LGEBs:
            h, x, _ = layer(h, x, edges, node_attr=scalars)
        if self.verbose:
            print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        # Masked mean over each graph's nodes, then decode to class logits.
        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)

        if self.verbose:
            print("Final preds: \n", pred.cpu().detach().numpy())
        return pred.squeeze(1)

### Now that we have the official code for the classical model, let's test it for equivariance as a sanity check

The cell below defines auxiliary functions that build the Lorentz boost matrices.

In [None]:
from math import sqrt
import numpy as np

# Speed of light (m/s). beta() divides velocities by this constant and rejects
# anything larger, so velocities passed to the helpers below are in m/s.
c = 299792458

"""Lorentz transformations describe the transition between two inertial reference
frames F and F', each of which is moving in some direction with respect to the
other. This code only calculates Lorentz transformations for movement in the x
direction with no spatial rotation (i.e., a Lorentz boost in the x direction).
The Lorentz transformations are calculated here as linear transformations of
four-vectors [ct, x, y, z] described by Minkowski space. Note that t (time) is
multiplied by c (the speed of light) in the first entry of each four-vector.

Thus, if X = [ct; x; y; z] and X' = [ct'; x'; y'; z'] are the four-vectors for
two inertial reference frames and X' moves in the x direction with velocity v
with respect to X, then the Lorentz transformation from X to X' is X' = BX,
where

    | γ  -γβ  0  0|
B = |-γβ  γ   0  0|
    | 0   0   1  0|
    | 0   0   0  1|

is the matrix describing the Lorentz boost between X and X',
γ = 1 / √(1 - v²/c²) is the Lorentz factor, and β = v/c is the velocity as
a fraction of c.
"""


def beta(velocity: float) -> float:
    """
    Return β = v/c, i.e. ``velocity`` (m/s) expressed as a fraction of c.

    Raises ValueError when the velocity exceeds c or is below 1 m/s.

    >>> beta(c)
    1.0
    >>> beta(199792458)
    0.666435904801848
    """
    # Guard clauses: reject superluminal and (near-)zero speeds up front.
    if velocity > c:
        raise ValueError("Speed must not exceed light speed 299,792,458 [m/s]!")
    if velocity < 1:
        # Usually the speed should be much higher than 1 (c order of magnitude)
        raise ValueError("Speed must be greater than or equal to 1!")
    fraction = velocity / c
    return fraction


def gamma(velocity: float) -> float:
    """
    Return the Lorentz factor γ = 1 / √(1 - β²) for ``velocity`` in m/s.

    Validation (and any ValueError) comes from ``beta``.

    >>> gamma(4)
    1.0000000000000002
    >>> gamma(1e5)
    1.0000000556325075
    >>> gamma(3e7)
    1.005044845777813
    >>> gamma(2.8e8)
    2.7985595722318277
    """
    b = beta(velocity)
    return 1 / sqrt(1 - b ** 2)


def transformation_matrix(velocity: float) -> np.ndarray:
    """
    Calculate the Lorentz transformation matrix for movement in the x direction:

    | γ  -γβ  0  0|
    |-γβ  γ   0  0|
    | 0   0   1  0|
    | 0   0   0  1|

    where γ is the Lorentz factor and β is the velocity as a fraction of c
    >>> transformation_matrix(29979245)
    array([[ 1.00503781, -0.10050378,  0.        ,  0.        ],
           [-0.10050378,  1.00503781,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  1.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  1.        ]])
    """
    # Compute γ and -γβ once; the y and z rows are untouched by an x boost.
    g = gamma(velocity)
    mix = -g * beta(velocity)
    rows = [
        [g, mix, 0, 0],
        [mix, g, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows)


### Now, the model

In [291]:
# Classical LorentzNet baseline: 8 node scalars, 4-dimensional latent space.
model = LorentzNet(
    n_scalar=8,
    n_hidden=4,
    n_class=2,
    dropout=0.2,
    n_layers=6,
    c_weight=1e-3,
)

### Let's start with a default prediction

In [292]:
# Reference forward pass on the untransformed jet.
pred = model(
    scalars=nodes,
    x=atom_positions,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [ 0.22275785 -0.00887579 -0.45730796  0.4752541 ]
h after (just the first particle): 
 [ 0.6250086  -0.4157089   0.19434586  3.7227166 ]
Final preds: 
 [[0.25635535 0.08354717]]


### ... applying an arbitrary, non-Lorentz transformation to the four-momentum vectors
i.e. scaling them by 0.1. Does the hidden representation stay the same?

In [293]:
# Same jet with every four-momentum scaled by 0.1 (not a Lorentz transformation).
pred = model(
    scalars=nodes,
    x=0.1 * atom_positions,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [ 0.22275785 -0.00887579 -0.45730796  0.4752541 ]
h after (just the first particle): 
 [ 1.5326474  -0.09580445 -0.10811514  4.131195  ]
Final preds: 
 [[0.25635535 0.08354717]]


### Even though the final logits did not change in this case (the logits of this untrained model are not a sensitive probe), the final hidden representation h, which combines scalar and four-momentum information, did change! Now, what about Lorentz transformations?

In [295]:
# Boost every four-momentum along x with v = 2.2e8 m/s: (B @ X^T)^T, done in
# float64 and cast back to float32.
pred = model(
    scalars=nodes,
    x=(
        torch.tensor(transformation_matrix(220000000))
        @ atom_positions.to(dtype=torch.float64).T
    ).to(dtype=torch.float32).T,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [ 0.22275785 -0.00887579 -0.45730796  0.4752541 ]
h after (just the first particle): 
 [ 0.6251303  -0.41550016  0.1935861   3.721067  ]
Final preds: 
 [[0.25635535 0.08354717]]


## Equivariance works. Now, let's move to our Q-LieEGNN

In [269]:
import torch
from torch import nn
import numpy as np
import pennylane as qml

class QLieGEB(nn.Module):
    """
    Quantum Lie-Equivariant Block (QLieGEB).

    Same message-passing structure as LorentzNet's LGEB, except that the edge
    network ``phi_e`` is the variational quantum circuit ``DressedQuantumNet``.

    Given Lie generators L (e.g. found through LieGAN, an oracle-preserving
    latent flow, or some other approach that we develop further), the metric
    tensor J is obtained from

                      L.J + J.(L^T) = 0,

    and passing it as ``A`` makes the block symmetry-preserving for the
    corresponding Lie group. With ``A=None`` the Minkowski metric (Lorentz group)
    is used. With another metric, e.g. the Euclidean one, Lorentz boosts break
    equivariance while transformations preserving that metric (such as spatial
    rotations) keep it.

    Notes:
        * The quantum edge network maps ``n_input * 2 + n_edge_attr`` features to
          the same number of expectation values. The message width is therefore
          forced to that value and the ``n_hidden`` argument is ignored.
        * That width is the number of qubits ``DressedQuantumNet`` uses.
          Presumably it must fit within the wires of the shared PennyLane device
          (10 for ``n_input=4`` without ``include_x``). NOTE(review): confirm.
        * ``dropout`` is accepted for interface compatibility but unused.
    """
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(QLieGEB, self).__init__()
        self.c_weight = c_weight
        # norm & inner product (2), plus x_i and x_j when include_x (2 + 4 + 4)
        n_edge_attr = 2 if not include_x else 10

        self.include_x = include_x

        # Quantum edge network: input width == output width == number of qubits.
        self.phi_e = DressedQuantumNet(n_input * 2 + n_edge_attr)

        # Message width is fixed by the circuit; the n_hidden argument is overridden.
        n_hidden = n_input * 2 + n_edge_attr
        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        # Tiny initial gain keeps the first coordinate updates close to zero.
        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        self.last_layer = last_layer
        if last_layer:
            # The last layer never updates coordinates.
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4

    def m_model(self, hi, hj, norms, dots):
        """Quantum edge messages, gated by a learned sigmoid weight."""
        out = torch.cat([hi, hj, norms, dots], dim=1)
        # DressedQuantumNet returns (1, n_edges, n_qubits); drop the leading axis.
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        """Like `m_model`, but also feeds the raw endpoint coordinates."""
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        """Residual node update from the summed incoming messages.

        `node_attr` is concatenated only when given, so `forward`'s default
        `node_attr=None` works together with `n_node_attr=0`.
        """
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        parts = [h, agg] if node_attr is None else [h, agg, node_attr]
        out = h + self.phi_h(torch.cat(parts, dim=1))
        return out

    def x_model(self, x, edges, x_diff, m):
        """Move each node along the mean of its weighted edge differences."""
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explodes it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight
        return x

    def minkowski_feats(self, edges, x):
        """Return psi(norm(x_i - x_j)), psi(<x_i, x_j>) and x_i - x_j per edge."""
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        """Run one layer; returns the updated (h, x) and the edge messages m."""
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots)  # one row per edge
        if not self.last_layer:
            x = self.x_model(x, edges, x_diff, m)
        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class QLieEGNN(nn.Module):
    r''' Quantum Lie-Equivariant GNN: LorentzNet with QLieGEB message-passing blocks.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of QLieGEB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate of the graph decoder.
        - `A` (Tensor or None): metric tensor J used for norms / inner products;
          None selects the Minkowski metric.
        - `include_x` (bool): also feed raw coordinates to the edge networks.
        - `verbose` (bool): print the first node's hidden features before and
          after the QLieGEB stack. Defaults to True, which reproduces the
          original output; set False for training loops.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False, verbose=True):
        super(QLieEGNN, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.verbose = verbose
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.QLieGEBs = nn.ModuleList([QLieGEB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        """Classify a batch of graphs whose nodes are flattened into one axis.

        Args:
            scalars: node scalar features, one row per node.
            x: node coordinates (four-momenta), one row per node.
            edges: pair (i, j) of index tensors listing directed edges.
            node_mask: multiplied into the final hidden features (zeros padding).
            edge_mask: unused; kept for interface compatibility.
            n_nodes: number of (padded) nodes per graph.

        Returns:
            Logits of shape (batch, n_class); `squeeze(1)` only removes the class
            axis when n_class == 1.
        """
        h = self.embedding(scalars)

        if self.verbose:
            print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for block in self.QLieGEBs:
            h, x, _ = block(h, x, edges, node_attr=scalars)

        if self.verbose:
            print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        # Masked mean over each graph's nodes, then decode to class logits.
        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)
        return pred.squeeze(1)


def unsorted_segment_sum(data, segment_ids, num_segments):
    r'''Sum the rows of `data` into `num_segments` buckets selected by `segment_ids`.

    PyTorch counterpart of TensorFlow's `unsorted_segment_sum`; buckets that
    receive no rows stay zero. Adapted from https://github.com/vgsatorras/egnn.
    '''
    buckets = data.new_zeros((num_segments, data.size(1)))
    # index_add_ accumulates in place and returns the same tensor.
    return buckets.index_add_(0, segment_ids, data)

def unsorted_segment_mean(data, segment_ids, num_segments):
    r'''Average the rows of `data` per segment id.

    PyTorch counterpart of TensorFlow's `unsorted_segment_mean`; empty segments
    yield zeros (counts are clamped to at least 1). Adapted from
    https://github.com/vgsatorras/egnn.
    '''
    shape = (num_segments, data.size(1))
    totals = data.new_zeros(shape)
    totals.index_add_(0, segment_ids, data)
    hits = data.new_zeros(shape)
    hits.index_add_(0, segment_ids, torch.ones_like(data))
    return totals / hits.clamp(min=1)

def normsq4(p):
    r''' Minkowski squared norm over the last axis:
         `\|p\|^2 = p[0]^2-p[1]^2-p[2]^2-p[3]^2`
    '''
    squares = p ** 2
    # Doubling the time term and subtracting the full sum flips the spatial signs.
    return 2 * squares[..., 0] - squares.sum(dim=-1)
    
def dotsq4(p,q):
    r''' Minkowski inner product over the last axis:
         `<p,q> = p[0]q[0]-p[1]q[1]-p[2]q[2]-p[3]q[3]`
    '''
    products = p * q
    total = products.sum(dim=-1)
    return 2 * products[..., 0] - total

def normA_fn(A):
    """Build the quadratic form p -> p^T A p, applied over the last axis of p."""
    def norm(p):
        return torch.einsum('...i, ij, ...j->...', p, A, p)
    return norm

def dotA_fn(A):
    """Build the bilinear form (p, q) -> p^T A q, applied over the last axis."""
    def dot(p, q):
        return torch.einsum('...i, ij, ...j->...', p, A, q)
    return dot
    
def psi(p):
    ''' Sign-preserving log compression, elementwise:
        `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`
    '''
    magnitude = (p.abs() + 1).log()
    return p.sign() * magnitude

## Quantum model

#### Let's start with a default prediction

In [271]:
# Quantum model with the default (Minkowski) metric.
model = QLieEGNN(
    n_scalar=8,
    n_hidden=4,
    n_class=2,
    dropout=0.2,
    n_layers=6,
    c_weight=1e-3,
)

In [272]:
# Reference forward pass of the quantum model on the untransformed jet.
pred = model(
    scalars=nodes,
    x=atom_positions,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [-0.02530114 -0.01923932 -0.2870935   0.00164617]
h after (just the first particle): 
 [ 0.9169079   1.3130671   0.57629734 -0.47118652]


### ... applying an arbitrary, non-Lorentz transformation to the four-momentum vectors
i.e. scaling them by 0.1. Does the hidden representation stay the same?

In [273]:
# Same jet with every four-momentum scaled by 0.1 (not a Lorentz transformation).
pred = model(
    scalars=nodes,
    x=0.1 * atom_positions,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [-0.02530114 -0.01923932 -0.2870935   0.00164617]
h after (just the first particle): 
 [ 2.43436    1.324191  -3.4988296 -1.3810511]


### Not at all! What about Lorentz transformations?

In [274]:
# Boost every four-momentum along x with v = 1.8e8 m/s: (B @ X^T)^T, done in
# float64 and cast back to float32.
pred = model(
    scalars=nodes,
    x=(
        torch.tensor(transformation_matrix(180000000))
        @ atom_positions.to(dtype=torch.float64).T
    ).to(dtype=torch.float32).T,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [-0.02530114 -0.01923932 -0.2870935   0.00164617]
h after (just the first particle): 
 [ 0.91627324  1.3119451   0.5798764  -0.47567284]


### Equivariance holds!

#### Now, let's do the predictions again for some other metric tensor J. 
#### This illustrates the situation where we have found infinitesimal generators for some experimental data
(e.g. following the LieGAN approach of Robin Walters and collaborators, the oracle-preserving latent flows of Roy Forestano et al., or some other approach that we develop further — which would be interesting). Once we have the generators, suppose we obtain the metric tensor by solving the following equation (as proposed in the LieGAN paper):

\begin{equation}
L\cdot J + J\cdot L^{T} = 0
\end{equation}

Here, Lorentz transformations should no longer preserve equivariance. To illustrate this, let's consider $J = \mathrm{diag}(1,1,1,1)$, i.e. the Euclidean norm and dot product. If the model works as intended, **boosts** should **break equivariance**, while **rotations** should **preserve** it:

In [275]:
# Euclidean metric: every component enters norms and dot products with weight +1.
J = torch.diag(torch.ones(4))
print("J: \n", J)

J: 
 tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])


In [276]:
# Rotation by pi in the xy plane: mixes components 1 and 2 (x and y) of each
# four-vector and leaves components 0 and 3 untouched. With the new (Euclidean)
# metric, Lorentz boosts should now break equivariance, but this rotation
# should preserve it.
rot = torch.tensor(
    [
        [np.cos(np.pi), -np.sin(np.pi), 0, 0],
        [np.sin(np.pi), np.cos(np.pi), 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
)

In [281]:
# Quantum model using the Euclidean metric J instead of the Minkowski metric.
model = QLieEGNN(
    n_scalar=8,
    n_hidden=4,
    n_class=2,
    dropout=0.2,
    n_layers=6,
    c_weight=1e-3,
    A=J,
)

#### Again, the default forward pass using the Euclidean metric J.

In [282]:
# Reference forward pass with the Euclidean-metric model.
pred = model(
    scalars=nodes,
    x=atom_positions,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [-0.20670539  0.26581004 -0.09239267 -0.22357208]
h after (just the first particle): 
 [-0.46950197 -3.965734   -3.0378942  -1.9866586 ]


#### Now, the Lorentz boosted jets:

In [283]:
# Boost every four-momentum along x with v = 2.4e8 m/s: (B @ X^T)^T, done in
# float64 and cast back to float32.
pred = model(
    scalars=nodes,
    x=(
        torch.tensor(transformation_matrix(240000000))
        @ atom_positions.to(dtype=torch.float64).T
    ).to(dtype=torch.float32).T,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [-0.20670539  0.26581004 -0.09239267 -0.22357208]
h after (just the first particle): 
 [-0.60887957 -3.9931903  -3.0501935  -2.342436  ]


#### Equivariance is broken. What about a rotation in the xy plane?

In [284]:
# Rotate every four-momentum in the xy plane: (R @ X^T)^T in float64, cast back.
pred = model(
    scalars=nodes,
    x=(rot @ atom_positions.to(dtype=torch.float64).T).to(dtype=torch.float32).T,
    edges=edges,
    node_mask=atom_mask,
    edge_mask=edge_mask,
    n_nodes=n_nodes,
)

h before (just the first particle): 
 [-0.20670539  0.26581004 -0.09239267 -0.22357208]
h after (just the first particle): 
 [-0.46950197 -3.965734   -3.0378942  -1.9866586 ]


#### Equivariant again.
I propose to work on this project by exploring how to improve symmetry discovery. Besides incorporating arbitrary Lie invariances, it would be very interesting to study infrared and collinear (IRC) safety, and to study how our model performs on tagging semi-visible jets for Beyond the Standard Model (BSM) discoveries, as was done in [6] for Hidden Valley models.