# PEARL reimplementation

@misc{kanatsoulis2025learningefficientpositionalencodings,
      title={Learning Efficient Positional Encodings with Graph Neural Networks}, 
      author={Charilaos I. Kanatsoulis and Evelyn Choi and Stephanie Jegelka and Jure Leskovec and Alejandro Ribeiro},
      year={2025},
      eprint={2502.01122},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2502.01122}, 
}

In [15]:
import random, numpy as np
import torch, torch.nn as nn
from torch_geometric.nn import MessagePassing

  from .autonotebook import tqdm as notebook_tqdm


## Hyper Parameters

In [12]:
class HyperParam:
    """Single container for every hyperparameter used in this notebook.

    All tunable values are class-level defaults. The compute device and its
    human-readable name are resolved per instance in ``__init__``, which also
    seeds the random number generators via :meth:`set_seed`.
    """

    # Dataset properties (ZINC specific)
    n_node_types: int = 28  # Number of atom types in ZINC
    n_edge_types: int = 3   # Number of bond types in ZINC

    # General hyperparameters
    seed: int = 42
    device: torch.device  # set in __init__ (instance attribute)
    device_name: str      # set in __init__ (instance attribute)

    # Model > MLP hyperparameters
    n_mlp_layers: int = 3
    mlp_hidden_dims: int = 128
    mlp_dropout_prob: float = 0.0
    mlp_norm_type: str = "batch"

    # Model > GINE hyperparameters
    n_base_layers: int = 4
    node_emb_dims: int = 128
    base_hidden_dims: int = 128
    gine_model_bn: bool = False
    pooling: str = "add"
    target_dim: int = 1

    # Model > GIN / SampleAggregator hyperparameters
    gin_sample_aggregator_bn: bool = True
    n_sample_aggr_layers: int = 8
    sample_aggr_hidden_dims: int = 40

    # Model > Positional Encoding / PEARL
    pe_dims: int = 37  # Based on SPE paper (Huang et al., 2023)
    basis: bool = True  # False for R-PEARL, True for B-PEARL
    num_samples: int = 120  # Number of samples for R-PEARL
    pearl_k: int = 7  # Polynomial filter order
    pearl_mlp_nlayers: int = 1
    pearl_mlp_hid: int = 37
    pearl_mlp_out: int = 37

    # Dataset hyperparameters
    use_subset: bool = True  # Use ZINC subset (12K graphs) or full (250K)
    train_batch_size: int = 32
    val_batch_size: int = 32
    test_batch_size: int = 32

    # Training hyperparameters
    learning_rate: float = 1e-3
    weight_decay: float = 1e-6
    num_epochs: int = 10  # Set to 1400 for full training (as in paper)
    n_warmup_steps: int = 100

    def __init__(self):
        """Pick the compute device (MPS, then CUDA, then CPU) and seed the RNGs."""
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
            self.device_name = "MPS"
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_name = torch.cuda.get_device_name(0)
        else:
            self.device = torch.device("cpu")
            self.device_name = "CPU"
        self.set_seed()

    def set_seed(self) -> None:
        """Set random seeds for reproducibility.

        Seeds Python's ``random``, NumPy and PyTorch (including all CUDA
        devices) with ``self.seed``.
        """
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

    # NOTE: ``device`` is deliberately a plain instance attribute. A method
    # with the same name would be shadowed by the attribute assigned in
    # ``__init__`` and could never be called on an instance.

# Build the global configuration object and print a short summary of the run:
# selected device, dataset variant and PEARL mode.
hp = HyperParam()
print(
    f"Device: {hp.device}",
    f"Device name: {hp.device_name}",
    "Using " + ("ZINC subset (12K)" if hp.use_subset else "Full ZINC (250K)"),
    "Mode: " + ("B-PEARL" if hp.basis else "R-PEARL"),
    sep="\n",
)

Device: cuda
Device name: AMD Radeon RX 7800 XT
Using ZINC subset (12K)
Mode: B-PEARL


## General MLP

We're going to apply on top of different models:
* A 
* B
* C

The MLP optionally supports masking. In that case, the shape of the mask must of course be compatible with the shape of the tensor `X` forwarded to the model.

* for each layers:
  1. Linear layer (linear projection)
  2. Apply masking (if present)
  3. Normalization (batch or layer)
  4. Activation function (usually the Relu)
  5. Dropout (if configured)

* at the end:
  1. Final linear layer (linear projection)
  2. Dropout (if configured)

