In [10]:
%cd /home/ltchen/gnnpp
import argparse
import json
import numpy as np
import os
import pandas as pd
# import lightning as L
import pytorch_lightning as L
import torch
import torch_geometric
from pytorch_lightning import LightningModule

from dataclasses import dataclass
from models.graphensemble.multigraph import Multigraph
from torch_geometric.loader import DataLoader
from torch.optim import AdamW
from utils.data import (
    load_dataframes_old,
    load_distances,
    normalize_features_and_create_graphs,
    split_graph,
    rm_edges,
    summary_statistics,
)

/home/ltchen/gnnpp


In [11]:
# args = argparse.ArgumentParser()
# args.add_argument("data", type=str, default="rf", help='Data to use for testing, can be "rf" or "f"')
# args.add_argument(
#     "leadtime", type=str, default="24h", help='Leadtime to use for testing, can be "24h", "72h" or "120h"'
# )
# args.add_argument("folder", type=str, default="trained_models/best_24h", help="Folder to load the models from")
#
# args = args.parse_args()
args = {
    "data": "rf",
"leadtime": "24h",
"folder": "trained_models/no_ensemble_24h"}
print("#################################################")
print(f"[INFO] Starting evaluation with data: {args['data']} and leadtime: {args['leadtime']}")
print("#################################################")

CHECKPOINT_FOLDER = args['folder']
JSONPATH = os.path.join(CHECKPOINT_FOLDER, "params.json")

# Load the JSON file
with open(JSONPATH, "r") as f:
    print(f"[INFO] Loading {JSONPATH}")
    args_dict = json.load(f)

@dataclass
class DummyConfig:
    pass

for key, value in args_dict.items():
    setattr(DummyConfig, key, value)

config = DummyConfig()
print("[INFO] Starting eval with config: ", args_dict)

# Load Data ######################################################################
dataframes = load_dataframes_old(mode="eval", leadtime=args['leadtime'])
# Only Summary ###################################################################
only_summary = False
if hasattr(config, "only_summary"):
    if config.only_summary is True or config.only_summary == "True":
        print("[INFO] Only using summary statistics...")
        dataframes = summary_statistics(dataframes)
        only_summary = True

dist = load_distances(dataframes["stations"])
graphs_train_rf, tests = normalize_features_and_create_graphs(
    training_data=dataframes["train"],
    valid_test_data=[dataframes["test_rf"], dataframes["test_f"]],
    mat=dist,
    max_dist=config.max_dist,
)
graphs_test_rf, graphs_test_f = tests

graphs_test = graphs_test_rf if args['data'] == "rf" else graphs_test_f

if args['data'] == "f" and not only_summary:
    print("[INFO] Splitting graphs for f data...")
    graphs_split = [split_graph(g) for g in graphs_test]
    graphs_test = [g for sublist in graphs_split for g in sublist]

# Remove Edges ##################################################################
if hasattr(config, "remove_edges"):
    if config.remove_edges == "True" or config.remove_edges is True:
        print("[INFO] Removing edges...")
        rm_edges(graphs_train_rf)
        rm_edges(graphs_test)

# Create Data Loaders ###########################################################
print("[INFO] Creating data loaders...")
train_loader = DataLoader(graphs_train_rf, batch_size=config.batch_size, shuffle=True)
# test_loader_rf = DataLoader(graphs_test_rf, batch_size=1, shuffle=False)
test_loader = DataLoader(graphs_test, batch_size=1 if args['data'] == "rf" else 5, shuffle=False)

# Create Model ##################################################################
print("[INFO] Creating ensemble...")

emb_dim = 20
in_channels = 55  # graphs_train_rf[0].x.shape[1] + emb_dim - 1

