# Oversmoothing Analysis of Node Classification

### Setup

In [None]:
import os

import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.utils import add_remaining_self_loops, degree
from torch_scatter import scatter
from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool, GCNConv, GATConv, SimpleConv
from torch_geometric.datasets import HeterophilousGraphDataset
from torch_geometric.loader import DataLoader
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### 1) Find and Download Datasets for Node Classification ([dataset link](https://github.com/yandex-research/heterophilous-graphs))

***Roman Empire***

* Based on the Roman Empire article from Wikipedia
* Each node corresponds to a word in the text
* Two words are connected by an edge if one follows the other in the text or if there is a dependency between them within the sentence
* The class of the node is its syntactic role found using spaCy
* Node features are fastText word embeddings

In [2]:
# Roman-empire graph; the dataset files are stored under data/.
rome = HeterophilousGraphDataset(root='data/', name='Roman-empire')

***Amazon Ratings***

In [28]:
# Amazon-ratings graph; the dataset files are stored under data/.
amazon = HeterophilousGraphDataset(root='data/', name='Amazon-ratings')

Downloading https://github.com/yandex-research/heterophilous-graphs/raw/main/data/amazon_ratings.npz
Processing...
Done!


***Minesweeper***

* 100x100 grid
* Each cell is connected to its eight neighbors
* Mines vs not mines
  * Binary classes
  * 20% are mines
* Node features are one-hot-encoded numbers of neighboring mines

In [4]:
# Minesweeper graph; the dataset files are stored under data/.
minesweeper = HeterophilousGraphDataset(root='data/', name='Minesweeper')

***Tolokers***

In [29]:
# Tolokers graph; the dataset files are stored under data/.
tolo = HeterophilousGraphDataset(root='data/', name='Tolokers')

Downloading https://github.com/yandex-research/heterophilous-graphs/raw/main/data/tolokers.npz
Processing...
Done!


***Questions***

In [30]:
# Questions graph; the dataset files are stored under data/.
questions = HeterophilousGraphDataset(root='data/', name='Questions')

Downloading https://github.com/yandex-research/heterophilous-graphs/raw/main/data/questions.npz
Processing...
Done!


### 2) Node Classification

In [39]:
class GCN(torch.nn.Module):
    """Node classifier: stacked multi-head ``GATConv`` layers followed by an MLP.

    Despite the class name, every message-passing layer is a graph-attention
    (``GATConv``) layer with concatenated heads. After each attention layer the
    output is batch-normalised, concatenated with a skip tensor (the raw input
    for the first layer, the previous layer's pre-activation projection
    afterwards) and projected back to the layer's width by a linear layer.

    Args:
        in_channels: Size of the input node features.
        conv_channels: Per-head output width of each hidden attention layer.
            Must contain at least one entry.
        mlp_channels: ``mlp_channels[0]`` is the per-head width of the final
            attention layer (so the MLP input is ``mlp_channels[0] * heads``)
            and ``mlp_channels[1]`` is the MLP hidden width. Entries beyond
            index 1 are not used.
        out_channels: Number of output classes.
        heads: Number of attention heads used by every ``GATConv``.

    Raises:
        ValueError: If ``conv_channels`` is empty or ``mlp_channels`` has
            fewer than two entries.
    """

    def __init__(
            self,
            in_channels,
            conv_channels: list,
            mlp_channels: list,
            out_channels,
            heads=4,
        ):
        super().__init__()

        # Fail early with a clear message instead of an IndexError halfway
        # through layer construction.
        if len(conv_channels) < 1:
            raise ValueError("conv_channels must contain at least one layer width")
        if len(mlp_channels) < 2:
            raise ValueError(
                "mlp_channels must contain at least two widths: "
                "[final attention layer width, MLP hidden width]"
            )

        self.n_conv_layers = len(conv_channels)
        self.n_mlp_layers = len(mlp_channels)

        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(p=0.2)

        # Per-head width of every attention layer; the last one feeds the MLP.
        head_widths = [*conv_channels, mlp_channels[0]]
        # Heads are concatenated, so each layer emits width * heads features.
        out_widths = [width * heads for width in head_widths]
        # Each layer consumes the (projected) output of the previous one.
        in_widths = [in_channels, *out_widths[:-1]]

        self.convs = nn.ModuleList([
            GATConv(n_in, width, heads=heads, concat=True)
            for n_in, width in zip(in_widths, head_widths)
        ])

        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(n_out) for n_out in out_widths
        ])

        # Maps [layer output ‖ skip] back to the layer's output width.
        self.projections = nn.ModuleList([
            nn.Linear(n_in + n_out, n_out)
            for n_in, n_out in zip(in_widths, out_widths)
        ])

        self.fc1 = nn.Linear(mlp_channels[0] * heads, mlp_channels[1])
        self.fc2 = nn.Linear(mlp_channels[1], out_channels)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities (``log_softmax`` over dim 1).

        Args:
            x: Node feature matrix passed to the first attention layer.
            edge_index: Graph connectivity passed to every attention layer.
        """
        skip = x
        for conv, norm, project in zip(self.convs, self.batch_norms, self.projections):
            # Conv + batch norm
            x = norm(conv(x, edge_index))

            # Skip connection + linear projection; the projected tensor
            # (before activation) becomes the skip input of the next layer.
            x = project(torch.cat([x, skip], dim=1))
            skip = x

            # Activation + dropout
            x = self.dropout(self.gelu(x))

        # MLP head
        x = self.dropout(self.gelu(self.fc1(x)))
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)
    
# Hyperparameters for the GCN class defined in this cell.
# conv_channels_list: per-head width of each hidden attention layer.
# mlp_channels_list: [per-head width of the final attention layer, MLP hidden width].
conv_channels_list = [512]
mlp_channels_list = [64, 32]

# Dataset used for this run, moved to the selected device.
dataset = minesweeper.to(device)

# None -> unweighted loss. The commented-out alternative weights the two
# classes by 1/0.8 and 1/0.2.
class_weights = None # torch.tensor([1/0.8, 1/0.2]).to(device)

# NOTE(review): train()/evaluate() below read dataset.x / dataset.edge_index
# directly; this loader is only consumed by the prediction cell further down.
loader = DataLoader(dataset, batch_size=32) #shuffle=True)

model = GCN(
    in_channels=dataset.num_node_features,
    conv_channels=conv_channels_list,
    mlp_channels=mlp_channels_list,
    out_channels=dataset.num_classes,
    heads=8,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=1e-2)
# The scheduler only changes the learning rate when scheduler.step(metric)
# is called with the monitored (to-be-minimised) value.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=50)
criterion = nn.NLLLoss(weight=class_weights)  # Use if returning log_softmax

# The masks are indexed as [node, split]; select one split column.
# presumably each column is one predefined train/val/test split - TODO confirm
split_index = 0
train_mask = dataset.train_mask[:, split_index]
val_mask = dataset.val_mask[:, split_index]
test_mask = dataset.test_mask[:, split_index]

# Training loop
def train():
    """Run one full-graph optimisation step and return the training loss.

    Uses the module-level ``model``, ``optimizer``, ``criterion``, ``dataset``
    and ``train_mask``. The loss is computed on the training nodes only and
    returned as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    log_probs = model(dataset.x, dataset.edge_index)
    train_loss = criterion(log_probs[train_mask], dataset.y[train_mask])

    train_loss.backward()
    optimizer.step()
    return train_loss.item()

