This notebook is used for debugging. It should have the same functionality as `run_models.py`, but with more verbose output.

**[TODO]** Update from run_models.py when the dust settles.

In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import torch

# Add the 'scripts' directory to Python Path
scripts_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

from training.help_functions import *

from gnn.help_functions import EIGN_Loss, compute_baseline_of_mean_target, compute_baseline_of_no_policies
from gnn.models.point_net_transf_gat import PointNetTransfGAT
from gnn.models.eign import EIGNLaplacianConv

In [2]:
# Repository root: two directory levels above the notebook's working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

# Root folder for all data; run artifacts are written below it.
base_dir = os.path.join(project_root, "data")

# Please adjust as needed
dataset_path = os.path.join(
    base_dir, "train_data", "edge_features_with_net_flow_aggregated"
)

In [3]:
# Names of the run settings that are forwarded to `setup_wandb` as the run
# config. The config is built by looking each name up in the dict returned by
# `get_parameters`, so every entry here must also be a key of that dict.
PARAMETERS = [
    "project_name",
    "predict_mode_stats",
    "in_channels",
    "use_all_features",
    "out_channels",
    "loss_fct",
    "use_weighted_loss",
    "point_net_conv_layer_structure_local_mlp",
    "point_net_conv_layer_structure_global_mlp",
    "gat_conv_layer_structure",
    "use_bootstrapping",
    "num_epochs",
    "batch_size",
    "lr",
    "early_stopping_patience",
    "use_dropout",
    "dropout",
    "gradient_accumulation_steps",
    "use_gradient_clipping",
    "device_nr",
    "unique_model_description",
]


def get_parameters(args):
    """
    Assemble the run configuration dictionary from an argument object.

    Parameters:
    - args: object exposing the run settings as attributes. The three
      ``*_layer_structure*`` attributes must be comma-separated strings of
      integers (e.g. "128,256").

    Returns:
    - dict with one entry per setting: layer structures become lists of int,
      ``batch_size`` is cast to int and ``lr`` to float, everything else is
      passed through unchanged. "project_name" is fixed to "IDP" and
      "unique_model_description" to "eign".
    """

    def _keep(value):
        # Settings that need no conversion are stored as-is.
        return value

    def _int_list(csv_text):
        # "128,256" -> [128, 256]
        return [int(piece) for piece in csv_text.split(",")]

    # (attribute name, conversion) pairs, listed in the order in which the
    # keys appear in the resulting dictionary.
    fields = (
        ("predict_mode_stats", _keep),
        ("in_channels", _keep),
        ("use_all_features", _keep),
        ("out_channels", _keep),
        ("loss_fct", _keep),
        ("use_weighted_loss", _keep),
        ("point_net_conv_layer_structure_local_mlp", _int_list),
        ("point_net_conv_layer_structure_global_mlp", _int_list),
        ("gat_conv_layer_structure", _int_list),
        ("use_bootstrapping", _keep),
        ("num_epochs", _keep),
        ("batch_size", int),
        ("lr", float),
        ("early_stopping_patience", _keep),
        ("use_dropout", _keep),
        ("dropout", _keep),
        ("gradient_accumulation_steps", _keep),
        ("use_gradient_clipping", _keep),
        ("device_nr", _keep),
    )

    # KEEP IN MIND: IF WE CHANGE PARAMETERS, WE NEED TO CHANGE THE NAME OF THE RUN IN WANDB (for the config)
    params = {"project_name": "IDP"}
    for attr_name, convert in fields:
        params[attr_name] = convert(getattr(args, attr_name))

    params["unique_model_description"] = "eign"
    return params

