This notebook is used for debugging. It should have the same functionality as `run_models.py`, but with more verbose output.

**[TODO]** Re-sync this notebook with `run_models.py` once that script's interface has stabilized.

In [1]:
%load_ext autoreload
%autoreload 2


import os
import sys

import torch

# Make the parent of the current working directory importable (the repo's
# 'scripts' directory when the notebook is launched from its own folder), so
# the `training` and `gnn` imports below resolve.
scripts_path = os.path.abspath(os.pardir)
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

from training.help_functions import *

from gnn.help_functions import (
    EIGN_Loss,
    compute_baseline_of_mean_target,
    compute_baseline_of_no_policies
)
from gnn.models.point_net_transf_gat import PointNetTransfGAT
from gnn.models.eign import EIGNLaplacianConv

In [2]:
# Repository root: two levels above the notebook's working directory.
project_root = os.path.abspath(os.path.join(os.pardir, os.pardir))

# Everything data-related lives under <project_root>/data. Please adjust as needed.
base_dir = os.path.join(project_root, "data")
dataset_path = os.path.join(base_dir, "train_data", "edge_features_with_net_flow")

In [3]:
# Names of the hyper-parameters copied from get_parameters() into the wandb run
# config (see the setup_wandb call further down). Every name listed here must be
# a key of the dict get_parameters() returns; order is preserved.
PARAMETERS = """
    project_name
    predict_mode_stats
    in_channels
    use_all_features
    out_channels
    loss_fct
    use_weighted_loss
    point_net_conv_layer_structure_local_mlp
    point_net_conv_layer_structure_global_mlp
    gat_conv_layer_structure
    use_bootstrapping
    num_epochs
    batch_size
    lr
    early_stopping_patience
    use_dropout
    dropout
    gradient_accumulation_steps
    use_gradient_clipping
    device_nr
    unique_model_description
    use_monte_carlo_dropout
""".split()


def _parse_int_list(value):
    """Turn a layer-structure spec into a list of ints.

    `value` may be a comma-separated string such as "128,256,512" or an
    already-parsed iterable of int-like items. Surrounding whitespace and empty
    segments (e.g. from a trailing comma) are ignored.

    Raises:
    - ValueError: if a segment is not an integer or no integers remain.
    """
    if isinstance(value, str):
        layers = [int(token) for token in value.split(",") if token.strip()]
    else:
        layers = [int(item) for item in value]
    if not layers:
        raise ValueError(
            f"Layer structure must contain at least one integer, got {value!r}"
        )
    return layers


def get_parameters(args):
    """
    Build the run's hyper-parameter dict from an argparse-like object.

    Parameters:
    - args: object exposing the hyper-parameters as attributes (e.g. the
      notebook's `Args` instance or an argparse.Namespace). The three
      layer-structure attributes may be comma-separated strings ("128,256")
      or sequences of ints.

    Returns:
    - dict with one entry per name in PARAMETERS. Layer structures are lists
      of ints, batch_size is an int and lr a float; "project_name" and
      "unique_model_description" are fixed for this notebook.

    Raises:
    - ValueError: if a layer structure is empty or contains a non-integer.
    """
    params = {
        # KEEP IN MIND: IF WE CHANGE PARAMETERS, WE NEED TO CHANGE THE NAME OF THE RUN IN WANDB (for the config)
        "project_name": "eign",
        "predict_mode_stats": args.predict_mode_stats,
        "in_channels": args.in_channels,
        "use_all_features": args.use_all_features,
        "out_channels": args.out_channels,
        "loss_fct": args.loss_fct,
        "use_weighted_loss": args.use_weighted_loss,
        "point_net_conv_layer_structure_local_mlp": _parse_int_list(
            args.point_net_conv_layer_structure_local_mlp
        ),
        "point_net_conv_layer_structure_global_mlp": _parse_int_list(
            args.point_net_conv_layer_structure_global_mlp
        ),
        "gat_conv_layer_structure": _parse_int_list(args.gat_conv_layer_structure),
        "use_bootstrapping": args.use_bootstrapping,
        "num_epochs": args.num_epochs,
        "batch_size": int(args.batch_size),
        "lr": float(args.lr),
        "early_stopping_patience": args.early_stopping_patience,
        "use_dropout": args.use_dropout,
        "dropout": args.dropout,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "use_gradient_clipping": args.use_gradient_clipping,
        "device_nr": args.device_nr,
        "use_monte_carlo_dropout": args.use_monte_carlo_dropout,
    }

    # Identifies the run directory and wandb run; kept last to preserve key order.
    params["unique_model_description"] = "eign_implementing"
    return params

