## Preparation

### Imports

In [None]:
import os
import pathlib

import codetiming
import numpy as np
import torch
import wandb


# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Loaded torch. Using *{device}* device.")

### Load dataset

In [None]:
from torch_geometric.loader import DataLoader
from my_graphs_dataset import GraphDataset
from algebraic_connectivity_dataset import ConnectivityDataset


def load_dataset(selected_graph_sizes, selected_features=None, split=0.8, batch_size=0, seed=42, is_sweep=False):
    """Load the connectivity dataset, shuffle and split it, and return one batch per split.

    Args:
        selected_graph_sizes: Mapping passed to ``GraphDataset(selection=...)``.
            ``None`` selects the built-in default (keys 3 to 8, each mapped to -1;
            presumably -1 means "all graphs of that size" - TODO confirm).
        selected_features: Feature names forwarded to ``ConnectivityDataset``.
            ``None`` or an empty list means "no explicit selection"; the returned
            ``features`` then come from ``dataset.features``.
        split: Fraction of the shuffled dataset used for training.
        batch_size: Graphs per batch. Any value <= 0 uses the size of the training
            split as the batch size for both loaders.
        seed: Seed given to ``torch.manual_seed`` before shuffling.
        is_sweep: When true, all informational printing is suppressed.

    Returns:
        Tuple ``(train_data_obj, test_data_obj, dataset_config, features)`` where the
        two data objects are the FIRST batch drawn from the train and test loaders.

    NOTE: Only the first batch of each loader is returned. With a ``batch_size``
    smaller than a split, the remaining graphs of that split are not part of the
    returned object (the Explain cells use ``batch_size=1`` to get a single graph).
    """
    if selected_graph_sizes is None:
        selected_graph_sizes = {
            3: -1,
            4: -1,
            5: -1,
            6: -1,
            7: -1,
            8: -1,
            # 9:  100000,
            # 10: 100000
        }
    # A fresh list per call; a `[]` default would be shared between calls.
    if selected_features is None:
        selected_features = []

    dataset_config = {
        "name": "ConnectivityDataset",
        "selected_graphs": str(selected_graph_sizes),
        "split": split,
        "batch_size": batch_size,
        "seed": seed,
    }

    # Load the dataset.
    root = pathlib.Path.cwd() / "Dataset"
    graphs_loader = GraphDataset(selection=selected_graph_sizes)
    dataset = ConnectivityDataset(root, graphs_loader, selected_features=selected_features)

    # General information
    if not is_sweep:
        print()
        print(f"Dataset: {dataset}:")
        print("====================")
        print(f"Number of graphs: {len(dataset)}")
        print(f"Number of features: {dataset.num_features}")

    # Store information about the dataset.
    dataset_config["num_graphs"] = len(dataset)
    features = selected_features if selected_features else dataset.features

    # Shuffle and split the dataset. Seeding first makes the shuffle reproducible.
    torch.manual_seed(seed)
    dataset = dataset.shuffle()

    train_size = round(dataset_config["split"] * len(dataset))
    train_dataset = dataset[:train_size]
    test_dataset = dataset[train_size:]

    if not is_sweep:
        print()
        print(f"Number of training graphs: {len(train_dataset)}")
        print(f"Number of test graphs: {len(test_dataset)}")

    # Batch and load data.
    # TODO: Batch size?
    batch_size = dataset_config["batch_size"] if dataset_config["batch_size"] > 0 else len(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # type: ignore
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # type: ignore

    # Materialise the first batch of each loader; with the default batch size this
    # is the whole training split in one batch.
    train_data_obj = next(iter(train_loader))
    test_data_obj = next(iter(test_loader))

    if not is_sweep:
        print()
        print("Batches:")
        for step, data in enumerate(train_loader):
            print(f"Step {step + 1}:")
            print("=======")
            print(f"Number of graphs in the current batch: {data.num_graphs}")
            print(data)
            print()

    return train_data_obj, test_data_obj, dataset_config, features

### Models

#### Basic GCN

In [None]:
from torch.nn import Linear, ReLU
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.nn import global_mean_pool


class MyGCN(torch.nn.Module):
    """Graph-level regressor: GCNConv + ReLU stack, mean pooling, linear head.

    Args:
        input_channels: Number of input features per node.
        mp_layers: Output width of every message-passing layer, in order.
    """

    def __init__(self, input_channels, mp_layers):
        super().__init__()

        # Layer k consumes the width produced by layer k - 1; the first layer
        # consumes the raw node features.
        in_widths = [input_channels, *mp_layers[:-1]]
        stack = []
        for in_width, out_width in zip(in_widths, mp_layers):
            stack.append((GCNConv(in_width, out_width), "x, edge_index -> x"))
            stack.append(ReLU())
        self.mp_layers = Sequential("x, edge_index", stack)

        # Readout head producing a single value per graph.
        self.lin = Linear(mp_layers[-1], 1)

    def forward(self, x, edge_index, batch):
        """Return one prediction per graph in the batch."""
        node_embeddings = self.mp_layers(x, edge_index)
        # Average node embeddings per graph: [batch_size, hidden_channels]
        graph_embeddings = global_mean_pool(node_embeddings, batch)
        return self.lin(graph_embeddings)

In [None]:
# Registry of the hand-written architectures, keyed by class name.
custom_gnns = {gnn_cls.__name__: gnn_cls for gnn_cls in (MyGCN,)}

#### Wrapper for pre-made models

In [None]:
from torch.nn import Linear
from torch_geometric.nn import GCN, GraphSAGE, GIN, GAT
from torch_geometric.nn import global_mean_pool

# Registry of the ready-made PyG architectures, keyed by class name.
premade_gnns = {gnn_cls.__name__: gnn_cls for gnn_cls in (GCN, GraphSAGE, GIN, GAT)}


class GNNWrapper(torch.nn.Module):
    """Attach mean pooling and a one-output linear head to a node-embedding GNN.

    Args:
        gnn_model: Callable (typically a model class) invoked as
            ``gnn_model(in_channels, hidden_channels, num_layers, **kwargs)``.
        in_channels: Number of input features per node.
        hidden_channels: Width of the embeddings fed to the linear head.
        num_layers: Number of layers requested from ``gnn_model``.
        **kwargs: Extra keyword arguments forwarded to ``gnn_model``.
    """

    def __init__(self, gnn_model, in_channels: int, hidden_channels: int, num_layers: int, **kwargs):
        super().__init__()
        backbone = gnn_model(in_channels, hidden_channels, num_layers, **kwargs)
        head = Linear(hidden_channels, 1)
        self.gnn = backbone
        self.classifier = head

    def forward(self, x, edge_index, batch):
        """Return one prediction per graph in the batch."""
        node_embeddings = self.gnn(x, edge_index)
        pooled = global_mean_pool(node_embeddings, batch)
        prediction = self.classifier(pooled)
        return prediction

## Training & Evaluation

### Definitions

#### Training definitions

In [None]:
from torch_geometric.data import Data


def generate_model(architecture, in_channels, hidden_channels, num_layers):
    """Build the model named by *architecture* and move it to the active device.

    Names found in ``premade_gnns`` are wrapped in ``GNNWrapper``; any other name
    is looked up in ``custom_gnns`` and built with ``num_layers`` message-passing
    layers of width ``hidden_channels``.
    """
    # GLOBALS: device, premade_gnns, custom_gnns
    if architecture not in premade_gnns:
        custom_cls = custom_gnns[architecture]
        layer_widths = [hidden_channels] * num_layers
        return custom_cls(input_channels=in_channels, mp_layers=layer_widths).to(device)

    wrapped = GNNWrapper(
        gnn_model=premade_gnns[architecture],
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
    )
    return wrapped.to(device)


def generate_optimizer(model, optimizer, lr):
    """Create the optimizer named by *optimizer* for the parameters of *model*.

    Raises:
        ValueError: If *optimizer* is anything other than ``"adam"``.
    """
    if optimizer != "adam":
        raise ValueError("Only Adam optimizer is currently supported.")
    return torch.optim.Adam(model.parameters(), lr=lr)


def training_pass(model, batch, optimizer, criterion):
    """Perform a single training pass (forward, backward, optimizer step) over the batch.

    Args:
        model: Module called as ``model(x, edge_index, batch=...)``.
        batch: Batch object exposing ``x``, ``edge_index``, ``batch`` and ``y``.
        optimizer: Optimizer whose ``step``/``zero_grad`` are invoked.
        criterion: Loss callable applied to ``(prediction, target)``.
    """
    data = batch.to(device)  # Move to CUDA if available.
    # Call the module itself rather than `.forward` so registered hooks run.
    out = model(data.x, data.edge_index, batch=data.batch)
    # Drop only the trailing output dimension; a bare squeeze() would also drop
    # the batch dimension for a single-graph batch and mismatch the target shape.
    # Assumes data.y has shape [num_graphs] - TODO confirm against the dataset.
    loss = criterion(out.squeeze(-1), data.y)  # Compute the loss.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    optimizer.zero_grad()  # Clear gradients.


def testing_pass(model, batch, criterion):
    """Perform a single gradient-free testing pass over the batch.

    Args:
        model: Module called as ``model(x, edge_index, batch=...)``.
        batch: Batch object exposing ``x``, ``edge_index``, ``batch`` and ``y``.
        criterion: Loss callable applied to ``(prediction, target)``.

    Returns:
        The loss as a Python float.
    """
    with torch.no_grad():
        data = batch.to(device)
        # Call the module itself rather than `.forward` so registered hooks run.
        out = model(data.x, data.edge_index, batch=data.batch)
        # squeeze(-1) keeps the batch dimension even for a single-graph batch.
        # Assumes data.y has shape [num_graphs] - TODO confirm against the dataset.
        loss = criterion(out.squeeze(-1), data.y).item()  # Compute the loss.
    return loss


def do_train(model, data, optimizer, criterion):
    """Put the model in training mode and run one training pass per batch.

    *data* is either a DataLoader (every batch is used) or a single batch object.

    Raises:
        ValueError: If *data* is neither of the two supported types.
    """
    model.train()

    if isinstance(data, DataLoader):
        batches = data  # Iterate in batches over the training dataset.
    elif isinstance(data, Data):
        batches = (data,)  # The whole dataset as one batch.
    else:
        raise ValueError("Data must be a DataLoader or a Batch object.")

    for batch in batches:
        training_pass(model, batch, optimizer, criterion)


def do_test(model, data, criterion):
    """Put the model in eval mode and compute the loss over *data*.

    Args:
        model: Model to evaluate.
        data: Either a DataLoader or a single batch object.
        criterion: Loss callable forwarded to ``testing_pass``.

    Returns:
        For a single batch, its loss. For a DataLoader, the average of the batch
        losses weighted by ``batch.num_graphs`` (previously only the loss of the
        last batch was returned). The weighting gives the exact dataset loss for
        a mean-reduced criterion - presumably the case for all criteria used
        here; verify if a different reduction is introduced.

    Raises:
        ValueError: If *data* has an unsupported type or the DataLoader yields
            no graphs.
    """
    model.eval()

    if isinstance(data, DataLoader):
        weighted_loss = 0.0
        total_graphs = 0
        for batch in data:
            graphs_in_batch = batch.num_graphs
            weighted_loss += testing_pass(model, batch, criterion) * graphs_in_batch
            total_graphs += graphs_in_batch
        if total_graphs == 0:
            raise ValueError("Cannot compute a loss over an empty DataLoader.")
        loss = weighted_loss / total_graphs
    elif isinstance(data, Data):
        loss = testing_pass(model, data, criterion)
    else:
        raise ValueError("Data must be a DataLoader or a Batch object.")

    return loss


def train(model, optimizer, criterion, train_data_obj, test_data_obj, num_epochs=100, is_sweep=False):
    """Run the training loop and record the train and test loss after every epoch.

    Each epoch performs one training pass, then evaluates on both data objects,
    logs both losses to W&B and, outside of sweeps, prints a progress line every
    10 epochs.

    Returns:
        Dict with ``"train_losses"`` and ``"test_losses"`` (NumPy arrays of length
        ``num_epochs``) and ``"duration"`` (value returned by the overall timer).
    """
    loss_history = {
        "train_losses": np.zeros(num_epochs),
        "test_losses": np.zeros(num_epochs),
    }

    # One timer for the whole run, one restarted after every progress report.
    total_timer = codetiming.Timer(logger=None)
    window_timer = codetiming.Timer(logger=None)
    total_timer.start()
    window_timer.start()

    for idx in range(num_epochs):
        epoch = idx + 1

        # Perform one pass over the training set and then test on both sets.
        do_train(model, train_data_obj, optimizer, criterion)
        train_loss = do_test(model, train_data_obj, criterion)
        test_loss = do_test(model, test_data_obj, criterion)

        # Store the losses.
        loss_history["train_losses"][idx] = train_loss
        loss_history["test_losses"][idx] = test_loss
        wandb.log({"train_loss": train_loss, "test_loss": test_loss})

        # Progress is reported every 10 epochs, and never during a sweep.
        if is_sweep or epoch % 10 != 0:
            continue

        avg_epoch_time = window_timer.stop() / 10
        print(
            f"Epoch: {epoch:03d}, "
            f"Train Loss: {train_loss:.4f}, "
            f"Test Loss: {test_loss:.4f}, "
            f"Avg. duration: {avg_epoch_time:.4f} s"
        )
        window_timer.start()

    window_timer.stop()
    duration = total_timer.stop()

    return {**loss_history, "duration": duration}

#### Evaluation definitions

In [None]:
import concurrent.futures

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from utils import create_graph_wandb, extract_graphs_from_batch, graphs_to_tuple


def plot_training_curves(num_epochs, train_losses, test_losses, criterion):
    """Show train and test loss per epoch as two line traces in one Plotly figure.

    *criterion* is used verbatim as the y-axis title.
    """
    epochs = list(range(1, num_epochs + 1))
    curves = ((train_losses, "Train Loss"), (test_losses, "Test Loss"))

    fig = go.Figure()
    for losses, label in curves:
        fig.add_trace(go.Scatter(x=epochs, y=losses, mode="lines", name=label))
    fig.update_layout(title="Training and Test Loss", xaxis_title="Epoch", yaxis_title=criterion)
    fig.show()


def eval_batch(model, batch, plot_graphs=False):
    """Predict on one batch and return a per-graph results table.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        batch: Batch object exposing ``x``, ``edge_index``, ``batch`` and ``y``.
        plot_graphs: When true, a visualization is created for every graph via
            ``create_graph_wandb``; otherwise the column holds ``"N/A"``.

    Returns:
        pandas DataFrame with the columns ``GraphVis``, ``Graph``, ``Nodes``,
        ``Edges``, ``True`` and ``Predicted`` (one row per graph).
    """
    # Make predictions. Disabling autograd here keeps the function usable on its
    # own: .numpy() raises on tensors that require grad.
    data = batch.to(device)
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    # reshape(-1) always yields a 1-D array; squeeze() would give a 0-d array
    # for a single-graph batch.
    predictions = out.cpu().numpy().reshape(-1)
    ground_truth = data.y.cpu().numpy()

    # Extract graphs and create visualizations. The helper output is unpacked
    # into three parallel sequences (graph, node count, edge count).
    nx_graphs = extract_graphs_from_batch(data)
    graphs, node_nums, edge_nums = zip(*graphs_to_tuple(nx_graphs))
    # FIXME: This is the only way to parallelize in Jupyter but runs out of memory.
    # with concurrent.futures.ProcessPoolExecutor(4) as executor:
    #     graph_visuals = executor.map(create_graph_wandb, nx_graphs, chunksize=10)
    if plot_graphs:
        graph_visuals = [create_graph_wandb(g) for g in nx_graphs]
    else:
        graph_visuals = ["N/A"] * len(nx_graphs)

    # Store to pandas DataFrame.
    return pd.DataFrame(
        {
            "GraphVis": graph_visuals,
            "Graph": graphs,
            "Nodes": node_nums,
            "Edges": edge_nums,
            "True": ground_truth,
            "Predicted": predictions,
        }
    )


def evaluate(model, test_data, plot_graphs=False, is_sweep=False):
    """Evaluate the model on *test_data* and collect error statistics.

    Args:
        model: Model to evaluate (switched to eval mode here).
        test_data: Either a DataLoader or a single batch object.
        plot_graphs: Forwarded to ``eval_batch``.
        is_sweep: When true, nothing is printed or shown.

    Returns:
        Dict with ``"mean_err"`` and ``"stddev_err"`` (mean / standard deviation
        of the absolute error), ``"fig_abs_err"`` and ``"fig_rel_err"`` (Plotly
        histograms) and ``"table"`` (W&B table built from the results frame).

    Raises:
        ValueError: If *test_data* has an unsupported type or is an empty
            DataLoader.
    """
    # Evaluate the model on the test set.
    model.eval()

    with torch.no_grad():
        if isinstance(test_data, DataLoader):
            # Collect all per-batch frames and concatenate once; concatenating
            # inside the loop copies the accumulated frame on every iteration.
            frames = [eval_batch(model, batch, plot_graphs) for batch in test_data]
            if not frames:
                raise ValueError("Cannot evaluate on an empty DataLoader.")
            # Every per-batch frame starts at row label 0, so renumber the rows.
            df = pd.concat(frames, ignore_index=True)
        elif isinstance(test_data, Data):
            df = eval_batch(model, test_data, plot_graphs)
        else:
            raise ValueError("Data must be a DataLoader or a Batch object.")

    # Calculate the statistics.
    df["Error"] = df["True"] - df["Predicted"]
    # NOTE(review): rows whose true value is 0 produce inf/NaN here - confirm
    # whether the dataset can contain such targets.
    df["Error %"] = 100 * df["Error"] / df["True"]
    df["abs(Error)"] = np.abs(df["Error"])
    err_mean = np.mean(df["abs(Error)"])
    err_stddev = np.std(df["abs(Error)"])

    # Create a W&B table.
    table = wandb.Table(dataframe=df)

    # Print and plot.
    # df = df.sort_values(by="abs(Error)")
    fig_abs_err = px.histogram(df, x="Error")
    fig_rel_err = px.histogram(df, x="Error %")

    if not is_sweep:
        print(f"Mean error: {err_mean:.4f}\n" f"Std. dev.: {err_stddev:.4f}\n")
        fig_abs_err.show()
        fig_rel_err.show()
        df = df.sort_values(by="Nodes")
        print(df)

    results = {
        "mean_err": err_mean,
        "stddev_err": err_stddev,
        "fig_abs_err": fig_abs_err,
        "fig_rel_err": fig_rel_err,
        "table": table,
    }
    return results

#### Main definition

In [None]:
def main(config=None, skip_evaluation=False):
    """Run one experiment: set up W&B, load the data, train and optionally evaluate.

    Args:
        config: Hyperparameter mapping handed to ``wandb.init``. ``None`` marks
            the call as a sweep run; the hyperparameters are then read from
            ``wandb.config`` and most printing/plotting is suppressed.
        skip_evaluation: When true, the evaluation step (error statistics,
            histograms, results table) is skipped.

    Returns:
        Tuple ``(run, model)`` with the W&B run object and the trained model.
    """
    # GLOBALS: device

    is_sweep = config is None

    # Set up dataset.
    selected_graph_sizes = {
        3: -1,
        4: -1,
        5: -1,
        6: -1,
        7: -1,
        8: -1,
        # 9:  100000,
        # 10: 100000
    }

    # Set up the run
    run = wandb.init(mode="disabled", project="gnn_fiedler_approx", tags=["lambda2", "fiedler", "baseline"], config=config)
    config = wandb.config
    if is_sweep:
        print(f"Running sweep with config: {config}...")

    # Load the dataset. The configured seed (42 when absent) drives the shuffle,
    # so the "seed" entry of the config is no longer silently ignored.
    train_data_obj, test_data_obj, dataset_config, features = load_dataset(
        selected_graph_sizes,
        selected_features=config.get("selected_features", []),
        seed=config.get("seed", 42),
        is_sweep=is_sweep,
    )

    wandb.config["dataset"] = dataset_config
    if "selected_features" not in wandb.config or not wandb.config["selected_features"]:
        print(features)
        wandb.config["selected_features"] = features
        print(wandb.config["selected_features"])

    # Set up the model, optimizer, and criterion.
    model = generate_model(
        config["architecture"],
        len(wandb.config["selected_features"]),
        config["hidden_channels"],
        config["num_layers"],
    )
    optimizer = generate_optimizer(model, config["optimizer"], config["learning_rate"])
    criterion = torch.nn.L1Loss()

    # Run training.
    train_results = train(
        model, optimizer, criterion, train_data_obj, test_data_obj, config["epochs"], is_sweep=is_sweep
    )
    run.summary["best_train_loss"] = min(train_results["train_losses"])
    run.summary["best_test_loss"] = min(train_results["test_losses"])
    run.summary["duration"] = train_results["duration"]
    if not is_sweep:
        plot_training_curves(
            config["epochs"], train_results["train_losses"], train_results["test_losses"], type(criterion).__name__
        )

    # Run evaluation.
    eval_results = None
    if not skip_evaluation:
        eval_results = evaluate(model, test_data_obj, plot_graphs=not is_sweep, is_sweep=is_sweep)
        run.summary["mean_err"] = eval_results["mean_err"]
        run.summary["stddev_err"] = eval_results["stddev_err"]
        run.log({"abs_err_hist": eval_results["fig_abs_err"], "rel_err_hist": eval_results["fig_rel_err"]})
        run.log({"results_table": eval_results["table"]})

    if is_sweep:
        # The error statistics exist only when the evaluation actually ran;
        # referencing them unconditionally raised NameError with skip_evaluation.
        summary = "    ...DONE. "
        if eval_results is not None:
            summary += (
                f"Mean error: {eval_results['mean_err']:.4f}, "
                f"Std. dev.: {eval_results['stddev_err']:.4f}, "
            )
        summary += f"Duration: {train_results['duration']:.4f} s."
        print(summary)

    return run, model

### Run

#### Standard run

In [None]:
# Hyperparameters of the single (non-sweep) run.
global_config = dict(
    seed=42,
    architecture="GAT",
    hidden_channels=10,
    num_layers=3,
    optimizer="adam",
    learning_rate=0.01,
    epochs=500,
)
# Train only; the evaluation step is skipped for this run.
run, model = main(config=global_config, skip_evaluation=True)

#### W&B sweep

In [None]:
# IPython magic: sets the WANDB_SILENT environment variable for this session.
%env WANDB_SILENT=True

import time

# NOTE(review): the purpose of this pause is not evident from the code - confirm
# whether it is still needed.
time.sleep(2)

# TODO: How to include seed and dataset configuration?

# Full grid over architectures and hyperparameters. Defined here but not passed
# to wandb.sweep below.
full_sweep_configuration = {
    "name": "full_first_sweep",
    "method": "grid",  # grid, random or Bayesian
    "metric": {"goal": "minimize", "name": "test_loss"},
    "parameters": {
        "architecture": {"values": ["GCN", "GraphSAGE", "GIN", "GAT"]},
        "hidden_channels": {"values": [8, 16, 32, 64]},
        "num_layers": {"values": [1, 2, 3, 5]},
        "optimizer": {"value": "adam"},
        "learning_rate": {"values": [0.1, 0.01, 0.001]},
        "epochs": {"value": 1000},
    },
    "early_terminate": {"type": "hyperband", "eta": 3, "min_iter": 300},
}

# Small sweep that varies only the selected node features.
test_sweep_configuration = {
    "name": "test_sweep",
    "method": "grid",  # grid, random or Bayesian
    "metric": {"goal": "minimize", "name": "test_loss"},
    "parameters": {
        "architecture": {"values": ["GAT"]},
        "hidden_channels": {"values": [16]},
        "num_layers": {"values": [3]},
        "selected_features": {"values": [
            [],
            ["degree"],
            ["degree", "degree_centrality"],
        ]},
        "optimizer": {"value": "adam"},
        "learning_rate": {"values": [0.01]},
        "epochs": {"value": 1000},
    },
}

# Register the test sweep and run it through `main` (called without arguments,
# so main treats every call as a sweep run).
sweep_id = wandb.sweep(sweep=test_sweep_configuration, project="gnn_fiedler_approx")

wandb.agent(sweep_id, function=main, count=5)

In [None]:
# Stop the W&B run returned by the standard-run cell (`run, model = main(...)`).
run.finish()

## Explain

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer, AttentionExplainer

# Reload the data with batch_size=1: load_dataset returns the first batch of each
# loader, so both data objects hold a single graph.
train_data_obj, test_data_obj, dataset_config, features = load_dataset(
    selected_graph_sizes=None,
    batch_size=1,
    is_sweep=True,
)
# model = generate_model("GraphSAGE", len(features), 10, 3)


#### GNNExplainer for model

In [None]:
# TODO: Are these results ok?
# Seems like the results are different on every run. Plus, how to interpret the
# results? What hyperparameters to use?

# Uses the notebook-level `model` (from the standard-run cell) and the
# `train_data_obj` / `features` loaded in the Explain setup cell.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),  # PGExplainer, AttentionExplainer, CaptumExplainer
    # explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    node_mask_type="attributes",  # "object", "common_attributes", "attributes"
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)
# No target is passed: with the 'model' explanation type the predictions are the targets.
explanation = explainer(data.x, data.edge_index, batch=data.batch)
# Print every explanation the result reports as available.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{explanation.__getattr__(exp)}\n")

explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

#### GNNExplainer for phenomenon

In [None]:
# TODO: Are these results ok?
# Seems like the results are different on every run. Plus, how to interpret the
# results? What hyperparameters to use?

# Uses the notebook-level `model` (from the standard-run cell) and the
# `train_data_obj` / `features` loaded in the Explain setup cell.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),  # PGExplainer, AttentionExplainer, CaptumExplainer
    explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    # explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    node_mask_type="attributes",  # "object", "common_attributes", "attributes"
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)
# The 'phenomenon' explanation type explains the labels, so data.y is passed as target.
explanation = explainer(data.x, data.edge_index, target=data.y, batch=data.batch)
# Print every explanation the result reports as available.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{explanation.__getattr__(exp)}\n")

explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

#### AttentionExplainer for model

In [None]:
# TODO: Are these results ok?
# Seems like the results are different on every run. Plus, how to interpret the
# results? What hyperparameters to use?

# Uses the notebook-level `model` (from the standard-run cell) and the
# `train_data_obj` loaded in the Explain setup cell.
# NOTE(review): AttentionExplainer presumably requires an attention-based model
# (the standard run trains a "GAT") - confirm before using other architectures.
explainer = Explainer(
    model=model,
    algorithm=AttentionExplainer(),  # PGExplainer, AttentionExplainer, CaptumExplainer
    # explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    node_mask_type=None,  # "object", "common_attributes", "attributes"
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)
explanation = explainer(data.x, data.edge_index, batch=data.batch)
# Print every explanation the result reports as available.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{explanation.__getattr__(exp)}\n")

