"From Coarse to Fine: Building and Training Graph Neural Networks with Anemoi-Graphs"

In [1]:
import einops
import matplotlib.pyplot as plt
import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from IPython.display import IFrame

from anemoi.graphs.edges import KNNEdges, CutOffEdges
from anemoi.graphs.edges.attributes import GaussianDistanceWeights
from anemoi.graphs.nodes import ReducedGaussianGridNodes
from anemoi.graphs.inspect import GraphInspector

KeyboardInterrupt: 

## 1. Create a simple graph with Anemoi-Graphs

We'll use `ReducedGaussianGridNodes` from `anemoi.graphs.nodes` to create two sets of nodes:
- Coarse grid: o48
- Fine grid: o96

These represent two spatial resolutions for our graph.

In [None]:
# Node builders for the two resolutions: the o48 grid is registered under the
# node name "input" (coarse) and the o96 grid under "target" (fine).
coarse_node_builder = ReducedGaussianGridNodes("o48", name="input")
fine_node_builder = ReducedGaussianGridNodes("o96", name="target")

# Start from an empty heterogeneous graph; the coarse builder is applied first,
# then the fine builder is applied to its result.
graph = fine_node_builder.update_graph(
    coarse_node_builder.update_graph(HeteroData())
)
print(graph)

In [None]:
# Connect the coarse ("input") nodes to the fine ("target") nodes with
# k-nearest-neighbour edges, using 4 neighbours.
edge_builder = KNNEdges(
    source_name="input",
    target_name="target",
    num_nearest_neighbours=4,
)
# A distance cut-off can be used instead of a fixed neighbour count:
# edge_builder = CutOffEdges(cutoff_factor=0.7, source_name='input', target_name='target')

graph = edge_builder.update_graph(graph)
print(graph)

## 2. Graph Inspection

- This section demonstrates how to inspect and visualize the graph structure created with Anemoi-Graphs.  
- Using `GraphInspector` from `anemoi.graphs.inspect`, you can interactively explore node types, edge connections, and graph attributes.  
- Example usage is provided to print a summary and plot nodes/edges.
- For command-line inspection, save your graph with `torch.save(graph, "graph.pt")` and use the CLI tool `anemoi-graphs inspect graph.pt` for a summary and visualization directly in the terminal.


In [None]:
# Create both output locations up front so the cell works on a fresh kernel:
# torch.save does not create parent directories, and the plot directory is
# created as well so the inspector has an existing place to write into.
# NOTE(review): GraphInspector appears to require an existing output directory;
# creating it here is harmless either way.
os.makedirs("graphs", exist_ok=True)
os.makedirs("interactive_plots", exist_ok=True)

# GraphInspector takes the path of a saved graph, so persist the graph first.
torch.save(graph, "graphs/my_first_graph.pt")
GraphInspector("graphs/my_first_graph.pt", "interactive_plots/").inspect()

# Close every open matplotlib figure once the inspection is done.
plt.close('all')

In [None]:
# The inspector call in this notebook is given "interactive_plots/" as its
# output directory, so the page is embedded from that same directory.
# NOTE(review): the file name is presumably "<source>_to_<target>.html" -
# confirm against the files the inspector actually wrote.
IFrame('interactive_plots/input_to_target.html', width=600, height=400)

## 3. Utility functions

- train()
- plot_loss_curve()
- plot_sample()



In [None]:
class DownscalingModel(nn.Module):
    """Wrap a bipartite GNN so that it can be called with the coarse field only.

    The wrapper keeps a reference to the graph and, on every call, supplies the
    wrapped GNN with the features of the ``target`` nodes and the
    ``input -> target`` edge index. Edge attributes are never passed.
    """

    def __init__(self, gnn, graph):
        super().__init__()
        self.gnn = gnn
        self.graph = graph

    def forward(self, x_src):
        """Run the wrapped GNN on a single sample.

        Args:
            x_src: Batched coarse features. The leading dimension is the batch
                axis and must have size 1; the remaining dimensions are handed
                to the GNN as its source-node features.

        Returns:
            The output of the wrapped GNN for the un-batched sample. The batch
            axis is not added back.

        Raises:
            AssertionError: If the batch axis of ``x_src`` is not of size 1.
        """
        assert x_src.shape[0] == 1, "Batch size greater than 1 not supported in this example."

        # Drop the batch axis and bring every input to the dtype the GNN is fed with.
        coarse_features = x_src[0, ...].to(torch.float32)
        fine_features = self.graph['target'].x.to(torch.float32)
        coarse_to_fine = self.graph['input', 'to', 'target'].edge_index.to(torch.int64)

        return self.gnn(
            x_src=coarse_features,
            x_dst=fine_features,
            edge_index=coarse_to_fine,
            edge_attr=None,
        )

