# Oversmoothing Analysis of Node Classification

### Setup

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.utils import add_remaining_self_loops, degree
from torch_scatter import scatter
from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool, GCNConv, GATConv, SimpleConv
from torch_geometric.datasets import HeterophilousGraphDataset
from torch_geometric.loader import DataLoader
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#### 1) Find and Download Datasets for Node Classification (dataset link)

***Roman Empire***

* Based on the Roman Empire article from Wikipedia
* Each node corresponds to a word in the text
* Two words are connected if one immediately follows the other, or if there is a syntactic dependency between them within the same sentence
* The class of the node is its syntactic role found using spaCy
* Node features are fastText word embeddings

In [2]:
# Roman-empire heterophilous benchmark, stored under data/.
rome = HeterophilousGraphDataset(name='Roman-empire', root='data/')

In [3]:
# Number of node classes (syntactic roles) in the Roman-empire dataset.
rome.num_classes

18

***Amazon Ratings***

***Minesweeper***

* A 100×100 grid of cells
* Each cell is connected to its eight neighbors
* Target: mine vs. not a mine
  * Binary classification
  * 20% of the cells are mines
* Node features are one-hot encodings of the number of neighboring mines

In [4]:
# Minesweeper heterophilous benchmark, stored under data/.
minesweeper = HeterophilousGraphDataset(name='Minesweeper', root='data/')

***Tolokers***

***Questions***

### Model Definition

