In [1]:
import math
import numpy as np
import wandb
import pickle

import torch
import torch_geometric
from torch_geometric.data import Data

from gnn_architectures import MyGnn

import gnn_io as gio
import gnn_architectures as garch

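## 1. Define model and parameters

Model hyperparameters are not set by hand here: they are read from the `config` dictionary stored in the `model.pkl` checkpoint when the model is rebuilt in section 2.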

## 2. Load data and model

In [2]:
# Read the serialized dataset: a list of plain dicts, one entry per graph.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
data_dict_list = torch.load('../data/dataset_1pct_0_100.pt')

# Rebuild torch_geometric Data objects from the stored tensors. Only these four
# keys are read from each dict; any other keys present are ignored.
_GRAPH_FIELDS = ('x', 'edge_index', 'pos', 'y')
datalist = [
    Data(**{field: graph[field] for field in _GRAPH_FIELDS})
    for graph in data_dict_list
]

In [3]:
# Normalize the dataset with the project helper from gnn_io.
dataset_normalized = gio.normalize_dataset(datalist)

# Reference error to compare the model's test loss against.
# NOTE(review): this is given `datalist`, not `dataset_normalized`. If
# normalize_dataset works in place these are the same objects; confirm which
# scale the baseline is meant to be reported on.
baseline_error = gio.compute_baseline_error(datalist)
print(f'Baseline error: {baseline_error}')

# Remaining preprocessing stages (project helpers from gnn_io; see that module
# for exactly which fields / dimensions each one touches).
dataset_normalized_x = gio.replace_x_with_normalized_x(dataset_normalized)
dataset_updated = gio.cut_dimensions(dataset_normalized_x)

Baseline error: 0.0058123222552239895


In [4]:
# Wrap the fully preprocessed dataset in a loader for evaluation.
# NOTE(review): is_train=True combined with is_test=True and train_ratio=0 looks
# contradictory; confirm against gio.create_dataloader which flag takes effect.
test_dl = gio.create_dataloader(
    dataset=dataset_updated,
    batch_size=16,
    is_train=True,
    is_test=True,
    train_ratio=0,
)

Total dataset length: 101


In [5]:
# Open a W&B run; evaluate() further down calls wandb.log, which needs it.
wandb.login()
wandb.init(project="test_with_1pct_data")

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Restore the saved checkpoint (a dict holding 'state_dict' and 'config').
# SECURITY: pickle.load can execute arbitrary code; only open trusted files.
with open('model.pkl', 'rb') as checkpoint_file:
    checkpoint = pickle.load(checkpoint_file)

state_dict, config = checkpoint['state_dict'], checkpoint['config']

# Rebuild the architecture from the hyperparameters saved with the weights.
_MODEL_KWARGS = (
    'in_channels',
    'out_channels',
    'hidden_size',
    'gat_layers',
    'gcn_layers',
    'output_layer',
)
loaded_gnn_instance = MyGnn(**{name: config[name] for name in _MODEL_KWARGS})

# Move to the target device, then restore the trained weights. The return value
# of load_state_dict (missing/unexpected keys) is this cell's displayed output.
loaded_model = loaded_gnn_instance.to(device)
loaded_model.load_state_dict(state_dict)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m ([33mtum-traffic-engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33menatterer[0m. Use [1m`wandb login --relogin`[0m to force relogin
wandb: ERROR Error while calling W&B API: entity your_entity_name not found during upsertBucket (<Response [404]>)


CommError: It appears that you do not have permission to access the requested resource. Please reach out to the project owner to grant you access. If you have the correct permissions, verify that there are no issues with your networking setup.(Error 404: Not Found)

## 3. Test the model

In [None]:
# Function to evaluate the model
def evaluate(model, test_dl, device):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    with torch.no_grad():  # Disable gradient computation
        for batch in test_dl:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = torch.nn.MSELoss()(outputs, targets)
            test_loss += loss.item()
    avg_test_loss = test_loss / len(test_dl)
    
    # Log the test loss to Wandb
    wandb.log({"test_loss": avg_test_loss})

# Score the restored checkpoint on the evaluation loader built earlier.
test_loss = evaluate(model=loaded_model, test_dl=test_dl, device=device)
print('Test Loss: ' + str(test_loss))