In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing


class GeometricGNN(MessagePassing):
    def __init__(self, in_channels, out_channels, threshold=0.5):
        super(GeometricGNN, self).__init__(aggr='add')  # Aggregation type ('add' means sum aggregation)
        self.linear = nn.Linear(in_channels, out_channels)
        self.layer_norm = nn.LayerNorm(out_channels)
        self.threshold = threshold
        self.relu = nn.ReLU()
    
    def forward(self, x, edge_index, edge_attr, pos):
        # x: Node features (num_nodes, in_channels)
        # edge_index: Graph connectivity (2, num_edges)
        # edge_attr: Edge features (num_edges, edge_features)
        # pos: Node positions for geometric computation (num_nodes, 2 or 3)
        
        # Perform message passing
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, pos=pos)

    def message(self, x_j, edge_attr, pos_i, pos_j):
        # x_j: Neighbor node features
        # edge_attr: Edge features
        # pos_i, pos_j: Positions of node i and its neighbor j
        
        # Compute Euclidean distance between node i and its neighbors j
        dist = torch.norm(pos_i - pos_j, p=2, dim=-1)
        
        # Filter edges based on thresholded geometric distance
        mask = dist < self.threshold
        
        # Message passing: edge features contribute to messages if mask is True
        message = edge_attr * mask.float().unsqueeze(1)
        
        return message

    def aggregate(self, inputs, index, x):
        # Perform the aggregation step by summing the messages from neighboring nodes
        num_neighbors = torch.bincount(index, minlength=x.size(0)).float().clamp(min=1)  # Avoid division by zero
        aggregated_message = torch.zeros_like(x)
        aggregated_message.index_add_(0, index, inputs)  # Aggregate messages
        return aggregated_message / num_neighbors.unsqueeze(1)  # Normalize by the number of neighbors
    
    def update(self, aggr_out, x):
        # Concatenate the original node features with the aggregated message
        concat_features = torch.cat([x, aggr_out], dim=1)
        
        # Linear transformation, LayerNorm, and ReLU activation
        out = self.linear(concat_features)
        out = self.layer_norm(out)
        out = self.relu(out)
        
        return out

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.triplet_margin_loss = nn.TripletMarginLoss(margin=margin)

    def forward(self, anchor, positive, negative):
        # anchor, positive, negative: Node embeddings for triplet loss
        loss = self.triplet_margin_loss(anchor, positive, negative)
        return loss

# Sample GNN model combining the geometric GNN layer and contrastive learning
class GeometricGNNModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, threshold=0.5):
        super(GeometricGNNModel, self).__init__()
        self.gnn_layer = GeometricGNN(in_channels, hidden_channels, threshold)
        self.out_linear = nn.Linear(hidden_channels, out_channels)
        self.contrastive_loss_fn = ContrastiveLoss(margin=1.0)

    def forward(self, data):
        # data: A PyTorch Geometric data object containing x, edge_index, edge_attr, pos
        x, edge_index, edge_attr, pos = data.x, data.edge_index, data.edge_attr, data.pos
        
        # Apply the GNN layer
        node_embeddings = self.gnn_layer(x, edge_index, edge_attr, pos)
        
        # Final output node embeddings after an additional linear transformation
        node_embeddings = self.out_linear(node_embeddings)
        
        return node_embeddings

    def compute_contrastive_loss(self, anchor, positive, negative):
        # Calculate contrastive loss using triplet margin loss
        return self.contrastive_loss_fn(anchor, positive, negative)

# Example training loop
def train(model, data, optimizer):
    model.train()
    
    # Forward pass to get node embeddings
    optimizer.zero_grad()
    node_embeddings = model(data)
    
    # Assume anchor, positive, and negative samples are provided
    # These are typically derived from graph augmentation or pre-specified node sets
    anchor, positive, negative = get_triplet_samples(node_embeddings)  # Implement this function

    # Compute contrastive loss
    loss = model.compute_contrastive_loss(anchor, positive, negative)
    
    # Backpropagation
    loss.backward()
    optimizer.step()
    
    return loss.item()

# Helper function to create triplet samples for contrastive learning
def get_triplet_samples(node_embeddings):
    # Placeholder for actual triplet generation logic
    anchor = node_embeddings[0:10]
    positive = node_embeddings[10:20]
    negative = node_embeddings[20:30]
    return anchor, positive, negative

# Usage
model = GeometricGNNModel(in_channels=32, hidden_channels=64, out_channels=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [3]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
import torch.nn.functional as F

class SimpleGNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SimpleGNN, self).__init__(aggr='add')  # "Add" aggregation (sum).
        self.linear = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix with shape [2, num_edges]
        
        # Apply a linear transformation to node features
        x = self.linear(x)
        
        # Start message passing
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: Neighboring node features. x_j represents the features of node j
        return x_j  # Pass along neighbor's feature as the message

    def update(self, aggr_out):
        # aggr_out: The aggregated messages from neighbors for each node
        return F.relu(aggr_out)  # Apply ReLU activation after aggregation


In [12]:
# Create sample node features: 4 nodes with 3-dimensional features
x = torch.tensor([[1, 2], 
                  [4, 5], 
                  [7, 8], 
                  [10, 11]], dtype=torch.float)

# Define the edges in the graph (in COO format)
# Each edge is represented as (source, target)
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 0, 3, 2]], # target nodes
                          dtype=torch.long)