In [None]:
class GCN(torch.nn.Module):
    """GAT encoder followed by a two-layer MLP classification head.

    NOTE: despite the class name, every message-passing layer is a
    ``GATConv`` (graph attention), not a ``GCNConv``.

    With ``k = len(conv_channels)`` the network is:

    * ``k + 1`` GATConv layers with concatenated heads, mapping
      ``in_channels -> conv_channels[0]*heads -> ... -> conv_channels[-1]*heads
      -> mlp_channels[0]*heads``; each is followed by
      BatchNorm1d -> ReLU -> Dropout(p=0.2);
    * ``fc1``: ``mlp_channels[0]*heads -> mlp_channels[1]``, then ReLU -> Dropout;
    * ``fc2``: ``mlp_channels[1] -> out_channels``, then ``log_softmax``.

    Args:
        in_channels: number of input node features.
        conv_channels: per-head output width of each hidden GATConv layer.
            May be empty, in which case a single GATConv maps the input
            straight to ``mlp_channels[0] * heads`` features.
        mlp_channels: widths of the MLP head. Only the first two entries are
            used: ``[0]`` is the per-head width of the last GATConv and
            ``[1]`` is the output width of ``fc1``. Extra entries are ignored.
        out_channels: number of output classes.
        heads: attention heads per GATConv layer.

    Raises:
        ValueError: if ``mlp_channels`` has fewer than two entries.
    """

    def __init__(
        self,
        in_channels,
        conv_channels: list,
        mlp_channels: list,
        out_channels,
        heads=4,
    ):
        super().__init__()

        if len(mlp_channels) < 2:
            raise ValueError(
                f"mlp_channels needs at least 2 entries, got {len(mlp_channels)}"
            )

        self.n_conv_layers = len(conv_channels)
        # Informational only: the head always consists of exactly fc1 and fc2.
        self.n_mlp_layers = len(mlp_channels)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

        # Per-head output width of every GATConv: the hidden layers, then a
        # final layer that produces the MLP input (not class scores).
        conv_out = [*conv_channels, mlp_channels[0]]
        # Input width of every GATConv: raw features for the first layer, then
        # the concatenated-head width of the previous layer.
        conv_in = [in_channels, *(c * heads for c in conv_channels)]

        # Convs are built before batch norms, in layer order, so parameter
        # initialisation consumes the RNG in the same order as before.
        self.convs = nn.ModuleList(
            GATConv(c_in, c_out, heads, concat=True)
            for c_in, c_out in zip(conv_in, conv_out)
        )
        self.batch_norms = nn.ModuleList(
            nn.BatchNorm1d(c_out * heads) for c_out in conv_out
        )

        self.fc1 = nn.Linear(mlp_channels[0] * heads, mlp_channels[1])
        self.fc2 = nn.Linear(mlp_channels[1], out_channels)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over ``out_channels`` classes.

        Args:
            x: node feature matrix, presumably (num_nodes, in_channels) --
                TODO confirm against the dataset.
            edge_index: graph connectivity in the format GATConv expects.

        Returns:
            ``log_softmax`` over dim 1 of the ``fc2`` output. Pair it with
            ``nn.NLLLoss``; ``nn.CrossEntropyLoss`` would apply
            ``log_softmax`` a second time.
        """
        for conv, norm in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = self.dropout(self.relu(x))

        x = self.dropout(self.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)
    
# Per-head widths of the hidden GATConv layers, and widths of the MLP head.
conv_channels_list = [512, 512]
mlp_channels_list = [64, 32]

dataset = rome.to(device)
# No class re-weighting. For Minesweeper (20% mines), inverse-frequency weights
# such as torch.tensor([1/0.8, 1/0.2]).to(device) can be used instead.
class_weights = None

# train()/evaluate() run on the full graph via `dataset` directly; this loader
# is only used for inspecting a batch. Pass shuffle=True to shuffle batches.
loader = DataLoader(dataset, batch_size=32)

model = GCN(
    in_channels=dataset.num_node_features,
    conv_channels=conv_channels_list,
    mlp_channels=mlp_channels_list,
    out_channels=dataset.num_classes,
    heads=8,
).to(device)

# NOTE(review): lr=3e-5 is small for Adam; tune it if the model keeps
# predicting a single class.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=1e-4)
# GCN.forward already returns log_softmax outputs, so NLLLoss is the matching
# criterion. CrossEntropyLoss would apply log_softmax a second time, which is
# numerically a no-op (log_softmax is idempotent) but hides what the model returns.
criterion = nn.NLLLoss(weight=class_weights)

# The masks are 2-D; column `split_index` selects one data split
# (presumably one column per predefined split -- confirm with the dataset docs).
split_index = 0
train_mask = dataset.train_mask[:, split_index]
val_mask = dataset.val_mask[:, split_index]
test_mask = dataset.test_mask[:, split_index]

# One full-batch optimisation step on the training nodes.
def train():
    """Perform a single optimisation step over ``train_mask``.

    The whole graph is passed through ``model``; the loss is computed only on
    the nodes selected by ``train_mask``.

    Returns:
        float: the training loss for this step.
    """
    model.train()
    optimizer.zero_grad()
    node_scores = model(dataset.x, dataset.edge_index)
    targets = dataset.y[train_mask]
    step_loss = criterion(node_scores[train_mask], targets)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

# Accuracy of the current model on the train/val/test node masks.
def evaluate():
    """Compute accuracy on the train, validation and test masks.

    Runs a single no-grad forward pass in eval mode over the whole graph.

    Returns:
        list[float]: ``[train_acc, val_acc, test_acc]``.
    """
    model.eval()
    with torch.no_grad():
        predicted = model(dataset.x, dataset.edge_index).argmax(dim=1)
        return [
            (predicted[m] == dataset.y[m]).sum().item() / m.sum().item()
            for m in (train_mask, val_mask, test_mask)
        ]

# Training process: full-batch training, with accuracies logged every
# `log_every` epochs and at the final epoch.
num_epochs = 200
log_every = 50
for epoch in range(num_epochs):
    loss = train()
    # evaluate() is an extra full forward pass in eval mode under no_grad, so it
    # does not influence training; run it only on epochs that are reported.
    if epoch % log_every == 0 or epoch == (num_epochs - 1):
        train_acc, val_acc, test_acc = evaluate()
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch 000, Loss: 2.9430, Train Acc: 0.0138, Val Acc: 0.0138, Test Acc: 0.0138


In [7]:
# Total number of scalar entries across all model parameters (buffers such as
# BatchNorm running statistics are not included).
sum(map(torch.Tensor.numel, model.parameters()))

35760738

In [40]:
idx = 53  # NOTE(review): not used in this cell

batch = next(iter(loader))
# Switch to inference mode explicitly. If the training cell was interrupted
# inside train(), the model would still be in training mode: dropout would be
# active, and BatchNorm would update its running statistics even under no_grad.
model.eval()
with torch.no_grad():
    # GCN.forward returns log-probabilities; their argmax equals that of the logits.
    logits = model(batch.x, batch.edge_index)

predictions = logits.argmax(dim=1)
predictions

tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')