In [4]:
# Load the serialized batches `datalist_batch_<n>.pt` (n = 1, 2, ...) from
# `dataset_path` into one flat list. Loading stops at the first batch index
# whose file does not exist, or after `max_batches` files when a limit is set.
max_batches = None  # Set to e.g. 9 to load only the first batches for a faster run
datalist = []
batch_num = 1
while max_batches is None or batch_num <= max_batches:
    batch_file = os.path.join(dataset_path, f"datalist_batch_{batch_num}.pt")
    if not os.path.exists(batch_file):
        break
    # Announce the batch only once its file is known to exist, so the log
    # lists exactly the batches that were read.
    print(f"Processing batch number: {batch_num}")
    # total_memory, available_memory, used_memory = get_memory_info()
    # print(f"Total Memory: {total_memory:.2f} GB")
    # print(f"Available Memory: {available_memory:.2f} GB")
    # print(f"Used Memory: {used_memory:.2f} GB")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load batch files that come from a trusted source.
    batch_data = torch.load(batch_file, map_location="cpu")
    if isinstance(batch_data, list):
        datalist.extend(batch_data)
    else:
        # Non-list payloads are not added; report them so that a shorter than
        # expected datalist can be explained from the log.
        print(
            f"Skipping {batch_file}: expected a list, got {type(batch_data).__name__}"
        )
    batch_num += 1
print(f"Loaded {len(datalist)} items into datalist")

Processing batch number: 1
Processing batch number: 2
Processing batch number: 3
Processing batch number: 4
Processing batch number: 5
Processing batch number: 6
Processing batch number: 7
Processing batch number: 8
Processing batch number: 9
Processing batch number: 10
Processing batch number: 11
Processing batch number: 12
Processing batch number: 13
Processing batch number: 14
Processing batch number: 15
Processing batch number: 16
Processing batch number: 9
Processing batch number: 10
Processing batch number: 11
Processing batch number: 12
Processing batch number: 13
Processing batch number: 14
Processing batch number: 15
Processing batch number: 16
Processing batch number: 17
Processing batch number: 18
Processing batch number: 19
Processing batch number: 17
Processing batch number: 18
Processing batch number: 19
Processing batch number: 20
Processing batch number: 21
Processing batch number: 22
Processing batch number: 23
Processing batch number: 24
Processing batch number: 25
Pr

In [5]:
# Run settings for this notebook; they stand in for the argparse arguments.
# The three *_layer_structure* entries are comma-separated strings because
# get_parameters() splits them into lists of integers.
arg_values = {
    "in_channels": 5,
    "use_all_features": False,
    "out_channels": 1,
    "loss_fct": "mse",
    "use_weighted_loss": True,
    "predict_mode_stats": False,
    "point_net_conv_layer_structure_local_mlp": "256",
    "point_net_conv_layer_structure_global_mlp": "512",
    "gat_conv_layer_structure": "128,256,512,256",
    "use_bootstrapping": False,
    "num_epochs": 100,
    "batch_size": 8,
    "lr": 0.001,
    "early_stopping_patience": 20,
    "use_dropout": True,
    "dropout": 0.3,
    "gradient_accumulation_steps": 3,
    "use_gradient_clipping": True,
    # NOTE(review): get_parameters() in this notebook does not read this entry,
    # so it does not reach the run config - confirm whether it is still needed.
    "lr_scheduler_warmup_steps": 10000,
    "device_nr": 0,
}