In [13]:
class MLPLayer(nn.Module):
    """One hidden MLP block: linear -> (masking) -> normalization -> ReLU -> dropout.

    The normalization is ``BatchNorm1d`` when ``hp.mlp_norm_type == "batch"``
    and ``LayerNorm`` for any other value. The dropout probability comes from
    ``hp.mlp_dropout_prob``.
    """

    def __init__(self, in_dims: int, out_dims: int, hp: HyperParam):
        super().__init__()
        self.linear = nn.Linear(in_dims, out_dims)
        if hp.mlp_norm_type == "batch":
            self.normalization = nn.BatchNorm1d(out_dims)
        else:
            self.normalization = nn.LayerNorm(out_dims)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=hp.mlp_dropout_prob)

    def forward(self, X: torch.Tensor, mask=None) -> torch.Tensor:
        """Apply the block to ``X``.

        Args:
            X: input tensor whose last dimension is the feature dimension.
            mask: optional index mask over the leading dimensions of the
                projected tensor (presumably boolean, since ``~mask`` is
                used). Masked-out entries are zeroed and excluded from the
                normalization statistics.

        Returns:
            Tensor with the same leading dimensions as ``X`` and the layer's
            output width as last dimension.
        """
        projected = self.linear(X)

        if mask is None:
            # Flatten every leading dimension so the normalization always
            # receives a 2-D (rows, features) input, then restore the shape.
            full_shape = projected.size()
            flat = projected.reshape(-1, full_shape[-1])
            normalized = self.normalization(flat).reshape(full_shape)
        else:
            # Zero the masked-out rows, then normalize only the valid rows
            # and write them back in place.
            projected[~mask] = 0
            projected[mask] = self.normalization(projected[mask])
            normalized = projected

        return self.dropout(self.activation(normalized))


class MLP(nn.Module):
    """Multi-layer perceptron with optional masking.

    The network is ``hp.n_mlp_layers - 1`` hidden :class:`MLPLayer` blocks of
    width ``hp.mlp_hidden_dims`` followed by a final linear projection to
    ``out_dims`` and a dropout with probability ``hp.mlp_dropout_prob``.
    With ``hp.n_mlp_layers <= 1`` there are no hidden blocks and the MLP
    reduces to a single linear layer from ``in_dims`` to ``out_dims``.
    """

    def __init__(self, in_dims: int, out_dims: int, hp: HyperParam):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(hp.n_mlp_layers - 1):
            self.layers.append(MLPLayer(in_dims, hp.mlp_hidden_dims, hp))
            in_dims = hp.mlp_hidden_dims
        # ``in_dims`` is the width actually fed to the output projection: the
        # hidden width when hidden blocks exist, the raw input width otherwise.
        self.linear = nn.Linear(in_dims, out_dims)
        self.dropout = nn.Dropout(p=hp.mlp_dropout_prob)

    def forward(self, X: torch.Tensor, mask=None) -> torch.Tensor:
        """Run ``X`` through the hidden blocks, the output projection and dropout.

        Args:
            X: input tensor whose last dimension is the feature dimension.
            mask: optional mask forwarded unchanged to every hidden block.

        Returns:
            Tensor whose last dimension is ``out_dims``.
        """
        for layer in self.layers:
            X = layer(X, mask=mask)
        X = self.linear(X)
        X = self.dropout(X)
        return X

    @property
    def out_dims(self) -> int:
        """Width of the MLP output (``out_features`` of the final projection)."""
        return self.linear.out_features

## GIN Model

This is used as the final stage of the sample aggregator model that produces the positional embeddings.
There are no edge embeddings in the input here, so a plain GIN is enough and a GINE is not needed in this case.

**NOTE** `node_dim=0` is crucial in our configuration. Since the dimensions of the expected forwarded input `X` are:

0. Dimension that identifies the node
1. Intermediate dimension: used to process the (M or N) independent initial node attributes in parallel
2. Node attributes

After the message passing implemented by each GIN layer, a batch normalization is applied.
The residual connection, on the other hand, is not applied in the paper's code base in the case of ZINC.
Actually I don't know why, since it seems that residual connections never hurt.
I also don't know why, for ZINC, the paper's implementation uses batch normalization instead of layer normalization.
Modern transformers make extensive use of layer normalization and residual connections.
In any case, I keep my reimplementation consistent with the paper's hyperparameter choices for ZINC.

One last comment about masking: it is used by the B-PEARL implementation, not by R-PEARL.
We want to support both variants.