# No node mask is requested above, so the feature-importance plot stays disabled.
# explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

#### PGExplainer - WIP

In [None]:
# FIXME: Something is wrong with the implementation.
# NOTE(review): `index` below is sampled from the node range (len(data.x)) while
# the explainer is configured with task_level='graph'. Presumably a node index is
# only meaningful for node-level tasks - verify against the PGExplainer
# documentation; this may be the cause of the problem.

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),  # PGExplainer, AttentionExplainer, CaptumExplainer
    explanation_type='phenomenon',  # what phenomenon leads from inputs to outputs, labels are targets for explanation
    # explanation_type='model',  # open the black box and explain model decisions, predictions are targets for explanation
    # node_mask_type="common_attributes",  # Node masks are not supported.
    edge_mask_type="object",
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

data = train_data_obj.to(device)

# Train the explainer itself before asking for explanations: 30 epochs (matching
# the `epochs=30` given to PGExplainer above), 20 randomly drawn indices per epoch.
for epoch in range(30):
  for index in torch.LongTensor(np.random.randint(0, len(data.x), 20)):
    loss = explainer.algorithm.train(epoch, model, data.x, data.edge_index, target=data.y, batch=data.batch, index=index.item())

explanation = explainer(data.x, data.edge_index, target=data.y, batch=data.batch)

