# Exercise 03

In this exercise we look at graphs and post-hoc explanations and want to give you practical experience with dimensionality reduction.

## Graph Explanations
In the lecture we talked about graphs and how they can be used as scene representations for robot learning tasks. In the exercise, we will have a look at more traditional graph tasks and how we can produce post-hoc explanations from the predictions. We'll use the MUTAG dataset, a collection of graphs representing chemical compounds. The task is to predict whether a compound is mutagenic or not.
 1.  Train a simple Graph Convolutional Network (GCN) on the MUTAG dataset for a graph classification task.
 2.  Use **GNNExplainer** to find the most influential subgraph and node features for a specific prediction.
 3.  Use **PGExplainer** to learn a parameterized explanation for the model's predictions.
 4.  Visualize and compare the explanations from both methods.

To get started, we first need to install PyTorch and PyTorch Geometric (`torch-geometric`).

In [None]:
!pip install torch
!pip install torch-geometric

Let's visualize the first molecule in the dataset to understand its structure. The node colors represent the different types of atoms.

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv, global_mean_pool
from torch_geometric.explain.algorithm import GNNExplainer, PGExplainer
from torch_geometric.explain import Explainer
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
from IPython.display import Image, display

# Load the MUTAG dataset through TUDataset; its files live under /tmp/MUTAG.
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
print(f"Dataset: {dataset.name}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of node features: {dataset.num_node_features}")

# The first graph of the dataset serves as the visualization example.
graph_to_viz = dataset[0]

# Drawing is done with NetworkX, so convert to an undirected NetworkX graph.
g = to_networkx(graph_to_viz, to_undirected=True)

# One colour index per node: the position of the largest entry in the node's
# feature row (used here as the atom type).
node_colors = graph_to_viz.x.argmax(dim=1).tolist()

plt.figure(figsize=(8, 8))
nx.draw(
    g,
    with_labels=True,
    node_color=node_colors,
    cmap="tab10",
    node_size=800,
)
plt.title("First Graph in MUTAG Dataset (Molecule)")
plt.show()

# Mini-batch loader over the complete dataset, with shuffling enabled.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

Here, we define a simple Graph Convolutional Network (GCN) with two convolutional layers followed by a global mean pooling layer and a final linear layer for classification.

