In [1]:
%cd /home/ltchen/gnnpp
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
import pytorch_lightning as L
import torch
import torch_geometric
import json
import wandb

from typing import Tuple
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import scatter
from torch.nn import Linear, ModuleList, ReLU
from torch_geometric.loader import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import AdamW
from pytorch_lightning.loggers import WandbLogger

from models.loss import NormalCRPS
from models.model_utils import MakePositive, EmbedStations
from models.graphensemble.multigraph import *
from utils.data import (
    load_dataframes,
    load_distances,
    normalize_features_and_create_graphs,
    rm_edges,
    summary_statistics,
)
from exploration.graph_creation import *

/home/ltchen/gnnpp


In [2]:
DIRECTORY = os.getcwd()
SAVEPATH = os.path.join(DIRECTORY, "explored_models/gnn_24h/models")
JSONPATH = os.path.join(DIRECTORY, "trained_models/no_ensemble_24h/params.json")

In [3]:
with open(JSONPATH, "r") as f:
    print(f"[INFO] Loading {JSONPATH}")
    args_dict = json.load(f)
config = args_dict
print(config)
print(config['lr'])
print(config['max_dist'])
print(type(config))
print(type(config['lr']))
print(type(config['gnn_hidden']))
'''{"batch_size":8,
"gnn_hidden":265,
"gnn_layers":2,
"heads":8,
"lr":0.0002,
"max_dist":100,
"max_epochs": 31}'''

[INFO] Loading /home/ltchen/gnnpp/trained_models/no_ensemble_24h/params.json
{'batch_size': 8, 'gnn_hidden': 256, 'gnn_layers': 1, 'heads': 8, 'lr': 0.0001, 'max_dist': 50, 'max_epochs': 23, 'remove_edges': 'False', 'only_summary': 'True'}
0.0001
50
<class 'dict'>
<class 'float'>
<class 'int'>


'{"batch_size":8,\n"gnn_hidden":265,\n"gnn_layers":2,\n"heads":8,\n"lr":0.0002,\n"max_dist":100,\n"max_epochs": 31}'

In [4]:
# load graph only for rf
dataframes = load_dataframes(leadtime= "24h")
dataframes = summary_statistics(dataframes)
dist = load_distances(dataframes["stations"])


[INFO] Dataframes exist. Will load pandas dataframes.
[INFO] Calculating summary statistics for train
[INFO] Calculating summary statistics for valid
[INFO] Calculating summary statistics for test_rf
[INFO] Calculating summary statistics for test_f
[INFO] Loading distances from file...


In [20]:
dataframes['valid']

