# Graph Neural Network

## Graphs

A graph is a tuple of a set of nodes and a set of edges, where an edge connects two nodes that are related to each other. Edges can be directed or undirected. Each node carries a feature vector, and the feature vectors of all nodes can be stacked into a feature matrix. This work uses simple (no self-loops), undirected graphs. Each graph represents a different halo, and its nodes are the subhalos within that host halo. The chosen subhalo features are the 3D comoving position, the stellar mass, the magnitude of the velocity and the stellar half-mass radius (the radius containing half of the stellar mass of the galaxy). There are several ways to connect the nodes. Since gravity is a long-range force, every node should in principle be connected to all others. Because its effect is stronger at smaller distances, we instead treat a linking radius as a hyperparameter and connect each node to all other nodes within that radius. The optimal radius turns out to be so large that 98% of all graphs are complete (all nodes are connected to each other).
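The radius-based linking can be illustrated with a small NumPy sketch. It is not the routine used in this work (the implementation further down relies on `radius_graph` from `torch_cluster`); the helper names `radius_edges` and `is_complete` and the toy positions are assumptions made for illustration only.

```python
import numpy as np

def radius_edges(pos, r):
    """Return a [2, E] array of directed edges (i, j) with |pos_i - pos_j| <= r and i != j."""
    diff = pos[:, None, :] - pos[None, :, :]             # [N, N, 3] pairwise separations
    dist = np.linalg.norm(diff, axis=-1)                 # [N, N] pairwise distances
    mask = (dist <= r) & ~np.eye(len(pos), dtype=bool)   # drop self-loops
    src, dst = np.nonzero(mask)
    return np.stack([src, dst])

def is_complete(edge_index, n_nodes):
    """A simple graph is complete if every ordered pair of distinct nodes is linked."""
    return edge_index.shape[1] == n_nodes * (n_nodes - 1)

# Three hypothetical subhalo positions
pos = np.array([[0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 3.0, 0.0]])

small = radius_edges(pos, r=1.5)    # only the closest pair is linked
large = radius_edges(pos, r=10.0)   # every pair is linked
print(small.shape[1], is_complete(small, 3))
print(large.shape[1], is_complete(large, 3))
```

Each undirected edge appears twice, once per direction, which is how undirected graphs are usually stored in an edge list. Once the radius exceeds the largest separation within a halo, the graph is complete and increasing the radius further changes nothing.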

## Message Passing Layer

GNNs make use of a mechanism called message passing: each node aggregates information from all nodes in its neighbourhood in order to update its feature vector into a new, so-called "hidden feature vector". The author found that, for this particular purpose, an edge convolution as described in the PyTorch Geometric documentation (https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) works best:

\begin{equation}
h_i = \max_{j\in N(i)} \psi ([x_i, x_j-x_i])
\end{equation}

where $\psi$ denotes a multilayer perceptron that takes $2 n_{feat}$ input channels and consists of three linear layers with 300, 300 and 100 output channels, separated by ReLU activation functions.

Implementation:

In [2]:
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing 
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
from torch_cluster import radius_graph

# Edge convolution layer
class EdgeLayer(MessagePassing):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(EdgeLayer, self).__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Sequential(Linear(2 * in_channels, mid_channels),
                       ReLU(),
                       Linear(mid_channels, mid_channels),
                       ReLU(),
                       Linear(mid_channels, out_channels))
        self.messages = 0.
        self.input = 0.

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

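        # Concatenate the features of the target node with the relative features of each neighbour
        edge_input = torch.cat([x_i, x_j - x_i], dim=-1)  # shape [E, 2 * in_channels]

        # Keep the last inputs and messages for later inspection
        self.input = edge_input
        self.messages = self.mlp(edge_input)

        return self.messages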
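To make the max aggregation concrete, the following NumPy sketch evaluates the edge convolution on a toy graph. The function `psi` is a hypothetical stand-in (one random linear layer plus ReLU), not the MLP used above, and the node features and edges are made up. The sketch also checks that the result does not depend on the order in which the edges are listed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph: 3 nodes with 2 features each, fully connected without self-loops.
x = np.array([[0.0, 1.0],
              [2.0, 0.5],
              [-1.0, 3.0]])
edge_index = np.array([[1, 2, 0, 2, 0, 1],    # source nodes j
                       [0, 0, 1, 1, 2, 2]])   # target nodes i

# Hypothetical stand-in for psi: a single linear layer followed by ReLU.
W = rng.normal(size=(4, 3))

def psi(z):
    return np.maximum(z @ W, 0.0)

def edge_conv(x, edge_index):
    src, dst = edge_index
    # One message per edge, built from [x_i, x_j - x_i]
    messages = psi(np.concatenate([x[dst], x[src] - x[dst]], axis=-1))  # [E, 3]
    h = np.full((x.shape[0], messages.shape[1]), -np.inf)
    np.maximum.at(h, dst, messages)   # element-wise max over the incoming messages of each node
    return h

h = edge_conv(x, edge_index)

# Shuffling the edge list must not change the hidden features
perm = rng.permutation(edge_index.shape[1])
h_perm = edge_conv(x, edge_index[:, perm])
print(h.shape, np.allclose(h, h_perm))
```

Because the maximum is taken element-wise over an unordered set of messages, the update is invariant under permutations of the neighbours.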

## General GNN architecture

The hidden feature vectors are passed to a pooling layer, which reduces the $n_{nodes}$ hidden vectors of a graph to a single global feature vector: the hidden features are aggregated over all nodes by sum, max and mean, and the three results are concatenated with the global properties of the halo. In the implementation below this vector has 3*latent_channels + 2 entries. It is then fed into an MLP of three linear layers separated by ReLU functions (two hidden layers with latent_channels channels each in the implementation below), whose output layer yields the target and its standard deviation.

\begin{equation}
y = \phi(\bigoplus_{i \in G_h} h_i)
\end{equation}

where $\phi$ denotes the MLP and $\bigoplus$ a permutation-invariant aggregation over the subhalos of halo $G_h$. In this implementation sum, mean and max are all used and concatenated to create the global features.
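The pooling step can be sketched in NumPy as follows. The helper `global_pool` and all numbers are illustrative assumptions; the actual implementation uses the pooling functions of PyTorch Geometric. As in batched graph libraries, `batch` assigns every node to a graph, and `u` holds one row of global properties per graph.

```python
import numpy as np

def global_pool(h, batch, u):
    """Concatenate sum, mean and max pooling per graph with the global features u."""
    n_graphs = int(batch.max()) + 1
    rows = []
    for g in range(n_graphs):
        h_g = h[batch == g]   # hidden vectors of all nodes in graph g
        rows.append(np.concatenate([h_g.sum(0), h_g.mean(0), h_g.max(0), u[g]]))
    return np.stack(rows)

# Two graphs: nodes 0 and 1 belong to graph 0, node 2 belongs to graph 1
h = np.array([[1.0, 2.0],
              [3.0, 0.0],
              [5.0, 5.0]])
batch = np.array([0, 0, 1])
u = np.array([[10.0, 2.0],
              [11.0, 1.0]])

pooled = global_pool(h, batch, u)
print(pooled.shape)
print(pooled[0])
```

With two hidden channels and two global properties, every graph is described by 3 * 2 + 2 = 8 numbers, independent of how many nodes it contains.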

Implementation:

In [4]:
class ModelGNN(torch.nn.Module):
    def __init__(self, use_model, node_features, n_layers, k_nn, hidden_channels=300, latent_channels=100, loop=False):
        super(ModelGNN, self).__init__()

        in_channels = node_features

        #Graph layer (hyperparameter optimization said only 1 is sufficient)
        self.layer = EdgeLayer(in_channels, hidden_channels, latent_channels)

        lin_in = latent_channels*3+2  # three pooled vectors (add, mean, max) plus the 2 global features in u
        self.lin = Sequential(Linear(lin_in, latent_channels),
                              ReLU(),
                              Linear(latent_channels, latent_channels),
                              ReLU(),
                              Linear(latent_channels, 2))

        self.k_nn = k_nn  # hyperparameter: linking radius passed to radius_graph
        self.pooled = 0.
        self.h = 0.
        self.loop = loop #no selfloops loop = False
        self.namemodel = use_model

    def forward(self, data):

        x, pos, batch, u = data.x, data.pos, data.batch, data.u

        # Get edges using positions by computing the neighbors within a radius
        edge_index = radius_graph(pos, r=self.k_nn, batch=batch, loop=self.loop)

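        # Message passing: update the node features with the edge convolution layer
        x = self.layer(x, edge_index)
        self.h = x  # keep the hidden feature vectors for later inspection
        x = x.relu()

        # Mix different global pooling layers
        addpool = global_add_pool(x, batch)  # [num_graphs, latent_channels]
        meanpool = global_mean_pool(x, batch)
        maxpool = global_max_pool(x, batch)
        self.pooled = torch.cat([addpool, meanpool, maxpool, u], dim=1)  # [num_graphs, 3 * latent_channels + 2]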
        # Final linear layer
        return self.lin(self.pooled)

## Training

The model is designed to output two values: the mean and the standard deviation of the halo mass posterior. Therefore the following loss function is minimized:

\begin{equation}
L = \log\left(\sum_{i \in \mathrm{batch}} (y_{\mathrm{truth},i}-y_{\mathrm{infer},i})^2\right) + \log\left(\sum_{i \in \mathrm{batch}} \left((y_{\mathrm{truth},i}-y_{\mathrm{infer},i})^2-\sigma_i^2\right)^2\right)
\end{equation}
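A small NumPy sketch shows how the two terms behave. The function name `moment_loss` and the numbers are illustrative assumptions. Like the training code, the sketch uses batch means instead of sums, which shifts the loss only by the additive constant $2\log N$ for a batch of size $N$.

```python
import numpy as np

def moment_loss(y_true, y_pred, sigma):
    """Sum of the log mean squared error and the log mismatch between squared error and variance."""
    sq_res = (y_true - y_pred) ** 2
    loss_mse = np.mean(sq_res)                       # first term: accuracy of the mean
    loss_lfi = np.mean((sq_res - sigma ** 2) ** 2)   # second term: calibration of sigma
    return np.log(loss_mse) + np.log(loss_lfi)

y_true = np.array([1.0, 2.0])
y_pred = np.array([1.5, 2.5])   # both predictions are off by 0.5

over = moment_loss(y_true, y_pred, sigma=np.array([1.0, 1.0]))   # sigma too large
near = moment_loss(y_true, y_pred, sigma=np.array([0.6, 0.6]))   # sigma close to the residual
print(over, near)
```

The second term is smallest when $\sigma_i^2$ matches the squared residual, so the predicted standard deviation is pushed towards the typical size of the error in the predicted mean.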

Implementation of the training routine:

Training on a GPU reduces the execution time, so a GPU should be used if one is available:

In [5]:
# use GPUs if available
if torch.cuda.is_available():
    print("CUDA Available")
    device = torch.device('cuda')
else:
    print('CUDA Not Available')
    device = torch.device('cpu')

CUDA Available


Permutation and translation invariance are already satisfied by the message passing and the graph design. Rotating all subhalos around the center of the halo should also leave the global halo properties unchanged. To encourage rotational invariance, a random rotation is applied to each batch of graphs at every training epoch.

In [6]:
from scipy.spatial.transform import Rotation as Rot

def train(loader, model, optimizer):
    model.train()

    loss_tot = 0
    for data in loader:  # Iterate in batches over the training dataset.

        # Rotate randomly for data augmentation (one rotation per batch)
        rotmat = torch.tensor(Rot.random().as_matrix(), dtype=torch.float32)
        data.pos = data.pos.float() @ rotmat.T  # same as applying rotmat to every position vector
        data.x[:, :3] = data.x[:, :3].float() @ rotmat.T  # the first three node features are the positions

        data = data.to(device)
        optimizer.zero_grad()  # Clear gradients.
        out = model(data)  # Perform a single forward pass.
        y_out, err_out = out[:,0], out[:,1]     # Take mean and standard deviation of the output

        # Compute loss as sum of two terms for likelihood-free inference
        loss_mse = torch.mean((y_out - data.y)**2 , axis=0)
        loss_lfi = torch.mean(((y_out - data.y)**2 - err_out**2)**2, axis=0)
        loss = torch.log(loss_mse) + torch.log(loss_lfi)

        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        loss_tot += loss.item()

    return loss_tot/len(loader) #mean training loss per batch
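The effect of the rotation augmentation can be checked with a short NumPy sketch. As an assumption for illustration, the rotation is drawn from the QR decomposition of a Gaussian matrix instead of `scipy.spatial.transform.Rotation.random`, and the positions are random. The sketch confirms that a rotation changes the position features but leaves all pairwise distances, and therefore the radius-based edges, unchanged.

```python
import numpy as np

def random_rotation(rng):
    """Draw a 3x3 rotation matrix from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))   # fix the column signs so that the decomposition is unique
    if np.linalg.det(q) < 0:      # flip one axis to turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q

def pairwise_dist(pos):
    return np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)

rng = np.random.default_rng(1)
pos = rng.normal(size=(5, 3))   # hypothetical subhalo positions
rot = random_rotation(rng)
pos_rot = pos @ rot.T           # row-vector form of rot @ p for every position p

print(np.allclose(rot @ rot.T, np.eye(3)), np.isclose(np.linalg.det(rot), 1.0))
print(np.allclose(pairwise_dist(pos), pairwise_dist(pos_rot)))
```

Since the positions also enter the network as node features, the model output is not rotation invariant by construction; showing the network randomly rotated copies of each halo lets it learn an approximate invariance instead.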

Testing routine:

In [9]:
import numpy as np

def test(loader, model, params):
    model.eval()

    errs = []
    loss_tot = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        with torch.no_grad():

            data = data.to(device)
            out = model(data)  # Perform a single forward pass.
            y_out, err_out = out[:,0], out[:,1] #mean and std
            err = (y_out.reshape(-1) - data.y)/data.y #relative error of each sample of one batch
            errs.append( np.abs(err.detach().cpu().numpy()).mean(axis=0) ) #mean relative error of one batch

            # Compute loss as sum of two terms for likelihood-free inference
            loss_mse = torch.mean((y_out - data.y)**2 , axis=0)
            loss_lfi = torch.mean(((y_out - data.y)**2 - err_out**2)**2, axis=0)
            loss = torch.log(loss_mse) + torch.log(loss_lfi)
            loss_tot += loss.item() #calculate total loss of all batches


    return loss_tot/len(loader), np.array(errs).mean(axis=0)  # mean loss per batch and mean absolute relative error
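The accuracy metric returned by `test` is a mean absolute relative error. A minimal NumPy sketch with made-up target values (the function name `mean_relative_error` is an assumption) illustrates the computation for a single batch:

```python
import numpy as np

def mean_relative_error(y_true, y_pred):
    """Mean of |y_pred - y_true| / |y_true| over one batch."""
    return np.mean(np.abs((y_pred - y_true) / y_true))

y_true = np.array([10.0, 20.0, 40.0])
y_pred = np.array([11.0, 18.0, 40.0])   # off by 10%, 10% and 0%

err = mean_relative_error(y_true, y_pred)
print(err)
```

Note that `test` averages these per-batch values, so every batch receives the same weight. If the last batch is smaller than the others, the result differs slightly from the mean over all individual samples.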

Full training routine:

In [10]:
def training_routine(model, train_loader, test_loader, params, verbose=True):

    use_model, learning_rate, weight_decay, n_layers, k_nn, n_epochs, training, simsuite, simset, n_sims = params

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) #choosing the ADAM optimizer for the weights

    train_losses, valid_losses = [], []
    valid_loss_min, err_min = 1000., 1000.

    # Training loop
    for epoch in range(1, n_epochs+1):
        train_loss = train(train_loader, model, optimizer)
        test_loss, err = test(test_loader, model, params)
        train_losses.append(train_loss); valid_losses.append(test_loss) 

        # Save the model if the validation loss has improved
        if test_loss <= valid_loss_min:  # compare against the best validation loss so far
            if verbose: print("Validation loss decreased ({:.2e} --> {:.2e}).  Saving model ...".format(valid_loss_min,test_loss))
            # namemodel(params) is not defined in this section; it is assumed to be defined elsewhere and to return the checkpoint file name
            torch.save(model.state_dict(), "Models/"+namemodel(params))
            valid_loss_min = test_loss  # the current loss is the new value to beat
            err_min = err

        if verbose: print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.2e}, Validation Loss: {test_loss:.2e}, Relative error: {err:.2e}')

    return train_losses, valid_losses