In [4]:
# Load the pre-processed graph batches datalist_batch_1.pt, datalist_batch_2.pt, ...
# from dataset_path, stopping at the first missing file.
# MAX_BATCH_NUM is an exclusive upper bound on the batch number (10 -> batches 1-9)
# that keeps debugging runs fast; set it to None to load every available batch.
MAX_BATCH_NUM = 10

datalist = []
batch_num = 1
while MAX_BATCH_NUM is None or batch_num < MAX_BATCH_NUM:
    batch_file = os.path.join(dataset_path, f"datalist_batch_{batch_num}.pt")
    if not os.path.exists(batch_file):
        break
    print(f"Processing batch number: {batch_num}")
    # NOTE(review): torch.load unpickles arbitrary Python objects, so only point
    # dataset_path at trusted files. On PyTorch >= 2.6 the default
    # weights_only=True may refuse non-tensor payloads; pass weights_only=False
    # explicitly (trusted data only) if that happens.
    batch_data = torch.load(batch_file, map_location="cpu")
    if isinstance(batch_data, list):
        datalist.extend(batch_data)
    else:
        # Report unexpected payloads instead of dropping them silently.
        print(
            f"Skipping {batch_file}: expected a list, got {type(batch_data).__name__}"
        )
    batch_num += 1

# Fail here with a clear message rather than later at datalist[0].
if not datalist:
    raise RuntimeError(
        f"No graph data loaded from {dataset_path!r}; "
        "expected files named datalist_batch_<n>.pt starting at n=1."
    )
print(f"Loaded {len(datalist)} items into datalist")

Processing batch number: 1
Processing batch number: 2
Processing batch number: 3
Loaded 1000 items into datalist


In [5]:
# Stand-in for run_models.py's argparse arguments: edit the values here.
ARG_DEFAULTS = {
    "in_channels": 5,
    "use_all_features": False,
    "out_channels": 1,
    "loss_fct": "mse",
    "use_weighted_loss": True,
    "predict_mode_stats": False,
    # Layer structures use the CLI's comma-separated string format;
    # get_parameters() parses them into lists of ints.
    "point_net_conv_layer_structure_local_mlp": "256",
    "point_net_conv_layer_structure_global_mlp": "512",
    "gat_conv_layer_structure": "128,256,512,256",
    "use_bootstrapping": False,
    "num_epochs": 20,
    "batch_size": 1,
    "lr": 0.001,
    "early_stopping_patience": 5,
    "use_dropout": True,
    "dropout": 0.3,
    "gradient_accumulation_steps": 3,
    "use_gradient_clipping": True,
    # NOTE(review): get_parameters() does not read this key, so it never reaches
    # `params` or the wandb config; confirm train_model does not expect it.
    "lr_scheduler_warmup_steps": 10000,
    "device_nr": 0,
    "use_monte_carlo_dropout": True,
}


class Args:
    """Minimal attribute container that mimics an argparse.Namespace."""

    def __init__(self, **entries):
        # Expose each keyword argument as an attribute (args.lr, args.dropout, ...).
        self.__dict__.update(entries)


# Keep the raw dict and the Namespace-like object under different names so
# `args` always refers to the object.
args = Args(**ARG_DEFAULTS)
# Seed the random number generators for reproducibility. This is a project helper,
# presumably from the `training.help_functions` star import. Which RNGs it seeds,
# and with what seed, is not visible here — TODO confirm.
set_random_seeds()

In [6]:
# Choose a GPU with the project helpers and restrict this process to it via
# CUDA_VISIBLE_DEVICES (the output reports "Using GPU 0 with CUDA_VISIBLE_DEVICES=0").
# Order matters: the visible-device mask only takes effect if it is set before
# CUDA is initialised in this process — TODO confirm no earlier cell touched CUDA.
# NOTE(review): args.device_nr is not consulted here; the GPU comes from
# select_best_gpu(). Behaviour when no GPU is found depends on that helper.
gpus = get_available_gpus()
best_gpu = select_best_gpu(gpus)
set_cuda_visible_device(best_gpu)
# Falls back to the CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Turn the notebook's Args object into the hyper-parameter dict used below.
params = get_parameters(args)

# Both the run directory and get_paths() are rooted at <base_dir>/<project_name>.
project_dir = os.path.join(base_dir, params["project_name"])

# Ensure <project_dir>/<unique_model_description> exists before anything is saved there.
unique_run_dir = os.path.join(project_dir, params["unique_model_description"])
os.makedirs(unique_run_dir, exist_ok=True)