(             time  station_id  model_orography  station_altitude  \
 0      2010-01-01           0        -1.706008               1.2   
 1      2010-01-01           1        -1.298122              -3.3   
 2      2010-01-01           2         0.333424              10.8   
 3      2010-01-01           3         1.302155               0.7   
 4      2010-01-01           4         2.576800               1.9   
 ...           ...         ...              ...               ...   
 100315 2013-12-31         115       521.714299             331.0   
 100316 2013-12-31         116       689.253673             424.0   
 100317 2013-12-31         117       972.938723             439.0   
 100318 2013-12-31         118      1752.460782            1478.0   
 100319 2013-12-31         119      2105.435549            1587.0   
 
         station_latitude  station_longitude  cape_mean   cape_std  \
 0              52.928000           4.781000  18.895952  17.770164   
 1              52.318000     

In [6]:
dataframes['train'][0].nunique()

time                  2612
station_id             120
model_orography        114
station_altitude       116
station_latitude       120
                     ...  
v_mean              301729
v_std               302992
t_mean              245228
t_std               302992
number                   1
Length: 65, dtype: int64

## GNN Architecture

In [5]:
# gnn architecture
class DeepSetAggregator(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(DeepSetAggregator, self).__init__()

        self.input = torch.nn.Linear(in_channels, hidden_channels)
        self.hidden1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.hidden2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.output = torch.nn.Linear(hidden_channels, out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x, index):
        x = self.input(x)
        x = self.relu(x)
        # print(f"DS - Input: {x.shape}")
        x = self.hidden1(x)
        x = self.relu(x)
        # print(f"DS - Hidden1: {x.shape}")
        x = scatter(x, index, dim=0, reduce="mean")
        # print(f"DS - scatter: {x.shape}")
        # print(f"DS - index: {index}")
        self.hidden2(x)
        x = self.relu(x)
        # print(f"DS - Hidden2: {x.shape}")
        x = self.output(x)
        # print(f"DS - output: {x.shape}")
        return x


class ResGnn(torch.nn.Module):
    def __init__(self, edge_dim: int, in_channels: int, out_channels: int, num_layers: int, hidden_channels: int, heads: int):
        super(ResGnn, self).__init__()
        assert num_layers > 0, "num_layers must be > 0."

        # Create Layers
        self.convolutions = ModuleList()
        for _ in range(num_layers):
            self.convolutions.append(
                GATv2Conv(-1, hidden_channels, heads=heads, edge_dim=edge_dim, add_self_loops=True, fill_value=0.01)
            )
        self.lin = Linear(hidden_channels * heads, out_channels)
        self.relu = ReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        x = x.float()
        edge_attr = edge_attr.float()
        for i, conv in enumerate(self.convolutions):
            if i == 0:
                # First Layer
                x = conv(x, edge_index, edge_attr)
                x = self.relu(x)
            else:
                x = x + self.relu(conv(x, edge_index, edge_attr))  # Residual Layers

        x = self.lin(x)
        return x

    @torch.no_grad()
    def get_attention(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Runs a forward Pass for the given graph only though the ResGNN layer.
        NOTE: the data that is given to this method must first pass through the layers before this layer in the Graph

        :param torch.Tensor x: Tensor of Node Features (NxD)
        :param torch.Tensor edge_index: Tensor of Edges (2xE)
        :param torch.Tensor edge_attr: Edge Attributes (ExNum_Attr)
        :return x, edge_index_attention, attention_weights: Tensor of Node Features (NxD), Tensor of Edges with
        self loops (2xE), Tensor of Attention per edge (ExNum_Heads)
        """
        x = x.float()
        edge_attr = edge_attr.float()

        # Pass Data though Layer to get the Attention
        attention_list = []
        # Note: edge_index_attention has to be added since we have self loops now
        edge_index_attention, attention_weights = None, None

        for i, conv in enumerate(
            self.convolutions,
        ):
            if i == 0:
                # First Layer
                x, (edge_index_attention, attention_weights) = conv(
                    x, edge_index, edge_attr, return_attention_weights=True
                )
                print("attention_weights:")
                print(attention_weights)
                print("edge_index_attention")
                print(edge_index_attention)
                print(f"attention_weights.shape{attention_weights.shape}")
                print(f"type(attention_weights){type(attention_weights)}")
                attention_list.append(attention_weights)
                x = self.relu(x)
                x = self.norm(x)
            else:
                x_conv, (edge_index_attention, attention_weights) = conv(
                    x, edge_index, edge_attr, return_attention_weights=True
                )
                attention_list.append(attention_weights)
                x = x + self.relu(x_conv)  # Residual Layers
        x = self.lin(x)

        # Attention weights of first layer
        attention_weights = attention_weights.mean(dim=1)
        print("attention_weights.mean(dim=1)")
        print(attention_weights)
        print(attention_weights.shape)

        return x, edge_index_attention, attention_weights, attention_list

# gnn architecture
class ThisMultigraph(L.LightningModule):
    def __init__(
        self,
        embedding_dim,
        edge_dim,
        in_channels,
        hidden_channels_gnn,
        out_channels_gnn,
        num_layers_gnn,
        heads,
        hidden_channels_deepset,
        optimizer_class,
        optimizer_params,
    ):
        super(ThisMultigraph, self).__init__()

        self.encoder = EmbedStations(num_stations_max=120, embedding_dim=embedding_dim)

        self.conv = ResGnn(
            edge_dim=edge_dim,
            in_channels=in_channels,
            hidden_channels=hidden_channels_gnn,
            out_channels=out_channels_gnn,
            num_layers=num_layers_gnn,
            heads=heads,
        )

        self.aggr = DeepSetAggregator(
            in_channels=out_channels_gnn, hidden_channels=hidden_channels_deepset, out_channels=2
        )

        self.postprocess = MakePositive()
        self.loss_fn = NormalCRPS()

        self.optimizer_class = optimizer_class
        self.optimizer_params = optimizer_params

    def forward(self, data):
        x, edge_index, edge_attr, batch_id, node_idx = data.x, data.edge_index, data.edge_attr, data.batch, data.n_idx
        node_idx = node_idx + batch_id * 120  # add batch_id to node_idx to get unique node indices
        # print(f"GNN - input: {x.shape}")
        x = self.encoder(x)
        # print(f"GNN - embedding: {x.shape}")
        x = self.conv(x, edge_index, edge_attr)
        # print(f"GNN - after conv: {x.shape}")
        x = self.aggr(x, node_idx)
        # print(f"GNN - after aggr: {x.shape}")
        x = self.postprocess(x)
        return x

    def training_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        loss = self.loss_fn.crps(mu_sigma=y_hat, y=batch.y)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=1
        )  # The batch size is not actually 1 but the loss is already averaged over the batch
        return loss

    def configure_optimizers(self):
        return self.optimizer_class(self.parameters(), **self.optimizer_params)

    def validation_step(self, batch, batch_idx):
        y_hat = self.forward(batch)
        loss = self.loss_fn.crps(mu_sigma=y_hat, y=batch.y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=1)
        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        y_hat = self.forward(batch)
        loss = self.loss_fn.crps(mu_sigma=y_hat, y=batch.y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=1)
        return loss

    def initialize(self, dataloader):
        batch = next(iter(dataloader))
        self.validation_step(batch, 0)

### Check DeepSetAggregator and ResGNN outside of Multigraph

In [23]:
graphs_train_rf, tests = normalize_features_and_create_graphs(
    training_data=dataframes["train"],
    valid_test_data=[dataframes["test_rf"], dataframes["test_f"]],
    mat=dist,
    max_dist=config['max_dist'],
)
graphs_test_rf, graphs_test_f = tests

graphs_test = graphs_test_rf

# print(graphs_test_rf[0].x.shape) (1342, 36)


print("[INFO] Creating data loaders...")
train_loader = DataLoader(graphs_train_rf, batch_size=config['batch_size'], shuffle=True)
print("[INFO] Creating model...")
emb_dim=20
in_channels = graphs_train_rf[0].x.shape[1] + emb_dim - 1


[INFO] Normalizing features...
[INFO] Creating graph data...
[INFO] Creating data loaders...
[INFO] Creating model...


In [11]:
embedding_dim=emb_dim
in_channels=in_channels
hidden_channels_gnn=config['gnn_hidden']
out_channels_gnn=config['gnn_hidden']
num_layers_gnn=config['gnn_hidden']
heads=config['heads']
hidden_channels_deepset=config['gnn_hidden']
optimizer_class=AdamW
optimizer_params=dict(lr=config['lr'])


In [10]:
graphs_test_rf[0].x.shape

torch.Size([120, 65])

## Train GNN

In [16]:
# train gnn
# build a graph with wandb => create multigraph - without summmary_statistics and no edges removed
with wandb.init(
    project="multigraph", id=f"training_run_24h_8", config=args_dict, tags=["final_training"]
):
    config = wandb.config
    #print("[INFO] Starting sweep with config: ", config)

    # graphs_train_rf, tests = normalize_features_and_create_graphs(
    #     training_data=dataframes["train"],
    #     valid_test_data=[dataframes["test_rf"], dataframes["test_f"]],
    #     mat=dist,
    #     max_dist=config['max_dist'],
    # )
    # graphs_test_rf, graphs_test_f = tests

    # graphs_test = graphs_test_rf

    # print(graphs_test_rf[0].x.shape) (1342, 36)


    # print("[INFO] Creating data loaders...")
    # train_loader = DataLoader(graphs_train_rf, batch_size=config['batch_size'], shuffle=True)

    # print("[INFO] Creating model...")
    # emb_dim=20
    # in_channels = graphs_train_rf[0].x.shape[1] + emb_dim - 1 #(36 + 20 - 1) = 55

    multigraph = ThisMultigraph(
        embedding_dim=emb_dim,
        edge_dim = 1,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    torch.compile(multigraph)

    # understand what this is
    batch = next(iter(train_loader))
    batch = batch  # .to("cuda")
    multigraph  # .to("cuda")
    multigraph.forward(batch)

    wandb_logger = WandbLogger(project="multigraph")
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=f"run_24h", monitor="train_loss", mode="min", save_top_k=1
    )

    # print("[INFO] Training model...")
    trainer = L.Trainer(
        max_epochs=config['max_epochs'],
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph, train_dataloaders=train_loader)
wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/ltchen/gnnpp/explored_models/gnn_24h/models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | encoder     | Embe

Epoch 39: 100%|██████████| 327/327 [00:06<00:00, 51.46it/s, v_num=4h_8, train_loss_step=0.821, train_loss_epoch=0.561]

`Trainer.fit` stopped: `max_epochs=40` reached.


Epoch 39: 100%|██████████| 327/327 [00:06<00:00, 51.13it/s, v_num=4h_8, train_loss_step=0.821, train_loss_epoch=0.561]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇█████
train_loss_epoch,█▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▅▄▃▄▄▃▃▃▂▂▂▃▃▂▂▂▃▂▂▂▃▂▃▁▃▂▂▃▂▃▂▂▂▃▂▄▃▁▂
trainer/global_step,▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▇▇▇▇▇█████

0,1
epoch,39.0
train_loss_epoch,0.56101
train_loss_step,0.8215
trainer/global_step,13079.0


In [8]:
with wandb.init(
    project="multigraph", id=f"training_run_24h_4", config=args_dict, tags=["final_training"], reinit=True
):
    config = wandb.config
    edge_dim = 1
    #print("[INFO] Starting sweep with config: ", config)

    # graphs_train_rf, tests = normalize_features_and_create_graphs(
    #     training_data=dataframes["train"],
    #     valid_test_data=[dataframes["test_rf"], dataframes["test_f"]],
    #     mat=dist,
    #     max_dist=config['max_dist'],
    # )
    # graphs_test_rf, graphs_test_f = tests

    # graphs_test = graphs_test_rf

    # print(graphs_test_rf[0].x.shape) (1342, 36)


    # print("[INFO] Creating data loaders...")
    # train_loader = DataLoader(graphs_train_rf, batch_size=config['batch_size'], shuffle=True)

    # print("[INFO] Creating model...")
    # emb_dim=20
    # in_channels = graphs_train_rf[0].x.shape[1] + emb_dim - 1 #(36 + 20 - 1) = 55

    multigraph2 = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    torch.compile(multigraph2)

    # understand what this is
    batch = next(iter(train_loader))
    batch = batch  # .to("cuda")
    multigraph2  # .to("cuda")
    multigraph2.forward(batch)

    wandb_logger = WandbLogger(project="multigraph")
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=f"run_24h", monitor="train_loss", mode="min", save_top_k=1
    )

    # print("[INFO] Training model...")
    trainer = L.Trainer(
        max_epochs=config['max_epochs'],
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph2, train_dataloaders=train_loader)
wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/

Epoch 30: 100%|██████████| 431/431 [00:13<00:00, 31.66it/s, v_num=4h_4, train_loss_step=0.510, train_loss_epoch=0.515]

`Trainer.fit` stopped: `max_epochs=31` reached.


Epoch 30: 100%|██████████| 431/431 [00:13<00:00, 31.65it/s, v_num=4h_4, train_loss_step=0.510, train_loss_epoch=0.515]


0,1
epoch,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
train_loss_epoch,█▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▅▄▇▃▄▅▅▃▄▃▄▄▄▃▂▂▃▃▃▃▃▅▅▃▂▃▄▄▂▂▃▂▁▁▁▁▃▃▂
trainer/global_step,▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██

0,1
epoch,30.0
train_loss_epoch,0.51528
train_loss_step,0.51018
trainer/global_step,13360.0


In [7]:
l_graphs_train_rf, l_tests = normalize_features_and_create_graphs1(df_train=dataframes['train'], df_valid_test=[dataframes['test_rf'], dataframes['test_f']], station_df=dataframes['stations'], attributes=["geo"], edges=[("geo", 100)], sum_stats = True)
l_graphs_test_rf, l_graphs_test_f = l_tests
l_graphs_test = l_graphs_test_rf

l_train_loader = DataLoader(l_graphs_train_rf, batch_size=config['batch_size'], shuffle=True)
print(l_graphs_train_rf[0].num_edge_features)

[INFO] Normalizing features...
fit_transform
transform 1
transform 2
[INFO] Converting temperature values...


100%|██████████| 3448/3448 [00:16<00:00, 207.67it/s]
100%|██████████| 732/732 [00:02<00:00, 281.24it/s]
100%|██████████| 730/730 [00:02<00:00, 264.42it/s]

1





In [8]:
l_graphs_train_rf[0].y

tensor([ 4.6000,  6.4000,  3.3000,  2.6000,  6.2000,  3.3000,  4.3000,  2.1000,
         3.1000,  4.8000,  1.2000,  3.1000,  2.1000,  1.1000,  3.4000, -0.5000,
         2.5000,  8.9000,  8.7000,     nan,  8.3000,  8.8000,  8.3000,  6.9000,
         7.9000,  6.7000,  7.4000,  6.1000,  8.6000,  5.8000,  5.7000, -4.6000,
        -3.2000, -4.6000,     nan, -2.0000,  5.6000,  5.9000, -1.5000,  2.7000,
         7.8000,  2.7000,  2.9000,  0.9000, -1.2000,  2.4000, -3.7000,  0.4000,
        -2.4000,  1.3000, -0.2000, -0.1000,  2.1000,  4.3000,  5.1000,  6.7000,
        -3.3000,  1.6000,  2.4000,  0.0000, -1.0000, -0.1000,     nan,  5.1000,
         0.3000,  3.5000,  6.0000,  1.1000, -5.2000,  1.7000,  3.9000,  3.0000,
         5.4000,  6.8000,     nan, -3.2000,  4.2000,  0.4000, -1.2000,  1.6000,
         5.2000,  9.6000,  8.8000,  8.5000,  9.1000,  8.4000,  8.3000,  8.8000,
         8.9000,  8.7000,  7.6000,  6.9000,  8.7000,  7.8000,  6.7000,  6.3000,
         7.4000,  7.1000,  8.2000,  8.00

In [14]:
with wandb.init(
    project="multigraph", id=f"training_run_24h_10", config=args_dict, tags=["final_training"], reinit=True
):
    config = wandb.config
    edge_dim = l_graphs_train_rf[0].num_edge_features
    #print("[INFO] Starting sweep with config: ", config)

    # graphs_train_rf, tests = normalize_features_and_create_graphs(
    #     training_data=dataframes["train"],
    #     valid_test_data=[dataframes["test_rf"], dataframes["test_f"]],
    #     mat=dist,
    #     max_dist=config['max_dist'],
    # )
    # graphs_test_rf, graphs_test_f = tests

    # graphs_test = graphs_test_rf

    # print(graphs_test_rf[0].x.shape) (1342, 36)


    # print("[INFO] Creating data loaders...")
    # train_loader = DataLoader(graphs_train_rf, batch_size=config['batch_size'], shuffle=True)

    # print("[INFO] Creating model...")
    # emb_dim=20
    # in_channels = graphs_train_rf[0].x.shape[1] + emb_dim - 1 #(36 + 20 - 1) = 55

    multigraph3 = Multigraph(
        embedding_dim=emb_dim,
        edge_dim=edge_dim,
        in_channels=in_channels,
        hidden_channels_gnn=config['gnn_hidden'],
        out_channels_gnn=config['gnn_hidden'],
        num_layers_gnn=config['gnn_layers'],
        heads=config['heads'],
        hidden_channels_deepset=config['gnn_hidden'],
        optimizer_class=AdamW,
        optimizer_params=dict(lr=config['lr']),
    )
    torch.compile(multigraph3)

    # understand what this is
    batch = next(iter(train_loader))
    batch = batch  # .to("cuda")
    # multigraph3  # .to("cuda")
    multigraph3.forward(batch)

    wandb_logger = WandbLogger(project="multigraph")
    checkpoint_callback = ModelCheckpoint(
        dirpath=SAVEPATH, filename=f"run_24h", monitor="train_loss", mode="min", save_top_k=1
    )

    # print("[INFO] Training model...")
    trainer = L.Trainer(
        max_epochs=config['max_epochs'],
        log_every_n_steps=1,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        logger=wandb_logger,
        callbacks=checkpoint_callback,
    )

    trainer.fit(model=multigraph3, train_dataloaders=l_train_loader)
wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/ltchen/.conda/envs/gnn_env4/

Epoch 30: 100%|██████████| 431/431 [00:15<00:00, 28.58it/s, v_num=h_10, train_loss_step=0.473, train_loss_epoch=0.486]

`Trainer.fit` stopped: `max_epochs=31` reached.


Epoch 30: 100%|██████████| 431/431 [00:15<00:00, 28.01it/s, v_num=h_10, train_loss_step=0.473, train_loss_epoch=0.486]


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█████
train_loss_epoch,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train_loss_step,██▇▆▇▄▆▅▄▅▅▂▅▃▄▂▃▃▃▂▃▃▃▄▃▂▁▁▂▂▁▂▂▂▃▂▂▂▂▁
trainer/global_step,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇█

0,1
epoch,30.0
train_loss_epoch,0.48608
train_loss_step,0.47318
trainer/global_step,13360.0


In [17]:
# evaluate gnn
test_loader = DataLoader(graphs_test, batch_size=1, shuffle=False)

# multigraph.load_state_dict(checkpoint["state_dict"])

# trainer = L.Trainer(log_every_n_steps=1, accelerator="gpu", devices=1, enable_progress_bar=True)
preds_list = []
preds = trainer.predict(model=multigraph, dataloaders=[test_loader])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 732/732 [00:03<00:00, 217.22it/s]


In [18]:
preds = [prediction.reshape(1, 120, 2).mean(axis=0) for prediction in preds] #ACHTUNG - reshape(1, 122, 2) mit 1 statt 5!
preds = torch.cat(preds, dim=0)
preds_list.append(preds)

targets = dataframes["test_rf"][1]
targets = torch.tensor(targets.t2m.values) - 273.15

stacked = torch.stack(preds_list)
final_preds = torch.mean(stacked, dim=0)

res = multigraph.loss_fn.crps(final_preds, targets)
print("#############################################")
print("#############################################")
print(f"final crps: {res.item()}")
print("#############################################")
print("#############################################")


#############################################
#############################################
final crps: 0.6571427750898216
#############################################
#############################################


In [11]:
for prediction in preds:
    print(prediction.shape)

torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([122, 2])
torch.Size([1

In [18]:
import gc
import torch

gc.collect()
torch.cuda.empty_cache()