# Create a PyTorch Geometric Data object
data = Data(x=x, edge_index=edge_index)

In [8]:
data

Data(x=[4, 3], edge_index=[2, 4])

In [14]:
# Instantiate the GNN layer
gnn_layer = SimpleGNN(in_channels=2, out_channels=3)  # 3 input features, 2 output features

# Apply the GNN layer to the graph data
out = gnn_layer(x=data.x, edge_index=data.edge_index)

print("Updated node features after GNN layer:")
print(out)

Updated node features after GNN layer:
tensor([[0.0000, 0.0000, 4.3526],
        [0.4707, 0.0000, 1.6616],
        [0.0000, 0.0000, 9.7346],
        [0.0000, 0.0000, 7.0436]], grad_fn=<ReluBackward0>)


In [16]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
import torch.nn.functional as F
import torch.optim as optim

# Define a simple custom GNN layer using MessagePassing for node classification
class SimpleGNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SimpleGNN, self).__init__(aggr='add')  # "Add" aggregation (sum).
        self.linear = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Apply a linear transformation to node features
        x = self.linear(x)
        
        # Start message passing
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Pass along the neighboring node's features as the message
        return x_j

    def update(self, aggr_out):
        # Apply ReLU activation to the aggregated messages
        return F.relu(aggr_out)

# Define a simple model that uses two GNN layers
class GNNModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNNModel, self).__init__()
        self.gnn1 = SimpleGNN(in_channels, hidden_channels)  # First GNN layer
        self.gnn2 = SimpleGNN(hidden_channels, out_channels)  # Second GNN layer

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # First GNN layer
        x = self.gnn1(x, edge_index)
        
        # Second GNN layer (outputs logits for each class)
        x = self.gnn2(x, edge_index)
        
        return F.log_softmax(x, dim=1)  # Log-softmax for classification

# Create sample node features (4 nodes, 3 features per node)
x = torch.tensor([[1, 2, 3], 
                  [4, 5, 6], 
                  [7, 8, 9], 
                  [10, 11, 12]], dtype=torch.float)

# Define the edges in the graph (COO format)
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 0, 3, 2]], # target nodes
                          dtype=torch.long)

# Define the labels for each node (node classification task)
# Assume we want to classify nodes into 2 classes (0 or 1)
y = torch.tensor([0, 1, 0, 1], dtype=torch.long)

# Create the data object (graph + features + labels)
data = Data(x=x, edge_index=edge_index, y=y)

# Instantiate the model, optimizer, and loss function
model = GNNModel(in_channels=3, hidden_channels=4, out_channels=2)  # 2 output classes
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.NLLLoss()  # Negative log likelihood loss (since we use log_softmax)

# Training loop
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients
    out = model(data)  # Forward pass (get predictions)
    loss = criterion(out, data.y)  # Compute the loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update the parameters
    return loss.item()

# Evaluate the model (compute accuracy)
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # Get the index of the max log-probability
    correct = pred.eq(data.y).sum().item()  # Count correct predictions
    accuracy = correct / data.num_nodes  # Accuracy
    return accuracy

# Train the model for multiple epochs
for epoch in range(100):
    loss = train()
    accuracy = test()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

Epoch 1, Loss: 0.6931, Accuracy: 0.5000
Epoch 2, Loss: 0.6931, Accuracy: 0.5000
Epoch 3, Loss: 0.6931, Accuracy: 0.5000
Epoch 4, Loss: 0.6931, Accuracy: 0.5000
Epoch 5, Loss: 0.6931, Accuracy: 0.5000
Epoch 6, Loss: 0.6931, Accuracy: 0.5000
Epoch 7, Loss: 0.6931, Accuracy: 0.5000
Epoch 8, Loss: 0.6931, Accuracy: 0.5000
Epoch 9, Loss: 0.6931, Accuracy: 0.5000
Epoch 10, Loss: 0.6931, Accuracy: 0.5000
Epoch 11, Loss: 0.6931, Accuracy: 0.5000
Epoch 12, Loss: 0.6931, Accuracy: 0.5000
Epoch 13, Loss: 0.6931, Accuracy: 0.5000
Epoch 14, Loss: 0.6931, Accuracy: 0.5000
Epoch 15, Loss: 0.6931, Accuracy: 0.5000
Epoch 16, Loss: 0.6931, Accuracy: 0.5000
Epoch 17, Loss: 0.6931, Accuracy: 0.5000
Epoch 18, Loss: 0.6931, Accuracy: 0.5000
Epoch 19, Loss: 0.6931, Accuracy: 0.5000
Epoch 20, Loss: 0.6931, Accuracy: 0.5000
Epoch 21, Loss: 0.6931, Accuracy: 0.5000
Epoch 22, Loss: 0.6931, Accuracy: 0.5000
Epoch 23, Loss: 0.6931, Accuracy: 0.5000
Epoch 24, Loss: 0.6931, Accuracy: 0.5000
Epoch 25, Loss: 0.6931, A

In [18]:
out = model(data)
out

tensor([[-0.6931, -0.6931],
        [-0.6931, -0.6931],
        [-0.6931, -0.6931],
        [-0.6931, -0.6931]], grad_fn=<LogSoftmaxBackward0>)

In [24]:
data.x

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])