model_save_path, path_to_save_dataloader = get_paths(
    base_dir=project_dir,
    unique_model_description=params["unique_model_description"],
    model_save_path="trained_model/model.pth",
)

# Split datalist into train / validation / test subsets, normalise each one and
# wrap them in DataLoaders. In the recorded run this gave 800/150/50 samples.
# Per the log, the test subset and the loaders/scalers are saved to disk
# (presumably under path_to_save_dataloader). Only the train/valid loaders and
# their scalers are returned here.
# NOTE(review): the log shows the validation and test subsets each *fit* their
# own scaler ("Fitting scaler" under "Normalizing validation set...") instead of
# reusing the training scaler. That leaks validation statistics into
# preprocessing and makes validation metrics less comparable — confirm this is intended.
train_dl, valid_dl, scalers_train, scalers_validation = (
    prepare_data_with_graph_features(
        datalist=datalist,
        batch_size=params["batch_size"],
        path_to_save_dataloader=path_to_save_dataloader,
        use_all_features=params["use_all_features"],
        use_bootstrapping=params["use_bootstrapping"],
    )
)

# Start the wandb run (see the login message in the output) with only the
# PARAMETERS subset of params as its config. This raises KeyError if PARAMETERS
# names a key get_parameters() did not set. Later cells read hyper-parameters
# from the returned `config` by attribute (config.loss_fct, config.lr, ...).
config = setup_wandb({param: params[param] for param in PARAMETERS})

Using GPU 0 with CUDA_VISIBLE_DEVICES=0
Starting prepare_data_with_graph_features with 1000 items
Splitting into subsets...
Total dataset length: 1000
Training subset length: 800
Validation subset length: 150
Test subset length: 50
Split complete. Train: 800, Valid: 150, Test: 50
Saving test set...
Test set saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/data_created_during_training/test_set.pt
Normalizing train set...
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 8/8 [00:01<00:00,  4.98it/s]
Normalizing x features: 100%|██████████| 8/8 [00:00<00:00,  9.67it/s]


x features normalized
Fitting and normalizing pos features...


Fitting scaler: 100%|██████████| 1/1 [00:01<00:00,  1.60s/it]
Normalizing pos features: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]


Pos features normalized
Train set normalized
Normalizing validation set...
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 2/2 [00:00<00:00,  6.75it/s]
Normalizing x features: 100%|██████████| 2/2 [00:00<00:00, 12.57it/s]


x features normalized
Fitting and normalizing pos features...


Fitting scaler: 100%|██████████| 1/1 [00:00<00:00,  3.22it/s]
Normalizing pos features: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


Pos features normalized
Validation set normalized
Normalizing test set...
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 1/1 [00:00<00:00,  9.75it/s]
Normalizing x features: 100%|██████████| 1/1 [00:00<00:00, 17.18it/s]


x features normalized
Fitting and normalizing pos features...


Fitting scaler: 100%|██████████| 1/1 [00:00<00:00,  9.68it/s]
Normalizing pos features: 100%|██████████| 1/1 [00:00<00:00, 13.42it/s]