FOLDER = os.path.join(CHECKPOINT_FOLDER, "models")
preds_list = []
for path in os.listdir(FOLDER):
    if path.endswith(".ckpt"):
        print(f"[INFO] Loading model from {path}")
        # Load Model from chekcpoint
        checkpoint = torch.load(os.path.join(FOLDER, path))

        multigraph = Multigraph(
            num_nodes=graphs_test_f[0].num_nodes,
            edge_dim=1,
            embedding_dim=emb_dim,
            in_channels=in_channels,
            hidden_channels_gnn=config.gnn_hidden,
            out_channels_gnn=config.gnn_hidden,
            num_layers_gnn=config.gnn_layers,
            heads=config.heads,
            hidden_channels_deepset=config.gnn_hidden,
            optimizer_class=AdamW,
            optimizer_params=dict(lr=config.lr),
        )
        # torch_geometric.compile(multigraph)

        # run a dummy forward pass to initialize the model
        batch = next(iter(train_loader))
        batch = batch  # .to("cuda")
        #multigraph  # .to("cuda")
        multigraph.forward(batch)
        print(type(multigraph))
        print(isinstance(multigraph, LightningModule))

        multigraph.load_state_dict(checkpoint["state_dict"])

        trainer = L.Trainer(log_every_n_steps=1, accelerator="gpu", devices=[1], enable_progress_bar=True)

        preds = trainer.predict(model=multigraph, dataloaders=[test_loader])

        if args['data'] == "f" and not only_summary:
            preds = [
                prediction.reshape(5, 122, 2).mean(axis=0) for prediction in preds
            ]  # Average over the batch dimension

        preds = torch.cat(preds, dim=0)
        preds_list.append(preds)



#################################################
[INFO] Starting evaluation with data: rf and leadtime: 24h
#################################################
[INFO] Loading trained_models/no_ensemble_24h/params.json
[INFO] Starting eval with config:  {'batch_size': 8, 'gnn_hidden': 256, 'gnn_layers': 1, 'heads': 8, 'lr': 0.0001, 'max_dist': 50, 'max_epochs': 23, 'remove_edges': 'False', 'only_summary': 'True'}
[INFO] Dataframes exist. Will load pandas dataframes.
[INFO] Only using summary statistics...
[INFO] Calculating summary statistics for train
[INFO] Calculating summary statistics for test_rf
[INFO] Calculating summary statistics for test_f
[INFO] Loading distances from file...
[INFO] Normalizing features...
[INFO] Creating graph data...


  checkpoint = torch.load(os.path.join(FOLDER, path))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ltchen/.conda/envs/gnn_env4/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/t

[INFO] Creating data loaders...
[INFO] Creating ensemble...
[INFO] Loading model from run_0.ckpt
<class 'models.graphensemble.multigraph.Multigraph'>
True
Predicting DataLoader 0: 100%|██████████| 836/836 [00:04<00:00, 194.69it/s]
[INFO] Loading model from run_1.ckpt
<class 'models.graphensemble.multigraph.Multigraph'>
True


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting DataLoader 0: 100%|██████████| 836/836 [00:04<00:00, 207.19it/s]


In [None]:
# ! Hacky wack of getting the targets
targets = dataframes["test_rf"][1] if args['data'] == "rf" else dataframes["test_f"][1]
targets = torch.tensor(targets.t2m.values) - 273.15

stacked = torch.stack(preds_list)
final_preds = torch.mean(stacked, dim=0)

res = multigraph.loss_fn.crps(final_preds, targets)
print("#############################################")
print("#############################################")
print(f"final crps: {res.item()}")
print("#############################################")
print("#############################################")

# Save Results ##################################################################
# Create DataFrame
df = pd.DataFrame(np.concatenate([targets.view(-1, 1), final_preds], axis=1), columns=["t2m", "mu", "sigma"])
df.to_csv(os.path.join(CHECKPOINT_FOLDER, f"{args['data']}_results.csv"), index=False)

# Create Log File ###############################################################
log_file = os.path.join(CHECKPOINT_FOLDER, f"{args['data']}.txt")
with open(log_file, "w") as f:
    f.write(f"Data: {args['data']}\n")
    f.write(f"Leadtime: {args['leadtime']}\n")
    f.write(f"Final crps: {res.item()}")