## Preparation

### Imports

In [None]:
%load_ext autoreload
%autoreload 2

import builtins

import numpy as np
import torch
import wandb


# Pick the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loaded torch. Using *{device}* device.")

# A hack to allow imported functions to use the device: publish it as a
# builtin name. This goes through the `builtins` module instead of
# `__builtins__`, because `__builtins__` is a CPython implementation detail
# and is a plain dict (no attribute assignment) outside of `__main__`.
builtins.device = device

### Set up the run

In [None]:
# Define the configuration.
config = {
    ## Model configuration
    "architecture": "GraphSAGE",
    "hidden_channels": 32,
    "gnn_layers": 5,
    "mlp_layers": 2,
    "activation": "tanh",
    "pool": "max",
    "jk": "cat",
    "dropout": 0.0,
    ## Training configuration
    "optimizer": "adam",
    "learning_rate": 0.01,
    "epochs": 2000,
    # Seed for the numpy and torch random number generators (applied below).
    "seed": 42,
    ## Dataset configuration
}

# Set up default values.
# Maps graph size -> number of graphs of that size to load.
# NOTE(review): -1 presumably means "use all available graphs" - confirm in load_dataset.
selected_graph_sizes = {
    3: -1,
    4: -1,
    5: -1,
    6: -1,
    7: -1,
    8: -1,
    # 9:  100000,
    # 10: 100000
}

# Set up the run
run = wandb.init(mode="disabled", project="gnn_fiedler_approx", tags=["lambda2", "baseline"], config=config)
# From here on `config` is the W&B config object rather than the plain dict above.
config = wandb.config

# Seed the random number generators so that a fresh "Restart & Run All" is
# repeatable. The seed is read back from the W&B config so that an overridden
# value is honoured.
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])

### Load the data

In [None]:
from algebraic_connectivity_script import load_dataset

# Load the dataset.
# Dataset options may be absent from the config; the split then defaults to 0.8.
dataset_options = config.get("dataset", {})
train_data_obj, test_data_obj, dataset_config, features, dataset_props = load_dataset(
    selected_graph_sizes,
    selected_features=config.get("selected_features", []),
    label_normalization=None,
    split=dataset_options.get("split", 0.8),
)

# Store the dataset configuration returned by load_dataset in the run config.
wandb.config["dataset"] = dataset_config
# If no feature selection was configured (key missing or empty), record the
# feature list returned by load_dataset instead.
if not wandb.config.get("selected_features"):
    wandb.config["selected_features"] = features

### Set up the model, optimizer and loss function

In [None]:
from algebraic_connectivity_script import generate_model, generate_optimizer

# Optional extra keyword arguments, forwarded unchanged to generate_model.
model_kwargs = config.get("model_kwargs", {})

# The config spells "no jk" as the string "none"; generate_model gets None instead.
jk_mode = None if config["jk"] == "none" else config["jk"]

# Keyword options taken from the run configuration.
model_options = {
    "mlp_layers": config["mlp_layers"],
    "act": config["activation"],
    "dropout": float(config["dropout"]),
    "pool": config["pool"],
    "jk": jk_mode,
}

model = generate_model(
    config["architecture"],
    dataset_props["feature_dim"],
    config["hidden_channels"],
    config["gnn_layers"],
    **model_options,
    **model_kwargs,
)
optimizer = generate_optimizer(model, config["optimizer"], config["learning_rate"])
# Mean absolute error between prediction and label.
criterion = torch.nn.L1Loss()

### Training

In [None]:
from algebraic_connectivity_script import train, plot_training_curves

# Run training.
num_epochs = config["epochs"]
train_results = train(model, optimizer, criterion, train_data_obj, test_data_obj, num_epochs, save_best=True)

# Summarise the run: the lowest loss seen on each split and the reported duration.
train_losses = train_results["train_losses"]
test_losses = train_results["test_losses"]
run.summary["best_train_loss"] = min(train_losses)
run.summary["best_test_loss"] = min(test_losses)
run.summary["duration"] = train_results["duration"]

# Plot both loss curves, labelled with the class name of the loss function.
plot_training_curves(num_epochs, train_losses, test_losses, type(criterion).__name__)

### Evaluation

In [None]:
# Load best model
from algebraic_connectivity_script import BEST_MODEL_PATH

# Restore the weights stored at BEST_MODEL_PATH. `map_location` remaps the
# stored tensors onto the device selected at the top of the notebook, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only kernel.
checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
# Epoch stored alongside the weights; passed on to evaluate() in the next cell.
eval_epoch = checkpoint["epoch"]
# Switch to inference mode (disables dropout etc.) before evaluating.
model.eval()

In [None]:
from algebraic_connectivity_script import evaluate

# Evaluate the restored model on both splits.
eval_results = evaluate(
    model,
    eval_epoch,
    criterion,
    train_data_obj,
    test_data_obj,
    dataset_props["transformation"],
)

# Copy the scalar error statistics into the run summary.
for metric in ("mean_err", "stddev_err", "good_within"):
    run.summary[metric] = eval_results[metric]

# Log the figures returned by evaluate() under their W&B names.
figure_keys = {
    "abs_err_hist": "fig_abs_err",
    "rel_err_hist": "fig_rel_err",
    "err_curve": "fig_err_curve",
}
run.log({log_name: eval_results[result_key] for log_name, result_key in figure_keys.items()})
# run.log({"results_table": eval_results["table"]})

In [None]:
# Stop the W&B run now that the summary values and figures above have been logged.
run.finish()

## Explain

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer, AttentionExplainer