Pos features normalized
Test set normalized
Creating train loader...
Train loader created
Creating validation loader...
Validation loader created
Creating test loader...
Test loader created
Dataloaders and scalers saved


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthuaduc24042001[0m ([33mthuaduc24042001-technical-university-of-munich[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
def create_model(
    architecture: str,
    config: object,
    device: torch.device,
    eign_overrides=None,
):
    """
    Factory function to create the specified model architecture.

    Parameters:
    - architecture: str, the name of the architecture to use; one of
      "point_net_transf_gat" or "eign"
    - config: object containing model parameters, read by attribute. Only the
      "point_net_transf_gat" branch reads it; "eign" uses built-in defaults.
    - device: torch device to put the model on
    - eign_overrides: optional dict of EIGNLaplacianConv keyword arguments that
      replace the built-in defaults (e.g. {"num_blocks": 2}). Ignored for other
      architectures. None (the default) keeps the built-in defaults.

    Returns:
    - Initialized model on the specified device

    Raises:
    - ValueError: if `architecture` is not one of the names above
    """
    if architecture == "point_net_transf_gat":
        return PointNetTransfGAT(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            point_net_conv_layer_structure_local_mlp=config.point_net_conv_layer_structure_local_mlp,
            point_net_conv_layer_structure_global_mlp=config.point_net_conv_layer_structure_global_mlp,
            gat_conv_layer_structure=config.gat_conv_layer_structure,
            use_dropout=config.use_dropout,
            dropout=config.dropout,
            predict_mode_stats=config.predict_mode_stats,
            dtype=torch.float32,
        ).to(device)

    if architecture == "eign":
        eign_kwargs = {
            "in_channels_signed": 1,
            "out_channels_signed": 1,
            # NOTE(review): these match config.in_channels / config.out_channels
            # in this notebook (5 and 1) but are not read from config. Keep them
            # in sync, or pass eign_overrides.
            "in_channels_unsigned": 5,
            "out_channels_unsigned": 1,
            "hidden_channels_signed": 32,
            "hidden_channels_unsigned": 32,
            "num_blocks": 4,
        }
        if eign_overrides:
            eign_kwargs.update(eign_overrides)
        return EIGNLaplacianConv(**eign_kwargs).to(device)

    raise ValueError(f"Unknown architecture: {architecture}")

In [8]:
# Build the EIGN model. The PointNet/Transformer/GAT alternative is
# create_model("point_net_transf_gat", config, device).
gnn_instance = create_model("eign", config, device)

# nn.Module.to() moves parameters in place and returns the module itself, so
# `model` and `gnn_instance` presumably name the same object (create_model has
# already placed it on `device`).
model = gnn_instance.to(device)

# EIGN_Loss receives the configured loss name plus the row count of `x` in the
# first raw sample (nodes or edges, depending on the dataset layout — TODO confirm).
loss_fct = EIGN_Loss(config.loss_fct, datalist[0].x.shape[0], device)

# Reference losses on the training loader: predicting the mean target and, judging
# by the name, assuming no policy effect. A useful model should beat both.
baseline_kwargs = dict(
    dataset=train_dl, loss_fct=loss_fct, device=device, scalers=scalers_train
)
baseline_loss_mean_target = compute_baseline_of_mean_target(**baseline_kwargs)
baseline_loss = compute_baseline_of_no_policies(**baseline_kwargs)
print("baseline loss mean", baseline_loss_mean_target)
print("baseline loss no ", baseline_loss)

early_stopping = EarlyStopping(patience=params["early_stopping_patience"], verbose=True)

baseline loss mean 114.76651000976562
baseline loss no  114.93736267089844


In [9]:
# AdamW over all model parameters with the configured learning rate and a fixed
# weight decay of 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-4)

# train_model runs the epoch loop. Judging by the log below, it also handles
# validation, early stopping, checkpointing and wandb logging. It returns the
# best validation loss together with the epoch at which it was reached.
best_val_loss, best_epoch = gnn_instance.train_model(
    config=config,
    loss_fct=loss_fct,
    optimizer=optimizer,
    train_dl=train_dl,
    valid_dl=valid_dl,
    device=device,
    early_stopping=early_stopping,
    model_save_path=model_save_path,
    scalers_train=scalers_train,
    scalers_validation=scalers_validation,
)

summary_message = "Best model saved to {} with validation loss: {} at epoch {}".format(
    model_save_path, best_val_loss, best_epoch
)
print(summary_message)

Epoch 1/20: 100%|██████████| 800/800 [00:48<00:00, 16.50it/s]


epoch: 0, validation loss: 109.50632671356202, lr: 0.00099875, r^2: -0.00019168853759765625
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 109.50632671356202
Checkpoint saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/checkpoints/checkpoint_epoch_0.pt


Epoch 2/20: 100%|██████████| 800/800 [00:48<00:00, 16.50it/s]


epoch: 1, validation loss: 109.16880620320639, lr: 0.0009932656741722626, r^2: 0.002891063690185547
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 109.16880620320639


Epoch 3/20: 100%|██████████| 800/800 [00:48<00:00, 16.58it/s]


epoch: 2, validation loss: 65.7386860148112, lr: 0.0009732127441394693, r^2: 0.39956629276275635
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 65.7386860148112


Epoch 4/20: 100%|██████████| 800/800 [00:47<00:00, 16.68it/s]


epoch: 3, validation loss: 59.50482958475749, lr: 0.000940388190986082, r^2: 0.4565041661262512
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 59.50482958475749


Epoch 5/20: 100%|██████████| 800/800 [00:48<00:00, 16.65it/s]


epoch: 4, validation loss: 57.18687469482422, lr: 0.0008956873829549011, r^2: 0.4776754379272461
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 57.18687469482422


Epoch 6/20: 100%|██████████| 800/800 [00:48<00:00, 16.37it/s]


epoch: 5, validation loss: 55.63652453104655, lr: 0.0008403296415627077, r^2: 0.4918358325958252
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 55.63652453104655


Epoch 7/20: 100%|██████████| 800/800 [00:48<00:00, 16.56it/s]


epoch: 6, validation loss: 53.89411473592122, lr: 0.0007758249816878191, r^2: 0.507750391960144
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 53.89411473592122


Epoch 8/20: 100%|██████████| 800/800 [00:48<00:00, 16.46it/s]


epoch: 7, validation loss: 51.201228294372555, lr: 0.0007039329223005235, r^2: 0.532346248626709
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 51.201228294372555


Epoch 9/20: 100%|██████████| 800/800 [00:48<00:00, 16.36it/s]


epoch: 8, validation loss: 51.654256140391034, lr: 0.0006266144913722947, r^2: 0.5282084941864014
EarlyStopping counter: 1 out of 5


Epoch 10/20: 100%|██████████| 800/800 [00:48<00:00, 16.54it/s]


epoch: 9, validation loss: 49.76831528345744, lr: 0.0005459787341447478, r^2: 0.5454339981079102
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 49.76831528345744


Epoch 11/20: 100%|██████████| 800/800 [00:48<00:00, 16.54it/s]


epoch: 10, validation loss: 49.26901751200358, lr: 0.0004642251838733197, r^2: 0.5499943494796753
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 49.26901751200358


Epoch 12/20: 100%|██████████| 800/800 [00:48<00:00, 16.48it/s]


epoch: 11, validation loss: 49.958893610636395, lr: 0.00038358386429381447, r^2: 0.5436933040618896
EarlyStopping counter: 1 out of 5


Epoch 13/20: 100%|██████████| 800/800 [00:48<00:00, 16.37it/s]


epoch: 12, validation loss: 48.80448270797729, lr: 0.00030625446038813056, r^2: 0.5542372465133667
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 48.80448270797729


Epoch 14/20: 100%|██████████| 800/800 [00:48<00:00, 16.40it/s]


epoch: 13, validation loss: 47.525541152954105, lr: 0.00023434631671210353, r^2: 0.5659186244010925
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 47.525541152954105


Epoch 15/20: 100%|██████████| 800/800 [00:49<00:00, 16.32it/s]


epoch: 14, validation loss: 47.61064317067464, lr: 0.00016982089997467532, r^2: 0.5651413202285767
EarlyStopping counter: 1 out of 5


Epoch 16/20: 100%|██████████| 800/800 [00:48<00:00, 16.50it/s]


epoch: 15, validation loss: 47.13144252777099, lr: 0.00011443829533923471, r^2: 0.5695182085037231
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 47.13144252777099


Epoch 17/20: 100%|██████████| 800/800 [00:48<00:00, 16.60it/s]


epoch: 16, validation loss: 47.263259684244794, lr: 6.970919588856186e-05, r^2: 0.568314254283905
EarlyStopping counter: 1 out of 5


Epoch 18/20: 100%|██████████| 800/800 [00:47<00:00, 16.73it/s]


epoch: 17, validation loss: 46.873480072021486, lr: 3.685369485561072e-05, r^2: 0.5718743801116943
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 46.873480072021486


Epoch 19/20: 100%|██████████| 800/800 [00:47<00:00, 16.70it/s]


epoch: 18, validation loss: 46.808010292053225, lr: 1.6768004660661315e-05, r^2: 0.572472333908081
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 46.808010292053225


Epoch 20/20: 100%|██████████| 800/800 [00:48<00:00, 16.63it/s]


epoch: 19, validation loss: 46.79047307332357, lr: 1.0000010572745346e-05, r^2: 0.5726325511932373
Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 46.79047307332357
Best validation loss:  46.79047307332357


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
lr,████▇▇▆▆▅▅▄▄▃▃▂▂▁▁▁▁
pearson,▁▂▇▇▇███████████████
r^2,▁▁▆▇▇▇▇█▇███████████
spearman,▁▂▄▄▅▆▆▇▇▇▇▇████████
train_loss,██▇▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_signed,██▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_unsigned,█▅▂▄▄▅▂▃▁▃▂▄▁▂▂▄▂▁▂▄▃▁▁▃▂▁▂▁▁▁▂▂▃▁▂▂▂▂▁▂
val_loss,██▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁

0,1
best_val_loss,46.79047
epoch,19.0
lr,1e-05
pearson,0.75857
r^2,0.57263
spearman,0.16313
train_loss,19574.80469
train_loss_signed,19505.19922
train_loss_unsigned,69.60474
val_loss,46.79047


Best model saved to /home/duc-nguyen/Downloads/projects/gnn_predicting_effects_of_traffic_policies/data/eign/eign_implementing/trained_model/model.pth with validation loss: 46.79047307332357 at epoch 19