In [None]:
class GINLayer(MessagePassing):
    """Single GIN message-passing layer without edge features.

    Neighbor features are sum-aggregated (``aggr="add"``), combined with the
    node's own features scaled by ``1 + eps`` (``eps`` is a learnable scalar
    initialized from a standard normal), and passed through an :class:`MLP`.
    ``node_dim=0`` tells the message-passing base class that nodes are indexed
    along the first axis of ``X``.
    """

    def __init__(self, in_dims: int, out_dims: int, hp: HyperParam):
        super().__init__(aggr="add", node_dim=0)
        self.eps = nn.Parameter(data=torch.randn(1))
        self.mlp = MLP(in_dims, out_dims, hp)

    def forward(self, X: torch.Tensor, edge_index: torch.Tensor, mask=None) -> torch.Tensor:
        """Aggregate neighbors, add the scaled self term and apply the MLP.

        Args:
            X: node features, with nodes along the first axis.
            edge_index: graph connectivity passed to ``propagate``.
            mask: optional mask forwarded to the MLP.
        """
        neighbor_sum = self.propagate(edge_index, X=X)
        self_term = (1 + self.eps) * X
        return self.mlp(self_term + neighbor_sum, mask=mask)

    def message(self, X_j: torch.Tensor) -> torch.Tensor:
        """Message sent along each edge: the source node's features, unchanged."""
        return X_j

    @property
    def out_dims(self) -> int:
        """Output feature width, as reported by the inner MLP."""
        return self.mlp.out_dims


class GIN(nn.Module):
    """Stack of :class:`GINLayer` blocks used as the sample aggregator.

    The stack has ``hp.n_sample_aggr_layers`` layers: the input width is
    ``hp.pearl_mlp_out``, hidden layers have width
    ``hp.sample_aggr_hidden_dims`` and the last layer outputs ``hp.pe_dims``
    features. When ``hp.gin_sample_aggregator_bn`` is true, a ``BatchNorm1d``
    follows every layer except the last one.

    Args:
        hp: hyperparameter container.
        residual: if true, each layer's input is added to its output. This
            requires every layer to have equal input and output widths.
    """

    def __init__(self, hp: HyperParam, residual: bool = False):
        super().__init__()
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if hp.gin_sample_aggregator_bn else None
        self.residual = residual

        in_dims = hp.pearl_mlp_out
        for _ in range(hp.n_sample_aggr_layers - 1):
            self.layers.append(GINLayer(in_dims, hp.sample_aggr_hidden_dims, hp))
            in_dims = hp.sample_aggr_hidden_dims
            if self.batch_norms is not None:
                self.batch_norms.append(nn.BatchNorm1d(hp.sample_aggr_hidden_dims))

        # ``in_dims`` is the hidden width when hidden layers exist and the raw
        # input width for a single-layer stack.
        self.layers.append(GINLayer(in_dims, hp.pe_dims, hp))

    def forward(self, X: torch.Tensor, edge_index: torch.Tensor, mask=None) -> torch.Tensor:
        """Run ``X`` through every GIN layer.

        Args:
            X: node features, with nodes along the first axis (2-D or 3-D).
            edge_index: graph connectivity forwarded to each layer.
            mask: optional mask (presumably boolean, since ``~mask`` is
                used); masked-out entries are zeroed after each layer and
                excluded from batch normalization.

        Returns:
            Output of the last layer, with ``hp.pe_dims`` features.
        """
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            X0 = X
            X = layer(X, edge_index, mask=mask)
            if mask is not None:
                X[~mask] = 0
            # Batch normalization follows every layer except the last one.
            if self.batch_norms is not None and i < last_index:
                if mask is None:
                    if X.ndim == 3:
                        # BatchNorm1d expects the feature axis at position 1
                        # for 3-D input, so swap it in and back out.
                        X = self.batch_norms[i](X.transpose(2, 1)).transpose(2, 1)
                    else:
                        X = self.batch_norms[i](X)
                else:
                    # Normalize only the valid entries and write them back.
                    X[mask] = self.batch_norms[i](X[mask])
            if self.residual:
                X = X + X0
        return X

    @property
    def out_dims(self) -> int:
        """Output feature width, as reported by the last layer."""
        return self.layers[-1].out_dims


ModuleList(
  (0): MLPLayer(
    (linear): Linear(in_features=16, out_features=128, bias=True)
    (normalization): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (1): MLPLayer(
    (linear): Linear(in_features=128, out_features=128, bias=True)
    (normalization): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
    (dropout): Dropout(p=0.0, inplace=False)
  )
)