In [None]:
class GCN(torch.nn.Module):
    """Two-layer GCN for graph-level classification.

    Node features pass through two ``GCNConv`` layers with ReLU activations,
    are averaged per graph with ``global_mean_pool``, regularised with
    dropout and finally mapped to class log-probabilities by a linear layer.

    Args:
        hidden_channels: Width of both convolutional layers and of the
            pooled graph embedding. Defaults to 64.
    """

    def __init__(self, hidden_channels=64):
        super().__init__()
        # Input and output sizes are read from the module-level `dataset`.
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)  # You can also try out the GATv2Conv layer
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return per-graph class log-probabilities.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor handed to the ``GCNConv`` layers.
            batch: Graph index of every node. When omitted, all nodes are
                treated as belonging to one single graph.

        Returns:
            Tensor with one row of log-probabilities per graph.
        """
        if batch is None:
            # A tensor default would be created once at definition time on
            # the CPU; build the "single graph" assignment per call instead,
            # on the same device as the input.
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

Now, we'll train our GCN model on the MUTAG dataset.

In [None]:
# Use the GPU when CUDA is available and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network on the chosen device and optimise all of its parameters
# with Adam.
model = GCN()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    """Run one optimisation pass over every mini-batch of ``loader``."""
    model.train()
    for graphs in loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        log_probs = model(graphs.x, graphs.edge_index, graphs.batch)
        # The model emits log-probabilities, hence the NLL loss.
        F.nll_loss(log_probs, graphs.y).backward()
        optimizer.step()

def test():
    """Return the model's accuracy over all graphs served by ``loader``.

    Note that ``loader`` is the same loader ``train()`` iterates over, so the
    returned value is the accuracy on the data the model was fitted on.

    Returns:
        Fraction of correctly classified graphs, as a float in [0, 1].
    """
    model.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a computation graph for every batch.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

In [None]:
# Train for 200 epochs and report the accuracy every 10th epoch.
for epoch in range(1, 201):
    train()
    if epoch % 10 == 0:
        # test() iterates the same `loader` that train() optimises on, so the
        # number reported here is the accuracy on the training data.
        train_acc = test()
        print(f'Epoch: {epoch:03d}, Train Accuracy: {train_acc:.4f}')

GNNExplainer is a model-agnostic explainer that identifies a compact subgraph structure and a small subset of node features that are most influential for a prediction.

Visualizing the GNNExplainer Explanation:
The highlighted subgraph shows the most important edges for the model's prediction.



In [None]:
# Let's pick a single graph to explain
data_to_explain = dataset[0].to(device)

# All nodes of the chosen molecule belong to one graph (index 0). The batch
# vector is passed explicitly instead of relying on the model's default
# `batch` argument, so pooling yields exactly one prediction on the right device.
explain_batch = torch.zeros(
    data_to_explain.num_nodes,
    dtype=torch.long,
    device=data_to_explain.x.device,
)

gnn_explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200, lr=0.01),
    # 'model': explain the model's own prediction rather than a given label.
    explanation_type='model',
    node_mask_type='object',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        # The GCN ends in log_softmax.
        return_type='log_probs',
    )
)

# Extra keyword arguments are forwarded to the model's forward().
explanation = gnn_explainer(
    data_to_explain.x,
    data_to_explain.edge_index,
    batch=explain_batch,
)

path = 'gnn_explainer_explanation.png'
explanation.visualize_graph(path, backend='networkx')
display(Image(filename=path))

In [None]:
# Single source of truth for the number of PGExplainer training epochs: the
# explainer refuses to explain until it has been trained for this many epochs,
# so the training loop below must use the same value.
PG_EPOCHS = 30

pg_explainer = Explainer(
    model=model,
    # PGExplainer owns trainable parameters; move them to the model's device.
    algorithm=PGExplainer(epochs=PG_EPOCHS, lr=0.003).to(device),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs',
    ),
)


# Train the PGExplainer
# We need to train the explainer on the training data
model.eval()  # keep dropout disabled so the targets are deterministic
for epoch in range(PG_EPOCHS):
    for data in loader:
        data = data.to(device)
        # PGExplainer needs a target for training, we use the model's prediction
        with torch.no_grad():
            target = model(data.x, data.edge_index, data.batch).argmax(dim=1)
        pg_explainer.algorithm.train(epoch, model, data.x, data.edge_index, target=target, batch=data.batch)


# Get an explanation for our chosen graph
# All of its nodes belong to one graph (index 0); pass the batch vector
# explicitly so the model pools them into exactly one prediction.
pg_batch = torch.zeros(
    data_to_explain.num_nodes,
    dtype=torch.long,
    device=data_to_explain.x.device,
)
# Here the ground-truth label of the graph is used as the target to explain.
pg_explanation = pg_explainer(
    data_to_explain.x,
    data_to_explain.edge_index,
    target=data_to_explain.y,
    batch=pg_batch,
)

path = 'pg_explainer_explanation.png'
pg_explanation.visualize_graph(path, backend='networkx')
display(Image(filename=path))

## PCA

PCA transforms data according to its main axes of variance. To explore the effect of PCA, we use `scikit-learn` to construct a "Swiss Roll" dataset.

We can plot the data in 3D and as a simple 2D projection onto the XZ-plane. The projection highlights the spiral structure of the Swiss roll, but loses all information about its width.

In [None]:
import numpy as np
from sklearn import datasets

import matplotlib.pyplot as plt

In [None]:
# Sample a noisy Swiss roll; `t` is returned together with the points and is
# used below to colour them.
data, t = datasets.make_swiss_roll(n_samples=1500, noise=0.05)

fig = plt.figure()

# Left panel: scatter plot of all three coordinates.
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.scatter(*data.T, c=t, alpha=0.5)
ax.set_title('Swiss Roll 3D')

# Right panel: only the first and third coordinate (XZ plane).
ax = fig.add_subplot(1, 2, 2)
ax.scatter(*data[:, [0, 2]].T, c=t, alpha=0.5)
ax.set_title('Swiss Roll 2D Projection')

plt.show()

We compute the covariance matrix of the dataset and then its eigenvalues and eigenvectors.

Note that `np.linalg.eigh` returns the eigenvalues, which indicate the amount of variance explained by each eigenvector, in ascending order, so the most informative directions come last.

For dimensionality reduction, we would sort the eigenvectors by their eigenvalues in descending order and keep only the first `n` dimensions.

In [None]:
# Rows of `data` are samples and columns are features, so tell np.cov that the
# variables live in the columns (equivalent to passing the transpose).
cov = np.cov(data, rowvar=False)

# A covariance matrix is symmetric, which is what eigh expects; eigh returns
# the eigenvalues in ascending order with the eigenvectors as matching columns.
eig_val, eig_vec = np.linalg.eigh(cov)

# Covariance matrix, an empty line, then the eigenvalues.
print(cov, "", eig_val, sep="\n")

As covered in the lecture, the eigenvectors span a matrix acting as a linear transformation. To transform the original data into the "PCA Space", we apply a matrix multiplication.

Afterwards we plot the transformed data according to the two axes with the highest eigenvalues.

Even though this is a 2D projection as above, due to the PCA transformation it better covers the characteristics of the data.

In [None]:
# PCA scores are defined on mean-free data: the covariance matrix (and hence
# the eigenvectors) describes variation around the mean, so centre the data
# before changing basis.
data_centered = data - data.mean(axis=0)
data_pca = np.dot(data_centered, eig_vec)

# Rank the components by explained variance (largest eigenvalue first)
# instead of hard-coding their column positions.
order = np.argsort(eig_val)[::-1]
first_pc, second_pc = order[0], order[1]

fig = plt.figure()
ax = fig.add_subplot()

# Plot the two components with the highest eigenvalues.
ax.scatter(data_pca[:, first_pc], data_pca[:, second_pc], c=t, alpha=0.5)
ax.set_xlabel('Principal component 1')
ax.set_ylabel('Principal component 2')
plt.show()

## t-SNE and UMAP in Practice

While writing our own implementations of t-SNE and UMAP would be interesting, we instead recommend the following interactive blog posts. They highlight the effects of the various hyperparameters.

https://distill.pub/2016/misread-tsne/

https://pair-code.github.io/understanding-umap/