# Graph Transformers

## Self-Attention as Message Passing

Let's assume that we have three tokens with scalar features:
$$x_1 = 1, \quad x_2 = 2, \quad x_3 = 0$$

Note: This is a toy model. In practice, token features are high-dimensional vectors.

For simplicity, we set $W^Q=W^K=W^V=I$, so that:
$$q_i = x_i,\quad k_i = x_i,\quad v_i = x_i$$

The self-attention update for token 1 is then:
$$z_1 = \sum_{j=1}^3 \alpha_{1j} v_j, \quad \alpha_{1j} = \frac{e^{q_1 k_j}}{\sum_{\ell} e^{q_1 k_\ell}}$$

Before doing any math, answer:

1. Which token do you expect token 1 to pay most attention to?
2. Which token should receive the least attention from others?

Now, compute and answer:

3. All $\alpha_{1j}$ and then $z_1$. Describe what $z_1$ represents in this example.
4. Are $\alpha_{12}$ and $\alpha_{21}$ equal? What does this say about attention as an "edge weight"?
5. What happens if all $x_i$ are equal? What kind of GNN is this equivalent to?
6. What changes in the self-attention mechanism when moving from a Transformer to a GAT?
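
To check your hand computations, the toy setup above can be reproduced numerically. This is a minimal sketch, assuming NumPy is available; the helper name `attention_weights` is made up for illustration, and indices are 0-based in code (so `alpha[0, 1]` corresponds to $\alpha_{12}$).

```python
import numpy as np

# Toy features from the example; W^Q = W^K = W^V = I, so q = k = v = x.
x = np.array([1.0, 2.0, 0.0])

def attention_weights(x):
    """Row i holds alpha_{ij} = softmax_j(q_i * k_j) for scalar features."""
    scores = np.outer(x, x)                               # scores[i, j] = q_i * k_j
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability only
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)

alpha = attention_weights(x)
z = alpha @ x   # z[i] = sum_j alpha[i, j] * v_j

print("alpha_1j:", alpha[0])
print("z_1:", z[0])
print("alpha_12 vs alpha_21:", alpha[0, 1], alpha[1, 0])

# Question 5: what do the weights look like when all features are equal?
print("equal features:", attention_weights(np.ones(3))[0])
```

Subtracting the row maximum before exponentiating does not change the softmax; it only avoids overflow for large scores.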

## Graph Laplacian Magic

In previous lectures, we used the Laplacian matrix $L$, but what does it actually mean?

Given an undirected graph with adjacency matrix $A$ and degree matrix $D$, $L$ is defined as:
$$L = D - A$$
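
One way to read $L$ is through the quadratic form $x^\top L x = \sum_{\{i,j\} \in E} (x_i - x_j)^2$, a standard identity for undirected, unweighted graphs: it measures how much a signal $x$ changes across edges. The sketch below checks this identity on a hypothetical 4-node cycle (not a graph from this notebook); the helper `edge_variation` is an illustrative name.

```python
import numpy as np

# Hypothetical example graph: a 4-node cycle 0-1-2-3-0.
A = np.array([
    [0, 1, 0, 1],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [1, 0, 1, 0],
], dtype=float)

D = np.diag(A.sum(axis=1))   # degree matrix
L = D - A                    # Laplacian as defined above

def edge_variation(A, x):
    """Sum of (x_i - x_j)^2 over undirected edges {i, j}."""
    i, j = np.triu_indices_from(A, k=1)   # visit each unordered pair once
    mask = A[i, j] > 0                    # keep only pairs joined by an edge
    return np.sum((x[i[mask]] - x[j[mask]]) ** 2)

x_smooth = np.array([1.0, 1.0, 1.0, 1.0])     # constant signal
x_rough = np.array([1.0, -1.0, 1.0, -1.0])    # flips sign across every edge

for name, x in [("smooth", x_smooth), ("rough", x_rough)]:
    print(name, x @ L @ x, edge_variation(A, x))
```

Both columns of the printout should agree, which is the identity in action.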

Consider the following path graph of $3$ nodes:

![Path graph with 3 nodes](assets/w8_path.png "G")

1. What is the Laplacian matrix for this graph?

2. Verify that the following are eigenvectors of $L$, and find the corresponding eigenvalues:
    - $(-1,0,1)^\top$
    - $(1,1,1)^\top$
    - $(1,-2,1)^\top$

    Refresher: $Lv=\lambda v$, where $v$ is the eigenvector and $\lambda$ is the eigenvalue.

3. Order the eigenvectors from slowly varying to rapidly varying (i.e., smooth vs oscillatory) across the graph. Then, look at the corresponding eigenvalues. What do you observe about how the eigenvalues are ordered relative to the variation speed?

4. Based on your finding in (3), which eigenvectors capture the global structure and which eigenvectors capture the local variation?

5. Now let's verify our findings on a larger graph. We will revisit the Karate Club graph and plot what the eigenvectors highlight within the graph.

    The following code snippet computes the eigenvectors and sorts them by their corresponding eigenvalues. Plot different eigenvectors across the graph to see which information they encode (local vs global). Does this match your finding in (4)?

In [None]:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

# load the Karate Club graph
G = nx.karate_club_graph()

# compute the Laplacian matrix
L = nx.laplacian_matrix(G).toarray()

# compute eigenvalues and eigenvectors
eigvals, eigvecs = np.linalg.eigh(L)

# sort eigenvalues and eigenvectors
idx = np.argsort(eigvals)
eigvals = eigvals[idx]
eigvecs = eigvecs[:, idx]

# plot the eigenvectors (add more indices to the list to compare them)
pos = nx.spring_layout(G, seed=42)
for k in [1]:
    plt.figure()
    nx.draw(G, pos, node_color=eigvecs[:, k], cmap='coolwarm', with_labels=False)
    plt.title(f"Laplacian eigenvector {k}")
    plt.show()
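
The eigenvector check from question 2 can also be done numerically. The sketch below is one possible way to do it, assuming the nodes of the path graph are numbered along the path (1 - 2 - 3, with node 2 in the middle); the helper `eigenvalue_if_eigenvector` is a made-up name, not a library function.

```python
import numpy as np

def eigenvalue_if_eigenvector(L, v, tol=1e-9):
    """Return lambda if L v = lambda v (up to tol), otherwise None."""
    v = np.asarray(v, dtype=float)
    Lv = L @ v
    lam = (v @ Lv) / (v @ v)   # Rayleigh quotient: the only possible lambda
    return lam if np.allclose(Lv, lam * v, atol=tol) else None

# Assumption: nodes are numbered along the path, 1 - 2 - 3.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
L = np.diag(A.sum(axis=1)) - A   # L = D - A

for v in [(-1, 0, 1), (1, 1, 1), (1, -2, 1)]:
    print(v, eigenvalue_if_eigenvector(L, v))
```

A vector that is not an eigenvector, such as `(1, 0, 0)`, makes the helper return `None`.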

## Eigenvector Sign Ambiguity

Recall that $Lv=\lambda v$, where $v$ is the eigenvector and $\lambda$ is the eigenvalue. This also means:
$$L(-v)=\lambda (-v)$$

1. Why is this a problem when we want to use the eigenvectors as positional encodings?

2. Would an attention-based model automatically be invariant to this sign ambiguity? Why or why not?

3. If we need sign invariance, why don't we just take $|v|$ before using it as a positional encoding?

4. How does `SignNet` ([paper link](https://arxiv.org/pdf/2202.13013)) solve this problem? You can experiment with the following code snippet, which implements a simplified version of `SignNet`.

In [None]:
import torch
import torch.nn as nn

class SimpleSignNet(nn.Module):
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.phi = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, v):
        # v: (n,) eigenvector
        v = v.unsqueeze(-1)        # (n, 1)

        out_pos = self.phi(v)      # φ(v)
        out_neg = self.phi(-v)     # φ(-v)

        return out_pos + out_neg   # sign-invariant


# toy eigenvector
v = torch.tensor([1.0, -2.0, 0.5])

model = SimpleSignNet()

z1 = model(v)
z2 = model(-v)

print(torch.allclose(z1, z2))
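
As a starting point for question 3, the sketch below looks at what $|v|$ does to one concrete vector. It assumes NumPy and reuses $(-1,0,1)^\top$ from the path graph exercise; it is an illustration of one effect, not a full answer.

```python
import numpy as np

# Eigenvector (-1, 0, 1) of the 3-node path graph from the earlier exercise.
v = np.array([-1.0, 0.0, 1.0])
abs_pe = np.abs(v)

# |v| is unchanged when the sign of v flips.
print("sign-invariant:", np.array_equal(np.abs(v), np.abs(-v)))

# Compare the encodings of the two end nodes before and after taking |v|.
print("end nodes equal in v:  ", v[0] == v[2])
print("end nodes equal in |v|:", abs_pe[0] == abs_pe[2])
```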