# GNN Explore
- exploring SimpleConv, GAT, DeepSetAggregator

In [62]:
from jupyter_server.services.contents import checkpoints
%cd /home/ltchen/gnnpp

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
# PyTorch Lightning
import pytorch_lightning as L
import wandb

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# PyTorch geometric
import torch_geometric
import torch_geometric.data as geom_data
import torch_geometric.nn as geom_nn

from torch_geometric.nn import GATv2Conv, GCNConv
from torch_geometric.nn.aggr import MeanAggregation
from torch_geometric.utils import scatter
from torch.nn import Linear, ModuleList, ReLU
from torch_geometric.loader import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import AdamW
from pytorch_lightning.loggers import WandbLogger

from models.loss import NormalCRPS, crps_active_stations, crps_averaged
from models.model_utils import MakePositive, EmbedStations
from utils.data import *
from torch_geometric.utils import to_networkx
from utils.data import *
import matplotlib as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from utils.plot import plot_map
import networkx as nx

/home/ltchen/gnnpp


In [17]:
dataframes = load_dataframes(mode="train", leadtime="24h")
dist = load_distances(dataframes["stations"])

dataframes = summary_statistics(dataframes)
train = dataframes["train"][0]
train_target = dataframes["train"][1]
test_rf = dataframes["test_rf"][0]
test_rf_target = dataframes["test_rf"][1]
test_f = dataframes["test_f"][0]
test_f_target = dataframes["test_f"][1]

[INFO] Dataframes exist. Will load pandas dataframes.
[INFO] Loading distances from file...
[INFO] Calculating summary statistics for train
[INFO] Calculating summary statistics for test_rf
[INFO] Calculating summary statistics for test_f


In [36]:
train_target

Unnamed: 0,time,station_id,t2m
0,1997-01-02,0,277.75
1,1997-01-02,1,279.55
2,1997-01-02,2,276.45
3,1997-01-02,3,275.75
4,1997-01-02,4,279.35
...,...,...,...
420651,2013-12-31,117,281.35
420652,2013-12-31,118,279.35
420653,2013-12-31,119,278.25
420654,2013-12-31,120,273.15


In [3]:
# nur ein ensemble mitglied (if we don't use summary statistics)
train = dataframes["train"][0][dataframes['train'][0]['number'] == 0]
train_target = dataframes["train"][1]
test_rf = dataframes["test_rf"][0][dataframes['test_rf'][0]['number'] == 0]
test_rf_target = dataframes["test_rf"][1]
test_f = dataframes["test_f"][0][dataframes['test_f'][0]['number'] == 0]
test_f_target = dataframes["test_f"][1]
test_f_target

Unnamed: 0,time,station_id,t2m
0,2017-01-01,0,278.65
1,2017-01-01,1,275.25
2,2017-01-01,2,279.75
3,2017-01-01,3,279.15
4,2017-01-01,4,275.05
...,...,...,...
89055,2018-12-31,117,277.95
89056,2018-12-31,118,276.85
89057,2018-12-31,119,276.35
89058,2018-12-31,120,270.65


In [18]:
max_dist = 100
graphs_train_rf, tests = normalize_features_and_create_graphs(
    training_data=(train, train_target),
    valid_test_data=[(test_rf, test_rf_target), (test_f, test_f_target)],
    mat=dist,
    max_dist=max_dist,
)

graphs_test_rf, graphs_test_f = tests

graphs_test = graphs_test_rf

[INFO] Normalizing features...
[INFO] Creating graph data...


In [11]:
#print(graphs_train_rf) #(1342, 36)
print(next(iter(graphs_train_rf)))
print(len(graphs_test_rf))
print(type(graphs_train_rf[:100]))

Data(x=[122, 65], edge_index=[2, 1420], edge_attr=[1420, 1], y=[122], timestamp=1997-01-02 00:00:00, n_idx=[122])
732
<class 'list'>


In [9]:
print(graphs_train_rf[0].x.shape)
print(graphs_train_rf[0].y.shape)
#print(graphs_train_rf[0].x)

