In [1]:
import argparse
import copy
import gc
import json
import math
import os
import psutil
import random
import signal
import subprocess
import sys
import traceback

import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler
from torch.utils.data import DataLoader, Dataset, Subset
from torch_geometric.data import Batch, Data
from torch_geometric.utils import to_undirected
from tqdm import tqdm

import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import help_functions as hf
import joblib
import wandb

# Adding scripts path to sys.path if not already included
scripts_path = os.path.abspath(os.path.join('..'))
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

import gnn_architectures_improved as garch
import gnn_io as gio

def get_memory_info():
        memory_info = psutil.virtual_memory()
        total_memory = memory_info.total / (1024 ** 3)  # Convert bytes to GB
        available_memory = memory_info.available / (1024 ** 3)  # Convert bytes to GB
        used_memory = memory_info.used / (1024 ** 3)  # Convert bytes to GB
        return total_memory, available_memory, used_memory

In [2]:
total_memory, available_memory, used_memory = get_memory_info()
print(f"Total Memory: {total_memory:.2f} GB")
print(f"Available Memory: {available_memory:.2f} GB")
print(f"Used Memory: {used_memory:.2f} GB")

Total Memory: 125.49 GB
Available Memory: 28.61 GB
Used Memory: 95.67 GB


In [3]:
try:
    dataset_path = '../../data/train_data/sim_output_1pm_capacity_reduction_10k_15_10_2024/'
    datalist = []
    batch_num = 1
    # while True and batch_num < 100:
    while True:
        print(f"Processing batch number: {batch_num}")
        # total_memory, available_memory, used_memory = get_memory_info()
        # print(f"Total Memory: {total_memory:.2f} GB")
        # print(f"Available Memory: {available_memory:.2f} GB")
        # print(f"Used Memory: {used_memory:.2f} GB")
        batch_file = os.path.join(dataset_path, f'datalist_batch_{batch_num}.pt')
        if not os.path.exists(batch_file):
            break
        batch_data = torch.load(batch_file, map_location='cpu')
        if isinstance(batch_data, list):
            datalist.extend(batch_data)
        batch_num += 1
    print(f"Loaded {len(datalist)} items into datalist")

except Exception as e:
    print(f"An error occurred: {str(e)}")
    
# Set parameters here
params = {"project_name": "test",
            "num_epochs": 1000,
            "batch_size": 8,
            "point_net_conv_layer_structure_local_mlp": [64, 128],
            "point_net_conv_layer_structure_global_mlp": [256, 64],
            "gat_conv_layer_structure": [128, 256, 256, 128],
            "graph_mlp_layer_structure": [128, 256, 128],
            "lr": 0.001,
            "gradient_accumulation_steps": 3,
            "in_channels": 15,
            "out_channels": 1,
            "early_stopping_patience": 100,
            "unique_model_description": "my_test",
            "dropout": 0.3,
            "use_dropout": False
        } 
    
base_dir = '../../data/' + params['project_name'] + '/'
unique_run_dir = os.path.join(base_dir, params['unique_model_description'])
os.makedirs(unique_run_dir, exist_ok=True)

# Define the paths here
def get_paths(base_dir: str, unique_model_description: str, model_save_path: str = 'trained_model/model.pth'):
    data_path = os.path.join(base_dir, unique_model_description)
    os.makedirs(data_path, exist_ok=True)
    model_save_to = os.path.join(data_path, model_save_path)
    path_to_save_dataloader = os.path.join(data_path, 'data_created_during_training/')
    os.makedirs(os.path.dirname(model_save_to), exist_ok=True)
    os.makedirs(path_to_save_dataloader, exist_ok=True)
    return model_save_to, path_to_save_dataloader

model_save_path, path_to_save_dataloader = get_paths(base_dir=base_dir, unique_model_description= params['unique_model_description'], model_save_path= 'trained_model/model.pth')

Processing batch number: 1
Processing batch number: 2
Processing batch number: 3
Processing batch number: 4
Processing batch number: 5
Processing batch number: 6
Processing batch number: 7
Processing batch number: 8
Processing batch number: 9
Processing batch number: 10
Processing batch number: 11
Processing batch number: 12
Processing batch number: 13
Processing batch number: 14
Processing batch number: 15
Processing batch number: 16
Processing batch number: 17
Processing batch number: 18
Processing batch number: 19
Processing batch number: 20
Processing batch number: 21
Processing batch number: 22
Processing batch number: 23
Processing batch number: 24
Processing batch number: 25
Processing batch number: 26
Processing batch number: 27
Processing batch number: 28
Processing batch number: 29
Processing batch number: 30
Processing batch number: 31
Processing batch number: 32
Processing batch number: 33
Processing batch number: 34
Processing batch number: 35
Processing batch number: 36
P

In [4]:
def check_nans_in_data(datalist):
    nan_counts = {'x': 0, 'pos': 0, 'mode_stats': 0}
    total_items = len(datalist)

    for data in tqdm(datalist, desc="Checking for NaNs"):
        if torch.isnan(data.x).any():
            nan_counts['x'] += 1
        if torch.isnan(data.pos).any():
            nan_counts['pos'] += 1
        if torch.isnan(data.mode_stats).any():
            nan_counts['mode_stats'] += 1

    print(f"NaN check results:")
    print(f"Total items checked: {total_items}")
    print(f"Items with NaNs in x: {nan_counts['x']} ({nan_counts['x']/total_items*100:.2f}%)")
    print(f"Items with NaNs in pos: {nan_counts['pos']} ({nan_counts['pos']/total_items*100:.2f}%)")
    print(f"Items with NaNs in mode_stats: {nan_counts['mode_stats']} ({nan_counts['mode_stats']/total_items*100:.2f}%)")

    return nan_counts

# Usage
nan_results = check_nans_in_data(datalist)

Checking for NaNs: 100%|██████████| 4950/4950 [00:01<00:00, 3706.47it/s]

NaN check results:
Total items checked: 4950
Items with NaNs in x: 0 (0.00%)
Items with NaNs in pos: 0 (0.00%)
Items with NaNs in mode_stats: 0 (0.00%)





In [17]:
datalist[0].x

