In [1]:
import math
import numpy as np
import wandb
import random
import torch
import torch_geometric
from torch_geometric.data import Data
import sys
import os
from tqdm import tqdm
import signal
import joblib
import argparse
import json
import os
import subprocess
from torch.utils.data import DataLoader, Dataset, Subset

from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
import help_functions as hf

import psutil
from torch_geometric.data import Data

scripts_path = os.path.abspath(os.path.join('..'))
if scripts_path not in sys.path:
    sys.path.append(scripts_path)
    
import gnn_io as gio
import gnn_architectures_district_features as garch
import copy
from torch_geometric.utils import to_undirected

# This is the current working status (11.10.2024)


In [2]:
import traceback

def normalize_dataset(dataset_input, directory_path):
    """Return a normalized copy of ``dataset_input``.

    The input is first duplicated with ``copy_subset``; the copy is then passed
    through ``normalize_x_values``, ``normalize_positional_features`` and
    ``normalize_mode_stats`` in that order. Every stage receives the same
    ``directory_path`` argument.

    Args:
        dataset_input: Collection of graph samples; must support ``len`` and
            whatever ``copy_subset`` requires of it.
        directory_path: Value forwarded unchanged to each normalization stage.

    Returns:
        The copied dataset as returned by the last normalization stage.

    Raises:
        Exception: Any error is reported (message plus traceback) and re-raised.
    """
    try:
        print(f"Starting normalization for {len(dataset_input)} items")
        working_copy = copy_subset(dataset_input)
        print("Dataset copied successfully")

        # Each entry pairs a normalization stage with its completion message.
        stages = (
            (normalize_x_values, "X values normalized successfully"),
            (normalize_positional_features, "Positional features normalized successfully"),
            (normalize_mode_stats, "Mode stats normalized successfully"),
        )
        for run_stage, done_message in stages:
            working_copy = run_stage(working_copy, directory_path)
            print(done_message)

        return working_copy
    except Exception as err:
        print(f"Error in normalize_dataset: {str(err)}")
        traceback.print_exc()
        raise
    

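# Intended usage: during training, call this function without scalers and with a directory path (the fitted scalers are saved there); during testing, call it with the saved scalers and without a directory path. NOTE: the active normalize_dataset above only accepts a directory path, i.e. only the training mode is implemented.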
# def normalize_dataset(dataset_input, directory_path):
#     dataset = copy_subset(dataset_input)
#     dataset = normalize_x_values(dataset, directory_path)
#     dataset = normalize_positional_features(dataset, directory_path)
#     dataset = normalize_mode_stats(dataset, directory_path)
#     return dataset
    
    
# def normalize_x_values(dataset, directory_path):
#     try:
#         shape_of_x = dataset[0].x.shape[1]
#         print(f"Shape of x: {shape_of_x}")
        
#         list_of_scalers_to_save = []
#         print("Processing x values...")

#         # Process in batches
#         batch_size = 100  # Adjust this value based on your available memory
#         for i in range(shape_of_x):
#             print(f"Processing feature {i}/{shape_of_x}")
#             scaler = StandardScaler()
            
#             # Fit scaler in batches
#             for j in range(0, len(dataset), batch_size):
#                 batch = dataset[j:j+batch_size]
#                 print(f"Processing batch {j//batch_size + 1}/{len(dataset)//batch_size + 1}")
#                 batch_x_values = torch.cat([data.x[:, i].reshape(-1, 1) for data in batch], dim=0)
#                 batch_x_values = replace_invalid_values(batch_x_values)
#                 scaler.partial_fit(batch_x_values.numpy())

#             list_of_scalers_to_save.append(scaler)

#             # Transform data
#             for j, data in enumerate(dataset):
#                 if j % 100 == 0:
#                     print(f"Transforming data point {j}/{len(dataset)}")
#                 data_x_dim = replace_invalid_values(data.x[:, i].reshape(-1, 1))
#                 normalized_x_dim = torch.tensor(scaler.transform(data_x_dim.numpy()), dtype=torch.float)
#                 if i == 0:
#                     data.normalized_x = normalized_x_dim
#                 else:
#                     data.normalized_x = torch.cat((data.normalized_x, normalized_x_dim), dim=1)

#         print("Saving scalers...")
#         joblib.dump(list_of_scalers_to_save, (directory_path + 'x_scaler.pkl'))
#         print("Scalers saved successfully")

#         print("Updating x values in dataset...")
#         for data in dataset:
#             data.x = data.normalized_x
#             del data.normalized_x
#         print("Dataset x values updated")

#         return dataset
#     except Exception as e:
#         print(f"Error in normalize_x_values: {str(e)}")
#         traceback.print_exc()
#         raise   
    
def normalize_x_values(dataset, directory_path):
    """Standardize every column of ``data.x`` across all graphs in ``dataset``.

    One ``StandardScaler`` is fitted per feature column on the values of all
    graphs (NaN and +/-inf are replaced by 0 first). The list of fitted scalers
    is dumped with joblib to ``directory_path + 'x_scaler.pkl'`` and each
    graph's ``x`` is then replaced by its standardized float32 version.

    Compared with the earlier version, the per-graph result is assembled with a
    single concatenation instead of growing a scratch ``normalized_x``
    attribute one column at a time, and the scalers are saved before any graph
    is modified, so a failed save leaves the dataset untouched.

    Args:
        dataset: Non-empty collection (indexable, iterable, sized) of graph
            objects whose ``x`` attribute is a 2-D tensor. The number of
            columns is taken from ``dataset[0]``. Graphs are modified in place.
        directory_path: String prefix for the scaler file; it is concatenated
            directly with the file name (no path separator is inserted).

    Returns:
        The same ``dataset`` object, now holding normalized ``x`` tensors.

    Raises:
        ValueError: If ``dataset`` is empty.
        Exception: Any failure is reported with its traceback and re-raised.
    """
    try:
        if len(dataset) == 0:
            raise ValueError("Cannot normalize x values of an empty dataset")

        num_features = dataset[0].x.shape[1]
        print(f"Shape of x: {num_features}")

        print("Concatenating x values...")
        # torch.cat yields a fresh tensor, so cleaning it does not touch the graphs yet.
        x_values = replace_invalid_values(torch.cat([data.x for data in dataset], dim=0))
        print(f"Concatenated x_values shape: {x_values.shape}")

        # One scaler per column keeps the on-disk format (a list of scalers) unchanged.
        list_of_scalers_to_save = []
        for i in range(num_features):
            print(f"Processing feature {i}/{num_features}")
            scaler = StandardScaler()
            scaler.fit(x_values[:, i].reshape(-1, 1).numpy())
            list_of_scalers_to_save.append(scaler)
        del x_values  # release the concatenated copy before transforming

        joblib.dump(list_of_scalers_to_save, (directory_path + 'x_scaler.pkl'))

        for j, data in enumerate(dataset):
            if j % 100 == 0:
                print(f"Processing data point {j}/{len(dataset)}")
            cleaned_x = replace_invalid_values(data.x)
            normalized_columns = [
                torch.tensor(scaler.transform(cleaned_x[:, i].reshape(-1, 1).numpy()), dtype=torch.float)
                for i, scaler in enumerate(list_of_scalers_to_save)
            ]
            data.x = torch.cat(normalized_columns, dim=1)

        return dataset
    except Exception as e:
        print(f"Error in normalize_x_values: {str(e)}")
        traceback.print_exc()
        raise

# Function to copy a Subset
def copy_subset(subset):
    """Return an independent deep copy of the samples selected by ``subset``.

    Only the items referenced by ``subset.indices`` are copied. The previous
    implementation deep-copied the whole parent dataset for every subset, so
    copying a train and a validation split duplicated the full dataset twice.

    Args:
        subset: Object exposing ``dataset`` (indexable) and ``indices``
            (iterable of positions into ``dataset``), e.g. a
            ``torch.utils.data.Subset``.

    Returns:
        A ``Subset`` with the same length and the same items in the same order
        as ``subset``. Its backing dataset is a new list holding just those
        items, so its ``indices`` are ``0..len-1`` rather than the positions in
        the original parent dataset.
    """
    # A shared memo keeps objects referenced by several samples shared in the
    # copy as well, as a single deepcopy of the whole container would.
    memo = {}
    copied_items = [copy.deepcopy(subset.dataset[idx], memo) for idx in subset.indices]
    return Subset(copied_items, list(range(len(copied_items))))


def normalize_positional_features(dataset, directory_path):
    """Standardize the ``pos`` tensor of every graph, one scaler per flattened column.

    Each graph's ``pos`` is viewed as ``(num_rows, -1)``; one ``StandardScaler``
    per resulting column is fitted on the rows of all graphs (NaN and +/-inf are
    replaced by 0 first). The scalers are dumped with joblib to
    ``directory_path + 'pos_scaler.pkl'`` and every graph's ``pos`` is replaced
    by its standardized float32 version, reshaped to that graph's own original
    shape.

    Fixes over the earlier version: the result used to be reshaped to the shape
    of ``dataset[0].pos`` for every graph, which fails as soon as graphs have
    different row counts; the per-graph result is now built with a single
    concatenation instead of a growing scratch attribute; and the scalers are
    saved before any graph is modified.

    Args:
        dataset: Non-empty collection (indexable, iterable, sized) of graph
            objects with a ``pos`` tensor whose trailing dimensions agree
            across graphs. Graphs are modified in place.
        directory_path: String prefix for the scaler file; concatenated
            directly with the file name.

    Returns:
        The same ``dataset`` object with normalized ``pos`` tensors.

    Raises:
        ValueError: If ``dataset`` is empty.
        Exception: Any failure is reported with its traceback and re-raised.
    """
    try:
        if len(dataset) == 0:
            raise ValueError("Cannot normalize positional features of an empty dataset")

        shape_of_pos = dataset[0].pos.shape
        print(f"Shape of pos: {shape_of_pos}")

        print("Concatenating positional values...")
        pos_values = replace_invalid_values(
            torch.cat([data.pos.reshape(data.pos.shape[0], -1) for data in dataset], dim=0)
        )
        print(f"Concatenated pos_values shape: {pos_values.shape}")

        num_features = pos_values.shape[1]
        list_of_scalers_to_save = []
        for i in range(num_features):
            print(f"Processing positional feature {i}/{num_features}")
            scaler = StandardScaler()
            scaler.fit(pos_values[:, i].reshape(-1, 1).numpy())
            list_of_scalers_to_save.append(scaler)
        del pos_values  # release the concatenated copy before transforming

        print("Saving positional scalers...")
        joblib.dump(list_of_scalers_to_save, (directory_path + 'pos_scaler.pkl'))
        print("Positional scalers saved successfully")

        print("Updating pos values in dataset...")
        for j, data in enumerate(dataset):
            if j % 100 == 0:
                print(f"Processing data point {j}/{len(dataset)}")
            # Remember this graph's own shape; row counts may differ between graphs.
            original_shape = data.pos.shape
            flat_pos = replace_invalid_values(data.pos.reshape(original_shape[0], -1))
            normalized_columns = [
                torch.tensor(scaler.transform(flat_pos[:, i].reshape(-1, 1).numpy()), dtype=torch.float)
                for i, scaler in enumerate(list_of_scalers_to_save)
            ]
            data.pos = torch.cat(normalized_columns, dim=1).reshape(original_shape)
        print("Dataset pos values updated")

        return dataset
    except Exception as e:
        print(f"Error in normalize_positional_features: {str(e)}")
        traceback.print_exc()
        raise

def normalize_mode_stats(dataset, directory_path):
    """Standardize each ``mode_stats[i, j]`` entry across all graphs in ``dataset``.

    A separate ``StandardScaler`` is fitted for every ``(i, j)`` position on the
    values collected from all graphs (NaN and +/-inf are replaced by 0 first).
    The standardized value is written back into ``data.mode_stats[i, j]`` and
    each scaler is dumped with joblib to
    ``directory_path + 'scaler_mode_stats_{i}_{j}.pkl'``.

    The grid size is taken from the first two dimensions of
    ``dataset[0].mode_stats`` instead of being hard-coded to 6 x 2, so the
    function keeps working if the layout of ``mode_stats`` changes.

    Args:
        dataset: Non-empty collection (indexable, iterable, sized) of graph
            objects with a ``mode_stats`` tensor of at least two dimensions.
            Graphs are modified in place.
        directory_path: String prefix for the scaler files; concatenated
            directly with each file name.

    Returns:
        The same ``dataset`` object with normalized ``mode_stats``.

    Raises:
        ValueError: If ``dataset`` is empty.
        Exception: Any failure is reported with its traceback and re-raised.
    """
    try:
        print("Starting mode stats normalization...")
        if len(dataset) == 0:
            raise ValueError("Cannot normalize mode stats of an empty dataset")

        num_sets, num_dims = dataset[0].mode_stats.shape[:2]
        # One scaler per (set, dimension) position.
        scalers = [[StandardScaler() for _ in range(num_dims)] for _ in range(num_sets)]

        for i in range(num_sets):
            for j in range(num_dims):
                print(f"Processing mode stats dimension {i}, {j}")
                scaler = scalers[i][j]
                values = np.vstack([
                    replace_invalid_values(data.mode_stats[i, j].reshape(-1, 1)).numpy()
                    for data in dataset
                ])
                print(f"Collected values shape: {values.shape}")

                scaler.fit(values)

                for k, data in enumerate(dataset):
                    if k % 100 == 0:
                        print(f"Transforming data point {k}/{len(dataset)} for dimension {i}, {j}")
                    entry = replace_invalid_values(data.mode_stats[i, j].reshape(-1, 1))
                    transformed = scaler.transform(entry.numpy()).flatten()
                    # Written in place, so the value is cast to mode_stats' existing dtype.
                    # NOTE(review): assumes mode_stats is a floating tensor - confirm.
                    data.mode_stats[i, j] = torch.tensor(transformed, dtype=torch.float32)

        print("Saving mode stats scalers...")
        for i in range(num_sets):
            for j in range(num_dims):
                # File names encode the (i, j) position the scaler belongs to.
                scaler_path = directory_path + f'scaler_mode_stats_{i}_{j}.pkl'
                joblib.dump(scalers[i][j], scaler_path)

        print("Mode stats scalers saved and dataset standardized.")
        return dataset
    except Exception as e:
        print(f"Error in normalize_mode_stats: {str(e)}")
        traceback.print_exc()
        raise

def replace_invalid_values(tensor):
    """Overwrite NaN, +inf and -inf entries of ``tensor`` with 0 and return it.

    The replacement happens in place: the returned object is the very tensor
    that was passed in, so any other reference to it (or to the storage it
    views) sees the zeros as well.

    Args:
        tensor: Torch tensor to clean.

    Returns:
        The same ``tensor`` object after the in-place replacement.
    """
    non_finite = torch.isnan(tensor) | torch.isinf(tensor)
    tensor[non_finite] = 0
    return tensor


def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Written to be passed as ``worker_init_fn`` to a ``DataLoader``. The seed is
    ``torch.initial_seed()`` reduced modulo 2**32.

    Args:
        worker_id: Worker index supplied by the caller; not used.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

In [3]:
# Set parameters here
params = dict(
    # Run identification
    project_name="test",
    # Training schedule
    num_epochs=1000,
    batch_size=8,
    # Layer structures handed to the model constructor
    point_net_conv_layer_structure_local_mlp=[64, 128],
    point_net_conv_layer_structure_global_mlp=[256, 64],
    gat_conv_layer_structure=[128, 256, 256, 128],
    graph_mlp_layer_structure=[128, 256, 128],
    # Optimization
    lr=0.001,
    gradient_accumulation_steps=3,
    # Model input/output sizes
    in_channels=15,
    out_channels=1,
    early_stopping_patience=100,
    # Sub-directory name used for this run's outputs
    unique_model_description="my_test",
    # Regularization
    dropout=0.3,
    use_dropout=False,
)

# Define the paths here
def get_paths(base_dir: str, unique_model_description: str, model_save_path: str = 'trained_model/model.pth'):
    """Build the output locations for a run and create the directories.

    Args:
        base_dir: Root directory for all runs.
        unique_model_description: Name of this run's sub-directory under
            ``base_dir``.
        model_save_path: Model file path relative to the run directory.

    Returns:
        Tuple ``(model_file, artifacts_dir)``: the full path of the model file
        and the run's ``data_created_during_training/`` directory (the trailing
        slash is kept in the returned string).
    """
    run_dir = os.path.join(base_dir, unique_model_description)
    model_file = os.path.join(run_dir, model_save_path)
    artifacts_dir = os.path.join(run_dir, 'data_created_during_training/')

    # Create the run directory, the model's parent directory and the artifacts directory.
    for directory in (run_dir, os.path.dirname(model_file), artifacts_dir):
        os.makedirs(directory, exist_ok=True)

    return model_file, artifacts_dir


def get_combined_data(dataset_path, max_batches=None):
    """Load consecutively numbered ``datalist_batch_{n}.pt`` files into one list.

    Files are read starting at ``n = 1`` until a file is missing,
    ``max_batches`` files have been attempted, or system memory usage exceeds
    90 percent. From each file that contains a list, only ``Data`` items that
    carry all of ``x``, ``edge_index``, ``edge_attr``, ``pos``, ``y`` and
    ``mode_stats`` are kept; everything else is reported and skipped.

    Changes over the earlier version: the memory guard now runs after every
    batch (it used to run only when the running total happened to be an exact
    multiple of 1000, so it was skipped whenever items were dropped), and a
    failure while processing a batch can no longer advance the batch counter
    twice.

    Args:
        dataset_path: Directory containing the batch files.
        max_batches: Maximum number of batch files to attempt, or ``None`` to
            continue until the first missing file.

    Returns:
        List of the accepted ``Data`` objects, in file order.
    """
    required_attrs = ['x', 'edge_index', 'edge_attr', 'pos', 'y', 'mode_stats']
    data_list = []
    batch_num = 1

    while max_batches is None or batch_num <= max_batches:
        batch_file = os.path.join(dataset_path, f'datalist_batch_{batch_num}.pt')
        try:
            # NOTE(security): torch.load unpickles the file contents; only point
            # this at batch files from a trusted source.
            batch_data = torch.load(batch_file, map_location='cpu')
        except FileNotFoundError:
            # The first gap in the numbering marks the end of the dataset.
            print(f"Finished loading {batch_num-1} batches")
            break
        except Exception as e:
            print(f"Error loading batch {batch_num}: {str(e)}")
            batch_num += 1
            continue

        if isinstance(batch_data, list):
            print(f"Batch {batch_num} type: {type(batch_data)}, length: {len(batch_data)}")
            for idx, item in enumerate(batch_data):
                try:
                    if not isinstance(item, Data):
                        print(f"Skipping non-Data item {idx} in batch {batch_num}.")
                        continue
                    missing_attrs = [attr for attr in required_attrs if not hasattr(item, attr)]
                    if missing_attrs:
                        print(f"Skipping invalid item {idx} in batch {batch_num}. Missing attributes: {missing_attrs}")
                        continue
                    data_list.append(item)
                except Exception as e:
                    # Best effort: one broken item must not abort the whole load.
                    print(f"Error processing item {idx} in batch {batch_num}: {str(e)}")
        else:
            print(f"Unexpected batch data type in batch {batch_num}: {type(batch_data)}")

        print(f"Loaded batch {batch_num}, current total: {len(data_list)} items")
        batch_num += 1

        # Checked after every batch so the guard cannot be skipped.
        if psutil.virtual_memory().percent > 90:
            print("Memory usage high, stopping data loading")
            break

    print(f"Successfully loaded {len(data_list)} data points")
    return data_list

# Usage: load the training graphs used by the cells below.
try:
    dataset_path = '../../data/train_data/sim_output_1pm_capacity_reduction_10k_09_10_2024/'
    data_list = get_combined_data(dataset_path, max_batches=10)  # Load at most the first ten batch files
    print(f"Final count: Successfully loaded {len(data_list)} data points")
except Exception as e:
    print(f"An error occurred: {str(e)}")
    # Re-raise: later cells need data_list, and continuing would either fail
    # with a NameError or silently reuse a stale data_list from an earlier run.
    raise

Batch 1 type: <class 'list'>, length: 100
Loaded batch 1, current total: 100 items
Batch 2 type: <class 'list'>, length: 100
Loaded batch 2, current total: 200 items
Batch 3 type: <class 'list'>, length: 100
Loaded batch 3, current total: 300 items
Batch 4 type: <class 'list'>, length: 100
Loaded batch 4, current total: 400 items
Batch 5 type: <class 'list'>, length: 100
Loaded batch 5, current total: 500 items
Batch 6 type: <class 'list'>, length: 100
Loaded batch 6, current total: 600 items
Batch 7 type: <class 'list'>, length: 100
Loaded batch 7, current total: 700 items
Batch 8 type: <class 'list'>, length: 100
Loaded batch 8, current total: 800 items
Batch 9 type: <class 'list'>, length: 100
Loaded batch 9, current total: 900 items
Batch 10 type: <class 'list'>, length: 100
Loaded batch 10, current total: 1000 items
Successfully loaded 1000 data points
Final count: Successfully loaded 1000 data points


In [4]:
# Resolve and create the output locations for this run
base_dir = f"../../data/{params['project_name']}/"
unique_run_dir = os.path.join(base_dir, params['unique_model_description'])
os.makedirs(unique_run_dir, exist_ok=True)

model_save_path, path_to_save_dataloader = get_paths(
    base_dir=base_dir,
    unique_model_description=params['unique_model_description'],
    model_save_path='trained_model/model.pth',
)

In [5]:
def prepare_data_with_graph_features(datalist, batch_size, path_to_save_dataloader):
    """Split ``datalist``, normalize the train/validation parts and build loaders.

    The data is split 80/15/5 via ``gio.split_into_subsets``. The train and
    validation subsets are each passed through ``normalize_dataset`` with the
    prefixes ``path_to_save_dataloader + "train_"`` and ``... + "valid_"``, and
    wrapped in shuffling ``DataLoader`` objects with four workers. The test
    subset is only counted in the log output; it is neither normalized nor
    returned.

    NOTE(review): the validation subset goes through its own
    ``normalize_dataset`` call, and the normalization functions in this
    notebook fit new scalers on whatever data they receive, so validation data
    is not scaled with the training statistics - confirm this is intended.

    Args:
        datalist: Sequence of graph samples to split.
        batch_size: Batch size for both loaders.
        path_to_save_dataloader: String prefix under which the normalization
            scalers are stored.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Raises:
        Exception: Any failure is reported with its traceback and re-raised.
    """
    print(f"Starting prepare_data_with_graph_features with {len(datalist)} items")

    try:
        print("Splitting into subsets...")
        train_set, valid_set, test_set = gio.split_into_subsets(
            dataset=datalist, train_ratio=0.8, val_ratio=0.15, test_ratio=0.05
        )
        print(f"Split complete. Train: {len(train_set)}, Valid: {len(valid_set)}, Test: {len(test_set)}")

        print("Normalizing train set...")
        train_prefix = path_to_save_dataloader + "train_"
        normalized_train = normalize_dataset(dataset_input=train_set, directory_path=train_prefix)
        print("Train set normalized")

        print("Normalizing validation set...")
        valid_prefix = path_to_save_dataloader + "valid_"
        normalized_valid = normalize_dataset(dataset_input=valid_set, directory_path=valid_prefix)
        print("Validation set normalized")

        print("Creating train loader...")
        # Options common to both loaders; only the train loader sets prefetch_factor.
        shared_loader_options = dict(
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            collate_fn=gio.collate_fn,
            worker_init_fn=seed_worker,
        )
        train_loader = DataLoader(dataset=normalized_train, prefetch_factor=2, **shared_loader_options)
        print("Train loader created")

        print("Creating validation loader...")
        val_loader = DataLoader(dataset=normalized_valid, **shared_loader_options)
        print("Validation loader created")

        return train_loader, val_loader
    except Exception as err:
        print(f"Error in prepare_data_with_graph_features: {str(err)}")
        import traceback
        traceback.print_exc()
        raise

train_dl, valid_dl = prepare_data_with_graph_features(
    datalist=data_list,
    batch_size=params['batch_size'],
    path_to_save_dataloader=path_to_save_dataloader,
)


Starting prepare_data_with_graph_features with 1000 items
Splitting into subsets...
Total dataset length: 1000
Training subset length: 800
Validation subset length: 150
Test subset length: 50
Split complete. Train: 800, Valid: 150, Test: 50
Normalizing train set...
Starting normalization for 800 items
Dataset copied successfully
Shape of x: 15
Concatenating x values...
Concatenated x_values shape: torch.Size([24928000, 15])
Processing feature 0/15
Input tensor shape: torch.Size([24928000, 1])
NaN count: 0, Inf count: 0
Processing data point 0/800
Input tensor shape: torch.Size([31160, 1])
NaN count: 0, Inf count: 0
Input tensor shape: torch.Size([31160, 1])
NaN count: 0, Inf count: 0
Input tensor shape: torch.Size([31160, 1])
NaN count: 0, Inf count: 0
Input tensor shape: torch.Size([31160, 1])
NaN count: 0, Inf count: 0
Input tensor shape: torch.Size([31160, 1])
NaN count: 0, Inf count: 0
Input tensor shape: torch.Size([31160, 1])
NaN count: 0, Inf count: 0
Input tensor shape: torch.S

In [6]:
def check_directionality(data, node_id=25318):
    """Check whether the edges at one node are mirrored in both directions.

    The out-neighbours and in-neighbours of ``node_id`` are read from
    ``data.edge_index`` (row 0 = source nodes, row 1 = target nodes). The node
    counts as bidirectional when it has as many outgoing as incoming edges and
    every out-neighbour also appears among the in-neighbours.

    Only this single node is inspected, so the result says nothing definite
    about the rest of the graph. The earlier version hard-coded node 25318
    while its comments and messages claimed to inspect node 0.

    Args:
        data: Object with an ``edge_index`` tensor of shape ``(2, num_edges)``.
        node_id: Index of the node to inspect. Defaults to 25318, the node the
            original hard-coded check used.

    Returns:
        ``True`` if the node's edges are bidirectional as defined above,
        otherwise ``False``.
    """
    sources, targets = data.edge_index[0], data.edge_index[1]

    # Targets of edges leaving node_id, and sources of edges arriving at it.
    outgoing = targets[sources == node_id].tolist()
    incoming = sources[targets == node_id].tolist()

    incoming_lookup = set(incoming)
    bidirectional = len(outgoing) == len(incoming) and all(node in incoming_lookup for node in outgoing)

    print(f"Outgoing edges from node {node_id}: {outgoing}")
    print(f"Incoming edges to node {node_id}: {incoming}")
    print(f"The edges at node {node_id} are {'bidirectional' if bidirectional else 'unidirectional'}")

    return bidirectional

# Run the check on the first sample of the dataset wrapped by the training loader.
# `data` and `is_bidirectional` stay in the notebook namespace for later cells.
data = train_dl.dataset[0]
is_bidirectional = check_directionality(data)

Outgoing edges from node 0: [9477, 9478, 9479, 31157]
Incoming edges to node 0: [9479, 23978, 31157]
The graph is unidirectional


In [7]:
# Show the repr of `data`, the training sample selected in the previous cell.
data

Data(edge_index=[2, 117679], num_nodes=31160, x=[31160, 15], edge_attr=[117679, 1], pos=[31160, 3, 2], y=[31160, 1], mode_stats=[6, 2])

In [8]:
# Pick a GPU through the project helpers before the torch device is created.
# NOTE(review): presumably set_cuda_visible_device sets CUDA_VISIBLE_DEVICES
# (the cell output reports it) - confirm in help_functions.
gpus = hf.get_available_gpus()
best_gpu = hf.select_best_gpu(gpus)
hf.set_cuda_visible_device(best_gpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Register the run's hyperparameters with Weights & Biases. The returned
# `config` object is used below via attribute access (config.lr, ...).
# Note the shortened key names for the two point-net MLP structures.
config = hf.setup_wandb(params['project_name'], {
    "epochs": params['num_epochs'],
    "batch_size": params['batch_size'],
    "lr": params['lr'],
    "gradient_accumulation_steps": params['gradient_accumulation_steps'],
    "early_stopping_patience": params['early_stopping_patience'],
    "point_net_conv_local_mlp": params['point_net_conv_layer_structure_local_mlp'],
    "point_net_conv_global_mlp": params['point_net_conv_layer_structure_global_mlp'],
    "gat_conv_layer_structure": params['gat_conv_layer_structure'],
    "graph_mlp_layer_structure": params['graph_mlp_layer_structure'],
    "in_channels": params['in_channels'],
    "out_channels": params['out_channels'],
    "dropout": params['dropout'],
    "use_dropout": params['use_dropout']
})

# Build the model from the logged configuration and move it to the chosen device.
model = garch.MyGnn(in_channels=config.in_channels, out_channels=config.out_channels, point_net_conv_layer_structure_local_mlp=config.point_net_conv_local_mlp,
                            point_net_conv_layer_structure_global_mlp=config.point_net_conv_global_mlp,
                            gat_conv_layer_structure=config.gat_conv_layer_structure,
                            graph_mlp_layer_structure=config.graph_mlp_layer_structure,
                            dropout=config.dropout, use_dropout=config.use_dropout)

model.to(device)

loss_fct = torch.nn.MSELoss()

# Reference losses computed on the training loader with the same loss function.
# NOTE(review): judging by the helper names these are the losses of predicting
# the mean target and of a "no policies" prediction - confirm in gnn_io.
baseline_loss_mean_target = gio.compute_baseline_of_mean_target(dataset=train_dl, loss_fct=loss_fct)
baseline_loss = gio.compute_baseline_of_no_policies(dataset=train_dl, loss_fct=loss_fct)
print("baseline loss mean " + str(baseline_loss_mean_target))
print("baseline loss no  " +str(baseline_loss) )

# Train with AdamW, early stopping and gradient accumulation; garch.train
# returns the best validation loss and the epoch at which it was reached.
early_stopping = gio.EarlyStopping(patience=params['early_stopping_patience'], verbose=True)
best_val_loss, best_epoch = garch.train(model=model, 
            config=config, 
            loss_fct=loss_fct,
            optimizer=torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-4),
            train_dl=train_dl,  
            valid_dl=valid_dl,
            device=device, 
            early_stopping=early_stopping,
            accumulation_steps=config.gradient_accumulation_steps,
            model_save_path=model_save_path,
            use_gradient_clipping=True,
            lr_scheduler_warmup_steps=20000,
            lr_scheduler_cosine_decay_rate=0.2)
print(f'Best model saved to {model_save_path} with validation loss: {best_val_loss} at epoch {best_epoch}')  

Using GPU 1 with CUDA_VISIBLE_DEVICES=1


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


Model initialized
baseline loss mean 48.22662353515625
baseline loss no  48.24829864501953


Epoch 1/1000: 100%|██████████| 100/100 [00:23<00:00,  4.19it/s]


epoch: 0, validation loss: 45.98786625109221, lr: 4.950000000000001e-06, r^2: -0.0007609128952026367
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 45.98786625109221
Checkpoint saved to ../../data/test/my_test/trained_model/checkpoints/checkpoint_epoch_0.pt


Epoch 2/1000: 100%|██████████| 100/100 [00:23<00:00,  4.20it/s]


epoch: 1, validation loss: 45.877669083444694, lr: 9.950000000000001e-06, r^2: 2.4020671844482422e-05
Best model saved to ../../data/test/my_test/trained_model/model.pth with validation loss: 45.877669083444694


Epoch 3/1000: 100%|██████████| 100/100 [00:24<00:00,  4.17it/s]


epoch: 2, validation loss: 45.999485417416224, lr: 1.4950000000000001e-05, r^2: 0.00015604496002197266
EarlyStopping counter: 1 out of 100


Epoch 4/1000:  98%|█████████▊| 98/100 [00:23<00:00,  4.51it/s]

In [None]:
# import networkx as nx
# import matplotlib.pyplot as plt
# import numpy as np
# from collections import Counter

# def check_scale_free_distribution(edge_index, num_nodes):
#     # Create a NetworkX graph from the edge_index
#     G = nx.Graph()
#     G.add_nodes_from(range(num_nodes))
#     edge_list = edge_index.t().tolist()
#     G.add_edges_from(edge_list)

#     # Calculate degree for each node
#     degrees = [d for n, d in G.degree()]
#     degree_counts = Counter(degrees)

#     # Sort the degree counts
#     sorted_degree_counts = sorted(degree_counts.items())
#     x = [k for k, v in sorted_degree_counts]
#     y = [v for k, v in sorted_degree_counts]

#     # Plot degree distribution on log-log scale
#     plt.figure(figsize=(10, 6))
#     plt.loglog(x, y, 'bo-')
#     plt.xlabel('Degree (log scale)')
#     plt.ylabel('Count (log scale)')
#     plt.title('Degree Distribution (Log-Log Scale)')
#     plt.grid(True)

#     # Fit a power law distribution
#     x_log = np.log(x)
#     y_log = np.log(y)
#     coeffs = np.polyfit(x_log[1:], y_log[1:], 1)
#     power_law_exponent = -coeffs[0]

#     # Plot the fitted line
#     x_fit = np.logspace(np.log10(min(x)), np.log10(max(x)), 100)
#     y_fit = np.exp(coeffs[1]) * x_fit**(-power_law_exponent)
#     plt.loglog(x_fit, y_fit, 'r--', label=f'Power Law Fit (γ ≈ {power_law_exponent:.2f})')

#     plt.legend()
#     plt.show()

#     print(f"Estimated power law exponent: γ ≈ {power_law_exponent:.2f}")
    
#     if 2 < power_law_exponent < 3:
#         print("The network shows characteristics of a scale-free network.")
#     else:
#         print("The network may not be scale-free.")

#     return power_law_exponent

# # Usage example:
# # Assuming you have a PyTorch Geometric Data object called 'data'
# exponent = check_scale_free_distribution(data.edge_index, data.num_nodes)

# from torch_geometric.utils import to_undirected, is_undirected

# # Assuming you're working with the first graph in your dataset
# data = train_dl.dataset[0]

# # Check if the graph is already undirected
# # if not is_undirected(data.edge_index):
# #     # If it's directed, convert it to undirected
# #     data.edge_index = to_undirected(data.edge_index)
# #     print("Graph has been converted to undirected.")
# # else:
# #     print("Graph is already undirected.")

# # Verify that the graph is now undirected
# print(f"Is the graph undirected? {is_undirected(data.edge_index)}")


# def normalize_mode_stats(dataset, directory_path):
#     # Initialize 12 StandardScalers for 6 sets of 2 dimensions
#     scalers = [[StandardScaler() for _ in range(2)] for _ in range(6)]

#     # Standardize the data
#     for i in range(6):  # Iterate over the first dimension (6 sets)
#         for j in range(2):  # Iterate over the second dimension (2D vectors)
#             values = np.vstack([data.mode_stats[i, j].numpy().reshape(-1, 1) for data in dataset])
#             # Fit the corresponding scaler on the extracted values
#             scalers[i][j].fit(values)
#             for data in dataset:
#                 transformed = scalers[i][j].transform(data.mode_stats[i, j].numpy().reshape(-1, 1)).flatten()
#                 # Convert the transformed NumPy array back into a torch tensor
#                 data.mode_stats[i, j] = torch.tensor(transformed, dtype=torch.float32)
    
#     # Save the scalers using joblib
#     for i in range(6):
#         for j in range(2):
#             # Dump the scalers with meaningful names to differentiate them
#             scaler_path = directory_path + f'scaler_mode_stats_{i}_{j}.pkl'
#             joblib.dump(scalers[i][j], scaler_path)

#     print("Mode stats scalers saved and dataset standardized.")
#     return dataset

# def replace_invalid_values(tensor):
#     tensor[tensor != tensor] = 0  # replace NaNs with 0
#     tensor[tensor == float('inf')] = 0  # replace inf with 0
#     tensor[tensor == float('-inf')] = 0  # replace -inf with 0
#     return tensor



# def prepare_data_with_graph_features(datalist, batch_size, path_to_save_dataloader):
#     # datalist = [Data(x=d['x'], edge_index=d['edge_index'], edge_attr=d['edge_attr'], pos=d['pos'], y=d['y'], mode_stats=d['mode_stats']) for d in data_dict_list]
#     train_set, valid_set, test_set = gio.split_into_subsets(dataset=datalist, train_ratio=0.8, val_ratio=0.15, test_ratio=0.05)
    
#     train_set_normalized = normalize_dataset(dataset_input = train_set, directory_path=path_to_save_dataloader + "train_")
#     valid_set_normalized = normalize_dataset(dataset_input = valid_set, directory_path=path_to_save_dataloader + "valid_")
#     # # test_set_normalized = normalize_dataset(dataset_input = test_set, directory_path=path_to_save_dataloader + "test_")
        
#     train_loader = DataLoader(dataset=train_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#     val_loader = DataLoader(dataset=valid_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#     # test_loader = DataLoader(dataset=test_set_normalized, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=gio.collate_fn, worker_init_fn=seed_worker)
#     # gio.save_dataloader(test_loader, path_to_save_dataloader + 'test_dl.pt')
#     # gio.save_dataloader_params(test_loader, path_to_save_dataloader + 'test_loader_params.json')
    
#     return train_loader, val_loader


# def normalize_x_values(dataset, directory_path):
#     shape_of_x = dataset[0].x.shape[1]
#     list_of_scalers_to_save = []
#     x_values = torch.cat([data.x for data in dataset], dim=0)

#     for i in range(shape_of_x):
#         all_node_features = replace_invalid_values(x_values[:, i].reshape(-1, 1)).numpy()
        
#         scaler = StandardScaler()
#         print(f"Scaler created for x values at index {i}: {scaler}")
#         scaler.fit(all_node_features)
#         list_of_scalers_to_save.append(scaler)

#         for data in dataset:
#             data_x_dim = replace_invalid_values(data.x[:, i].reshape(-1, 1))
#             normalized_x_dim = torch.tensor(scaler.transform(data_x_dim.numpy()), dtype=torch.float)
#             if i == 0:
#                 data.normalized_x = normalized_x_dim
#             else:
#                 data.normalized_x = torch.cat((data.normalized_x, normalized_x_dim), dim=1)

#     joblib.dump(list_of_scalers_to_save, (directory_path + 'x_scaler.pkl'))
#     for data in dataset:
#         data.x = data.normalized_x
#         del data.normalized_x
#     return dataset


# def normalize_positional_features(dataset, directory_path):
#     # Initialize 6 StandardScalers for 3 sets of 2 dimensions
#     scalers = [[StandardScaler() for _ in range(2)] for _ in range(3)]

#     # Standardize the data
#     for i in range(3):  # Iterate over the second dimension (3 sets)
#         for j in range(2):  # Iterate over the third dimension (2D vectors)
#             values = np.vstack([data.pos[:, i, j].numpy() for data in dataset]).reshape(-1, 1)
#             # Fit the corresponding scaler on the extracted values
#             scalers[i][j].fit(values)
#             for data in dataset:
#                 transformed = scalers[i][j].transform(data.pos[:, i, j].numpy().reshape(-1, 1)).flatten()
#                 # Convert the transformed NumPy array back into a torch tensor
#                 data.pos[:, i, j] = torch.tensor(transformed, dtype=torch.float32)
#     # Save the scalers using joblib
#     for i in range(3):
#         for j in range(2):
#             # Dump the scalers with meaningful names to differentiate them
#             scaler_path = directory_path + f'scaler_pos_{i}_{j}.pkl'
#             joblib.dump(scalers[i][j], scaler_path)

#     print("Postional scalers saved and dataset standardized.")
#     return dataset