# Reload the data for the explainers with a batch size of 1. This overwrites
# train_data_obj / test_data_obj / dataset_config / features from the training
# section. The "Load the data" cell unpacks five values from load_dataset, so
# the starred catch-all absorbs any values after the first four instead of
# failing with "too many values to unpack".
# NOTE(review): the arguments here (None, batch_size, is_sweep) differ from the
# earlier call - confirm they still match load_dataset's current signature.
train_data_obj, test_data_obj, dataset_config, features, *_ = load_dataset(None, batch_size=1, is_sweep=True)
# model = generate_model("GCN", len(features), 10, 3)


#### GNNExplainer for model

In [None]:
# TODO: Are these results ok? How should they be interpreted, and what
# hyperparameters should be used?
# The results used to differ on every run: GNNExplainer optimises masks that
# start from random values. Seeding torch right before explaining makes this
# cell repeatable; change the seed to check how stable an explanation is.
torch.manual_seed(0)

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),  # PGExplainer, AttentionExplainer, CaptumExplainer
    # explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    node_mask_type="attributes",  # "object", "common_attributes", "attributes"
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)
explanation = explainer(data.x, data.edge_index, batch=data.batch)
# Print every mask the explanation provides.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{getattr(explanation, exp)}\n")

explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

#### GNNExplainer for phenomenon

In [None]:
# TODO: Are these results ok? How should they be interpreted, and what
# hyperparameters should be used?
# The results used to differ on every run: GNNExplainer optimises masks that
# start from random values. Seeding torch right before explaining makes this
# cell repeatable; change the seed to check how stable an explanation is.
torch.manual_seed(0)

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),  # PGExplainer, AttentionExplainer, CaptumExplainer
    explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    # explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    node_mask_type="attributes",  # "object", "common_attributes", "attributes"
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)
# For a 'phenomenon' explanation the ground-truth labels are passed as the target.
explanation = explainer(data.x, data.edge_index, target=data.y, batch=data.batch)
# Print every mask the explanation provides.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{getattr(explanation, exp)}\n")

explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

#### AttentionExplainer for model

In [None]:
# TODO: Are these results ok?
# Seems like the results are different on every run. Plus, how to interpret the
# results? What hyperparameters to use?
# NOTE(review): AttentionExplainer presumably needs a model with attention-based
# layers (e.g. GAT); the "GraphSAGE" architecture configured above may not
# provide any attention coefficients - confirm before relying on this cell.

# Same task description as for the other explainers: graph-level regression on raw outputs.
attention_model_config = {
    "mode": "regression",
    "task_level": "graph",
    "return_type": "raw",
}

explainer = Explainer(
    model=model,
    # Alternatives: PGExplainer, AttentionExplainer, CaptumExplainer
    algorithm=AttentionExplainer(),
    # 'model': open the black box and explain model decisions, predictions are targets for explanation.
    # 'phenomenon': what phenomenon leads from inputs to outputs, labels are targets for explanation.
    explanation_type="model",
    # Other possible values: "object", "common_attributes", "attributes"
    node_mask_type=None,
    edge_mask_type="object",
    model_config=attention_model_config,
)

data = train_data_obj.to(device)
explanation = explainer(data.x, data.edge_index, batch=data.batch)
# Print every mask the explanation provides.
for mask_name in explanation.available_explanations:
    print(f"{mask_name}:\n{getattr(explanation, mask_name)}\n")

# explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

#### PGExplainer - WIP

In [None]:
# FIXME: Work in progress - verify the resulting explanations before relying on them.

# The explainer's `epochs` setting and the training loop below must agree,
# so both read the same constant.
PG_EPOCHS = 30
PG_SAMPLES_PER_EPOCH = 20

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=PG_EPOCHS, lr=0.003),  # PGExplainer, AttentionExplainer, CaptumExplainer
    explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    # explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    # node_mask_type="common_attributes",  # Node masks are not supported.
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)

# PGExplainer.train uses `index` to select entries of the model output and of
# `target`. For this graph-level task those are per-graph values, so the index
# is sampled from the number of targets; sampling from the number of nodes
# (len(data.x)) produces out-of-range indices.
# NOTE(review): assumes data.y holds one target per graph - confirm.
num_targets = data.y.size(0)
for epoch in range(PG_EPOCHS):
    for index in np.random.randint(0, num_targets, PG_SAMPLES_PER_EPOCH):
        loss = explainer.algorithm.train(
            epoch, model, data.x, data.edge_index, target=data.y, batch=data.batch, index=int(index)
        )

explanation = explainer(data.x, data.edge_index, target=data.y, batch=data.batch)

# Print every mask the explanation provides.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{getattr(explanation, exp)}\n")

# explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

## Housekeeping

### Save the model

In [None]:
# torch.save(model.state_dict(), "model.pth")
# print("Saved PyTorch Model State to model.pth")


### Make predictions with loaded model

In [None]:
# model = NeuralNetwork().to(device)
# model.load_state_dict(torch.load("model.pth"))

# classes = [
#     "T-shirt/top",
#     "Trouser",
#     "Pullover",
#     "Dress",
#     "Coat",
#     "Sandal",
#     "Shirt",
#     "Sneaker",
#     "Bag",
#     "Ankle boot",
# ]

# model.eval()
# x, y = test_data[0][0], test_data[0][1]
# with torch.no_grad():
#     x = x.to(device)
#     pred = model(x)
#     predicted, actual = classes[pred[0].argmax(0)], classes[y]
#     print(f'Predicted: "{predicted}", Actual: "{actual}"')

## Additional W&B APIs

In [None]:
# api = wandb.Api()

# # Access attributes directly from the run object
# # or from the W&B App
# username = "marko-krizmancic"
# project = "gnn_fiedler_approx"
# run_id = ["nrcdc1y4", "11l94b1a", "ptj7b0vx"]

# for id in run_id:
#     run = api.run(f"{username}/{project}/{id}")
#     run.config["model"] = "GCN"
#     run.update()