tensor([[0., 0., 0.,  ..., 1., 1., 0.],
        [0., 0., 0.,  ..., 1., 1., 0.],
        [0., 0., 0.,  ..., 1., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [18]:
datalist[0].pos

tensor([[[ 2.3630, 48.8842],
         [ 2.3631, 48.8844],
         [ 2.3630, 48.8843]],

        [[ 2.3630, 48.8842],
         [ 2.3629, 48.8841],
         [ 2.3629, 48.8842]],

        [[ 2.3614, 48.8810],
         [ 2.3614, 48.8811],
         [ 2.3614, 48.8810]],

        ...,

        [[ 2.3134, 48.8943],
         [ 2.3134, 48.8943],
         [ 2.3134, 48.8943]],

        [[ 2.2712, 48.8380],
         [ 2.2789, 48.8359],
         [ 2.2750, 48.8370]],

        [[ 2.2712, 48.8380],
         [ 2.2712, 48.8380],
         [ 2.2712, 48.8380]]])

In [19]:
datalist[0].mode_stats

tensor([[1.1822e+03, 3.6665e+03],
        [1.8033e+03, 4.9954e+03],
        [4.8058e+02, 4.4846e+03],
        [7.8831e-01, 1.0569e+03],
        [1.6104e+03, 5.4706e+03],
        [1.0098e+03, 1.2123e+03]])

In [5]:
def normalize_dataset(dataset_input, directory_path):
    data_list = [dataset_input.dataset[idx] for idx in dataset_input.indices]
    print("LEN DATALIST")
    print(len(data_list))
    print("Fitting and normalizing x features...")
    normalized_data_list, x_scaler = normalize_x_features_batched(data_list)
    print("x features normalized")
    print(len(normalized_data_list))
    

    print("Fitting and normalizing pos features...")
    normalized_data_list, pos_scaler = normalize_pos_features_batched(normalized_data_list)
    print("Pos features normalized")
    
    print("Fitting and normalizing modestats features...")
    normalized_data_list, modestats_scaler = normalize_modestats_features_batched(normalized_data_list)
    print("Modestats features normalized")
    
    print("FINAL LEN")
    print(len(normalized_data_list))
    return normalized_data_list, (x_scaler, pos_scaler, modestats_scaler)

def normalize_x_features_batched(data_list, batch_size=100):
    scaler = StandardScaler()
    
    # First pass: Fit the scaler
    for i in tqdm(range(0, len(data_list), batch_size), desc="Fitting scaler"):
        batch = data_list[i:i+batch_size]
        batch_x = np.vstack([data.x.numpy() for data in batch])
        scaler.partial_fit(batch_x)
    
    # Second pass: Transform the data
    for i in tqdm(range(0, len(data_list), batch_size), desc="Normalizing x features"):
        batch = data_list[i:i+batch_size]
        batch_x = np.vstack([data.x.numpy() for data in batch])
        batch_x_normalized = scaler.transform(batch_x)
        for j, data in enumerate(batch):
            data.x = torch.tensor(batch_x_normalized[j*31140:(j+1)*31140], dtype=torch.float32)
    
    return data_list, scaler

def normalize_pos_features_batched(data_list, batch_size=1000):
    scaler = StandardScaler()
    
    # First pass: Fit the scaler
    for i in tqdm(range(0, len(data_list), batch_size), desc="Fitting scaler"):
        batch = data_list[i:i+batch_size]
        batch_pos = np.vstack([data.pos.numpy().reshape(-1, 6) for data in batch])
        scaler.partial_fit(batch_pos)
    
    # Second pass: Transform the data
    for i in tqdm(range(0, len(data_list), batch_size), desc="Normalizing pos features"):
        batch = data_list[i:i+batch_size]
        for data in batch:
            pos_reshaped = data.pos.numpy().reshape(-1, 6)
            pos_normalized = scaler.transform(pos_reshaped)
            data.pos = torch.tensor(pos_normalized.reshape(31140, 3, 2), dtype=torch.float32)
    
    return data_list, scaler

def normalize_modestats_features_batched(data_list, batch_size=1000):
    scaler = StandardScaler()
    
    # First pass: Fit the scaler
    for i in tqdm(range(0, len(data_list), batch_size), desc="Fitting scaler"):
        batch = data_list[i:i+batch_size]
        batch_modestats = np.vstack([data.mode_stats.numpy().reshape(1, -1) for data in batch])
        scaler.partial_fit(batch_modestats)
    
    # Second pass: Transform the data
    for i in tqdm(range(0, len(data_list), batch_size), desc="Normalizing modestats features"):
        batch = data_list[i:i+batch_size]
        for data in batch:
            modestats_reshaped = data.mode_stats.numpy().reshape(1, -1)
            modestats_normalized = scaler.transform(modestats_reshaped)
            data.mode_stats = torch.tensor(modestats_normalized.reshape(6, 2), dtype=torch.float32)
    
    return data_list, scaler

def replace_invalid_values(tensor):
    # print(f"Input tensor shape: {tensor.shape}")
    # nan_count = torch.isnan(tensor).sum().item()
    # inf_count = torch.isinf(tensor).sum().item()
    # print(f"NaN count: {nan_count}, Inf count: {inf_count}")
    
    tensor[torch.isnan(tensor)] = 0  # replace NaNs with 0
    tensor[torch.isinf(tensor)] = 0  # replace inf and -inf with 0
    return tensor


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [21]:
# def prepare_data_with_graph_features(datalist, batch_size, path_to_save_dataloader):
#     print(f"Starting prepare_data_with_graph_features with {len(datalist)} items")
    
#     try:
#         print("Splitting into subsets...")
#         train_set, valid_set, test_set = gio.split_into_subsets(dataset=datalist, train_ratio=0.8, val_ratio=0.15, test_ratio=0.05)
#         print(f"Split complete. Train: {len(train_set)}, Valid: {len(valid_set)}, Test: {len(test_set)}")
        
#         print("Normalizing train set...")
#         train_set_normalized, scalers_train = normalize_dataset(dataset_input=train_set, directory_path=path_to_save_dataloader + "train_")
#         print("Train set normalized")
        
#         print("Normalizing validation set...")
#         valid_set_normalized, scalers_validation = normalize_dataset(dataset_input=valid_set, directory_path=path_to_save_dataloader + "valid_")
#         print("Validation set normalized")
#         print(len(valid_set_normalized))
        
#         print("Creating train loader...")
#         train_loader = DataLoader(dataset=train_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#         print("Train loader created")
        
#         print("Creating validation loader...")
#         val_loader = DataLoader(dataset=valid_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#         print("Validation loader created")
        
#         return train_loader, val_loader, scalers_train, scalers_validation
#     except Exception as e:
#         print(f"Error in prepare_data_with_graph_features: {str(e)}")
#         import traceback
#         traceback.print_exc()
#         raise

# train_dl, valid_dl, scalers_train, scalers_validation = prepare_data_with_graph_features(datalist=datalist, batch_size= params['batch_size'], path_to_save_dataloader= path_to_save_dataloader)

Starting prepare_data_with_graph_features with 4950 items
Splitting into subsets...
Total dataset length: 4950
Training subset length: 3960
Validation subset length: 742
Test subset length: 248
Split complete. Train: 3960, Valid: 742, Test: 248
Normalizing train set...
LEN DATALIST
3960
Fitting and normalizing x features...


Fitting scaler:   0%|          | 0/40 [00:00<?, ?it/s]

Fitting scaler: 100%|██████████| 40/40 [00:39<00:00,  1.01it/s]
Normalizing x features: 100%|██████████| 40/40 [00:37<00:00,  1.08it/s]


x features normalized
3960
Fitting and normalizing pos features...


Fitting scaler:   0%|          | 0/4 [00:00<?, ?it/s]

: 

In [6]:
batch_size = params['batch_size']
print("Splitting into subsets...")
train_set, valid_set, test_set = gio.split_into_subsets(dataset=datalist, train_ratio=0.8, val_ratio=0.15, test_ratio=0.05)
print(f"Split complete. Train: {len(train_set)}, Valid: {len(valid_set)}, Test: {len(test_set)}")

Splitting into subsets...
Total dataset length: 4950
Training subset length: 3960
Validation subset length: 742
Test subset length: 248
Split complete. Train: 3960, Valid: 742, Test: 248


In [7]:
print("Normalizing train set...")
train_set_normalized, scalers_train = normalize_dataset(dataset_input=train_set, directory_path=path_to_save_dataloader + "train_")
print("Train set normalized")

Normalizing train set...
LEN DATALIST
3960
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 40/40 [00:31<00:00,  1.26it/s]
Normalizing x features: 100%|██████████| 40/40 [00:15<00:00,  2.53it/s]


x features normalized
3960
Fitting and normalizing pos features...


Fitting scaler: 100%|██████████| 4/4 [00:13<00:00,  3.31s/it]
Normalizing pos features: 100%|██████████| 4/4 [00:10<00:00,  2.61s/it]


Pos features normalized
Fitting and normalizing modestats features...


Fitting scaler: 100%|██████████| 4/4 [00:00<00:00, 84.29it/s]
Normalizing modestats features: 100%|██████████| 4/4 [00:00<00:00,  7.68it/s]

Modestats features normalized
FINAL LEN
3960
Train set normalized





In [8]:
print("Normalizing validation set...")
valid_set_normalized, scalers_validation = normalize_dataset(dataset_input=valid_set, directory_path=path_to_save_dataloader + "valid_")
print("Validation set normalized")
print(len(valid_set_normalized))

Normalizing validation set...
LEN DATALIST
742
Fitting and normalizing x features...


Fitting scaler: 100%|██████████| 8/8 [00:04<00:00,  1.78it/s]
Normalizing x features: 100%|██████████| 8/8 [00:02<00:00,  2.79it/s]


x features normalized
742
Fitting and normalizing pos features...


Fitting scaler: 100%|██████████| 1/1 [00:02<00:00,  2.21s/it]
Normalizing pos features: 100%|██████████| 1/1 [00:01<00:00,  1.61s/it]


Pos features normalized
Fitting and normalizing modestats features...


Fitting scaler: 100%|██████████| 1/1 [00:00<00:00, 126.19it/s]
Normalizing modestats features: 100%|██████████| 1/1 [00:00<00:00, 12.11it/s]

Modestats features normalized
FINAL LEN
742
Validation set normalized
742





In [9]:
print("Creating train loader...")
train_loader = DataLoader(dataset=train_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
print("Train loader created")

print("Creating validation loader...")
val_loader = DataLoader(dataset=valid_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
print("Validation loader created")

Creating train loader...
Train loader created
Creating validation loader...
Validation loader created


In [10]:
train_loader.dataset[0].mode_stats

tensor([[ 4.5609,  4.5641],
        [ 2.4442,  2.2361],
        [ 1.5264,  1.4667],
        [-2.9498, -2.9884],
        [-0.2325, -2.1963],
        [-5.2668, -5.2668]])

In [11]:
train_dl = train_loader
valid_dl = val_loader


In [12]:
gpus = hf.get_available_gpus()
best_gpu = hf.select_best_gpu(gpus)
hf.set_cuda_visible_device(best_gpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = hf.setup_wandb(params['project_name'], {
    "epochs": params['num_epochs'],
    "batch_size": params['batch_size'],
    "lr": params['lr'],
    "gradient_accumulation_steps": params['gradient_accumulation_steps'],
    "early_stopping_patience": params['early_stopping_patience'],
    "point_net_conv_local_mlp": params['point_net_conv_layer_structure_local_mlp'],
    "point_net_conv_global_mlp": params['point_net_conv_layer_structure_global_mlp'],
    "gat_conv_layer_structure": params['gat_conv_layer_structure'],
    "graph_mlp_layer_structure": params['graph_mlp_layer_structure'],
    "in_channels": params['in_channels'],
    "out_channels": params['out_channels'],
    "dropout": params['dropout'],
    "use_dropout": params['use_dropout']
})

model = garch.MyGnn(in_channels=config.in_channels, out_channels=config.out_channels, point_net_conv_layer_structure_local_mlp=config.point_net_conv_local_mlp,
                            point_net_conv_layer_structure_global_mlp=config.point_net_conv_global_mlp,
                            gat_conv_layer_structure=config.gat_conv_layer_structure,
                            graph_mlp_layer_structure=config.graph_mlp_layer_structure,
                            dropout=config.dropout, use_dropout=config.use_dropout)

model.to(device)

loss_fct = torch.nn.MSELoss()

baseline_loss_mean_target = gio.compute_baseline_of_mean_target(dataset=train_dl, loss_fct=loss_fct)
baseline_loss = gio.compute_baseline_of_no_policies(dataset=train_dl, loss_fct=loss_fct)
print("baseline loss mean " + str(baseline_loss_mean_target))
print("baseline loss no  " +str(baseline_loss) )

early_stopping = gio.EarlyStopping(patience=params['early_stopping_patience'], verbose=True)
best_val_loss, best_epoch = garch.train(model=model, 
            config=config, 
            loss_fct=loss_fct,
            optimizer=torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-4),
            train_dl=train_dl,  
            valid_dl=valid_dl,
            device=device, 
            early_stopping=early_stopping,
            accumulation_steps=config.gradient_accumulation_steps,
            model_save_path=model_save_path,
            use_gradient_clipping=True,
            lr_scheduler_warmup_steps=20000,
            lr_scheduler_cosine_decay_rate=0.2)
print(f'Best model saved to {model_save_path} with validation loss: {best_val_loss} at epoch {best_epoch}')  

Using GPU 0 with CUDA_VISIBLE_DEVICES=0


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


Model initialized
MyGnn(
  (point_net_conv_1): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
  ), global_nn=Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
  ))
  (point_net_conv_2): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=66, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
  ), global_nn=Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
  ))
  (point_net_conv_3): PointNetConv(local_nn=Sequential(
    (0)

Epoch 1/1000: 100%|██████████| 495/495 [01:20<00:00,  6.12it/s]


epoch: 0, validation loss: 2.962466355293028, lr: 2.47e-05, r^2: -0.00022220611572265625
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.962466355293028
Checkpoint saved to ../../data/test/my_test/trained_model/checkpoints/checkpoint_epoch_0.pt


Epoch 2/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 1, validation loss: 2.931720866951891, lr: 4.945e-05, r^2: 0.009590446949005127
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.931720866951891


Epoch 3/1000: 100%|██████████| 495/495 [01:21<00:00,  6.09it/s]


epoch: 2, validation loss: 2.8703624074177077, lr: 7.42e-05, r^2: 0.030202865600585938
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.8703624074177077


Epoch 4/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 3, validation loss: 2.833978147916896, lr: 9.894999999999999e-05, r^2: 0.04276949167251587
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.833978147916896


Epoch 5/1000: 100%|██████████| 495/495 [01:21<00:00,  6.09it/s]


epoch: 4, validation loss: 2.8111383530401413, lr: 0.0001237, r^2: 0.05066913366317749
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.8111383530401413


Epoch 6/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 5, validation loss: 2.8192207556898876, lr: 0.00014845, r^2: 0.048224449157714844
EarlyStopping counter: 1 out of 100


Epoch 7/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 6, validation loss: 2.9590860848785727, lr: 0.0001732, r^2: 0.0005736947059631348
EarlyStopping counter: 2 out of 100


Epoch 8/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 7, validation loss: 2.7239293206122612, lr: 0.00019795, r^2: 0.07986128330230713
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.7239293206122612


Epoch 9/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 8, validation loss: 2.6757970368990334, lr: 0.00022270000000000002, r^2: 0.09650003910064697
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.6757970368990334


Epoch 10/1000: 100%|██████████| 495/495 [01:21<00:00,  6.09it/s]


epoch: 9, validation loss: 2.67127533881895, lr: 0.00024745, r^2: 0.09760415554046631
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.67127533881895


Epoch 11/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 10, validation loss: 2.659845152208882, lr: 0.0002722, r^2: 0.10166752338409424
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.659845152208882


Epoch 12/1000: 100%|██████████| 495/495 [01:21<00:00,  6.09it/s]


epoch: 11, validation loss: 2.666088201666391, lr: 0.00029695, r^2: 0.09942156076431274
EarlyStopping counter: 1 out of 100


Epoch 13/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 12, validation loss: 2.725151836231191, lr: 0.0003217, r^2: 0.07979017496109009
EarlyStopping counter: 2 out of 100


Epoch 14/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 13, validation loss: 2.6515517004074587, lr: 0.00034645, r^2: 0.10450887680053711
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.6515517004074587


Epoch 15/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 14, validation loss: 2.8467646875689105, lr: 0.00037119999999999997, r^2: 0.03860306739807129
EarlyStopping counter: 1 out of 100


Epoch 16/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 15, validation loss: 2.7128318996839624, lr: 0.00039595000000000006, r^2: 0.0837705135345459
EarlyStopping counter: 2 out of 100


Epoch 17/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 16, validation loss: 2.5424389198262203, lr: 0.00042070000000000003, r^2: 0.1412370800971985
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.5424389198262203


Epoch 18/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 17, validation loss: 2.5172953810743106, lr: 0.00044545, r^2: 0.14963650703430176
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.5172953810743106


Epoch 19/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 18, validation loss: 2.4832903236471195, lr: 0.0004702, r^2: 0.16138958930969238
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.4832903236471195


Epoch 20/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 19, validation loss: 2.440176543369088, lr: 0.00049495, r^2: 0.17577511072158813
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.440176543369088


Epoch 21/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 20, validation loss: 2.4584474819962696, lr: 0.0005197000000000001, r^2: 0.16999465227127075
Checkpoint saved to ../../data/test/my_test/trained_model/checkpoints/checkpoint_epoch_20.pt
EarlyStopping counter: 1 out of 100


Epoch 22/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 21, validation loss: 2.423237185324392, lr: 0.00054445, r^2: 0.18165606260299683
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.423237185324392


Epoch 23/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 22, validation loss: 2.383851102603379, lr: 0.0005692000000000001, r^2: 0.19487369060516357
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.383851102603379


Epoch 24/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 23, validation loss: 2.3598719361007854, lr: 0.00059395, r^2: 0.20299017429351807
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.3598719361007854


Epoch 25/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 24, validation loss: 2.3470516127924763, lr: 0.0006187, r^2: 0.20722264051437378
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.3470516127924763


Epoch 26/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 25, validation loss: 2.4297991721860823, lr: 0.00064345, r^2: 0.17927569150924683
EarlyStopping counter: 1 out of 100


Epoch 27/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 26, validation loss: 2.3602595995831233, lr: 0.0006682, r^2: 0.20272880792617798
EarlyStopping counter: 2 out of 100


Epoch 28/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 27, validation loss: 2.436246971930227, lr: 0.00069295, r^2: 0.17718863487243652
EarlyStopping counter: 3 out of 100


Epoch 29/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 28, validation loss: 2.3464901293477705, lr: 0.0007177, r^2: 0.20759552717208862
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.3464901293477705


Epoch 30/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 29, validation loss: 2.3018383031250327, lr: 0.0007424500000000001, r^2: 0.2224147915840149
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.3018383031250327


Epoch 31/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 30, validation loss: 2.3086859820991434, lr: 0.0007672, r^2: 0.2199857234954834
EarlyStopping counter: 1 out of 100


Epoch 32/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 31, validation loss: 2.316098810524069, lr: 0.00079195, r^2: 0.2179507613182068
EarlyStopping counter: 2 out of 100


Epoch 33/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 32, validation loss: 2.3781781581140335, lr: 0.0008167, r^2: 0.19672662019729614
EarlyStopping counter: 3 out of 100


Epoch 34/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 33, validation loss: 2.333648440658405, lr: 0.0008414500000000001, r^2: 0.21169525384902954
EarlyStopping counter: 4 out of 100


Epoch 35/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 34, validation loss: 2.3821157486208024, lr: 0.0008662, r^2: 0.19551688432693481
EarlyStopping counter: 5 out of 100


Epoch 36/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 35, validation loss: 2.336846943824522, lr: 0.00089095, r^2: 0.21053647994995117
EarlyStopping counter: 6 out of 100


Epoch 37/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 36, validation loss: 2.330545784324728, lr: 0.0009157, r^2: 0.21275633573532104
EarlyStopping counter: 7 out of 100


Epoch 38/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 37, validation loss: 2.298000599748345, lr: 0.00094045, r^2: 0.22382193803787231
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.298000599748345


Epoch 39/1000: 100%|██████████| 495/495 [01:20<00:00,  6.11it/s]


epoch: 38, validation loss: 2.2879310372055217, lr: 0.0009651999999999999, r^2: 0.2271556258201599
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.2879310372055217


Epoch 40/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 39, validation loss: 2.2904008819210913, lr: 0.00098995, r^2: 0.22657173871994019
EarlyStopping counter: 1 out of 100


Epoch 41/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 40, validation loss: 2.21142170634321, lr: 0.0003999996218996755, r^2: 0.25312870740890503
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.21142170634321
Checkpoint saved to ../../data/test/my_test/trained_model/checkpoints/checkpoint_epoch_40.pt


Epoch 42/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 41, validation loss: 2.1971981410057313, lr: 0.0003999972768877301, r^2: 0.25789082050323486
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.1971981410057313


Epoch 43/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 42, validation loss: 2.1791428058378157, lr: 0.000399992788261618, r^2: 0.2638503909111023
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.1791428058378157


Epoch 44/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 43, validation loss: 2.1682041909105036, lr: 0.0003999861560694491, r^2: 0.26755571365356445
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.1682041909105036


Epoch 45/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 44, validation loss: 2.1602882544199624, lr: 0.0003999773803823088, r^2: 0.2704750895500183
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.1602882544199624


Epoch 46/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 45, validation loss: 2.1481677255322857, lr: 0.0003999664612942568, r^2: 0.2743532061576843
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.1481677255322857


Epoch 47/1000: 100%|██████████| 495/495 [01:21<00:00,  6.10it/s]


epoch: 46, validation loss: 2.153240002611632, lr: 0.00039995339892232615, r^2: 0.27270740270614624
EarlyStopping counter: 1 out of 100


Epoch 48/1000: 100%|██████████| 495/495 [01:21<00:00,  6.11it/s]


epoch: 47, validation loss: 2.1460842637605566, lr: 0.00039993819340652226, r^2: 0.27527791261672974
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 2.1460842637605566


Epoch 49/1000:  62%|██████▏   | 306/495 [00:49<00:32,  5.76it/s]

In [None]:
# Call this function during training without the scalars and with the directory path, and during the testing with the saved scalars and without a directory path to save.
# def normalize_dataset(dataset_input, directory_path):
#     dataset = copy_subset(dataset_input)
#     dataset = normalize_x_values(dataset, directory_path)
#     dataset = normalize_positional_features(dataset, directory_path)
#     dataset = normalize_mode_stats(dataset, directory_path)
#     return dataset
    
    
# def normalize_x_values(dataset, directory_path):
#     try:
#         shape_of_x = dataset[0].x.shape[1]
#         print(f"Shape of x: {shape_of_x}")
        
#         list_of_scalers_to_save = []
#         print("Processing x values...")

#         # Process in batches
#         batch_size = 100  # Adjust this value based on your available memory
#         for i in range(shape_of_x):
#             print(f"Processing feature {i}/{shape_of_x}")
#             scaler = StandardScaler()
            
#             # Fit scaler in batches
#             for j in range(0, len(dataset), batch_size):
#                 batch = dataset[j:j+batch_size]
#                 print(f"Processing batch {j//batch_size + 1}/{len(dataset)//batch_size + 1}")
#                 batch_x_values = torch.cat([data.x[:, i].reshape(-1, 1) for data in batch], dim=0)
#                 batch_x_values = replace_invalid_values(batch_x_values)
#                 scaler.partial_fit(batch_x_values.numpy())

#             list_of_scalers_to_save.append(scaler)

#             # Transform data
#             for j, data in enumerate(dataset):
#                 if j % 100 == 0:
#                     print(f"Transforming data point {j}/{len(dataset)}")
#                 data_x_dim = replace_invalid_values(data.x[:, i].reshape(-1, 1))
#                 normalized_x_dim = torch.tensor(scaler.transform(data_x_dim.numpy()), dtype=torch.float)
#                 if i == 0:
#                     data.normalized_x = normalized_x_dim
#                 else:
#                     data.normalized_x = torch.cat((data.normalized_x, normalized_x_dim), dim=1)

#         print("Saving scalers...")
#         joblib.dump(list_of_scalers_to_save, (directory_path + 'x_scaler.pkl'))
#         print("Scalers saved successfully")

#         print("Updating x values in dataset...")
#         for data in dataset:
#             data.x = data.normalized_x
#             del data.normalized_x
#         print("Dataset x values updated")

#         return dataset
#     except Exception as e:
#         print(f"Error in normalize_x_values: {str(e)}")
#         traceback.print_exc()
#         raise   


# import networkx as nx
# import matplotlib.pyplot as plt
# import numpy as np
# from collections import Counter

# def check_scale_free_distribution(edge_index, num_nodes):
#     # Create a NetworkX graph from the edge_index
#     G = nx.Graph()
#     G.add_nodes_from(range(num_nodes))
#     edge_list = edge_index.t().tolist()
#     G.add_edges_from(edge_list)

#     # Calculate degree for each node
#     degrees = [d for n, d in G.degree()]
#     degree_counts = Counter(degrees)

#     # Sort the degree counts
#     sorted_degree_counts = sorted(degree_counts.items())
#     x = [k for k, v in sorted_degree_counts]
#     y = [v for k, v in sorted_degree_counts]

#     # Plot degree distribution on log-log scale
#     plt.figure(figsize=(10, 6))
#     plt.loglog(x, y, 'bo-')
#     plt.xlabel('Degree (log scale)')
#     plt.ylabel('Count (log scale)')
#     plt.title('Degree Distribution (Log-Log Scale)')
#     plt.grid(True)

#     # Fit a power law distribution
#     x_log = np.log(x)
#     y_log = np.log(y)
#     coeffs = np.polyfit(x_log[1:], y_log[1:], 1)
#     power_law_exponent = -coeffs[0]

#     # Plot the fitted line
#     x_fit = np.logspace(np.log10(min(x)), np.log10(max(x)), 100)
#     y_fit = np.exp(coeffs[1]) * x_fit**(-power_law_exponent)
#     plt.loglog(x_fit, y_fit, 'r--', label=f'Power Law Fit (γ ≈ {power_law_exponent:.2f})')

#     plt.legend()
#     plt.show()

#     print(f"Estimated power law exponent: γ ≈ {power_law_exponent:.2f}")
    
#     if 2 < power_law_exponent < 3:
#         print("The network shows characteristics of a scale-free network.")
#     else:
#         print("The network may not be scale-free.")

#     return power_law_exponent

# # Usage example:
# # Assuming you have a PyTorch Geometric Data object called 'data'
# exponent = check_scale_free_distribution(data.edge_index, data.num_nodes)

# from torch_geometric.utils import to_undirected, is_undirected

# # Assuming you're working with the first graph in your dataset
# data = train_dl.dataset[0]

# # Check if the graph is already undirected
# # if not is_undirected(data.edge_index):
# #     # If it's directed, convert it to undirected
# #     data.edge_index = to_undirected(data.edge_index)
# #     print("Graph has been converted to undirected.")
# # else:
# #     print("Graph is already undirected.")

# # Verify that the graph is now undirected
# print(f"Is the graph undirected? {is_undirected(data.edge_index)}")


# def normalize_mode_stats(dataset, directory_path):
#     # Initialize 12 StandardScalers for 6 sets of 2 dimensions
#     scalers = [[StandardScaler() for _ in range(2)] for _ in range(6)]

#     # Standardize the data
#     for i in range(6):  # Iterate over the first dimension (6 sets)
#         for j in range(2):  # Iterate over the second dimension (2D vectors)
#             values = np.vstack([data.mode_stats[i, j].numpy().reshape(-1, 1) for data in dataset])
#             # Fit the corresponding scaler on the extracted values
#             scalers[i][j].fit(values)
#             for data in dataset:
#                 transformed = scalers[i][j].transform(data.mode_stats[i, j].numpy().reshape(-1, 1)).flatten()
#                 # Convert the transformed NumPy array back into a torch tensor
#                 data.mode_stats[i, j] = torch.tensor(transformed, dtype=torch.float32)
    
#     # Save the scalers using joblib
#     for i in range(6):
#         for j in range(2):
#             # Dump the scalers with meaningful names to differentiate them
#             scaler_path = directory_path + f'scaler_mode_stats_{i}_{j}.pkl'
#             joblib.dump(scalers[i][j], scaler_path)

#     print("Mode stats scalers saved and dataset standardized.")
#     return dataset

# def replace_invalid_values(tensor):
#     tensor[tensor != tensor] = 0  # replace NaNs with 0
#     tensor[tensor == float('inf')] = 0  # replace inf with 0
#     tensor[tensor == float('-inf')] = 0  # replace -inf with 0
#     return tensor



# def prepare_data_with_graph_features(datalist, batch_size, path_to_save_dataloader):
#     # datalist = [Data(x=d['x'], edge_index=d['edge_index'], edge_attr=d['edge_attr'], pos=d['pos'], y=d['y'], mode_stats=d['mode_stats']) for d in data_dict_list]
#     train_set, valid_set, test_set = gio.split_into_subsets(dataset=datalist, train_ratio=0.8, val_ratio=0.15, test_ratio=0.05)
    
#     train_set_normalized = normalize_dataset(dataset_input = train_set, directory_path=path_to_save_dataloader + "train_")
#     valid_set_normalized = normalize_dataset(dataset_input = valid_set, directory_path=path_to_save_dataloader + "valid_")
#     # # test_set_normalized = normalize_dataset(dataset_input = test_set, directory_path=path_to_save_dataloader + "test_")
        
#     train_loader = DataLoader(dataset=train_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#     val_loader = DataLoader(dataset=valid_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#     # test_loader = DataLoader(dataset=test_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#     # gio.save_dataloader(test_loader, path_to_save_dataloader + 'test_dl.pt')
#     # gio.save_dataloader_params(test_loader, path_to_save_dataloader + 'test_loader_params.json')
    
#     return train_loader, val_loader


# def normalize_x_values(dataset, directory_path):
#     shape_of_x = dataset[0].x.shape[1]
#     list_of_scalers_to_save = []
#     x_values = torch.cat([data.x for data in dataset], dim=0)

#     for i in range(shape_of_x):
#         all_node_features = replace_invalid_values(x_values[:, i].reshape(-1, 1)).numpy()
        
#         scaler = StandardScaler()
#         print(f"Scaler created for x values at index {i}: {scaler}")
#         scaler.fit(all_node_features)
#         list_of_scalers_to_save.append(scaler)

#         for data in dataset:
#             data_x_dim = replace_invalid_values(data.x[:, i].reshape(-1, 1))
#             normalized_x_dim = torch.tensor(scaler.transform(data_x_dim.numpy()), dtype=torch.float)
#             if i == 0:
#                 data.normalized_x = normalized_x_dim
#             else:
#                 data.normalized_x = torch.cat((data.normalized_x, normalized_x_dim), dim=1)

#     joblib.dump(list_of_scalers_to_save, (directory_path + 'x_scaler.pkl'))
#     for data in dataset:
#         data.x = data.normalized_x
#         del data.normalized_x
#     return dataset


# def normalize_positional_features(dataset, directory_path):
#     # Initialize 6 StandardScalers for 3 sets of 2 dimensions
#     scalers = [[StandardScaler() for _ in range(2)] for _ in range(3)]

#     # Standardize the data
#     for i in range(3):  # Iterate over the second dimension (3 sets)
#         for j in range(2):  # Iterate over the third dimension (2D vectors)
#             values = np.vstack([data.pos[:, i, j].numpy() for data in dataset]).reshape(-1, 1)
#             # Fit the corresponding scaler on the extracted values
#             scalers[i][j].fit(values)
#             for data in dataset:
#                 transformed = scalers[i][j].transform(data.pos[:, i, j].numpy().reshape(-1, 1)).flatten()
#                 # Convert the transformed NumPy array back into a torch tensor
#                 data.pos[:, i, j] = torch.tensor(transformed, dtype=torch.float32)
#     # Save the scalers using joblib
#     for i in range(3):
#         for j in range(2):
#             # Dump the scalers with meaningful names to differentiate them
#             scaler_path = directory_path + f'scaler_pos_{i}_{j}.pkl'
#             joblib.dump(scalers[i][j], scaler_path)

#     print("Postional scalers saved and dataset standardized.")
#     return dataset



# working version, but only up to 2000 datapoints
# def get_combined_data(dataset_path, max_batches=None):
#     data_list = []
#     batch_num = 1
#     while max_batches is None or batch_num <= max_batches:
#         try:
#             batch_file = os.path.join(dataset_path, f'datalist_batch_{batch_num}.pt')
#             batch_data = torch.load(batch_file, map_location='cpu')
#             print(f"Batch {batch_num} type: {type(batch_data)}, length: {len(batch_data)}")
            
#             if isinstance(batch_data, list):
#                 for idx, item in enumerate(batch_data):
#                     try:
#                         # print(f"Item {idx} type: {type(item)}")
#                         if isinstance(item, Data):
#                             required_attrs = ['x', 'edge_index', 'pos', 'y', 'mode_stats']
#                             missing_attrs = [attr for attr in required_attrs if not hasattr(item, attr)]
#                             if not missing_attrs:
#                                 data_list.append(item)
#                                 # print(f"Added item {idx} to data_list")
#                             else:
#                                 print(f"Skipping invalid item {idx} in batch {batch_num}. Missing attributes: {missing_attrs}")
#                         else:
#                             print(f"Skipping non-Data item {idx} in batch {batch_num}.")
#                     except Exception as e:
#                         print(f"Error processing item {idx} in batch {batch_num}: {str(e)}")
#             else:
#                 print(f"Unexpected batch data type in batch {batch_num}: {type(batch_data)}")
            
#             batch_num += 1
#             print(f"Loaded batch {batch_num-1}, current total: {len(data_list)} items")
            
#             if len(data_list) % 1000 == 0:
#                 if psutil.virtual_memory().percent > 90:
#                     print("Memory usage high, stopping data loading")
#                     break
#         except FileNotFoundError:
#             print(f"Finished loading {batch_num-1} batches")
#             break
#         except Exception as e:
#             print(f"Error loading batch {batch_num}: {str(e)}")
#             batch_num += 1

#     print(f"Successfully loaded {len(data_list)} data points")
#     return data_list

# def get_combined_data(dataset_path, max_data_points=10000):
#     data_list = []
#     batch_num = 1
#     while len(data_list) < max_data_points:
#         try:
#             batch_file = os.path.join(dataset_path, f'datalist_batch_{batch_num}.pt')
#             batch_data = torch.load(batch_file, map_location='cpu')
#             print(f"Batch {batch_num} type: {type(batch_data)}, length: {len(batch_data)}")
            
#             if isinstance(batch_data, list):
#                 for idx, item in enumerate(batch_data):
#                     if len(data_list) >= max_data_points:
#                         break
#                     try:
#                         if isinstance(item, Data):
#                             required_attrs = ['x', 'edge_index', 'pos', 'y', 'mode_stats']
#                             missing_attrs = [attr for attr in required_attrs if not hasattr(item, attr)]
#                             if not missing_attrs:
#                                 data_list.append(item)
#                             else:
#                                 print(f"Skipping invalid item {idx} in batch {batch_num}. Missing attributes: {missing_attrs}")
#                         else:
#                             print(f"Skipping non-Data item {idx} in batch {batch_num}.")
#                     except Exception as e:
#                         print(f"Error processing item {idx} in batch {batch_num}: {str(e)}")
#             else:
#                 print(f"Unexpected batch data type in batch {batch_num}: {type(batch_data)}")
            
#             batch_num += 1
#             print(f"Loaded batch {batch_num-1}, current total: {len(data_list)} items")
            
#             if psutil.virtual_memory().percent > 90:
#                 print("Memory usage high, stopping data loading")
#                 break
#         except FileNotFoundError:
#             print(f"Finished loading {batch_num-1} batches")
#             break
#         except Exception as e:
#             print(f"Error loading batch {batch_num}: {str(e)}")
#             batch_num += 1

#     print(f"Successfully loaded {len(data_list)} data points")
#     return data_list

# # Usage
# try:
#     dataset_path = '../../data/train_data/sim_output_1pm_capacity_reduction_10k_11_10_2024/'
#     data_list = get_combined_data(dataset_path)  # Let's look at the first two batches
#     print(f"Final count: Successfully loaded {len(data_list)} data points")
# except Exception as e:
#     print(f"An error occurred: {str(e)}")

# def normalize_x_values(dataset, directory_path):
#     try:
#         shape_of_x = dataset[0].x.shape[1]
#         print(f"Shape of x: {shape_of_x}")
        
#         list_of_scalers_to_save = []
#         print("Concatenating x values...")
#         x_values = torch.cat([data.x for data in dataset], dim=0)
#         print(f"Concatenated x_values shape: {x_values.shape}")

#         for i in range(shape_of_x):
#             print(f"Processing feature {i}/{shape_of_x}")
#             all_node_features = replace_invalid_values(x_values[:, i].reshape(-1, 1)).numpy()
            
#             scaler = StandardScaler()
#             scaler.fit(all_node_features)
#             list_of_scalers_to_save.append(scaler)

#             for j, data in enumerate(dataset):
#                 if j % 100 == 0:
#                     print(f"Processing data point {j}/{len(dataset)}")
#                 data_x_dim = replace_invalid_values(data.x[:, i].reshape(-1, 1))
#                 normalized_x_dim = torch.tensor(scaler.transform(data_x_dim.numpy()), dtype=torch.float)
#                 if i == 0:
#                     data.normalized_x = normalized_x_dim
#                 else:
#                     data.normalized_x = torch.cat((data.normalized_x, normalized_x_dim), dim=1)

#         joblib.dump(list_of_scalers_to_save, (directory_path + 'x_scaler.pkl'))
#         for data in dataset:
#             data.x = data.normalized_x
#             del data.normalized_x
#         return dataset
#     except Exception as e:
#         print(f"Error in normalize_x_values: {str(e)}")
#         traceback.print_exc()
#         raise

# def normalize_positional_features(dataset, directory_path):
#     try:
#         shape_of_pos = dataset[0].pos.shape
#         print(f"Shape of pos: {shape_of_pos}")
        
#         list_of_scalers_to_save = []
#         print("Concatenating positional values...")
#         pos_values = torch.cat([data.pos.reshape(data.pos.shape[0], -1) for data in dataset], dim=0)
#         print(f"Concatenated pos_values shape: {pos_values.shape}")

#         for i in range(pos_values.shape[1]):
#             print(f"Processing positional feature {i}/{pos_values.shape[1]}")
#             all_pos_features = replace_invalid_values(pos_values[:, i].reshape(-1, 1)).numpy()
            
#             scaler = StandardScaler()
#             scaler.fit(all_pos_features)
#             list_of_scalers_to_save.append(scaler)

#             for j, data in enumerate(dataset):
#                 if j % 100 == 0:
#                     print(f"Processing data point {j}/{len(dataset)}")
#                 data_pos_dim = replace_invalid_values(data.pos.reshape(data.pos.shape[0], -1)[:, i].reshape(-1, 1))
#                 normalized_pos_dim = torch.tensor(scaler.transform(data_pos_dim.numpy()), dtype=torch.float)
#                 if i == 0:
#                     data.normalized_pos = normalized_pos_dim
#                 else:
#                     data.normalized_pos = torch.cat((data.normalized_pos, normalized_pos_dim), dim=1)

#         print("Saving positional scalers...")
#         joblib.dump(list_of_scalers_to_save, (directory_path + 'pos_scaler.pkl'))
#         print("Positional scalers saved successfully")

#         print("Updating pos values in dataset...")
#         for data in dataset:
#             data.pos = data.normalized_pos.reshape(shape_of_pos)
#             del data.normalized_pos
#         print("Dataset pos values updated")

#         return dataset
#     except Exception as e:
#         print(f"Error in normalize_positional_features: {str(e)}")
#         traceback.print_exc()
#         raise

# def normalize_mode_stats(dataset, directory_path):
#     try:
#         print("Starting mode stats normalization...")
#         # Initialize 12 StandardScalers for 6 sets of 2 dimensions
#         scalers = [[StandardScaler() for _ in range(2)] for _ in range(6)]

#         # Standardize the data
#         for i in range(6):  # Iterate over the first dimension (6 sets)
#             for j in range(2):  # Iterate over the second dimension (2D vectors)
#                 print(f"Processing mode stats dimension {i}, {j}")
#                 values = np.vstack([replace_invalid_values(data.mode_stats[i, j].reshape(-1, 1)) for data in dataset])
#                 print(f"Collected values shape: {values.shape}")
                
#                 # Fit the corresponding scaler on the extracted values
#                 scalers[i][j].fit(values)
                
#                 for k, data in enumerate(dataset):
#                     if k % 100 == 0:
#                         print(f"Transforming data point {k}/{len(dataset)} for dimension {i}, {j}")
#                     data_mode_stats_dim = replace_invalid_values(data.mode_stats[i, j].reshape(-1, 1))
#                     transformed = scalers[i][j].transform(data_mode_stats_dim).flatten()
#                     # Convert the transformed NumPy array back into a torch tensor
#                     data.mode_stats[i, j] = torch.tensor(transformed, dtype=torch.float32)

#         print("Saving mode stats scalers...")
#         # Save the scalers using joblib
#         for i in range(6):
#             for j in range(2):
#                 # Dump the scalers with meaningful names to differentiate them
#                 scaler_path = directory_path + f'scaler_mode_stats_{i}_{j}.pkl'
#                 joblib.dump(scalers[i][j], scaler_path)

#         print("Mode stats scalers saved and dataset standardized.")
#         return dataset
#     except Exception as e:
#         print(f"Error in normalize_mode_stats: {str(e)}")
#         traceback.print_exc()
#         raise