In [None]:
class DummyDataset(Dataset):
    """Synthetic downscaling dataset made of random sine-wave fields.

    Every item is a freshly generated field on the ``target`` nodes together
    with its interpolation onto the ``input`` nodes, returned as
    ``(x_coarse, x_fine)``.

    Args:
        num_samples: Value reported by ``len(dataset)``.
        graph: Graph holding the ``input`` and ``target`` node sets.
    """

    def __init__(self, num_samples: int, graph: HeteroData):
        self.num_samples = num_samples
        self.graph = graph
        self.proj_matrix = self.build_interp_matrix(graph)

    @property
    def num_variables(self) -> int:
        """Number of variables (channels) per node in every generated field."""
        return 1

    def build_interp_matrix(self, graph: HeteroData):
        """Build the sparse matrix that interpolates fine fields to the coarse grid.

        KNN edges (3 neighbours) are created from the ``target`` nodes to the
        ``input`` nodes and weighted with ``GaussianDistanceWeights(norm="l1")``.
        The weights are assembled into a sparse matrix of shape
        ``(num_target_nodes, num_input_nodes)`` which is returned transposed,
        i.e. as ``(num_input_nodes, num_target_nodes)``.

        NOTE: ``update_graph`` is called on the graph that was passed in, so the
        ``target -> input`` edges may also become visible on the caller's graph
        object - TODO confirm whether anemoi-graphs updates in place.
        """
        edge_builder = KNNEdges(num_nearest_neighbours=3, source_name='target', target_name='input')
        graph = edge_builder.update_graph(graph)

        edge_index = graph["target", "to", "input"].edge_index
        weights = GaussianDistanceWeights(norm="l1")(
            x=(graph["target"], graph["input"]),
            edge_index=edge_index,
        )

        interp_matrix = torch.sparse_coo_tensor(
            edge_index,
            weights.squeeze(),
            (graph['target'].num_nodes, graph['input'].num_nodes),
            device=edge_index.device,
        )
        return interp_matrix.coalesce().T

    def create_random_2d_sine_wave_field(self):
        """Generate one random field on the ``target`` nodes.

        The field is ``sin(a * c0) * cos(b * c1)`` where ``c0`` and ``c1`` are
        the first two columns of the ``target`` node features and ``a``, ``b``
        are drawn uniformly from ``[0, 10)``.

        Returns:
            A ``float32`` tensor with one trailing channel axis,
            ``(num_target_nodes, 1)``.
        """
        coords = self.graph["target"].x
        # torch is used for both the math and the random draws: numpy is not
        # imported in this notebook, and torch's RNG is controlled by
        # torch.manual_seed.
        freq_0 = 10 * torch.rand(1).item()
        freq_1 = 10 * torch.rand(1).item()
        sine_wave = torch.sin(freq_0 * coords[:, 0]) * torch.cos(freq_1 * coords[:, 1])
        return sine_wave.to(torch.float32).unsqueeze(-1)

    def interpolate_to_coarse(self, fine_field):
        """Project a fine field onto the coarse nodes with the interpolation matrix."""
        return torch.sparse.mm(self.proj_matrix, fine_field)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return a new random ``(x_coarse, x_fine)`` pair.

        ``idx`` is ignored: every access generates a fresh field, so the same
        index yields different samples on repeated access.
        """
        x_fine = self.create_random_2d_sine_wave_field()
        x_coarse = self.interpolate_to_coarse(x_fine)
        return x_coarse, x_fine

dataset = DummyDataset(num_samples=100, graph=graph)

In [None]:
def train(gnn, dataset, epochs: int, steps_per_epoch: int) -> tuple[nn.Module, list[float]]:
    """Train ``gnn`` on ``dataset`` and report the mean loss of every epoch.

    The GNN is wrapped in a ``DownscalingModel`` together with the notebook-level
    ``graph`` and optimised with Adam (lr=0.01) against an MSE loss, using a
    shuffled ``DataLoader`` with batch size 1.

    Args:
        gnn: Module to train; it is called through ``DownscalingModel``.
        dataset: Dataset yielding ``(x_coarse, x_fine)`` pairs.
        epochs: Number of epochs to run.
        steps_per_epoch: Maximum number of optimisation steps per epoch. An
            epoch ends earlier if the dataloader is exhausted first.

    Returns:
        A tuple ``(model, train_losses)`` with the trained ``DownscalingModel``
        and the per-epoch mean training loss.
    """
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    model = DownscalingModel(gnn=gnn, graph=graph)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.MSELoss()

    train_losses = []
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        samples_seen = 0
        for step, (x_coarse, x_fine) in enumerate(dataloader, start=1):
            optimizer.zero_grad()
            y_pred = model(x_coarse)
            # DownscalingModel returns the prediction without the batch axis while
            # the dataloader keeps it on the target; align the shapes explicitly
            # instead of relying on broadcasting inside the loss.
            loss = criterion(y_pred.reshape(x_fine.shape), x_fine)
            loss.backward()
            optimizer.step()

            batch_size = x_coarse.size(0)
            loss_sum += loss.item() * batch_size
            samples_seen += batch_size
            if step >= steps_per_epoch:
                break
        # Average over the samples actually processed, so the value stays correct
        # when the dataloader holds fewer batches than steps_per_epoch.
        epoch_loss = loss_sum / max(samples_seen, 1)
        train_losses.append(epoch_loss)
        print(f"Epoch {epoch}, Loss: {epoch_loss:.4f}")
    return model, train_losses

In [None]:
def plot_loss_curve(train_losses):
    """Plot the training loss against the epoch number.

    Args:
        train_losses: Sequence with one loss value per epoch, in epoch order.
    """
    # Epochs are reported 1-based during training, so the x-axis starts at 1
    # rather than at the 0-based list index.
    epochs = range(1, len(train_losses) + 1)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(epochs, train_losses, marker='o')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Train Loss")
    ax.set_title("Training Loss Curve")
    ax.grid(True)
    plt.show()

def _lat_lon_degrees(node_store):
    """Return ``(lat, lon)`` numpy arrays for the nodes of one node set."""
    # NOTE(review): assumes anemoi-graphs stores node coordinates in ``x`` as
    # (latitude, longitude) in radians - confirm against the library docs.
    coords = torch.rad2deg(node_store.x.detach().cpu())
    return coords[:, 0].numpy(), coords[:, 1].numpy()


def plot_sample(model, sample):
    """Plot input, target, prediction and error for one sample.

    Args:
        model: Trained model exposing the graph it was built with as
            ``model.graph`` (node sets ``input`` and ``target``); it is called
            with a batched coarse field.
        sample: Tuple ``(x_coarse, x_fine)`` as returned by the dataset.
    """
    x_coarse, x_fine = sample

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        y_pred = model(x_coarse.unsqueeze(0)).detach().cpu()
    # Give the prediction the target's shape so the error is a plain
    # element-wise difference regardless of a leading batch axis.
    y_pred = y_pred.reshape(x_fine.shape)

    # Node coordinates come from the graph held by the model rather than from
    # notebook-level variables.
    coarse_lat_vals, coarse_lon_vals = _lat_lon_degrees(model.graph['input'])
    lat_vals, lon_vals = _lat_lon_degrees(model.graph['target'])

    # (title, longitudes, latitudes, values, colormap, colorbar label)
    panels = [
        ("Input (Coarse)", coarse_lon_vals, coarse_lat_vals, x_coarse, 'viridis', "Value"),
        ("Target (Fine)", lon_vals, lat_vals, x_fine, 'viridis', "Value"),
        ("Prediction", lon_vals, lat_vals, y_pred, 'viridis', "Predicted Value"),
        ("Error (Prediction - Target)", lon_vals, lat_vals, y_pred - x_fine, 'coolwarm', "Error"),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    for ax, (title, lons, lats, values, cmap, cbar_label) in zip(axes.flat, panels):
        points = ax.scatter(lons, lats, c=values.numpy().squeeze(), cmap=cmap)
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        ax.set_title(title)
        fig.colorbar(points, ax=ax, label=cbar_label)

    fig.tight_layout()
    plt.show()

## 4. Build a Simple GNN Model

Now, let's build a simple bipartite Graph Neural Network on top of PyTorch Geometric's `MessagePassing` base class. We'll use dummy node features and labels for demonstration.

In [None]:
class BipartiteGNN(MessagePassing):
    """Single message-passing layer from source nodes to destination nodes.

    Messages are a linear embedding of the source features (plus a linear
    embedding of the edge attributes when given). They are summed per
    destination node, added to a linear embedding of the destination features
    and finally projected to ``out_channels``.

    Args:
        in_channels_src: Feature size of the source nodes.
        in_channels_dst: Feature size of the destination nodes.
        hidden_dim: Size of the shared embedding space.
        out_channels: Feature size of the output.
        edge_attr_dim: Feature size of the edge attributes; no edge embedding
            layer is created when it is 0.
    """

    def __init__(
        self,
        in_channels_src: int,
        in_channels_dst: int,
        hidden_dim: int,
        out_channels: int,
        edge_attr_dim: int = 0,
    ):
        super().__init__(aggr='add')
        self.lin_src = torch.nn.Linear(in_channels_src, hidden_dim)
        self.lin_dst = torch.nn.Linear(in_channels_dst, hidden_dim)
        if edge_attr_dim > 0:
            self.lin_edges = torch.nn.Linear(edge_attr_dim, hidden_dim)
        else:
            self.lin_edges = None
        self.projection = nn.Linear(hidden_dim, out_channels)

    def forward(self, x_src, x_dst, edge_index, edge_attr=None):
        """Propagate from source to destination nodes and project the result.

        Args:
            x_src: Source features, ``[num_src_nodes, in_channels_src]``.
            x_dst: Destination features, ``[num_dst_nodes, in_channels_dst]``.
            edge_index: Edges from source to destination, ``[2, num_edges]``.
            edge_attr: Optional edge attributes.
        """
        aggregated = self.propagate(
            edge_index.to(torch.int64),
            x=(x_src, x_dst),
            edge_attr=edge_attr,
        )
        return self.projection(aggregated)

    def message(self, x_j, edge_attr=None):
        """Embed the source features of every edge, adding the edge embedding if any."""
        msg = self.lin_src(x_j)
        if edge_attr is None:
            return msg
        return msg + self.lin_edges(edge_attr)

    def update(self, aggr_out, x):
        """Add the embedded destination features to the aggregated messages."""
        # ``x`` is the (x_src, x_dst) pair handed to ``propagate``.
        dst_features = x[1]
        return aggr_out + self.lin_dst(dst_features)

## 5. Train the GNN on Dummy Data

Let's train the model for a few epochs and observe the loss. This is a demonstration with random data.

In [None]:
# Hyper-parameters of the demo run; they are reused by the later training cells.
NUM_EPOCHS = 10
STEPS_PER_EPOCH = 100
HIDDEN_DIM = 16

# The source nodes carry the dataset's variables and the model predicts the same
# number of variables. The destination features are the ``target`` node features
# of the graph (presumably two coordinate columns - hence in_channels_dst=2).
gnn = BipartiteGNN(
    hidden_dim=HIDDEN_DIM,
    in_channels_src=dataset.num_variables,
    in_channels_dst=2,
    out_channels=dataset.num_variables,
    edge_attr_dim=0,
)
model, train_losses = train(
    gnn,
    dataset,
    epochs=NUM_EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
)


In [None]:
# Loss history of the BipartiteGNN run, one value per epoch.
plot_loss_curve(train_losses=train_losses)

In [None]:
# Compare input, target and prediction of the trained model on one dataset sample.
plot_sample(model=model, sample=dataset[0])

## 6. Extra I: Graph Transformer



In [None]:
class GraphTransformerNN(MessagePassing):
    """Multi-head graph-transformer layer from source to destination nodes.

    Queries come from the destination nodes, keys and values from the source
    nodes. Attention coefficients are normalised over the incoming edges of each
    destination node, the weighted values are summed, a linear embedding of the
    destination features is added as residual, and the result is projected and
    passed through a small MLP that maps to ``out_channels``.

    Args:
        in_channels_src: Feature size of the source nodes.
        in_channels_dst: Feature size of the destination nodes.
        hidden_dim: Feature size of each attention head.
        out_channels: Feature size of the output.
        num_heads: Number of attention heads.
        edge_attr_dim: Feature size of the edge attributes; no edge embedding
            layer is created when it is 0.
        qk_norm: If True, queries and keys are layer-normalised before the
            attention scores are computed.
        mlp_hidden_ratio: The hidden width of the output MLP is
            ``mlp_hidden_ratio * hidden_dim``.
    """

    def __init__(
        self,
        in_channels_src: int,
        in_channels_dst: int,
        hidden_dim: int,
        out_channels: int,
        num_heads: int = 1,
        edge_attr_dim: int = 0,
        qk_norm: bool = False,
        mlp_hidden_ratio: int = 4,
    ):
        super().__init__(aggr='add')
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.mlp_hidden_ratio = mlp_hidden_ratio
        self.qk_norm = qk_norm

        embed_dim = self.num_heads * self.hidden_dim

        # NOTE(review): LayerNorm over a feature axis of size 1 maps every input
        # to its bias term, so with in_channels_src=1 the source signal is erased
        # before the attention. Consider normalising after an embedding instead.
        self.layer_norm_attention_src = nn.LayerNorm(normalized_shape=in_channels_src)
        self.layer_norm_attention_dst = nn.LayerNorm(normalized_shape=in_channels_dst)

        self.lin_key = nn.Linear(in_channels_src, embed_dim)
        self.lin_query = nn.Linear(in_channels_dst, embed_dim)
        self.lin_value = nn.Linear(in_channels_src, embed_dim)

        self.lin_self = nn.Linear(in_channels_dst, embed_dim)
        self.projection = nn.Linear(embed_dim, embed_dim)

        # Only create the edge embedding when there are edge attributes to embed;
        # a Linear layer with zero input features has nothing to learn from.
        self.lin_edge = nn.Linear(edge_attr_dim, embed_dim) if edge_attr_dim > 0 else None

        self.node_data_extractor = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(embed_dim, self.mlp_hidden_ratio * self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_ratio * self.hidden_dim, out_channels),
        )

        # The normalisation layers used by ``get_qkve`` when qk_norm is enabled.
        # They act on the flat [nodes, num_heads x hidden_dim] tensors.
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(normalized_shape=embed_dim)
            self.k_norm = nn.LayerNorm(normalized_shape=embed_dim)

    def forward(self, x_src, x_dst, edge_index, edge_attr=None):
        """Compute the output features of the destination nodes.

        Args:
            x_src: Source features, ``[num_src_nodes, in_channels_src]``.
            x_dst: Destination features, ``[num_dst_nodes, in_channels_dst]``.
            edge_index: Edges from source to destination, ``[2, num_edges]``.
            edge_attr: Optional edge attributes, ``[num_edges, edge_attr_dim]``.

        Returns:
            Tensor of shape ``[num_dst_nodes, out_channels]``.
        """
        x_src = self.layer_norm_attention_src(x_src)
        x_dst = self.layer_norm_attention_dst(x_dst)

        query, key, value, edge_attr = self.get_qkve(x_src, x_dst, edge_attr)
        # query: [num_dst_nodes, num_heads x hidden_dim]
        # key: [num_src_nodes, num_heads x hidden_dim]
        # value: [num_src_nodes, num_heads x hidden_dim]

        x_res = self.lin_self(x_dst)
        # x_res: [num_dst_nodes, num_heads x hidden_dim]

        out = self.propagate(
            edge_index=edge_index,
            size=(x_src.size(0), x_dst.size(0)),
            edge_attr=edge_attr,
            query=query,
            key=key,
            value=value,
        )
        # out: [num_dst_nodes, num_heads x hidden_dim]

        out = self.projection(out + x_res)
        # out: [num_dst_nodes, num_heads x hidden_dim]

        out = self.node_data_extractor(out)
        # out: [num_dst_nodes, out_channels]

        return out

    def get_qkve(self, x_src, x_dst, edge_attr):
        """Project node features to query/key/value and embed the edge attributes.

        Returns:
            Tuple ``(query, key, value, edge_attr)``; ``edge_attr`` is the
            embedded edge attributes, or ``None`` when none were given.

        Raises:
            ValueError: If edge attributes are given but the layer was built
                with ``edge_attr_dim=0``.
        """
        query = self.lin_query(x_dst)
        key = self.lin_key(x_src)
        value = self.lin_value(x_src)

        if self.qk_norm:
            query = self.q_norm(query)
            key = self.k_norm(key)

        if edge_attr is not None:
            if self.lin_edge is None:
                raise ValueError("edge_attr was given but the layer was created with edge_attr_dim=0.")
            edge_attr = self.lin_edge(edge_attr)

        return query, key, value, edge_attr

    def message(
        self,
        query_i: torch.Tensor,
        key_j: torch.Tensor,
        value_j: torch.Tensor,
        edge_attr: torch.Tensor,
        index: torch.Tensor,
        ptr: torch.Tensor,
        size_i: int | None,
    ) -> torch.Tensor:
        """Compute the attention-weighted value of every edge.

        ``edge_attr`` arrives here already embedded by ``get_qkve`` (or as
        ``None``); it is added to both the keys and the values.
        """
        query_i = einops.rearrange(query_i, "edges (heads vars) -> edges heads vars", heads=self.num_heads)
        key_j = einops.rearrange(key_j, "edges (heads vars) -> edges heads vars", heads=self.num_heads)
        value_j = einops.rearrange(value_j, "edges (heads vars) -> edges heads vars", heads=self.num_heads)
        # query_i, key_j, value_j: [num_edges, num_heads, hidden_dim]

        if edge_attr is not None:
            edge_attr = einops.rearrange(edge_attr, "edges (heads vars) -> edges heads vars", heads=self.num_heads)
            key_j = key_j + edge_attr

        # Scaled dot-product scores, normalised over the edges that share a destination node
        alpha = (query_i * key_j).sum(dim=-1) / self.hidden_dim ** 0.5
        alpha = softmax(alpha, index, ptr, size_i)
        # alpha: [num_edges, num_heads]

        if edge_attr is not None:
            # The edge attributes are already embedded and split per head, so
            # they are added directly; embedding them again would fail on shape.
            value_j = value_j + edge_attr

        out = value_j * alpha.view(-1, self.num_heads, 1)
        # out: [num_edges, num_heads, hidden_dim]

        out = einops.rearrange(out, "edges heads vars -> edges (heads vars)")
        # out: [num_edges, num_heads x hidden_dim]

        return out

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out


# DEBUG: smoke test of the forward pass. The node features of the graph are used
# as both source and destination inputs and no edge attributes are passed.
gnn = GraphTransformerNN(
    in_channels_src=2,
    in_channels_dst=2,
    out_channels=6,
    hidden_dim=16,
    num_heads=3,
    edge_attr_dim=1,
)
output = gnn(
    x_src=graph['input'].x,
    x_dst=graph['target'].x,
    edge_index=graph['input', 'to', 'target'].edge_index.to(torch.int64),
    edge_attr=None,
)
# Expected: [num_target_nodes, out_channels]
print(output.shape)

In [None]:
# Same training setup as for the BipartiteGNN, with the graph transformer as
# the wrapped GNN (single head, no edge attributes).
gnn = GraphTransformerNN(
    hidden_dim=HIDDEN_DIM,
    in_channels_src=dataset.num_variables,
    in_channels_dst=2,
    out_channels=dataset.num_variables,
    edge_attr_dim=0,
)
model, train_losses = train(
    gnn,
    dataset,
    epochs=NUM_EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
)

In [None]:
# Loss history of the GraphTransformerNN run, one value per epoch.
plot_loss_curve(train_losses=train_losses)

In [None]:
# Compare input, target and prediction of the trained transformer on one dataset sample.
plot_sample(model=model, sample=dataset[0])