In [1]:
import sys
import os

# Make the project root (the parent of the directory the kernel started in) importable and
# use it as the working directory. Later cells import `graphcast_utils` and read relative
# "data/..." paths, which presumably live under this root.
# The starting directory is remembered on the first run, so re-running this cell neither
# keeps climbing the directory tree nor appends duplicate entries to sys.path.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
PROJECT_ROOT = os.path.dirname(NOTEBOOK_DIR)

if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

In [2]:
import os
import json
import logging

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch import Tensor

from graphcast_utils import (
    add_edge_features,
    add_node_features,
    cell_to_adj,
    create_graph,
    create_heterograph,
    get_edge_len,
    latlon2xyz,
    xyz2latlon,
    find_subset_indices,
)

# Fully Connected Graph

In [11]:
# Build a fully connected bipartite graph: every coarse-mesh node is linked to every
# fine-mesh node (Coarse2Fine edges). `edges` has shape [2, N_COARSE_MESH * N_FINE_MESH]:
# row 0 holds coarse (source) ids and row 1 holds fine (destination) ids, coarse-major.
N_FINE_MESH = 4
N_COARSE_MESH = 5

n_fine_mesh = torch.arange(N_FINE_MESH)      # fine-mesh node ids 0..N_FINE_MESH-1
n_coarse_mesh = torch.arange(N_COARSE_MESH)  # coarse-mesh node ids 0..N_COARSE_MESH-1
grid_m, grid_n = torch.meshgrid(n_coarse_mesh, n_fine_mesh, indexing='ij')

# Stacking the flattened grids along dim 0 yields the [2, E] edge index directly
# (same values as stacking on dim 1 and transposing, but contiguous).
edges = torch.stack([grid_m.reshape(-1), grid_n.reshape(-1)], dim=0)

assert edges.shape == (2, N_COARSE_MESH * N_FINE_MESH)

In [12]:
# Display the Coarse2Fine edge index built above: row 0 = coarse ids, row 1 = fine ids.
edges

tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]])

# Randomly Connected Graph

In [13]:
# Build a random bipartite graph with N_EDGES edges. Each edge joins a uniformly drawn coarse
# node (source) to a uniformly drawn fine node (destination); the draws are independent,
# so duplicate edges are possible.
# A dedicated seeded generator makes the draws (and the printed output of the next cell)
# reproducible on Restart & Run All without touching torch's global RNG state.
N_coarse = 10
N_fine = 5
N_EDGES = 6
RNG_SEED = 0

edge_rng = torch.Generator().manual_seed(RNG_SEED)
coarse_idxs = torch.randint(0, N_coarse, (N_EDGES,), generator=edge_rng)
fine_idxs = torch.randint(0, N_fine, (N_EDGES,), generator=edge_rng)

In [14]:
# Stack source (coarse) ids over destination (fine) ids into a [2, num_edges] edge index.
edge_index = torch.stack((coarse_idxs, fine_idxs), dim=0)

# Per-row id shift for one graph copy, shaped [2, 1] so it broadcasts across all edges.
idx_offset = torch.tensor([N_coarse, N_fine]).unsqueeze(1)

# Batch two copies side by side: copy k shifts coarse ids by k * N_coarse and fine ids by
# k * N_fine, so node ids belonging to different copies never collide.
edge_index_batch = torch.cat(
    [edge_index + k * idx_offset for k in range(2)],
    dim=1,
)

print(edge_index)
print(edge_index_batch)

tensor([[0, 2, 6, 9, 4, 8],
        [3, 2, 4, 4, 3, 3]])
tensor([[ 0,  2,  6,  9,  4,  8, 10, 12, 16, 19, 14, 18],
        [ 3,  2,  4,  4,  3,  3,  8,  7,  9,  9,  8,  8]])


# Check Bipartite Graph Operation

In [15]:
import torch
from torch_geometric.nn import MessagePassing

class BipartiteGraphOperator(MessagePassing):
    """Sum-aggregate source-node features onto the destination nodes of a bipartite graph.

    No ``message`` override is defined, so the base class's message is used, combined
    with ``'add'`` aggregation. Each destination node therefore receives the sum of the
    features of the source nodes that ``assign_index`` connects to it.
    """

    def __init__(self):
        super().__init__('add')

    def forward(self, x, assign_index, N, M):
        # N source rows in ``x``; the output has M destination rows.
        bipartite_size = (N, M)
        return self.propagate(assign_index, size=bipartite_size, x=x)

# Three source nodes (features 1, 2, 3) each connect to both of two destination nodes,
# so both destinations should receive 1 + 2 + 3 = 6.
conv = BipartiteGraphOperator()

x = torch.tensor([[1.0], [2.0], [3.0]])
row = torch.tensor([0, 1, 2, 0, 1, 2])  # source-node ids (rows of x)
col = torch.tensor([0, 0, 0, 1, 1, 1])  # destination-node ids (rows of out)
index = torch.stack([row, col], dim=0)

out = conv(x, index, N=3, M=2)
print(out.size())
print(out)

assert torch.equal(out, torch.tensor([[6.0], [6.0]]))

torch.Size([2, 1])
tensor([[6.],
        [6.]])


In [16]:
class SumGraphOperator(MessagePassing):
    """Residual-sum message passing for plain or bipartite graphs.

    Each destination node receives the sum (``'add'`` aggregation) of its source
    neighbours' features, plus its own features. ``edge_index[0]`` holds source ids
    and ``edge_index[1]`` holds destination ids.

    For a bipartite input ``x = (x_src, x_dst)`` the call returns the tuple
    ``(x_src, aggregated + x_dst)``; for a single tensor ``x`` it returns
    ``aggregated + x``.
    """

    def __init__(self):
        super().__init__('add')

    def forward(self, x, edge_index):
        """Propagate ``x`` along ``edge_index``.

        Args:
            x: node features; either a single tensor (the same node set at both ends of
                every edge) or an ``(x_src, x_dst)`` tuple for a bipartite graph.
            edge_index: ``[2, num_edges]`` tensor of (source, destination) ids.

        Returns:
            ``aggregated + x`` for a single tensor, or ``(x_src, aggregated + x_dst)``
            for a bipartite tuple.
        """
        if isinstance(x, tuple):
            x_src, x_dst = x
            # Pass the (num_src, num_dst) size explicitly so the aggregated output has
            # one row per destination node.
            return self.propagate(
                edge_index,
                size=(x_src.size(0), x_dst.size(0)),
                x=x,
            )
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Forward source-node features unchanged. Only x_j is requested, so the unused
        # destination features (x_i) are not gathered for every edge.
        return x_j

    def update(self, aggr_out, x):
        # Residual connection: add each destination node's own features.
        if isinstance(x, tuple):
            x_src, x_dst = x
            return x_src, aggr_out + x_dst
        return aggr_out + x
    
# Bipartite check: x holds 3 source nodes and y holds 3 destination nodes.
# Row 0 of `index` lists source ids (rows of x); row 1 lists destination ids (rows of y).
conv = SumGraphOperator()
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[4.0], [5.0], [6.0]])

index = torch.tensor([
    [0, 0, 1, 1, 2, 2],  # source (x) ids
    [0, 2, 0, 1, 1, 2],  # destination (y) ids
])

out = conv((x, y), index)
print(out[0])
print(out[1])

# Sources pass through unchanged. Each destination gets the sum of its neighbours plus itself:
# y0: x0 + x1 + 4 = 7, y1: x1 + x2 + 5 = 10, y2: x0 + x2 + 6 = 10.
assert torch.equal(out[0], x)
assert torch.equal(out[1], torch.tensor([[7.0], [10.0], [10.0]]))

tensor([[1.],
        [2.],
        [3.]])
tensor([[ 7.],
        [10.],
        [10.]])


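Conclusion: in `conv((src, dst), index)`, `index` is a `[2, num_edges]` tensor. Its first row holds source-node ids (rows of `src`) and its second row holds destination-node ids (rows of `dst`). Messages flow from `src` to `dst`, and the aggregated result has one row per `dst` node.

index = [
    [src1, src2, ...],
    [dst1, dst2, ...]
]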

In [24]:
# Inspect the edge index currently bound to `index`. On a fresh top-to-bottom run, this is
# the tensor defined in the SumGraphOperator cell.
# NOTE(review): the saved output below matches the `row`/`col` index from the
# BipartiteGraphOperator cell instead, so it is stale. Re-run to refresh it.
index

tensor([[0, 1, 2, 0, 1, 2],
        [0, 0, 0, 1, 1, 1]])

# Transformer Conv

# Multiscale Graph

In [3]:
big_dataset = "era5_uk_big"
small_dataset = "era5_uk_small"

# Load each dataset's static `nwp_xy.npy` array (presumably NWP grid-point coordinates;
# confirm against the data pipeline). Paths are relative to the working directory set by
# the first cell's os.chdir.
big_grid_xy, small_grid_xy = (
    np.load(f"data/{name}/static/nwp_xy.npy")
    for name in (big_dataset, small_dataset)
)