This notebook is used for debugging. It should have the same functionality as `run_models.py`, but with more verbose output.

In [1]:
import math
import numpy as np
import wandb
import random
import torch
import torch_geometric
from torch_geometric.data import Data
import sys
import os
from tqdm import tqdm
import signal
import joblib
import argparse
import json
import os
import subprocess
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.preprocessing import StandardScaler

import help_functions as hf

# Make the parent directory importable so the project modules imported right
# below (gnn_io, gnn_architecture) can be resolved — presumably the 'scripts'
# directory; the guard keeps re-runs of this cell from adding duplicates.
scripts_path = os.path.abspath(os.pardir)
already_on_path = scripts_path in sys.path
if not already_on_path:
    sys.path.append(scripts_path)
    
import gnn_io as gio
import gnn_architecture as garch

In [2]:
# Root under which this run's output directories are created.
base_dir = '../../data'

# Input graphs: batch files written by
# data_preprocessing.process_simulations_for_gnn.py
dataset_path = '../../data/train_data/dist_not_connected_10k_1pct/'

In [3]:
# Names of the entries of `params` (see get_parameters) that are forwarded to
# the wandb config. Order matters only for readability; names must match the
# keys produced by get_parameters exactly (including the 'use_bootrappping' spelling).
PARAMETERS = """
    project_name
    predict_mode_stats
    in_channels
    use_all_features
    out_channels
    loss_fct
    use_weighted_loss
    point_net_conv_layer_structure_local_mlp
    point_net_conv_layer_structure_global_mlp
    gat_conv_layer_structure
    use_bootrappping
    num_epochs
    batch_size
    lr
    early_stopping_patience
    use_dropout
    dropout
    use_monte_carlo_dropout
    gradient_accumulation_steps
    use_gradient_clipping
    lr_scheduler_warmup_steps
    device_nr
    unique_model_description
""".split()

def get_parameters(args):
    """Assemble the run's hyper-parameter dict from an ``args`` object.

    ``args`` may be an argparse namespace or any object exposing the same
    attributes (this notebook uses a small ``Args`` wrapper). The three
    layer-structure attributes are comma-separated strings (e.g. ``"128,256"``)
    and are parsed into lists of ints; ``batch_size`` and ``lr`` are coerced to
    ``int``/``float``; every other value is passed through unchanged.

    A ``unique_model_description`` string is derived from the architecture and
    dropout settings; the caller uses it to name the run's output directory.

    Args:
        args: object carrying the hyper-parameter attributes read below.

    Returns:
        dict: one entry per name in ``PARAMETERS``, in that order.

    Raises:
        ValueError: if a layer-structure string contains a non-integer entry.
    """
    def as_int_list(csv_text):
        # "128,256,512" -> [128, 256, 512]
        return list(map(int, csv_text.split(',')))

    # KEEP IN MIND: IF WE CHANGE PARAMETERS, WE NEED TO CHANGE THE NAME OF THE RUN IN WANDB (for the config)
    params = dict(
        project_name="runs_01_2025",
        predict_mode_stats=args.predict_mode_stats,
        in_channels=args.in_channels,
        use_all_features=args.use_all_features,
        out_channels=args.out_channels,
        loss_fct=args.loss_fct,
        use_weighted_loss=args.use_weighted_loss,
        point_net_conv_layer_structure_local_mlp=as_int_list(args.point_net_conv_layer_structure_local_mlp),
        point_net_conv_layer_structure_global_mlp=as_int_list(args.point_net_conv_layer_structure_global_mlp),
        gat_conv_layer_structure=as_int_list(args.gat_conv_layer_structure),
        use_bootrappping=args.use_bootrappping,
        num_epochs=args.num_epochs,
        batch_size=int(args.batch_size),
        lr=float(args.lr),
        early_stopping_patience=args.early_stopping_patience,
        use_dropout=args.use_dropout,
        dropout=args.dropout,
        use_monte_carlo_dropout=args.use_monte_carlo_dropout,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        use_gradient_clipping=args.use_gradient_clipping,
        lr_scheduler_warmup_steps=args.lr_scheduler_warmup_steps,
        device_nr=args.device_nr,
    )

    def layer_tag(key):
        # Render a parsed layer list through the project helper with '_' as
        # separator (presumably "128_256" for [128, 256] — see gnn_io).
        return gio.int_list_to_string(lst=params[key], delimiter='_')

    # NOTE: lr, batch_size, loss settings etc. are not part of this name, so runs
    # differing only in those share one output directory.
    params["unique_model_description"] = "".join([
        f"pnc_local_{layer_tag('point_net_conv_layer_structure_local_mlp')}_",
        f"pnc_global_{layer_tag('point_net_conv_layer_structure_global_mlp')}_",
        f"gat_conv_{layer_tag('gat_conv_layer_structure')}_",
        f"use_dropout_{params['use_dropout']}_",
        f"do_{params['dropout']}_",
        f"use_mc_do_{params['use_monte_carlo_dropout']}_",
        f"predict_mode_stats_{params['predict_mode_stats']}",
    ])
    return params