# drop nans in target? => gradienten können dann nciht berechnet werden

# standardize data correctly? (should be standardized using normalize_features_and_create_graphs...
# target size: [975] vs. [975, 1]

torch.Size([122, 65])
torch.Size([122])


In [20]:
graphs_train_rf = graphs_train_rf[:500]
graphs_test_f = graphs_test_f[:300]

In [19]:
batch_size = 8
train_loader = DataLoader(graphs_train_rf, batch_size=batch_size, shuffle=True)
test_f_loader = DataLoader(graphs_test_f, batch_size=batch_size, shuffle=False)

## Simple GNN with Conv

In [48]:
class DeepSetAggregator(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(DeepSetAggregator, self).__init__()

        self.input = torch.nn.Linear(in_channels, hidden_channels)
        self.hidden1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.hidden2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.output = torch.nn.Linear(hidden_channels, out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x, index):
        x = self.input(x)
        x = self.relu(x)
        # print(f"Input: {x.shape}")
        x = self.hidden1(x)
        x = self.relu(x)
        # print(f"Hidden1: {x.shape}")
        x = scatter(x, index, dim=0, reduce="mean")
        # print(f"scatter: {x.shape}")
        # print(f"index: {index}")
        self.hidden2(x)
        x = self.relu(x)
        # print(f"Hidden2: {x.shape}")
        x = self.output(x)
        return x

class SimpleGCN(L.LightningModule):
    def __init__(self, in_features, h_features, out_features, optimizer_class, optimizer_params):
        super().__init__()
        self.conv1 = GCNConv(in_features, h_features)
        self.out = Linear(h_features, out_features)
        self.postprocess = MakePositive()
        self.loss = NormalCRPS()
        self.aggr = DeepSetAggregator(in_channels=out_features, hidden_channels=h_features, out_channels=2)
        self.optimizer_class = optimizer_class
        self.optimizer_params = optimizer_params
        self.save_hyperparameters()

    def forward(self, data):
        x, edge_index, edge_attr, batch_id, node_idx = data.x, data.edge_index, data.edge_attr, data.batch, data.n_idx
        node_idx = node_idx + batch_id * 122
        h = self.conv1(x, edge_index).relu()
        z = self.out(h) #.squeeze()
        print(f"z: {z.shape}")
        print(f"Output from model - Min: {z.min()}, Max: {z.max()}, Mean: {z.mean()}") #wieso mehrere Werte
        z = self.aggr(z, node_idx) # => aggregate to mean
        z = self.postprocess(z)
        return z

    def training_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        #print(f"y_hat: {y_hat}")
        #print(f"y_hat size: {y_hat.shape}")
        #print(f"Batch y: {batch.y}")
        #print(f"batch y size: {batch.y.shape}")
        loss = self.loss.crps(mu_sigma=y_hat, y=batch.y)
        print(f"Loss: {loss}")
        #print(f"loss size: {loss.shape}")
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        loss = self.loss.crps(mu_sigma=y_hat, y=batch.y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        loss = self.loss.crps(mu_sigma=y_hat, y=batch.y)

        self.log('train_loss', loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return self.optimizer_class(self.parameters(), **self.optimizer_params)

    def initialize(self, dataloader): # wird gar nicht ausgeführt?
        batch = next(iter(dataloader))
        print(f"batch: {batch}")
        self.validation_step(batch, 0)

    def predict(self, batch, batch_idx):
        y_hat = self.forward(batch)
        return y_hat



In [27]:
in_channels = graphs_train_rf[0].x.shape[1]

# model = SimpleGCN(in_features=in_channels, h_features=100, out_features=1, optimizer_class=AdamW, optimizer_params={"lr": 0.001})
model = SimpleGCN(in_features=in_channels, h_features=100, out_features=2, optimizer_class=torch.optim.SGD, optimizer_params={"lr": 0.001})

#print(model)
train_iter = iter(train_loader)
batch = next(train_iter)
batch2 = next(train_iter)
print(batch.y.shape)
#print(len(batch))
model.forward(batch)
# print(model.forward(batch2))
print(batch)
print(batch2)


torch.Size([976])
z: torch.Size([976, 2])
Output from model - Min: -8.454617500305176, Max: 0.9300990700721741, Mean: -1.910292625427246
DataBatch(x=[976, 65], edge_index=[2, 11360], edge_attr=[11360, 1], y=[976], timestamp=[8], n_idx=[976], batch=[976], ptr=[9])
DataBatch(x=[976, 65], edge_index=[2, 11360], edge_attr=[11360, 1], y=[976], timestamp=[8], n_idx=[976], batch=[976], ptr=[9])


In [21]:
DIRECTORY = os.getcwd()
SAVEPATH = os.path.join(DIRECTORY, "explored_models/simpleconv_24h/models")
JSONPATH = os.path.join(DIRECTORY, "trained_models/simpleconv_24h/params.json")

In [8]:
with wandb.init(
        project="gnn_explore1",
        id=f"SimpleConv",
        #config=config,
        tags=["exploration"],
):
    in_channels = graphs_train_rf[0].x.shape[1] #+ emb_dim - 1
    print("[INFO] Creating model...")
    model = SimpleGCN(in_features=in_channels, h_features=64, out_features=1, optimizer_class=torch.optim.AdamW, optimizer_params={"lr": 0.001})
    batch = next(iter(train_loader))
    wandb_logger = WandbLogger(project="gnn_explore")

    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=f"run_24h", monitor="train_loss", mode="min", save_top_k=1
    )


    trainer = L.Trainer(max_epochs=15,
        log_every_n_steps=10,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        enable_model_summary=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,)
    print("[INFO] Training...")
    trainer.fit(model=model, train_dataloaders=train_loader)

    final_loss = trainer.logged_metrics["train_loss"]
    print("Final CRPS Loss:", final_loss)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVIC

[INFO] Creating model...
[INFO] Training...
Epoch 0:   0%|          | 0/431 [00:00<?, ?it/s] z: torch.Size([976, 1])
Output from model - Min: -4.21035099029541, Max: 1.0959666967391968, Mean: -1.7018977403640747
Loss: 10.084575653076172
Epoch 0:   0%|          | 1/431 [00:00<04:47,  1.49it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -4.812294006347656, Max: 0.5897536873817444, Mean: -2.1870615482330322
Loss: 6.0087127685546875
Epoch 0:   0%|          | 2/431 [00:00<02:26,  2.93it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -5.533427715301514, Max: 0.6286254525184631, Mean: -2.458021640777588
Loss: 9.824127197265625
Epoch 0:   1%|          | 3/431 [00:00<01:40,  4.25it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -6.036773681640625, Max: 0.39576244354248047, Mean: -2.856152057647705
Loss: 7.3949174880981445
Epoch 0:   1%|          | 4/431 [00:00<01:17,  5.54it/s, v_num=Conv]

/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 976. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


z: torch.Size([976, 1])
Output from model - Min: -6.800191879272461, Max: 1.3710496425628662, Mean: -2.9689369201660156
Loss: 7.8977227210998535
Epoch 0:   1%|          | 5/431 [00:00<01:02,  6.77it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -7.403378009796143, Max: 0.3072235882282257, Mean: -3.458500862121582
Loss: 6.661684036254883
Epoch 0:   1%|▏         | 6/431 [00:00<00:53,  7.98it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -7.940145015716553, Max: 0.5282481908798218, Mean: -3.590759038925171
Loss: 7.742902755737305
Epoch 0:   2%|▏         | 7/431 [00:00<00:46,  9.12it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -8.722893714904785, Max: 0.3542054295539856, Mean: -3.988144636154175
Loss: 8.677055358886719
Epoch 0:   2%|▏         | 8/431 [00:00<00:41, 10.24it/s, v_num=Conv]z: torch.Size([976, 1])
Output from model - Min: -9.3491849899292, Max: 0.5568503737449646, Mean: -4.426943778991699
Loss: 5.688777923583984
Epoch 0:   2%

`Trainer.fit` stopped: `max_epochs=15` reached.


Epoch 14: 100%|██████████| 431/431 [00:05<00:00, 84.14it/s, v_num=Conv]


Traceback (most recent call last):
  File "/tmp/ipykernel_1341860/4121820392.py", line 29, in <module>
    final_loss = trainer.logged_metrics["train_loss"]
KeyError: 'train_loss'


0,1
epoch,▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇▇▇██
train_loss_epoch,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▂▂▂▁▁▂▂▂▁▂▁▂▁▁▂▂▂▁▁▂▁▁▂▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███

0,1
epoch,14.0
train_loss_epoch,0.98489
train_loss_step,1.04815
trainer/global_step,6464.0


KeyError: 'train_loss'

In [45]:
in_channels = graphs_train_rf[0].x.shape[1]
model = SimpleGCN.load_from_checkpoint(os.path.join(SAVEPATH, "run_24h.ckpt"))
model.freeze()


In [46]:
trainer = L.Trainer(max_epochs=15,
        log_every_n_steps=10,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        enable_model_summary=True,
        # logger=wandb_logger,
        # callbacks=checkpoint_callback,
                    )
trainer.predict(model=model, dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0:   0%|          | 0/431 [00:00<?, ?it/s]z: torch.Size([976, 1])
Output from model - Min: -9.27012825012207, Max: 7.238727569580078, Mean: -1.2122957706451416
Predicting DataLoader 0:   0%|          | 1/431 [00:00<00:03, 137.05it/s]z: torch.Size([976, 1])
Output from model - Min: -6.103565216064453, Max: 7.370894908905029, Mean: -1.1382495164871216
Predicting DataLoader 0:   0%|          | 2/431 [00:00<00:03, 122.76it/s]z: torch.Size([976, 1])
Output from model - Min: -10.1791353225708, Max: 4.975143909454346, Mean: -2.117088794708252
Predicting DataLoader 0:   1%|          | 3/431 [00:00<00:03, 124.49it/s]z: torch.Size([976, 1])
Output from model - Min: -14.387320518493652, Max: 8.383459091186523, Mean: -2.896730899810791
Predicting DataLoader 0:   1%|          | 4/431 [00:00<00:03, 129.02it/s]z: torch.Size([976, 1])
Output from model - Min: -9.949376106262207, Max: 5.886409759521484, Mean: -1.3091517686843872
Predicting DataLoader 0:   1%|          | 5/431 [00:

[tensor([[ 6.2097,  1.5042],
         [ 5.5557,  1.5256],
         [ 6.3334,  1.5009],
         ...,
         [12.0216,  1.5284],
         [12.0216,  1.5284],
         [11.4721,  1.5272]]),
 tensor([[6.1574, 1.5055],
         [6.7662, 1.4869],
         [6.0169, 1.5102],
         ...,
         [6.2746, 1.5025],
         [6.2747, 1.5025],
         [5.9824, 1.5115]]),
 tensor([[15.3684,  1.6366],
         [15.0892,  1.6242],
         [14.8191,  1.6122],
         ...,
         [ 1.9705,  1.6272],
         [ 1.9705,  1.6272],
         [ 1.9499,  1.6275]]),
 tensor([[ 8.1608,  1.4927],
         [ 7.7329,  1.4894],
         [ 8.2828,  1.4945],
         ...,
         [-4.5700,  2.2138],
         [-4.5700,  2.2138],
         [-4.1097,  2.1434]]),
 tensor([[ 4.8332,  1.5524],
         [ 4.7116,  1.5574],
         [ 4.9844,  1.5461],
         ...,
         [10.2211,  1.5116],
         [10.2211,  1.5116],
         [ 9.7414,  1.5023]]),
 tensor([[6.8654, 1.4826],
         [5.5883, 1.5245],
        

In [65]:
preds_list = []

targets = test_f_target # R2F
model.eval()
preds = trainer.predict(model=model, dataloaders=test_f_loader) #R2F
print("test_rf_targets:")
print(targets)
print(f"preds shape before: {len(preds)}")

preds = torch.cat(preds, dim=0)
print(f"preds shape after: {preds.shape}")
# Reverse transform of the y_scaler (only on the mean)
# preds[:, 0] = torch.Tensor(y_scaler.inverse_transform(preds[:, 0].view(-1, 1))).flatten()
preds_list.append(preds)
print(f"len preds_list: {len(preds_list)}")

targets = torch.Tensor(targets.t2m.values)
print("t2m values:")
print(targets)
print(targets.shape)

stacked = torch.stack(preds_list)
print(stacked.shape)
final_preds = torch.mean(stacked, dim=0)
print("final_preds")
print(final_preds)
print(final_preds.shape)

res = model.loss.crps(final_preds, targets) # weil ich predict benutze anstatt loss? => probieren mit loss
print(res)
print("#############################################")
print("#############################################")
print(f"final crps: {res.item()}")
print("#############################################")
print("#############################################")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0:   0%|          | 0/92 [00:00<?, ?it/s]z: torch.Size([976, 1])
Output from model - Min: -1.152315616607666, Max: 11.222573280334473, Mean: 3.2589409351348877
Predicting DataLoader 0:   1%|          | 1/92 [00:00<00:00, 128.21it/s]z: torch.Size([976, 1])
Output from model - Min: -0.9620112180709839, Max: 7.250721454620361, Mean: 2.4508533477783203
Predicting DataLoader 0:   2%|▏         | 2/92 [00:00<00:00, 121.66it/s]z: torch.Size([976, 1])
Output from model - Min: 0.8579966425895691, Max: 8.953042030334473, Mean: 4.717904567718506
Predicting DataLoader 0:   3%|▎         | 3/92 [00:00<00:00, 125.04it/s]z: torch.Size([976, 1])
Output from model - Min: -2.235095500946045, Max: 7.640069484710693, Mean: 2.0289511680603027
Predicting DataLoader 0:   4%|▍         | 4/92 [00:00<00:00, 130.46it/s]z: torch.Size([976, 1])
Output from model - Min: -2.883188247680664, Max: 6.882808208465576, Mean: 1.263904333114624
Predicting DataLoader 0:   5%|▌         | 5/92 [00:00<00:00

## Simple GNN only with GATv2Conv
- GATv2Conv
- mean aggregation

In [None]:
# GAT => was macht das aus? GATv2Conv, get_attentions?
class GAT(torch.nn.Module):
    def __init__(self, in_features, h_features, out_features, num_heads):
        super(GAT, self).__init__()

        # layers
        self.conv1 = GATv2Conv(in_features, h_features, heads=num_heads, edge_dim=1, add_self_loop=True, fill_value=0.01)
        self.out = Linear(h_features * num_heads, out_features)
        self.relu = nn.ReLU()
        # self.loss = ??

    def forward(self, x, edge_index, edge_attr):
        x = x.float() # why?
        edge_attr = edge_attr.float()
        h = self.conv1(x, edge_index, edge_attr).relu()
        z = self.out(h).squeeze()
        return z

    @torch.no_grad()
    def get_attention(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = x.float()
        edge_attr = edge_attr.float()
        attention_list = []

        edge_index_attention, attention_weights = None, None
        x, (edge_index_attention, attention_weights) = self.conv1(x, edge_index, edge_attr, return_attention_weights=True)
        attention_list.append(attention_weights)
        x = self.relu(x)
        x = self.norm(x)
        x = self.out(x)
        return x, edge_index_attention, attention_weights, attention_list


In [None]:
class LGAT(L.LightningModule):
    def __init__(self, in_features, h_features, out_features, num_heads, optimizer_class, optimizer_params):
        super(LGAT, self).__init__()
        self.conv = GAT(in_features, h_features, out_features, num_heads)
        self.postprocess = MakePositive()
        self.loss_fn = NormalCRPS()
        self.optimizer_class = optimizer_class
        self.optimizer_params = optimizer_params

    def forward(self, data):
        x, edge_index, edge_attr, batch_id, node_idx = data.x, data.edge_index, data.edge_attr, data.batch, data.n_idx
        node_idx = node_idx + batch_id * 122  # add batch_id to node_idx to get unique node indices
        x = self.encoder(x)
        x = self.conv(x, edge_index, edge_attr)
        x = self.aggr(x, node_idx)
        x = self.postprocess(x)
        return x

    def training_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        loss = self.loss_fn.crps(mu_sigma=y_hat, y=batch.y)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=1
        )  # The batch size is not actually 1 but the loss is already averaged over the batch
        return loss

    def configure_optimizers(self):
        return self.optimizer_class(self.parameters(), **self.optimizer_params)

    def validation_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        loss = self.loss_fn.crps(mu_sigma=y_hat, y=batch.y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=1)
        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        y_hat = self.forward(batch)
        loss = self.loss_fn.crps(mu_sigma=y_hat, y=batch.y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=1)
        return loss

    def initialize(self, dataloader):
        batch = next(iter(dataloader))
        self.validation_step(batch, 0)

In [None]:
# Moritz ResGNN as reference
class ResGnn(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, num_layers: int, hidden_channels: int, heads: int):
        super(ResGnn, self).__init__()
        assert num_layers > 0, "num_layers must be > 0."

        # Create Layers
        self.convolutions = ModuleList()
        for _ in range(num_layers):
            self.convolutions.append(
                GATv2Conv(-1, hidden_channels, heads=heads, edge_dim=1, add_self_loops=True, fill_value=0.01)
            )
        self.lin = Linear(hidden_channels * heads, out_channels)
        self.relu = ReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        x = x.float()
        edge_attr = edge_attr.float()
        for i, conv in enumerate(self.convolutions):
            if i == 0:
                # First Layer
                x = conv(x, edge_index, edge_attr)
                x = self.relu(x)
            else:
                x = x + self.relu(conv(x, edge_index, edge_attr))  # Residual Layers

        x = self.lin(x)
        return x

    @torch.no_grad()
    def get_attention(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Runs a forward Pass for the given graph only though the ResGNN layer.
        NOTE: the data that is given to this method must first pass through the layers before this layer in the Graph

        :param torch.Tensor x: Tensor of Node Features (NxD)
        :param torch.Tensor edge_index: Tensor of Edges (2xE)
        :param torch.Tensor edge_attr: Edge Attributes (ExNum_Attr)
        :return x, edge_index_attention, attention_weights: Tensor of Node Features (NxD), Tensor of Edges with
        self loops (2xE), Tensor of Attention per edge (ExNum_Heads)
        """
        x = x.float()
        edge_attr = edge_attr.float()

        # Pass Data though Layer to get the Attention
        attention_list = []
        # Note: edge_index_attention has to be added since we have self loops now
        edge_index_attention, attention_weights = None, None

        for i, conv in enumerate(
            self.convolutions,
        ):
            if i == 0:
                # First Layer
                x, (edge_index_attention, attention_weights) = conv(
                    x, edge_index, edge_attr, return_attention_weights=True
                )
                attention_list.append(attention_weights)
                x = self.relu(x)
                x = self.norm(x)
            else:
                x_conv, (edge_index_attention, attention_weights) = conv(
                    x, edge_index, edge_attr, return_attention_weights=True
                )
                attention_list.append(attention_weights)
                x = x + self.relu(x_conv)  # Residual Layers
        x = self.lin(x)

        # Attention weights of first layer
        attention_weights = attention_weights.mean(dim=1)

        return x, edge_index_attention, attention_weights, attention_list

## GNN with DeepSetAggregator (without GAT)


## GNN with GAT and DeepSetAggregator

##