In [1]:
%cd /home/ltchen/gnnpp
import sys
import os
import pytorch_lightning as L
import torch
import torch_geometric
import json
import wandb

from typing import Tuple
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import scatter
from torch.nn import Linear, ModuleList, ReLU
from torch_geometric.loader import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import AdamW
from pytorch_lightning.loggers import WandbLogger

from models.loss import NormalCRPS
from models.model_utils import MakePositive, EmbedStations
from utils.data import (
    load_dataframes,
    load_distances,
    normalize_features_and_create_graphs,
    rm_edges,
    summary_statistics,
)
from exploration.graph_creation import *
from models.graphensemble.multigraph import *

/home/ltchen/gnnpp


- SAVEPATH: directory where model checkpoints are saved
- JSONPATH: JSON file with the hyper-parameters
- RESULTPATH: directory for test results (f.txt, f_<FILENAME>_results.csv, rf.txt, rf_<FILENAME>_results.csv)

# 24h Leadtime Graphs

In [2]:
# Forecast lead time; it selects the hyper-parameter file and is part of every output path.
leadtime = "24h"

# Working directory of the kernel; all paths below are built relative to it.
DIRECTORY = os.getcwd()
# Put the parent of the working directory at the front of the module search path.
sys.path.insert(0, os.path.abspath(os.path.join(DIRECTORY, '..')))
SAVEPATH = os.path.join(DIRECTORY, f"leas_final_models/gnn_run4_{leadtime}/models")
JSONPATH = os.path.join(DIRECTORY, f"trained_models/best_{leadtime}/params.json")

# Read the hyper-parameters once; `config` and `args_dict` refer to the same dict.
with open(JSONPATH, "r") as f:
    print(f"[INFO] Loading {JSONPATH}")
    config = args_dict = json.load(f)

# Training epochs per graph variant (values taken over from gnn_run3).
max_epoch_list = dict(g1=31, g2=26, g3=31, g4=32, g5=23)

[INFO] Loading /home/ltchen/gnnpp/trained_models/best_24h/params.json


In [None]:
# Reference copy of a hyper-parameter set, kept as a bare string literal; no code parses it.
# NOTE(review): presumably mirrors the params.json loaded in the config cell — confirm the values match.
'''{"batch_size":8,
"gnn_hidden":265,
"gnn_layers":2,
"heads":8,
"lr":0.0002,
"max_dist":100,
"max_epochs": 31}'''

In [3]:
# Load the evaluation-mode dataframes for this lead time and run them through
# summary_statistics() in one step; later cells index the result by split name
# ('train', 'test_rf', 'test_f', 'stations').
dataframes = summary_statistics(load_dataframes(mode="eval", leadtime=leadtime))

[INFO] Dataframes exist. Will load pandas dataframes.
[INFO] Calculating summary statistics for train
[INFO] Calculating summary statistics for test_rf
[INFO] Calculating summary statistics for test_f


## Graph 1

In [4]:
graph_name = "g1"

# Build the training graphs and the two test sets. The test sets are passed
# as [rf, f] and come back in that same order.
graphs1_train_rf, tests1 = normalize_features_and_create_graphs1(
    df_train=dataframes['train'],
    df_valid_test=[dataframes['test_rf'], dataframes['test_f']],
    station_df=dataframes['stations'],
    attributes=["geo"],
    edges=[("geo", 100)],
    sum_stats=True,
)
graphs1_test_rf, graphs1_test_f = tests1

# Only the training loader shuffles; the test loaders keep dataset order so the
# concatenated predictions line up with the target rows.
g1_train_loader = DataLoader(graphs1_train_rf, batch_size=config['batch_size'], shuffle=True)
g1_test_f_loader, g1_test_rf_loader = (
    DataLoader(graphs, batch_size=config['batch_size'], shuffle=False)
    for graphs in (graphs1_test_f, graphs1_test_rf)
)

# Graph-independent aliases used by the training and evaluation cells.
train_loader = g1_train_loader
test_f_loader, test_rf_loader = g1_test_f_loader, g1_test_rf_loader
test_loader = [test_f_loader, test_rf_loader]  # order: f first, then rf

emb_dim = 20
# NOTE(review): the "- 1" presumably accounts for one node-feature column (the
# station id) being replaced by its embedding — confirm against EmbedStations.
in_channels = graphs1_train_rf[0].x.shape[1] + emb_dim - 1
edge_dim = graphs1_train_rf[0].num_edge_features
max_epochs = max_epoch_list[graph_name]

[INFO] Normalizing features...
fit_transform
transform 1
transform 2


100%|██████████| 3448/3448 [00:16<00:00, 211.07it/s]
100%|██████████| 732/732 [00:02<00:00, 279.00it/s]
100%|██████████| 730/730 [00:02<00:00, 251.85it/s]


In [5]:
# Names for this graph's W&B run, its checkpoint file and its result files.
PROJECTNAME = "gnn_run4"
FILENAME = graph_name + "_run_" + leadtime
TRAINNAME = graph_name + "_train_run_" + leadtime

# Directory the evaluation cell writes its CSV and log files into.
RESULTPATH = os.path.join(DIRECTORY, f"leas_trained_models/best_{leadtime}/best_{leadtime}_{graph_name}")

