In [1]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))
os.chdir("../")

In [2]:
import json
import numpy as np
import torch

# GraphCast Global

In [3]:
grid_xy = np.load("data/era5_global/static/nwp_xy.npy")

In [4]:
grid_xy.shape

(2, 721, 1440)

In [5]:
grid_xy = np.load("data/era5_uk/static/nwp_xy.npy")

In [6]:
grid_xy.shape

(2, 65, 57)

# Investigate Icospheres

In [7]:
dataset_path = "data/era5_global/"
icospheres_path = os.path.join(dataset_path, "icospheres.json")
with open(icospheres_path, "r") as f:
    loaded_dict = json.load(f)
    icospheres = {
        key: (np.array(value) if isinstance(value, list) else value)
        for key, value in loaded_dict.items()
    }
    print(f"Opened pre-computed graph at {icospheres_path}.")

Opened pre-computed graph at data/era5_global/icospheres.json.


In [8]:
icospheres

{'vertices': array([], dtype=float64),
 'faces': array([], dtype=float64),
 'order_0_vertices': array([[-0.52573111,  0.85065081,  0.        ],
        [ 0.52573111,  0.85065081,  0.        ],
        [-0.52573111, -0.85065081,  0.        ],
        [ 0.52573111, -0.85065081,  0.        ],
        [ 0.        , -0.52573111,  0.85065081],
        [ 0.        ,  0.52573111,  0.85065081],
        [ 0.        , -0.52573111, -0.85065081],
        [ 0.        ,  0.52573111, -0.85065081],
        [ 0.85065081,  0.        , -0.52573111],
        [ 0.85065081,  0.        ,  0.52573111],
        [-0.85065081,  0.        , -0.52573111],
        [-0.85065081,  0.        ,  0.52573111]]),
 'order_0_faces': array([[ 0, 11,  5],
        [ 0,  5,  1],
        [ 0,  1,  7],
        [ 0,  7, 10],
        [ 0, 10, 11],
        [ 1,  5,  9],
        [ 5, 11,  4],
        [11, 10,  2],
        [10,  7,  6],
        [ 7,  1,  8],
        [ 3,  9,  4],
        [ 3,  4,  2],
        [ 3,  2,  6],
        [ 3,

In [9]:
order_5_face_centroid = icospheres["order_5_face_centroid"]

In [10]:
order_5_face_centroid.shape

(20480, 3)

In [11]:
order_5_face_centroid[[0, 59, 214]]

array([[-0.5296202 ,  0.84793042,  0.01458344],
       [-0.52314931,  0.84361806,  0.11940895],
       [-0.55266781,  0.75184675,  0.35885871]])

In [13]:
order_5_faces = icospheres["order_5_faces"]
order_5_faces

array([[    0,  2562,  2564],
       [  642,  2563,  2562],
       [  644,  2564,  2563],
       ...,
       [ 2561, 10241, 10240],
       [ 2559, 10239, 10241],
       [10240, 10241, 10239]])

In [18]:
l = [1, 2, 3]
torch.save(l, "test.pt")

In [20]:
loaded_l = torch.load("test.pt")
loaded_l

[1, 2, 3]

In [15]:
NODES_PER_LEVEL

[12, 42, 162, 642, 2562, 10242, 40962]

In [16]:
max_order = 6
NODES_PER_LEVEL = [icospheres[f"order_{order}_vertices"].shape[0] for order in range(max_order + 1)]

def get_node_level(node_idx):
    for order, max_node_idx in enumerate(NODES_PER_LEVEL):
        if node_idx < max_node_idx:
            return order
    return 6

print(get_node_level(30444))
print(get_node_level(544))

6
3


In [7]:
order_0_faces = icospheres["order_0_faces"]
order_0_faces.min(), order_0_faces.max()

(0, 11)

In [8]:
order_1_faces = icospheres["order_1_faces"]
order_1_faces.min(), order_1_faces.max()

(0, 41)

In [9]:
order_1_faces

array([[ 0, 12, 14],
       [11, 13, 12],
       [ 5, 14, 13],
       [12, 13, 14],
       [ 0, 14, 16],
       [ 5, 15, 14],
       [ 1, 16, 15],
       [14, 15, 16],
       [ 0, 16, 18],
       [ 1, 17, 16],
       [ 7, 18, 17],
       [16, 17, 18],
       [ 0, 18, 20],
       [ 7, 19, 18],
       [10, 20, 19],
       [18, 19, 20],
       [ 0, 20, 12],
       [10, 21, 20],
       [11, 12, 21],
       [20, 21, 12],
       [ 1, 15, 23],
       [ 5, 22, 15],
       [ 9, 23, 22],
       [15, 22, 23],
       [ 5, 13, 25],
       [11, 24, 13],
       [ 4, 25, 24],
       [13, 24, 25],
       [11, 21, 27],
       [10, 26, 21],
       [ 2, 27, 26],
       [21, 26, 27],
       [10, 19, 29],
       [ 7, 28, 19],
       [ 6, 29, 28],
       [19, 28, 29],
       [ 7, 17, 31],
       [ 1, 30, 17],
       [ 8, 31, 30],
       [17, 30, 31],
       [ 3, 32, 34],
       [ 9, 33, 32],
       [ 4, 34, 33],
       [32, 33, 34],
       [ 3, 34, 36],
       [ 4, 35, 34],
       [ 2, 36, 35],
       [34, 3

In [17]:
max_order = (
    len([key for key in icospheres.keys() if "faces" in key]) - 2
)
max_order

6

In [18]:
cum_nodes = 0
for order in range(max_order + 1):
    nodes = icospheres[f"order_{order}_vertices"].shape[0]
    cum_nodes += nodes
    print(f"Level {order}: {nodes}, Cumulative: {cum_nodes}")

Level 0: 12, Cumulative: 12
Level 1: 42, Cumulative: 54
Level 2: 162, Cumulative: 216
Level 3: 642, Cumulative: 858
Level 4: 2562, Cumulative: 3420
Level 5: 10242, Cumulative: 13662
Level 6: 40962, Cumulative: 54624


In [22]:
graphcast_grid = 1440 * 721
graphcast_grid

1038240

In [23]:
54624 / 1038240

0.05261211280628757

## Check global grid size

In [19]:
import glob

In [20]:
samples_path = os.path.join(dataset_path, "samples/train")
sample = np.load(glob.glob(f'{samples_path}/*')[0])
sample.shape

(2178, 48)

In [25]:
66 * 33

2178

In [26]:
162 / sample.shape[0]

0.0743801652892562

# Verify MEPS Graph

In [2]:
import torch
import os

graph_dir_path = "/vol/bitbucket/bet20/neural-lam/graphs/multiscale"
static_dir_path = "/vol/bitbucket/bet20/neural-lam/data/meps_example/static"

In [3]:
grid_features = torch.load(os.path.join(static_dir_path, "grid_features.pt"))
print(grid_features.shape)

mesh_features = torch.load(os.path.join(graph_dir_path, "mesh_features.pt"))
print(mesh_features[0].shape)

torch.Size([63784, 4])
torch.Size([6561, 2])


In [4]:
g2m_features = torch.load(os.path.join(graph_dir_path, "g2m_features.pt"))
g2m_edge_index = torch.load(os.path.join(graph_dir_path, "g2m_edge_index.pt"))

print(g2m_edge_index.shape)

gmin, gmax = g2m_edge_index[0].min(), g2m_edge_index[0].max()
mmin, mmax = g2m_edge_index[1].min(), g2m_edge_index[1].max()
print(gmin, gmax, mmin, mmax)
print(gmax - gmin + 1, mmax - mmin + 1)
print(g2m_features.shape[0])

# print(g2m_edge_index.min(dim=1, keepdim=True)[0])
# print(g2m_edge_index.min(dim=1, keepdim=True).values[0])

g2m_edge_index = g2m_edge_index - g2m_edge_index.min(dim=1, keepdim=True)[0]

torch.Size([2, 100656])
tensor(6561) tensor(70344) tensor(0) tensor(6560)
tensor(63784) tensor(6561)
100656


In [5]:
g2m_edge_index

tensor([[    0,     1,     2,  ..., 63781, 63782, 63783],
        [    0,     0,     0,  ...,  6560,  6560,  6560]])

In [6]:
num_rec = g2m_edge_index[1].max() + 1

In [7]:
g2m_edge_index[0] = (
    g2m_edge_index[0] + num_rec
)

print(g2m_edge_index)

tensor([[ 6561,  6562,  6563,  ..., 70342, 70343, 70344],
        [    0,     0,     0,  ...,  6560,  6560,  6560]])


# Sanity Check Graph

In [6]:
import numpy as np
import torch
import os

In [7]:
dir = "./data/era5_uk/static/nwp_xy.npy"
grid_xy = torch.tensor(np.load(dir))
_, lon, lat = grid_xy.shape

graph_name = "uk_graphcast"
graph_dir_path = os.path.join("graphs", graph_name)

In [8]:
print(lon, lat)
print("Local grid nodes", lon * lat)

X, Y = 1440, 721
fraction = (lon * lat) / (X * Y)
print("Global grid nodes", X * Y)
print("Local : Global", fraction)

57 65
Local grid nodes 3705
Global grid nodes 1038240
Local : Global 0.0035685390661118815


### Verify M2M Graph

In [9]:
N_MESH_NODES = 40962
N_MESH_EDGES = 327660 

mesh_features = torch.load(os.path.join(graph_dir_path, "mesh_features.pt"))
m2m_features = torch.load(os.path.join(graph_dir_path, "m2m_features.pt"))
m2m_edge_index = torch.load(os.path.join(graph_dir_path, "m2m_edge_index.pt"))

print(mesh_features[0].shape)
print(m2m_features[0].shape)
print(m2m_edge_index[0].shape)
print(m2m_edge_index[0].unique().shape)

print("Expected mesh nodes", N_MESH_NODES * fraction)
print("Expected mesh edges", N_MESH_EDGES * fraction)

torch.Size([149, 3])
torch.Size([1008, 4])
torch.Size([2, 1008])
torch.Size([149])
Expected mesh nodes 146.1744972260749
Expected mesh edges 1169.2675104022192


### Verify G2M Graph

In [10]:
N_G2M_EDGES = 1618746

g2m_features = torch.load(os.path.join(graph_dir_path, "g2m_features.pt"))
g2m_edge_index = torch.load(os.path.join(graph_dir_path, "g2m_edge_index.pt"))

print(g2m_features.shape)
print(g2m_edge_index.shape)
print(g2m_edge_index[0].unique().shape)
print(g2m_edge_index[1].unique().shape)
print(lat * lon)
print("Expected g2m edges", N_G2M_EDGES * fraction)

torch.Size([5321, 4])
torch.Size([2, 5321])
torch.Size([3705])
torch.Size([149])
3705
Expected g2m edges 5776.558339112344


### Verify M2G Graph

In [11]:
N_M2G_EDGES = 3114720

m2g_features = torch.load(os.path.join(graph_dir_path, "m2g_features.pt"))
m2g_edge_index = torch.load(os.path.join(graph_dir_path, "m2g_edge_index.pt"))

print(m2g_features.shape)
print(m2g_edge_index.shape)
print(m2g_edge_index[0].unique().shape)
print(m2g_edge_index[1].unique().shape)
print(lat * lon)
print("Expected m2g edges", N_M2G_EDGES * fraction)

torch.Size([11033, 4])
torch.Size([2, 11033])
torch.Size([149])
torch.Size([3705])
3705
Expected m2g edges 11115.0


### Verify against full graphcast

In [14]:
# num_grid_nodes = 
num_m2m_edges = 327660
num_m2g_edges = 3114720

```
g2m graph: Graph(num_nodes={'grid': 3705, 'mesh': 355},
      num_edges={('grid', 'g2m', 'mesh'): 7719},
      metagraph=[('grid', 'mesh', 'g2m')])
m2g graph: Graph(num_nodes={'grid': 3705, 'mesh': 355},
      num_edges={('mesh', 'm2g', 'grid'): 11094},
      metagraph=[('mesh', 'grid', 'm2g')])
mesh graph: Graph(num_nodes=355, num_edges=2584,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'x': Scheme(shape=(4,), dtype=torch.float32)})
```

# Verify Mesh Node Subset

In [15]:
import os
import json

import numpy as np
import torch

from graphcast_mesh import Graph


graph = "uk_graphcast"
dataset = "era5_uk"

graph_dir_path = os.path.join("graphs", graph)
# os.makedirs(graph_dir_path, exist_ok=True)

data_dir_path = os.path.join("data", dataset)
icosophere_path = os.path.join(data_dir_path, "icospheres.json")

nwp_xy_path = os.path.join(data_dir_path, "static", "nwp_xy.npy")
local_lat_lon_grid = np.load(nwp_xy_path) # (2, lon, lat) or (2, x, y)
local_lat_lon_grid = torch.from_numpy(local_lat_lon_grid).permute(2, 1, 0) # (lat, lon, 2)
print(f"Local area shape: {local_lat_lon_grid.shape}")
print(f"Opened lat lon grid at {nwp_xy_path}.")

  from .autonotebook import tqdm as notebook_tqdm


Local area shape: torch.Size([65, 57, 2])
Opened lat lon grid at data/era5_uk/static/nwp_xy.npy.


In [16]:
input_res = (721, 1440) # (lat, lon)
latitudes = torch.linspace(-90, 90, steps=input_res[0])
longitudes = torch.linspace(-180, 180, steps=input_res[1] + 1)[1:]
lat_lon_grid = torch.stack(
    torch.meshgrid(latitudes, longitudes, indexing="ij"), dim=-1
) # (lat, lon, 2)
print(f"Global area shape: {lat_lon_grid.shape}")

graph = Graph(icosophere_path, lat_lon_grid, local_lat_lon_grid)

Global area shape: torch.Size([721, 1440, 2])
Opened pre-computed graph at data/era5_uk/icospheres.json.


In [17]:
g2m_graph = graph.create_g2m_graph()

g2m graph: Graph(num_nodes={'grid': 3705, 'mesh': 355},
      num_edges={('grid', 'g2m', 'mesh'): 7719},
      metagraph=[('grid', 'mesh', 'g2m')])


  g2m_graph.dstdata["pos"] = torch.tensor(


In [18]:
m2g_graph = graph.create_m2g_graph()

m2g graph: Graph(num_nodes={'grid': 3705, 'mesh': 355},
      num_edges={('mesh', 'm2g', 'grid'): 11094},
      metagraph=[('mesh', 'grid', 'm2g')])


In [19]:
mesh_graph, mesh_pos = graph.create_mesh_graph(debug=True)

mesh graph: Graph(num_nodes=355, num_edges=2584,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'x': Scheme(shape=(4,), dtype=torch.float32)})


In [20]:
mesh_pos

tensor([[ 0.5257,  0.8507,  0.0000],
        [ 0.6817,  0.7166, -0.1476],
        [ 0.4906,  0.8640, -0.1134],
        ...,
        [ 0.6366,  0.7711,  0.0090],
        [ 0.6366,  0.7711, -0.0090],
        [ 0.6262,  0.7796,  0.0000]])

In [21]:
from graphcast_utils import xyz2latlon
mesh_coords = xyz2latlon(mesh_pos)

In [30]:
np.save(
    os.path.join(graph_dir_path, "mesh_pos.npy"),
    mesh_coords.numpy().T.astype("float32"),
)

In [25]:
mesh_coords

tensor([[  0.0000,  58.2825],
        [ -8.4891,  46.4277],
        [ -6.5119,  60.4080],
        [  4.0192,  52.7328],
        [ -4.0192,  52.7328],
        [  0.0000,  46.4277],
        [  0.0000,  61.5964],
        [  3.2587,  62.7772],
        [  3.1515,  59.3076],
        [ -3.1515,  59.3076],
        [ -3.2587,  62.7772],
        [-10.0731,  61.5915],
        [  1.9471,  55.6005],
        [ -1.9471,  55.6005],
        [ -5.2773,  56.5627],
        [ -9.8948,  50.4597],
        [ -7.5302,  53.6154],
        [ -6.2056,  49.6753],
        [ -8.8260,  57.6001],
        [  0.0000,  52.7328],
        [  4.2680,  46.4277],
        [  2.0758,  49.6753],
        [ -2.0758,  49.6753],
        [ -4.2680,  46.4277],
        [  0.0000,  59.9117],
        [  1.5761,  60.4529],
        [  1.5494,  58.7861],
        [  1.6022,  62.1764],
        [  0.0000,  63.3364],
        [  3.2057,  61.0133],
        [ -1.5494,  58.7861],
        [ -1.5761,  60.4529],
        [ -4.8883,  61.5940],
        [ 

In [23]:
lon, lat = mesh_coords[:, 0], mesh_coords[:, 1]

In [24]:
print(lon.min(), lon.max(), lat.min(), lat.max())

tensor(-10.6876) tensor(4.6971) tensor(46.4277) tensor(63.6808)
