In [1]:
import xarray as xr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch.utils.data import DataLoader

In [4]:
ds = xr.open_dataset("../data.grib")
t2m_numpy = ds.t2m.to_numpy()
sp_numpy = ds.sp.to_numpy()
data = np.stack((t2m_numpy, sp_numpy), axis=-1)
data.shape

Ignoring index file '../data.grib.923a8.idx' incompatible with GRIB file


(12, 25, 45, 2)

In [3]:
num_timestamps, num_latitudes, num_longitudes, num_features = data.shape
data = data.reshape(num_timestamps, -1, num_features)
data.shape

(12, 1125, 2)

In [4]:
def node_index(i, j, num_cols):
    return i * num_cols + j

edge_index = []
for i in range(num_latitudes):
    for j in range(num_longitudes):
        if i > 0:
            edge_index.append([node_index(i, j, num_longitudes), node_index(i - 1, j, num_longitudes)])
        if j > 0:
            edge_index.append([node_index(i, j, num_longitudes), node_index(i, j - 1, num_longitudes)])

grid_size = data.shape[1]
edge_index = torch.tensor(edge_index, dtype=torch.long).t()
edge_index.size()

torch.Size([2, 2180])

In [5]:
class WeatherGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(WeatherGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.fc(x)
        x = x.view(x.shape[0], grid_size, num_features)
        return x

input_dim = num_features
hidden_dim = 4096
output_dim = num_features # * grid_size
model = WeatherGNN(input_dim, hidden_dim, output_dim)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [6]:
train_ratio = 0.8
num_samples = data.shape[0]
train_size = int(train_ratio * num_samples)
train_data, val_data = data[:train_size], data[train_size:]

train_loader = DataLoader(train_data, batch_size=2, shuffle=True)
val_loader = DataLoader(val_data, batch_size=2, shuffle=False)

num_epochs = 100

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = criterion.to(device)
edge_index = edge_index.to(device)

In [8]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        x = batch.to(device)
        output = model(x, edge_index)
        loss = criterion(output, x) 
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.shape[0] # batch.num_graphs
    
    if (epoch+1) % 10 == 0:
        avg_loss = total_loss / len(train_loader.dataset)
        print(f"Epoch {epoch + 1}/{num_epochs}\nTrain Loss: {avg_loss:.4f}")
    
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch in val_loader:
            x = batch.to(device)
            output = model(x, edge_index)
            loss = criterion(output, x)
            val_loss += loss.item() * batch.shape[0] # batch.num_graphs
        
        if (epoch+1) % 10 == 0:
            avg_val_loss = val_loss / len(val_loader.dataset)
            print(f"Val Loss: {avg_val_loss:.4f}\n---------")

Epoch 10/100
Train Loss: 15571295.5000
Val Loss: 20061834.0000
---------
Epoch 20/100
Train Loss: 5770589.0556
Val Loss: 5776851.1667
---------
Epoch 30/100
Train Loss: 5986333.7778
Val Loss: 5770451.3333
---------
Epoch 40/100
Train Loss: 7275906.7778
Val Loss: 7364385.8333
---------
Epoch 50/100
Train Loss: 5736275.6111
Val Loss: 5771170.5000
---------
Epoch 60/100
Train Loss: 5729458.6111
Val Loss: 5761696.5000
---------
Epoch 70/100
Train Loss: 5737716.5556
Val Loss: 5774022.5000
---------
Epoch 80/100
Train Loss: 5737507.1667
Val Loss: 5793409.8333
---------
Epoch 90/100
Train Loss: 5732635.4444
Val Loss: 5770974.3333
---------
Epoch 100/100
Train Loss: 5835633.1111
Val Loss: 5803538.1667
---------