# Evaluation loop
def evaluate():
    """Return ``[train_acc, val_acc, test_acc]`` for the current model.

    Runs a single forward pass over the whole graph in eval mode without
    gradient tracking, then measures the fraction of correctly classified
    nodes under each of the module-level masks.
    """
    model.eval()
    with torch.no_grad():
        predicted = model(dataset.x, dataset.edge_index).argmax(dim=1)
        # Element-wise correctness for every node; masks select the split.
        is_correct = predicted == dataset.y
        return [
            is_correct[split].sum().item() / split.sum().item()
            for split in (train_mask, val_mask, test_mask)
        ]

# Training process
num_epochs = 1000
for epoch in range(num_epochs):
    loss = train()

    # ReduceLROnPlateau has no effect unless it is stepped with the monitored
    # metric. The training loss is the only loss available here; a validation
    # loss would be the preferable signal if evaluate() exposed one.
    scheduler.step(loss)

    # Accuracies are only reported on logging epochs, so only compute them
    # there instead of running an extra full-graph forward pass every epoch.
    if epoch % 50 == 0 or epoch == (num_epochs - 1):
        train_acc, val_acc, test_acc = evaluate()
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch 000, Loss: 0.6590, Train Acc: 0.8000, Val Acc: 0.8000, Test Acc: 0.8000
Epoch 050, Loss: 0.4006, Train Acc: 0.8000, Val Acc: 0.8000, Test Acc: 0.8000
Epoch 100, Loss: 0.3415, Train Acc: 0.8418, Val Acc: 0.8380, Test Acc: 0.8400
Epoch 150, Loss: 0.3086, Train Acc: 0.8614, Val Acc: 0.8364, Test Acc: 0.8512
Epoch 200, Loss: 0.2964, Train Acc: 0.8680, Val Acc: 0.8320, Test Acc: 0.8472
Epoch 250, Loss: 0.2873, Train Acc: 0.8718, Val Acc: 0.8320, Test Acc: 0.8484
Epoch 300, Loss: 0.2808, Train Acc: 0.8766, Val Acc: 0.8300, Test Acc: 0.8492
Epoch 350, Loss: 0.2749, Train Acc: 0.8794, Val Acc: 0.8296, Test Acc: 0.8460
Epoch 400, Loss: 0.2651, Train Acc: 0.8806, Val Acc: 0.8324, Test Acc: 0.8456
Epoch 450, Loss: 0.2630, Train Acc: 0.8852, Val Acc: 0.8288, Test Acc: 0.8484
Epoch 500, Loss: 0.2604, Train Acc: 0.8852, Val Acc: 0.8268, Test Acc: 0.8428
Epoch 550, Loss: 0.2534, Train Acc: 0.8878, Val Acc: 0.8272, Test Acc: 0.8428
Epoch 600, Loss: 0.2540, Train Acc: 0.8900, Val Acc: 0.8276, Tes

In [40]:
# Save model
weights_path = f"models/gcn_{dataset.name}.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)

In [None]:
# Load model from file
# map_location lets a checkpoint saved on one device (e.g. CUDA) be restored
# on whatever device this session is using.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))

In [31]:
# Build a freshly initialised model with the current hyperparameters and
# report its total parameter count.
# NOTE: this rebinds `model`, so weights trained or loaded earlier are no
# longer referenced by that name after this cell runs.
model = GCN(
    heads=8,
    out_channels=dataset.num_classes,
    mlp_channels=mlp_channels_list,
    conv_channels=conv_channels_list,
    in_channels=dataset.num_node_features,
).to(device)

sum(param.numel() for param in model.parameters())

74092146

In [40]:
idx = 53

# Take the first batch produced by the loader and predict a class per node.
batch = next(iter(loader))

# Switch to inference mode: without this, dropout stays active and batch norm
# uses batch statistics, so predictions are noisy and non-deterministic.
model.eval()
with torch.no_grad():
    logits = model(batch.x, batch.edge_index)

predictions = logits.argmax(dim=1)
predictions

tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')