# Convert the dictionary to an object with attributes
class Args:
    """Attribute-style view of a settings dictionary.

    Every keyword argument becomes an instance attribute of the same name, so
    the object can be read like an argparse namespace (``args.lr`` etc.).
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)


# `args` is only ever bound to the attribute object; the raw dictionary keeps
# its own name so that one name does not refer to two different types.
args = Args(**arg_values)
set_random_seeds()

In [6]:
# Select a GPU before the torch device object is created.
# NOTE(review): get_available_gpus / select_best_gpu / set_cuda_visible_device
# are not defined in this notebook (presumably provided by the star import of
# training.help_functions); their selection criteria are not visible here.
gpus = get_available_gpus()
best_gpu = select_best_gpu(gpus)
set_cuda_visible_device(best_gpu)
# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = get_parameters(args)

# Create directory for the run
unique_run_dir = os.path.join(
    base_dir, params["project_name"], params["unique_model_description"]
)
os.makedirs(unique_run_dir, exist_ok=True)

# Resolve where the trained model and the dataloaders are stored for this run.
# NOTE(review): get_paths is a project helper; how it combines base_dir,
# unique_model_description and model_save_path is not visible here.
model_save_path, path_to_save_dataloader = get_paths(
    base_dir=os.path.join(base_dir, params["project_name"]),
    unique_model_description=params["unique_model_description"],
    model_save_path="trained_model/model.pth",
)
# Build the train/validation dataloaders together with their scalers from the
# loaded datalist, using the batch size and feature options from `params`.
train_dl, valid_dl, scalers_train, scalers_validation = (
    prepare_data_with_graph_features(
        datalist=datalist,
        batch_size=params["batch_size"],
        path_to_save_dataloader=path_to_save_dataloader,
        use_all_features=params["use_all_features"],
        use_bootstrapping=params["use_bootstrapping"],
    )
)

# Pass only the settings named in PARAMETERS to wandb; each name must be a key
# of `params`, otherwise this lookup raises KeyError.
config = setup_wandb({param: params[param] for param in PARAMETERS})

Using GPU 0 with CUDA_VISIBLE_DEVICES=0
Starting prepare_data_with_graph_features with 4984 items
Splitting into subsets...
Total dataset length: 4984
Training subset length: 3987
Validation subset length: 747
Test subset length: 250
Split complete. Train: 3987, Valid: 747, Test: 250
Normalizing train set...
Fitting and normalizing x features...
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 40/40 [00:08<00:00,  4.80it/s]
Fitting scaler: 100%|██████████| 40/40 [00:08<00:00,  4.80it/s]
Normalizing x features: 100%|██████████| 40/40 [00:05<00:00,  7.14it/s]



x features normalized


Fitting x_signed scaler: 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
Fitting x_signed scaler: 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
Normalizing x_signed features: 100%|██████████| 4/4 [00:00<00:00,  4.06it/s]



Fitting and normalizing pos features...
Train set normalized
Normalizing validation set...
Fitting and normalizing x features...
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 8/8 [00:01<00:00,  4.72it/s]
Fitting scaler: 100%|██████████| 8/8 [00:01<00:00,  4.72it/s]
Normalizing x features: 100%|██████████| 8/8 [00:01<00:00,  7.30it/s]



x features normalized


Fitting x_signed scaler: 100%|██████████| 1/1 [00:00<00:00,  3.31it/s]
Fitting x_signed scaler: 100%|██████████| 1/1 [00:00<00:00,  3.31it/s]
Normalizing x_signed features: 100%|██████████| 1/1 [00:00<00:00,  5.76it/s]



Fitting and normalizing pos features...
Validation set normalized
Normalizing test set...
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 3/3 [00:00<00:00,  5.58it/s]
Fitting scaler: 100%|██████████| 3/3 [00:00<00:00,  5.58it/s]
Normalizing x features: 100%|██████████| 3/3 [00:00<00:00,  7.87it/s]



x features normalized


Fitting x_signed scaler: 100%|██████████| 1/1 [00:00<00:00, 28.45it/s]
Fitting x_signed scaler: 100%|██████████| 1/1 [00:00<00:00, 28.45it/s]
Normalizing x_signed features: 100%|██████████| 1/1 [00:00<00:00, 19.27it/s]



Fitting and normalizing pos features...
Test set normalized
Creating train loader...
Train loader created
Creating validation loader...
Validation loader created
Creating test loader...
Test loader created


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Dataloaders and scalers saved


[34m[1mwandb[0m: Currently logged in as: [33mthuaduc24042001[0m ([33mthuaduc24042001-technical-university-of-munich[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
def create_model(architecture: str, config: object, device: torch.device):
    """
    Build the requested network and move it to ``device``.

    Parameters:
    - architecture: str, either "point_net_transf_gat" or "eign"
    - config: object whose attributes supply the PointNetTransfGAT
      hyperparameters; it is not read for "eign"
    - device: torch device to put the model on

    Returns:
    - Initialized model on the specified device

    Raises:
    - ValueError: if ``architecture`` is not one of the supported names
    """
    if architecture == "eign":
        # NOTE(review): the EIGN channel sizes and block count are hard-coded
        # here and are not taken from `config` - confirm this is intended.
        network = EIGNLaplacianConv(
            in_channels_signed=1,
            out_channels_signed=1,
            in_channels_unsigned=5,
            out_channels_unsigned=1,
            hidden_channels_signed=32,
            hidden_channels_unsigned=32,
            num_blocks=4,
        )
        return network.to(device)

    if architecture != "point_net_transf_gat":
        raise ValueError(f"Unknown architecture: {architecture}")

    network = PointNetTransfGAT(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        point_net_conv_layer_structure_local_mlp=config.point_net_conv_layer_structure_local_mlp,
        point_net_conv_layer_structure_global_mlp=config.point_net_conv_layer_structure_global_mlp,
        gat_conv_layer_structure=config.gat_conv_layer_structure,
        use_dropout=config.use_dropout,
        dropout=config.dropout,
        predict_mode_stats=config.predict_mode_stats,
        dtype=torch.float32,
    )
    return network.to(device)

In [8]:
# Build the EIGN network. `gnn_instance` and `model` are both kept because
# later cells use both names.
gnn_instance = create_model("eign", config, device)
model = gnn_instance.to(device)

# NOTE(review): the second positional argument is the first dimension of the
# first sample's `x`; what EIGN_Loss does with it is not visible here.
loss_fct = EIGN_Loss(
    config.loss_fct,
    datalist[0].x.shape[0],
    device,
    config.use_weighted_loss,
)

# Reference losses on the training loader, printed for comparison with the
# validation loss reported during training.
baseline_loss_mean_target = compute_baseline_of_mean_target(
    dataset=train_dl,
    loss_fct=loss_fct,
    device=device,
    scalers=scalers_train,
)
baseline_loss = compute_baseline_of_no_policies(
    dataset=train_dl,
    loss_fct=loss_fct,
    device=device,
    scalers=scalers_train,
)
print(f"baseline loss mean {baseline_loss_mean_target!s}")
print(f"baseline loss no  {baseline_loss!s}")

early_stopping = EarlyStopping(
    verbose=True,
    patience=params["early_stopping_patience"],
)

baseline loss mean 4.058615684509277
baseline loss no  4.065121173858643


In [9]:
# Tally the model's parameter counts in a single pass
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {total_params - trainable_params:,}")

# List every named parameter with its element count and shape
print("\nParameter breakdown by layer:")
for name, param in model.named_parameters():
    print(f"{name}: {param.numel():,} parameters, shape: {param.shape}")

Total parameters: 35,968
Trainable parameters: 35,968
Non-trainable parameters: 0

Parameter breakdown by layer:
blocks.0.signed_fusion_layer.lin_layer1.weight: 1,024 parameters, shape: torch.Size([32, 32])
blocks.0.signed_fusion_layer.lin_layer2.weight: 1,024 parameters, shape: torch.Size([32, 32])
blocks.0.unsigned_fusion_layer.lin_layer1.weight: 1,024 parameters, shape: torch.Size([32, 32])
blocks.0.unsigned_fusion_layer.lin_layer1.bias: 32 parameters, shape: torch.Size([32])
blocks.0.unsigned_fusion_layer.lin_layer2.weight: 1,024 parameters, shape: torch.Size([32, 32])
blocks.0.unsigned_fusion_layer.lin_layer2.bias: 32 parameters, shape: torch.Size([32])
blocks.0.unsigned_conv.lin.weight: 160 parameters, shape: torch.Size([32, 5])
blocks.0.unsigned_conv.conv.bias: 32 parameters, shape: torch.Size([32])
blocks.0.unsigned_conv.conv.lin.weight: 160 parameters, shape: torch.Size([32, 5])
blocks.0.unsigned_signed_conv.lin.weight: 160 parameters, shape: torch.Size([32, 5])
blocks.0.signe

In [None]:
# Run training through the model's own `train_model` method and unpack the
# best validation loss and the epoch at which it was reached.
# NOTE(review): train_model is defined on the model class, which is not visible
# in this notebook; the meaning of its arguments is assumed from their names.
best_val_loss, best_epoch = gnn_instance.train_model(
    config=config,
    loss_fct=loss_fct,
    # AdamW over the model's parameters; learning rate comes from the run
    # config, weight decay is fixed at 1e-4.
    optimizer=torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-4),
    train_dl=train_dl,
    valid_dl=valid_dl,
    device=device,
    early_stopping=early_stopping,
    model_save_path=model_save_path,
    scalers_train=scalers_train,
    scalers_validation=scalers_validation,
    # NOTE(review): presumably toggles use of the signed edge features - confirm
    # against train_model.
    use_signed=False,
)

print(
    f"Best model saved to {model_save_path} with validation loss: {best_val_loss} at epoch {best_epoch}"
)

Epoch 1/100: 100%|██████████| 499/499 [01:18<00:00,  6.37it/s]



epoch: 0, validation loss: 3.869412457689326, lr: 0.00019959919839679358, r^2: 0.018782854080200195
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 3.869412457689326
Checkpoint saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/checkpoints/checkpoint_epoch_0.pt


Epoch 2/100: 100%|██████████| 499/499 [01:18<00:00,  6.35it/s]
Epoch 2/100: 100%|██████████| 499/499 [01:18<00:00,  6.35it/s]


epoch: 1, validation loss: 2.5360361109388636, lr: 0.0003995991983967936, r^2: 0.35947632789611816
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 2.5360361109388636


Epoch 3/100: 100%|██████████| 499/499 [01:17<00:00,  6.46it/s]
Epoch 3/100: 100%|██████████| 499/499 [01:17<00:00,  6.46it/s]


epoch: 2, validation loss: 2.1454220781935023, lr: 0.0005995991983967936, r^2: 0.4582584500312805
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 2.1454220781935023


Epoch 4/100: 100%|██████████| 499/499 [01:17<00:00,  6.44it/s]



epoch: 3, validation loss: 1.9038724442745776, lr: 0.0007995991983967936, r^2: 0.5202885866165161
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.9038724442745776


Epoch 5/100: 100%|██████████| 499/499 [01:17<00:00,  6.41it/s]



epoch: 4, validation loss: 1.7652639017460194, lr: 0.0009995991983967937, r^2: 0.5534452199935913
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.7652639017460194


Epoch 6/100: 100%|██████████| 499/499 [01:17<00:00,  6.43it/s]
Epoch 6/100: 100%|██████████| 499/499 [01:17<00:00,  6.43it/s]


epoch: 5, validation loss: 1.6707858037441334, lr: 0.0009997304459184093, r^2: 0.5786421298980713
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.6707858037441334


Epoch 7/100: 100%|██████████| 499/499 [01:17<00:00,  6.41it/s]



epoch: 6, validation loss: 1.5998551528504554, lr: 0.0009989199124447589, r^2: 0.5967550277709961
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.5998551528504554


Epoch 8/100: 100%|██████████| 499/499 [01:17<00:00,  6.43it/s]



epoch: 7, validation loss: 1.5158839923270204, lr: 0.0009975692847985092, r^2: 0.6173065304756165
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.5158839923270204


Epoch 9/100: 100%|██████████| 499/499 [01:17<00:00,  6.40it/s]
Epoch 9/100: 100%|██████████| 499/499 [01:17<00:00,  6.40it/s]


epoch: 8, validation loss: 1.471365603994816, lr: 0.0009956800398711614, r^2: 0.6283063888549805
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.471365603994816


Epoch 10/100: 100%|██████████| 499/499 [01:18<00:00,  6.37it/s]



epoch: 9, validation loss: 1.4336389214434522, lr: 0.0009932542435243065, r^2: 0.6375834941864014
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.4336389214434522


Epoch 11/100: 100%|██████████| 499/499 [01:17<00:00,  6.41it/s]



epoch: 10, validation loss: 1.4222131067133965, lr: 0.0009902945483306342, r^2: 0.6405092477798462
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.4222131067133965


Epoch 12/100: 100%|██████████| 499/499 [01:17<00:00,  6.43it/s]
Epoch 12/100: 100%|██████████| 499/499 [01:17<00:00,  6.43it/s]


epoch: 11, validation loss: 1.343320298068067, lr: 0.0009868041906733853, r^2: 0.66117924451828
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.343320298068067


Epoch 13/100: 100%|██████████| 499/499 [01:17<00:00,  6.42it/s]



epoch: 12, validation loss: 1.323764144106114, lr: 0.0009827869872074126, r^2: 0.6660875082015991
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.323764144106114


Epoch 14/100: 100%|██████████| 499/499 [01:17<00:00,  6.44it/s]



epoch: 13, validation loss: 1.3044337972681572, lr: 0.0009782473306857266, r^2: 0.6705885529518127
Best model saved to /home/dnguyen/gnn_predicting_effects_of_traffic_policies/data/IDP/eign/trained_model/model.pth with validation loss: 1.3044337972681572


Epoch 15/100:  20%|██        | 102/499 [00:15<01:00,  6.51it/s]