with wandb.init(
        project=PROJECTNAME, id=TRAINNAME, config=args_dict, tags=["final"]
):
    # Rebinds the notebook-wide `config` to the run's config object; later cells
    # keep reading it after this run has been closed.
    config = wandb.config

    multigraph = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    # The former bare `torch.compile(multigraph)` call was removed: torch.compile
    # returns a wrapped module and leaves its argument untouched, so discarding
    # the return value compiled nothing.

    # One forward pass on a real batch before training; no gradients are needed.
    # NOTE(review): presumably this materialises lazily-shaped layers — confirm.
    batch = next(iter(train_loader))
    with torch.no_grad():
        multigraph(batch)

    wandb_logger = WandbLogger(project=PROJECTNAME)
    # Keep only the checkpoint with the lowest logged "train_loss".
    # NOTE(review): if SAVEPATH already holds "<TRAINNAME>.ckpt" from an earlier
    # run, Lightning may write "<TRAINNAME>-v1.ckpt" instead — confirm that the
    # reload cell picks up the fresh file.
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=TRAINNAME, monitor="train_loss", mode="min", save_top_k=1
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph, train_dataloaders=train_loader)

[34m[1mwandb[0m: Currently logged in as: [33mleachen01[0m ([33mleachen01-karlsruhe-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/

Epoch 30: 100%|██████████| 431/431 [00:13<00:00, 31.93it/s, v_num=_24h, train_loss_step=0.486, train_loss_epoch=0.498]

`Trainer.fit` stopped: `max_epochs=31` reached.


Epoch 30: 100%|██████████| 431/431 [00:13<00:00, 31.31it/s, v_num=_24h, train_loss_step=0.486, train_loss_epoch=0.498]


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
train_loss_epoch,█▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train_loss_step,▇█▅▄▃▃▄▄▄▂▃▂▂▂▄▃▃▃▃▄▂▂▃▁▁▄▃▁▂▂▂▂▁▁▂▂▂▂▂▁
trainer/global_step,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▇▇▇▇▇█████

0,1
epoch,30.0
train_loss_epoch,0.49845
train_loss_step,0.4858
trainer/global_step,13360.0


In [None]:
# Checkpoint file name used by the ModelCheckpoint callback of the training cell.
CKPT_PATH = os.path.join(SAVEPATH, f"{TRAINNAME}.ckpt")

# Restore the model from disk with the same constructor arguments that were
# used for training, and switch it to evaluation mode right away.
multigraph = Multigraph.load_from_checkpoint(
    CKPT_PATH,
    # input / embedding sizes derived in the graph-setup cell
    embedding_dim=emb_dim,
    edge_dim=edge_dim,
    in_channels=in_channels,
    # architecture hyper-parameters
    hidden_channels_gnn=config['gnn_hidden'],
    out_channels_gnn=config['gnn_hidden'],
    num_layers_gnn=config['gnn_layers'],
    heads=config['heads'],
    hidden_channels_deepset=config['gnn_hidden'],
    # optimizer settings
    optimizer_class=AdamW,
    optimizer_params=dict(lr=config['lr']),
).eval()

# Fresh default trainer, used only for prediction.
trainer = L.Trainer()

In [6]:
# Evaluate on both test sets; `test_loader` is ordered [f, rf] to match data_list.
data_list = ["f", "rf"]
for data, tl in zip(data_list, test_loader):
    # predict() returns one tensor per batch; in the recorded run each one is
    # (976, 2), the two columns being written out below as mu and sigma.
    preds = trainer.predict(model=multigraph, dataloaders=[tl])
    print(preds[0].shape)
    # NOTE: an earlier variant reshaped every batch with
    # reshape(1, 122, 2).mean(axis=0) — caution, with 1 instead of 5 as leading size.
    preds = torch.cat(preds, dim=0)
    preds_list = [preds]

    # Targets are the t2m column of the second element of the test split.
    # NOTE(review): the 273.15 offset is presumably Kelvin -> degrees Celsius — confirm.
    targets = torch.tensor(dataframes[f"test_{data}"][1].t2m.values) - 273.15

    # Mean over the stacked prediction sets (a single set here).
    stacked = torch.stack(preds_list)
    final_preds = stacked.mean(dim=0)

    res = multigraph.loss_fn.crps(final_preds, targets)
    print("\n".join(
        ["#############################################"] * 2
        + [f"final crps: {res.item()}"]
        + ["#############################################"] * 2
    ))

    os.makedirs(RESULTPATH, exist_ok=True)

    # One row per prediction: observed value next to the predicted mu and sigma.
    df = pd.DataFrame(
        np.concatenate([targets.view(-1, 1), final_preds], axis=1),
        columns=["t2m", "mu", "sigma"],
    )
    df.to_csv(os.path.join(RESULTPATH, f"{data}_{FILENAME}_results.csv"), index=False)

    # Short text log holding the final score.
    log_file = os.path.join(RESULTPATH, f"{data}.txt")
    with open(log_file, "w") as f:
        f.writelines([
            f"Data: {data}\n",
            f"Leadtime: {leadtime}\n",
            f"Final crps: {res.item()}",
        ])

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 80.68it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6999879884438922
#############################################
#############################################


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 79.75it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.7336807390409388
#############################################
#############################################


## Graph 2

In [6]:
graph_name = "g2"

# Build the training graphs and the two test sets. The test sets are passed
# as [rf, f] and come back in that same order.
graphs2_train_rf, tests2 = normalize_features_and_create_graphs1(
    df_train=dataframes['train'],
    df_valid_test=[dataframes['test_rf'], dataframes['test_f']],
    station_df=dataframes['stations'],
    attributes=["geo", "alt", "lon", "lat", "alt-orog"],
    edges=[("geo", 100)],
    sum_stats=True,
)
graphs2_test_rf, graphs2_test_f = tests2

# Only the training loader shuffles; the test loaders keep dataset order so the
# concatenated predictions line up with the target rows.
g2_train_loader = DataLoader(graphs2_train_rf, batch_size=config['batch_size'], shuffle=True)
g2_test_f_loader, g2_test_rf_loader = (
    DataLoader(graphs, batch_size=config['batch_size'], shuffle=False)
    for graphs in (graphs2_test_f, graphs2_test_rf)
)

# Graph-independent aliases used by the training and evaluation cells.
train_loader = g2_train_loader
test_f_loader, test_rf_loader = g2_test_f_loader, g2_test_rf_loader
test_loader = [test_f_loader, test_rf_loader]  # order: f first, then rf

emb_dim = 20
# NOTE(review): the "- 1" presumably accounts for one node-feature column (the
# station id) being replaced by its embedding — confirm against EmbedStations.
in_channels = graphs2_train_rf[0].x.shape[1] + emb_dim - 1
edge_dim = graphs2_train_rf[0].num_edge_features
max_epochs = max_epoch_list[graph_name]

[INFO] Normalizing features...
fit_transform
transform 1
transform 2


100%|██████████| 3448/3448 [00:16<00:00, 205.29it/s]
100%|██████████| 732/732 [00:02<00:00, 263.96it/s]
100%|██████████| 730/730 [00:02<00:00, 273.75it/s]


In [8]:
# Names for this graph's W&B run, its checkpoint file and its result files.
PROJECTNAME = "gnn_run4"
FILENAME = graph_name + "_run_" + leadtime
TRAINNAME = graph_name + "_train_run_" + leadtime

# Directory the evaluation cells write their CSV and log files into.
RESULTPATH = os.path.join(DIRECTORY, f"leas_trained_models/best_{leadtime}/best_{leadtime}_{graph_name}")

with wandb.init(
        project=PROJECTNAME, id=TRAINNAME, config=args_dict, tags=["final"]
):
    # Rebinds the notebook-wide `config` to the run's config object; later cells
    # keep reading it after this run has been closed.
    config = wandb.config

    multigraph = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    # The former bare `torch.compile(multigraph)` call was removed: torch.compile
    # returns a wrapped module and leaves its argument untouched, so discarding
    # the return value compiled nothing.

    # One forward pass on a real batch before training; no gradients are needed.
    # NOTE(review): presumably this materialises lazily-shaped layers — confirm.
    batch = next(iter(train_loader))
    with torch.no_grad():
        multigraph(batch)

    wandb_logger = WandbLogger(project=PROJECTNAME)
    # Keep only the checkpoint with the lowest logged "train_loss".
    # NOTE(review): if SAVEPATH already holds "<TRAINNAME>.ckpt" from an earlier
    # run, Lightning may write "<TRAINNAME>-v1.ckpt" instead — confirm that the
    # reload cell picks up the fresh file.
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=TRAINNAME, monitor="train_loss", mode="min", save_top_k=1
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | encoder     | EmbedStations     | 2.4 K  | train
1 | conv        | ResGnn            | 9.9 M  | train
2 | aggr        | DeepSetAggregator | 212 K  | train
3 | postprocess | MakePositive      | 0      | train
4 | loss_fn     | Norm

Epoch 25: 100%|██████████| 431/431 [00:13<00:00, 31.74it/s, v_num=_24h, train_loss_step=0.560, train_loss_epoch=0.530]

`Trainer.fit` stopped: `max_epochs=26` reached.


Epoch 25: 100%|██████████| 431/431 [00:13<00:00, 31.05it/s, v_num=_24h, train_loss_step=0.560, train_loss_epoch=0.530]


0,1
epoch,▁▁▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇█████
train_loss_epoch,█▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁
train_loss_step,█▆▅▅▅█▆▅▇▇█▇▅▆▅▆▅▅▅▅▃▆▄█▂▄▄▄▄▄▄▄▃▄▃▃▆▄▁▅
trainer/global_step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█

0,1
epoch,25.0
train_loss_epoch,0.52988
train_loss_step,0.5605
trainer/global_step,11205.0


In [None]:
# Checkpoint file name used by the ModelCheckpoint callback of the training cell.
CKPT_PATH = os.path.join(SAVEPATH, f"{TRAINNAME}.ckpt")

# Restore the model from disk with the same constructor arguments that were
# used for training, and switch it to evaluation mode right away.
multigraph = Multigraph.load_from_checkpoint(
    CKPT_PATH,
    # input / embedding sizes derived in the graph-setup cell
    embedding_dim=emb_dim,
    edge_dim=edge_dim,
    in_channels=in_channels,
    # architecture hyper-parameters
    hidden_channels_gnn=config['gnn_hidden'],
    out_channels_gnn=config['gnn_hidden'],
    num_layers_gnn=config['gnn_layers'],
    heads=config['heads'],
    hidden_channels_deepset=config['gnn_hidden'],
    # optimizer settings
    optimizer_class=AdamW,
    optimizer_params=dict(lr=config['lr']),
).eval()

# Fresh default trainer, used only for prediction.
trainer = L.Trainer()

In [11]:
# Evaluation on the "rf" test set.
data = "rf"
# predict() returns one tensor per batch; in the recorded run each one is
# (976, 2), the two columns being written out below as mu and sigma.
preds_rf = trainer.predict(model=multigraph, dataloaders=[test_rf_loader])
print(preds_rf[0].shape)
# NOTE: an earlier variant reshaped every batch with
# reshape(1, 122, 2).mean(axis=0) — caution, with 1 instead of 5 as leading size.
preds_rf = torch.cat(preds_rf, dim=0)
preds_list_rf = [preds_rf]

# Targets are the t2m column of the second element of the test split.
# NOTE(review): the 273.15 offset is presumably Kelvin -> degrees Celsius — confirm.
targets_rf = torch.tensor(dataframes["test_rf"][1].t2m.values) - 273.15

# Mean over the stacked prediction sets (a single set here).
stacked_rf = torch.stack(preds_list_rf)
final_preds_rf = stacked_rf.mean(dim=0)

res_rf = multigraph.loss_fn.crps(final_preds_rf, targets_rf)
print("\n".join(
    ["#############################################"] * 2
    + [f"final crps: {res_rf.item()}"]
    + ["#############################################"] * 2
))

os.makedirs(RESULTPATH, exist_ok=True)

# One row per prediction: observed value next to the predicted mu and sigma.
df_rf = pd.DataFrame(
    np.concatenate([targets_rf.view(-1, 1), final_preds_rf], axis=1),
    columns=["t2m", "mu", "sigma"],
)
df_rf.to_csv(os.path.join(RESULTPATH, f"{data}_{FILENAME}_results.csv"), index=False)

# Short text log holding the final score.
log_file = os.path.join(RESULTPATH, f"{data}.txt")
with open(log_file, "w") as f:
    f.writelines([
        f"Data: {data}\n",
        f"Leadtime: {leadtime}\n",
        f"Final crps: {res_rf.item()}",
    ])

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 78.10it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6385229321035536
#############################################
#############################################


In [12]:
# Evaluation on the "f" test set.
data = "f"
# predict() returns one tensor per batch; in the recorded run each one is
# (976, 2), the two columns being written out below as mu and sigma.
preds_f = trainer.predict(model=multigraph, dataloaders=[test_f_loader])
print(preds_f[0].shape)
# NOTE: an earlier variant reshaped every batch with
# reshape(1, 122, 2).mean(axis=0) — caution, with 1 instead of 5 as leading size.
preds_f = torch.cat(preds_f, dim=0)
preds_list_f = [preds_f]

# Targets are the t2m column of the second element of the test split.
# NOTE(review): the 273.15 offset is presumably Kelvin -> degrees Celsius — confirm.
targets_f = torch.tensor(dataframes["test_f"][1].t2m.values) - 273.15

# Mean over the stacked prediction sets (a single set here).
stacked_f = torch.stack(preds_list_f)
final_preds_f = stacked_f.mean(dim=0)

res_f = multigraph.loss_fn.crps(final_preds_f, targets_f)
print("\n".join(
    ["#############################################"] * 2
    + [f"final crps: {res_f.item()}"]
    + ["#############################################"] * 2
))

os.makedirs(RESULTPATH, exist_ok=True)

# One row per prediction: observed value next to the predicted mu and sigma.
df_f = pd.DataFrame(
    np.concatenate([targets_f.view(-1, 1), final_preds_f], axis=1),
    columns=["t2m", "mu", "sigma"],
)
df_f.to_csv(os.path.join(RESULTPATH, f"{data}_{FILENAME}_results.csv"), index=False)

# Short text log holding the final score.
log_file = os.path.join(RESULTPATH, f"{data}.txt")
with open(log_file, "w") as f:
    f.writelines([
        f"Data: {data}\n",
        f"Leadtime: {leadtime}\n",
        f"Final crps: {res_f.item()}",
    ])

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 79.27it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.644819088460565
#############################################
#############################################


## Graph 3

In [4]:
graph_name = "g3"

# Build the training graphs and the two test sets. The test sets are passed
# as [rf, f] and come back in that same order.
graphs3_train_rf, tests3 = normalize_features_and_create_graphs1(
    df_train=dataframes['train'],
    df_valid_test=[dataframes['test_rf'], dataframes['test_f']],
    station_df=dataframes['stations'],
    attributes=["geo", "alt", "lon", "lat", "alt-orog"],
    edges=[("geo", 55), ("alt", 6.5), ("alt-orog", 2.5)],
    sum_stats=True,
)
graphs3_test_rf, graphs3_test_f = tests3

# Only the training loader shuffles; the test loaders keep dataset order so the
# concatenated predictions line up with the target rows.
g3_train_loader = DataLoader(graphs3_train_rf, batch_size=config['batch_size'], shuffle=True)
g3_test_f_loader, g3_test_rf_loader = (
    DataLoader(graphs, batch_size=config['batch_size'], shuffle=False)
    for graphs in (graphs3_test_f, graphs3_test_rf)
)

# Graph-independent aliases used by the training and evaluation cells.
train_loader = g3_train_loader
test_f_loader, test_rf_loader = g3_test_f_loader, g3_test_rf_loader
test_loader = [test_f_loader, test_rf_loader]  # order: f first, then rf

emb_dim = 20
# NOTE(review): the "- 1" presumably accounts for one node-feature column (the
# station id) being replaced by its embedding — confirm against EmbedStations.
in_channels = graphs3_train_rf[0].x.shape[1] + emb_dim - 1
edge_dim = graphs3_train_rf[0].num_edge_features
max_epochs = max_epoch_list[graph_name]

[INFO] Normalizing features...
fit_transform
transform 1
transform 2


100%|██████████| 3448/3448 [00:16<00:00, 204.22it/s]
100%|██████████| 732/732 [00:02<00:00, 276.71it/s]
100%|██████████| 730/730 [00:02<00:00, 257.24it/s]


In [5]:
# Names for this graph's W&B run, its checkpoint file and its result files.
PROJECTNAME = "gnn_run4"
FILENAME = graph_name + "_run_" + leadtime
TRAINNAME = graph_name + "_train_run_" + leadtime

# Directory the evaluation cell writes its CSV and log files into.
RESULTPATH = os.path.join(DIRECTORY, f"leas_trained_models/best_{leadtime}/best_{leadtime}_{graph_name}")

with wandb.init(
        project=PROJECTNAME, id=TRAINNAME, config=args_dict, tags=["final"]
):
    # Rebinds the notebook-wide `config` to the run's config object; later cells
    # keep reading it after this run has been closed.
    config = wandb.config

    multigraph = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    # The former bare `torch.compile(multigraph)` call was removed: torch.compile
    # returns a wrapped module and leaves its argument untouched, so discarding
    # the return value compiled nothing.

    # One forward pass on a real batch before training; no gradients are needed.
    # NOTE(review): presumably this materialises lazily-shaped layers — confirm.
    batch = next(iter(train_loader))
    with torch.no_grad():
        multigraph(batch)

    wandb_logger = WandbLogger(project=PROJECTNAME)
    # Keep only the checkpoint with the lowest logged "train_loss".
    # NOTE(review): if SAVEPATH already holds "<TRAINNAME>.ckpt" from an earlier
    # run, Lightning may write "<TRAINNAME>-v1.ckpt" instead — confirm that the
    # reload cell picks up the fresh file.
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=TRAINNAME, monitor="train_loss", mode="min", save_top_k=1
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph, train_dataloaders=train_loader)

[34m[1mwandb[0m: Currently logged in as: [33mleachen01[0m ([33mleachen01-karlsruhe-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/

Epoch 30: 100%|██████████| 431/431 [00:13<00:00, 30.90it/s, v_num=_24h, train_loss_step=0.397, train_loss_epoch=0.492]

`Trainer.fit` stopped: `max_epochs=31` reached.


Epoch 30: 100%|██████████| 431/431 [00:14<00:00, 30.30it/s, v_num=_24h, train_loss_step=0.397, train_loss_epoch=0.492]


0,1
epoch,▁▁▁▁▁▂▂▂▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█████
train_loss_epoch,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁
train_loss_step,██▄▅▅▃▄▃▃▄▃▄▄▃▄▄▃▃▃▃▃▄▄▂▂▂▃▂▂▃▃▂▃▂▃▂▁▂▁▁
trainer/global_step,▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇████████

0,1
epoch,30.0
train_loss_epoch,0.49172
train_loss_step,0.3974
trainer/global_step,13360.0


In [None]:
# Checkpoint file name used by the ModelCheckpoint callback of the training cell.
CKPT_PATH = os.path.join(SAVEPATH, f"{TRAINNAME}.ckpt")

# Restore the model from disk with the same constructor arguments that were
# used for training, and switch it to evaluation mode right away.
multigraph = Multigraph.load_from_checkpoint(
    CKPT_PATH,
    # input / embedding sizes derived in the graph-setup cell
    embedding_dim=emb_dim,
    edge_dim=edge_dim,
    in_channels=in_channels,
    # architecture hyper-parameters
    hidden_channels_gnn=config['gnn_hidden'],
    out_channels_gnn=config['gnn_hidden'],
    num_layers_gnn=config['gnn_layers'],
    heads=config['heads'],
    hidden_channels_deepset=config['gnn_hidden'],
    # optimizer settings
    optimizer_class=AdamW,
    optimizer_params=dict(lr=config['lr']),
).eval()

# Fresh default trainer, used only for prediction.
trainer = L.Trainer()

In [8]:
# Evaluate on both test sets; `test_loader` is ordered [f, rf] to match data_list.
data_list = ["f", "rf"]
for data, tl in zip(data_list, test_loader):
    # predict() returns one tensor per batch; in the recorded run each one is
    # (976, 2), the two columns being written out below as mu and sigma.
    preds = trainer.predict(model=multigraph, dataloaders=[tl])
    print(preds[0].shape)
    # NOTE: an earlier variant reshaped every batch with
    # reshape(1, 122, 2).mean(axis=0) — caution, with 1 instead of 5 as leading size.
    preds = torch.cat(preds, dim=0)
    preds_list = [preds]

    # Targets are the t2m column of the second element of the test split.
    # NOTE(review): the 273.15 offset is presumably Kelvin -> degrees Celsius — confirm.
    targets = torch.tensor(dataframes[f"test_{data}"][1].t2m.values) - 273.15

    # Mean over the stacked prediction sets (a single set here).
    stacked = torch.stack(preds_list)
    final_preds = stacked.mean(dim=0)

    res = multigraph.loss_fn.crps(final_preds, targets)
    print("\n".join(
        ["#############################################"] * 2
        + [f"final crps: {res.item()}"]
        + ["#############################################"] * 2
    ))

    os.makedirs(RESULTPATH, exist_ok=True)

    # One row per prediction: observed value next to the predicted mu and sigma.
    df = pd.DataFrame(
        np.concatenate([targets.view(-1, 1), final_preds], axis=1),
        columns=["t2m", "mu", "sigma"],
    )
    df.to_csv(os.path.join(RESULTPATH, f"{data}_{FILENAME}_results.csv"), index=False)

    # Short text log holding the final score.
    log_file = os.path.join(RESULTPATH, f"{data}.txt")
    with open(log_file, "w") as f:
        f.writelines([
            f"Data: {data}\n",
            f"Leadtime: {leadtime}\n",
            f"Final crps: {res.item()}",
        ])

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 75.16it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6489576454025832
#############################################
#############################################


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 79.11it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6575640421709488
#############################################
#############################################


## Graph 4

In [9]:
graph_name = "g4"

# Build the training graphs and the two test sets. The test sets are passed
# as [rf, f] and come back in that same order.
graphs4_train_rf, tests4 = normalize_features_and_create_graphs1(
    df_train=dataframes['train'],
    df_valid_test=[dataframes['test_rf'], dataframes['test_f']],
    station_df=dataframes['stations'],
    attributes=["geo", "alt", "lon", "lat", "alt-orog"],
    edges=[("geo", 100), ("alt", 10), ("alt-orog", 5)],
    sum_stats=True,
)
graphs4_test_rf, graphs4_test_f = tests4

# Only the training loader shuffles; the test loaders keep dataset order so the
# concatenated predictions line up with the target rows.
g4_train_loader = DataLoader(graphs4_train_rf, batch_size=config['batch_size'], shuffle=True)
g4_test_f_loader, g4_test_rf_loader = (
    DataLoader(graphs, batch_size=config['batch_size'], shuffle=False)
    for graphs in (graphs4_test_f, graphs4_test_rf)
)

# Graph-independent aliases used by the training and evaluation cells.
train_loader = g4_train_loader
test_f_loader, test_rf_loader = g4_test_f_loader, g4_test_rf_loader
test_loader = [test_f_loader, test_rf_loader]  # order: f first, then rf

emb_dim = 20
# NOTE(review): the "- 1" presumably accounts for one node-feature column (the
# station id) being replaced by its embedding — confirm against EmbedStations.
in_channels = graphs4_train_rf[0].x.shape[1] + emb_dim - 1
edge_dim = graphs4_train_rf[0].num_edge_features
max_epochs = max_epoch_list[graph_name]

[INFO] Normalizing features...
fit_transform
transform 1
transform 2


100%|██████████| 3448/3448 [00:16<00:00, 206.60it/s]
100%|██████████| 732/732 [00:02<00:00, 274.97it/s]
100%|██████████| 730/730 [00:02<00:00, 247.06it/s]


In [10]:
# Names for this graph's W&B run, its checkpoint file and its result files.
PROJECTNAME = "gnn_run4"
FILENAME = graph_name + "_run_" + leadtime
TRAINNAME = graph_name + "_train_run_" + leadtime

# Directory the evaluation cell writes its CSV and log files into.
RESULTPATH = os.path.join(DIRECTORY, f"leas_trained_models/best_{leadtime}/best_{leadtime}_{graph_name}")

with wandb.init(
        project=PROJECTNAME, id=TRAINNAME, config=args_dict, tags=["final"]
):
    # Rebinds the notebook-wide `config` to the run's config object; later cells
    # keep reading it after this run has been closed.
    config = wandb.config

    multigraph = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    # The former bare `torch.compile(multigraph)` call was removed: torch.compile
    # returns a wrapped module and leaves its argument untouched, so discarding
    # the return value compiled nothing.

    # One forward pass on a real batch before training; no gradients are needed.
    # NOTE(review): presumably this materialises lazily-shaped layers — confirm.
    batch = next(iter(train_loader))
    with torch.no_grad():
        multigraph(batch)

    wandb_logger = WandbLogger(project=PROJECTNAME)
    # Keep only the checkpoint with the lowest logged "train_loss".
    # NOTE(review): if SAVEPATH already holds "<TRAINNAME>.ckpt" from an earlier
    # run, Lightning may write "<TRAINNAME>-v1.ckpt" instead — confirm that the
    # reload cell picks up the fresh file.
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=TRAINNAME, monitor="train_loss", mode="min", save_top_k=1
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/ltchen/gnnpp/leas_final_models/gnn_run4_24h/models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | encoder    

Epoch 31: 100%|██████████| 431/431 [00:18<00:00, 23.12it/s, v_num=_24h, train_loss_step=0.450, train_loss_epoch=0.485]

`Trainer.fit` stopped: `max_epochs=32` reached.


Epoch 31: 100%|██████████| 431/431 [00:18<00:00, 22.76it/s, v_num=_24h, train_loss_step=0.450, train_loss_epoch=0.485]


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█
train_loss_epoch,█▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇██████

0,1
epoch,31.0
train_loss_epoch,0.485
train_loss_step,0.45047
trainer/global_step,13791.0


In [None]:
# Checkpoint file name used by the ModelCheckpoint callback of the training cell.
CKPT_PATH = os.path.join(SAVEPATH, f"{TRAINNAME}.ckpt")

# Restore the model from disk with the same constructor arguments that were
# used for training, and switch it to evaluation mode right away.
multigraph = Multigraph.load_from_checkpoint(
    CKPT_PATH,
    # input / embedding sizes derived in the graph-setup cell
    embedding_dim=emb_dim,
    edge_dim=edge_dim,
    in_channels=in_channels,
    # architecture hyper-parameters
    hidden_channels_gnn=config['gnn_hidden'],
    out_channels_gnn=config['gnn_hidden'],
    num_layers_gnn=config['gnn_layers'],
    heads=config['heads'],
    hidden_channels_deepset=config['gnn_hidden'],
    # optimizer settings
    optimizer_class=AdamW,
    optimizer_params=dict(lr=config['lr']),
).eval()

# Fresh default trainer, used only for prediction.
trainer = L.Trainer()

In [11]:
# Evaluate the current model on both test splits ("f" and "rf"): print the CRPS,
# store per-sample targets/predictions as CSV and write a short log file.
data_list = ["f", "rf"]

# The output directory is the same for both splits, so create it once.
os.makedirs(RESULTPATH, exist_ok=True)

for data, tl in zip(data_list, test_loader):
    # trainer.predict returns one tensor per batch; the recorded run printed
    # 92 batches of shape [976, 2] (presumably 8 graphs x 122 stations with
    # columns mu and sigma -- TODO confirm).
    preds = trainer.predict(model=multigraph, dataloaders=[tl])
    print(preds[0].shape)

    # Only a single model is evaluated here, so the per-batch outputs are just
    # concatenated; there is nothing to average over (the former
    # stack-and-mean over a one-element list was a no-op).
    final_preds = torch.cat(preds, dim=0)

    targets = dataframes[f"test_{data}"][1]
    # presumably converts t2m from Kelvin to degrees Celsius -- TODO confirm
    targets = torch.tensor(targets.t2m.values) - 273.15

    # Fail early with a readable message if predictions and targets do not
    # line up row by row.
    if final_preds.shape[0] != targets.numel():
        raise ValueError(
            f"test_{data}: got {final_preds.shape[0]} predictions for {targets.numel()} targets"
        )

    res = multigraph.loss_fn.crps(final_preds, targets)
    print("#############################################")
    print("#############################################")
    print(f"final crps: {res.item()}")
    print("#############################################")
    print("#############################################")

    df = pd.DataFrame(np.concatenate([targets.view(-1, 1), final_preds], axis=1), columns=["t2m", "mu", "sigma"])
    df.to_csv(os.path.join(RESULTPATH, f"{data}_{FILENAME}_results.csv"), index=False)

    # Create Log File ###############################################################
    log_file = os.path.join(RESULTPATH, f"{data}.txt")
    # Dedicated handle name so the notebook-level `f` is not rebound.
    with open(log_file, "w") as log_fh:
        log_fh.write(f"Data: {data}\n")
        log_fh.write(f"Leadtime: {leadtime}\n")
        log_fh.write(f"Final crps: {res.item()}")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 53.83it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6631041168586904
#############################################
#############################################


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 52.21it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.671750525752253
#############################################
#############################################


## Graph 5

In [17]:
# Build the graph-5 datasets and loaders and expose them under the generic
# names (train_loader, test_loader, ...) used by the training/evaluation cells.
graph_name = "g5"

# Edges are requested for three attributes with one numeric value each
# (presumably a distance threshold per attribute -- TODO confirm).
graphs5_train_rf, tests5 = normalize_features_and_create_graphs1(
    df_train=dataframes['train'],
    df_valid_test=[dataframes['test_rf'], dataframes['test_f']],
    station_df=dataframes['stations'],
    attributes=["geo", "alt", "lon", "lat", "alt-orog"],
    edges=[("geo", 100), ("alt", 10), ("alt-orog", 5)],
    sum_stats=True,
)
# Unpacked as (rf, f), i.e. the order in which the test frames were passed in.
graphs5_test_rf, graphs5_test_f = tests5

# Only the training loader shuffles; the test loaders keep dataset order.
g5_train_loader = DataLoader(
    graphs5_train_rf, batch_size=config['batch_size'], shuffle=True
)
g5_test_f_loader = DataLoader(
    graphs5_test_f, batch_size=config['batch_size'], shuffle=False
)
g5_test_rf_loader = DataLoader(
    graphs5_test_rf, batch_size=config['batch_size'], shuffle=False
)

# Generic aliases; test_loader is ordered (f, rf).
train_loader, test_f_loader, test_rf_loader = (
    g5_train_loader,
    g5_test_f_loader,
    g5_test_rf_loader,
)
test_loader = [test_f_loader, test_rf_loader]

# Model input sizes derived from the first training graph.
emb_dim = 20
edge_dim = graphs5_train_rf[0].num_edge_features
# presumably one raw feature column (the station id) is replaced by the
# emb_dim-wide embedding, hence "+ emb_dim - 1" -- TODO confirm
in_channels = graphs5_train_rf[0].x.shape[1] + emb_dim - 1

# Number of training epochs fixed for this graph in max_epoch_list.
max_epochs = max_epoch_list[graph_name]

[INFO] Normalizing features...
fit_transform
transform 1
transform 2


100%|██████████| 3448/3448 [00:17<00:00, 200.19it/s]
100%|██████████| 732/732 [00:02<00:00, 247.93it/s]
100%|██████████| 730/730 [00:02<00:00, 270.21it/s]


In [18]:
# Train the model for the currently selected graph inside a wandb run and
# checkpoint the epoch with the lowest training loss as `<TRAINNAME>.ckpt`.
PROJECTNAME = "gnn_run4"
FILENAME = graph_name + "_run_" + leadtime
TRAINNAME = graph_name + "_train_run_" + leadtime

RESULTPATH = os.path.join(DIRECTORY, f"leas_trained_models/best_{leadtime}/best_{leadtime}_{graph_name}")

with wandb.init(
        project=PROJECTNAME, id=TRAINNAME, config=args_dict, tags=["final"]
):
    # From here on `config` refers to the wandb run configuration.
    config = wandb.config

    multigraph = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    # NOTE: a bare `torch.compile(multigraph)` used to sit here. torch.compile
    # returns the optimized module instead of modifying its argument, so the
    # discarded call had no effect and training ran uncompiled. It is dropped
    # to keep behaviour identical; assign the result if compilation is wanted.

    # One forward pass on a real batch before training (presumably to
    # materialise lazily initialised layers -- TODO confirm).
    batch = next(iter(train_loader))
    multigraph.forward(batch)

    wandb_logger = WandbLogger(project=PROJECTNAME)
    # enable_version_counter=False: when `<TRAINNAME>.ckpt` already exists from
    # an earlier run, overwrite it instead of writing `<TRAINNAME>-v1.ckpt`.
    # The checkpoint-loading cell always reads `<TRAINNAME>.ckpt`, so without
    # this a re-run would silently evaluate the stale checkpoint.
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH,
        filename=TRAINNAME,
        monitor="train_loss",
        mode="min",
        save_top_k=1,
        enable_version_counter=False,
    )

    # print("[INFO] Training model...")
    trainer = L.Trainer(
        max_epochs=max_epochs,
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/ltchen/gnnpp/leas_final_models/gnn_run4_24h/models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | encoder    

Epoch 22: 100%|██████████| 431/431 [00:18<00:00, 23.17it/s, v_num=_24h, train_loss_step=0.553, train_loss_epoch=0.545]

`Trainer.fit` stopped: `max_epochs=23` reached.


Epoch 22: 100%|██████████| 431/431 [00:18<00:00, 22.81it/s, v_num=_24h, train_loss_step=0.553, train_loss_epoch=0.545]


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
train_loss_epoch,█▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train_loss_step,███▇▅▅▄▄▃▃▄▅▃▄▃▄▃▄▄▄▄▅▃▃▂▃▄▃▅▃▃▁▃▂▂▃▂▃▂▁
trainer/global_step,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇██

0,1
epoch,22.0
train_loss_epoch,0.54482
train_loss_step,0.55287
trainer/global_step,9912.0


In [None]:
# Reload the trained model from the checkpoint file named after TRAINNAME and
# prepare a Trainer for inference.
# NOTE(review): if the ModelCheckpoint callback that produced this file uses its
# default version counter, a re-run into a directory that already contains
# `<TRAINNAME>.ckpt` writes `<TRAINNAME>-v1.ckpt` instead, so this path may point
# at an older run -- confirm before trusting the evaluation numbers.
CKPT_PATH = os.path.join(SAVEPATH, TRAINNAME+'.ckpt')

# The constructor arguments are supplied again next to the checkpoint;
# presumably Multigraph does not persist them as hyperparameters -- TODO confirm.
multigraph = Multigraph.load_from_checkpoint(
    CKPT_PATH,
    embedding_dim=emb_dim,
    edge_dim=edge_dim,
    in_channels=in_channels,
    hidden_channels_gnn=config['gnn_hidden'],
    out_channels_gnn=config['gnn_hidden'],
    num_layers_gnn=config['gnn_layers'],
    heads=config['heads'],
    hidden_channels_deepset=config['gnn_hidden'],
    optimizer_class=AdamW,
    optimizer_params=dict(lr=config['lr']),
)

multigraph.eval()
# Pin inference to one device, like the training Trainer. The logs show two
# visible GPUs; a multi-device strategy would split the prediction batches
# across processes and the concatenated predictions would no longer line up
# with the targets.
trainer = L.Trainer(devices=1)

In [19]:
# Evaluate the current model on both test splits ("f" and "rf"): print the CRPS,
# store per-sample targets/predictions as CSV and write a short log file.
data_list = ["f", "rf"]

# The output directory is the same for both splits, so create it once.
os.makedirs(RESULTPATH, exist_ok=True)

for data, tl in zip(data_list, test_loader):
    # trainer.predict returns one tensor per batch; the recorded run printed
    # 92 batches of shape [976, 2] (presumably 8 graphs x 122 stations with
    # columns mu and sigma -- TODO confirm).
    preds = trainer.predict(model=multigraph, dataloaders=[tl])
    print(preds[0].shape)

    # Only a single model is evaluated here, so the per-batch outputs are just
    # concatenated; there is nothing to average over (the former
    # stack-and-mean over a one-element list was a no-op).
    final_preds = torch.cat(preds, dim=0)

    targets = dataframes[f"test_{data}"][1]
    # presumably converts t2m from Kelvin to degrees Celsius -- TODO confirm
    targets = torch.tensor(targets.t2m.values) - 273.15

    # Fail early with a readable message if predictions and targets do not
    # line up row by row.
    if final_preds.shape[0] != targets.numel():
        raise ValueError(
            f"test_{data}: got {final_preds.shape[0]} predictions for {targets.numel()} targets"
        )

    res = multigraph.loss_fn.crps(final_preds, targets)
    print("#############################################")
    print("#############################################")
    print(f"final crps: {res.item()}")
    print("#############################################")
    print("#############################################")

    df = pd.DataFrame(np.concatenate([targets.view(-1, 1), final_preds], axis=1), columns=["t2m", "mu", "sigma"])
    df.to_csv(os.path.join(RESULTPATH, f"{data}_{FILENAME}_results.csv"), index=False)

    # Create Log File ###############################################################
    log_file = os.path.join(RESULTPATH, f"{data}.txt")
    # Dedicated handle name so the notebook-level `f` is not rebound.
    with open(log_file, "w") as log_fh:
        log_fh.write(f"Data: {data}\n")
        log_fh.write(f"Leadtime: {leadtime}\n")
        log_fh.write(f"Final crps: {res.item()}")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 52.78it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6461403042023351
#############################################
#############################################


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting DataLoader 0: 100%|██████████| 92/92 [00:01<00:00, 51.58it/s]
torch.Size([976, 2])
#############################################
#############################################
final crps: 0.6492747347012809
#############################################
#############################################