In [None]:
# Load the preprocessed graph batches datalist_batch_1.pt, datalist_batch_2.pt, ...
# from dataset_path, stopping at the first missing index. After the loop,
# batch_num holds that first missing index.
datalist = []
batch_num = 1
while True:
    batch_file = os.path.join(dataset_path, f'datalist_batch_{batch_num}.pt')
    if not os.path.exists(batch_file):
        break
    # Log only batches that actually exist (previously the terminating,
    # nonexistent batch number was reported as well).
    print(f"Processing batch number: {batch_num}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    # On torch >= 2.6 the default weights_only=True may reject PyG Data objects;
    # pass weights_only=False explicitly if loading fails for that reason.
    batch_data = torch.load(batch_file, map_location='cpu')
    if isinstance(batch_data, list):
        datalist.extend(batch_data)
    else:
        # Only list payloads can be merged; warn instead of dropping the batch silently.
        print(f"Warning: {batch_file} does not contain a list "
              f"(got {type(batch_data).__name__}); skipping it")
    batch_num += 1
print(f"Loaded {len(datalist)} items into datalist")

# Fail fast: every later cell needs data (e.g. datalist[0] when building the loss).
if not datalist:
    raise RuntimeError(
        f"No graph data loaded from {dataset_path!r}; "
        f"expected files named 'datalist_batch_<n>.pt' starting at n=1"
    )

In [5]:
# Hard-coded stand-in for the argparse arguments of run_models.py, so this
# notebook can be run without a command line. Attribute names must match what
# get_parameters reads.
class Args:
    """Plain attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **entries):
        vars(self).update(entries)


args = Args(
    in_channels=5,
    use_all_features=False,
    out_channels=1,
    loss_fct="mse",
    use_weighted_loss=True,
    predict_mode_stats=False,
    point_net_conv_layer_structure_local_mlp="256",
    point_net_conv_layer_structure_global_mlp="512",
    gat_conv_layer_structure="128,256,512,256",
    use_bootrappping=False,
    num_epochs=3000,
    batch_size=8,
    lr=0.001,
    early_stopping_patience=100,
    use_dropout=True,
    dropout=0.3,
    use_monte_carlo_dropout=True,
    gradient_accumulation_steps=3,
    use_gradient_clipping=True,
    lr_scheduler_warmup_steps=10000,
    device_nr=0,
)
# Seed the random number generators through the project helper before the
# dataloaders and the model are built in the cells below.
# NOTE(review): which libraries and which seed value are covered is defined in
# help_functions.set_random_seeds — confirm there.
hf.set_random_seeds()

In [None]:
# --- Device: let the project helpers choose a GPU and make it the visible one,
# then fall back to CPU when CUDA is unavailable.
# NOTE(review): if CUDA was already initialised in this kernel (e.g. this cell is
# re-run), changing the visible device likely has no effect — restart the kernel.
gpus = hf.get_available_gpus()
best_gpu = hf.select_best_gpu(gpus)
hf.set_cuda_visible_device(best_gpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Hyper-parameters and output locations: <base_dir>/<project>/<model description>
params = get_parameters(args)
project_dir = os.path.join(base_dir, params['project_name'])
unique_run_dir = os.path.join(project_dir, params['unique_model_description'])
os.makedirs(unique_run_dir, exist_ok=True)

model_save_path, path_to_save_dataloader = hf.get_paths(
    base_dir=project_dir,
    unique_model_description=params['unique_model_description'],
    model_save_path='trained_model/model.pth',
)

# --- Data: train/validation loaders plus the scalers fitted for each split.
train_dl, valid_dl, scalers_train, scalers_validation = hf.prepare_data_with_graph_features(
    datalist=datalist,
    batch_size=params['batch_size'],
    path_to_save_dataloader=path_to_save_dataloader,
    use_all_features=params['use_all_features'],
    use_bootstrapping=params['use_bootrappping'],
)

# --- Experiment tracking: only the keys listed in PARAMETERS go into the config.
wandb_config = {name: params[name] for name in PARAMETERS}
config = hf.setup_wandb(wandb_config)

In [None]:
# Build the GNN from `config` (the object returned by hf.setup_wandb), so every
# architecture hyper-parameter comes from the same source.
gnn_instance = garch.MyGnn(
    in_channels=config.in_channels,
    out_channels=config.out_channels,
    point_net_conv_layer_structure_local_mlp=config.point_net_conv_layer_structure_local_mlp,
    point_net_conv_layer_structure_global_mlp=config.point_net_conv_layer_structure_global_mlp,
    gat_conv_layer_structure=config.gat_conv_layer_structure,
    use_dropout=config.use_dropout,
    dropout=config.dropout,
    use_monte_carlo_dropout=config.use_monte_carlo_dropout,
    predict_mode_stats=config.predict_mode_stats,
    dtype=torch.float32,
)
model = gnn_instance.to(device)

# NOTE(review): the second argument is the first dimension of the first graph's
# feature matrix (the node count under PyG conventions) — presumably all graphs
# share the same node set; confirm.
loss_fct = gio.GNN_Loss(config.loss_fct, datalist[0].x.shape[0], device, config.use_weighted_loss)

# Reference losses on the training loader for judging the model: predicting the
# mean target, and the "no policies" baseline computed by gnn_io.
baseline_loss_mean_target = gio.compute_baseline_of_mean_target(dataset=train_dl, loss_fct=loss_fct, device=device, scalers=scalers_train)
baseline_loss = gio.compute_baseline_of_no_policies(dataset=train_dl, loss_fct=loss_fct, device=device, scalers=scalers_train)
print("baseline loss mean " + str(baseline_loss_mean_target))
print("baseline loss no  " + str(baseline_loss) )

# Read the patience from `config` like every other hyper-parameter in this cell;
# reading it from `params` would silently ignore a wandb-side override.
early_stopping = gio.EarlyStopping(patience=config.early_stopping_patience, verbose=True)

In [None]:
# AdamW with a fixed weight decay of 1e-4; the learning rate comes from the run config.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-4)

# Train with early stopping; garch.train returns the best validation loss and
# the epoch it was reached, and saves the model to model_save_path.
best_val_loss, best_epoch = garch.train(
    model=model,
    config=config,
    loss_fct=loss_fct,
    optimizer=optimizer,
    train_dl=train_dl,
    valid_dl=valid_dl,
    device=device,
    early_stopping=early_stopping,
    model_save_path=model_save_path,
    scalers_train=scalers_train,
    scalers_validation=scalers_validation,
)

print(f'Best model saved to {model_save_path} with validation loss: {best_val_loss} at epoch {best_epoch}')