# Print every explanation the result reports as available.
for exp in explanation.available_explanations:
    print(f"{exp}:\n{explanation.__getattr__(exp)}\n")

# explanation.visualize_feature_importance(feat_labels=features)
explanation.visualize_graph()

## Housekeeping

### Save the model

In [None]:
# torch.save(model.state_dict(), "model.pth")
# print("Saved PyTorch Model State to model.pth")


### Make predictions with loaded model

In [None]:
# model = NeuralNetwork().to(device)
# model.load_state_dict(torch.load("model.pth"))

# classes = [
#     "T-shirt/top",
#     "Trouser",
#     "Pullover",
#     "Dress",
#     "Coat",
#     "Sandal",
#     "Shirt",
#     "Sneaker",
#     "Bag",
#     "Ankle boot",
# ]

# model.eval()
# x, y = test_data[0][0], test_data[0][1]
# with torch.no_grad():
#     x = x.to(device)
#     pred = model(x)
#     predicted, actual = classes[pred[0].argmax(0)], classes[y]
#     print(f'Predicted: "{predicted}", Actual: "{actual}"')

## Additional W&B APIs

In [None]:
# api = wandb.Api()

# # Access attributes directly from the run object
# # or from the W&B App
# username = "marko-krizmancic"
# project = "gnn_fiedler_approx"
# run_id = ["nrcdc1y4", "11l94b1a", "ptj7b0vx"]

# for id in run_id:
#     run = api.run(f"{username}/{project}/{id}")
#     run.config["model"] = "GCN"
#     run.update()