[example] add GGCM #6899

Open · wants to merge 20 commits into base: master
8 changes: 8 additions & 0 deletions examples/README.md
@@ -5,6 +5,14 @@ The folder contains example implementations of selected research papers related
* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)

To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).

## 2023

- <a name="labor"></a> Zheng Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/pytorch/ogc)

- Tags: semi-supervised node classification

## 2022
- <a name="labor"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/labor/train_lightning.py)
41 changes: 41 additions & 0 deletions examples/pytorch/ggcm/README.md
@@ -0,0 +1,41 @@
# DGL Implementation of GGCM

This DGL example implements the GGCM method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
The authors' original implementation can be found [here](https://github.com/zhengwang100/ogc_ggcm).
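
At a high level, GGCM builds the node embeddings without any trainable propagation weights: at every layer it applies a lazy random walk on the symmetrically normalized adjacency (lazy graph convolution, LGC) and a lazy random walk on a freshly sampled random regular "negative" graph (inverse graph convolution, IGC), averages the two views, accumulates them across layers, and finally mixes the accumulated signal with the raw features. The sketch below mirrors the loop in `ggcm.py` with dense tensors for readability; the function name, the fixed `neg_adj` (which `ggcm.py` actually re-samples at every layer), and the default hyperparameter values are illustrative only.

```python
import torch


def ggcm_embedding_sketch(feats, adj_norm, neg_adj, layer_num=16,
                          alpha=0.15, decline=1.0, decline_neg=0.5):
    """Dense-tensor sketch of the GGCM embedding computation."""
    n = feats.shape[0]
    I = torch.eye(n)
    beta, beta_neg = 1.0, 1.0
    acc = torch.zeros_like(feats)
    x = feats
    for _ in range(layer_num):
        # IGC: lazy propagation on the random "negative" graph.
        x_neg = ((1 - beta_neg) * I + beta_neg * neg_adj) @ x
        # LGC: lazy propagation on the normalized adjacency.
        x = ((1 - beta) * I + beta * adj_norm) @ x
        # Multi-scale accumulation of the two views.
        acc += (x + x_neg) / 2.0
        beta *= decline
        beta_neg *= decline_neg
    # Mix the accumulated embeddings with the raw input features.
    return alpha * feats + (1 - alpha) * acc / layer_num
```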


## Example Implementor

This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.


## Dependencies
Python 3.11.5<br>
PyTorch 2.0.1<br>
DGL 1.1.2<br>
scikit-learn 1.3.1<br>


## Dataset
This example uses DGL's built-in Citeseer, Cora, and Pubmed datasets, summarized below (a minimal loading sketch follows the table):
| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |
| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |
| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |
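
The datasets are loaded through `dgl.data`. A minimal loading sketch, mirroring how `train.py` in this example loads them (Citeseer and the self-loop transform are just one choice):

```python
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset

# Add self-loops once at load time, as train.py does.
data = CiteseerGraphDataset(transform=AddSelfLoop())
graph = data[0]
print(graph.num_nodes(), graph.num_edges(), data.num_classes)
```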


## Usage
Run with one of the following commands (available datasets: "cora", "citeseer", "pubmed"):
```bash
python train.py --dataset citeseer
python train.py --dataset cora --decline 1.0 --alpha 0.15 --epochs 100 --lr 0.2 --layer_num 16 --negative_rate 20.0 --wd 1e-5 --decline_neg 0.5
python train.py --dataset pubmed --decline 1.0 --alpha 0.1 --epochs 100 --lr 0.2 --layer_num 16 --negative_rate 20.0 --wd 2e-5 --decline_neg 0.5
```

## Performance

| Dataset | Citeseer | Cora | Pubmed |
| :-: | :-: | :-: | :-: |
| GGCM (DGL) | 74.1 | 83.5 | 80.7 |
| GGCM (reported) | 74.2 | 83.6 | 80.8 |
55 changes: 55 additions & 0 deletions examples/pytorch/ggcm/ggcm.py
@@ -0,0 +1,55 @@
import dgl.sparse as dglsp

import torch
import torch.nn as nn

from utils import (
inverse_graph_convolution,
lazy_random_walk,
symmetric_normalize_adjacency,
)


class GGCM(nn.Module):
def __init__(self):
super(GGCM, self).__init__()

def get_embedding(self, graph, args):
        # Compute the propagated node embeddings (no parameters are trained here).
beta = 1.0
beta_neg = 1.0
layer_num, alpha = args.layer_num, args.alpha
device = args.device
features = graph.ndata["feat"]
orig_feats = features.clone()
temp_sum = torch.zeros_like(features)

node_num = features.shape[0]
I_N = dglsp.identity((node_num, node_num))
A_hat = symmetric_normalize_adjacency(graph)

        # degree of the random "negative" graph used by IGC
        edge_num = int(args.negative_rate * graph.num_edges() / node_num)
        # round up to an even degree: networkx.random_regular_graph
        # requires n * d to be even
        edge_num = ((edge_num + 1) // 2) * 2

for _ in range(layer_num):
            # inverse graph convolution (IGC), lazy version
neg_A_hat = inverse_graph_convolution(edge_num, node_num, I_N).to(
device
)
inv_lazy_A = lazy_random_walk(neg_A_hat, beta_neg, I_N).to(device)
inv_features = dglsp.spmm(inv_lazy_A, features)

# lazy graph convolution (LGC)
lazy_A = lazy_random_walk(A_hat, beta, I_N).to(device)
features = dglsp.spmm(lazy_A, features)

# add for multi-scale version
temp_sum += (features + inv_features) / 2.0
beta *= args.decline
beta_neg *= args.decline_neg
embeds = alpha * orig_feats + (1 - alpha) * (
temp_sum / (layer_num * 1.0)
)
return embeds
124 changes: 124 additions & 0 deletions examples/pytorch/ggcm/train.py
@@ -0,0 +1,124 @@
import argparse
import copy

import torch
import torch.nn.functional as F
import torch.optim as optim

from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset

from ggcm import GGCM
from utils import Classifier


def evaluate(model, embeds, graph):
model.eval()
with torch.no_grad():
output = model(embeds)
pred = output.argmax(dim=-1)
label = graph.ndata["label"]
val_mask, test_mask = graph.ndata["val_mask"], graph.ndata["test_mask"]
loss = F.cross_entropy(output[val_mask], label[val_mask])
accs = []
for mask in [val_mask, test_mask]:
accs.append(float((pred[mask] == label[mask]).sum() / mask.sum()))
return loss.item(), accs[0], accs[1]


def main(args):
# prepare data
transform = AddSelfLoop()
if args.dataset == "cora":
data = CoraGraphDataset(transform=transform)
elif args.dataset == "citeseer":
data = CiteseerGraphDataset(transform=transform)
elif args.dataset == "pubmed":
data = PubmedGraphDataset(transform=transform)
else:
raise ValueError("Unknown dataset: {}".format(args.dataset))

graph = data[0].to(args.device)
features = graph.ndata["feat"]
train_mask = graph.ndata["train_mask"]
in_feats = features.shape[1]
n_classes = data.num_classes

# get node embedding
ggcm = GGCM()
embeds = ggcm.get_embedding(graph, args)

# create classifier model
classifier = Classifier(in_feats, n_classes)
optimizer = optim.Adam(
classifier.parameters(), lr=args.lr, weight_decay=args.wd
)

# train classifier
best_acc = -1
for i in range(args.epochs):
classifier.train()
output = classifier(embeds)
loss = F.cross_entropy(
output[train_mask], graph.ndata["label"][train_mask]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()

loss_val, acc_val, acc_test = evaluate(classifier, embeds, graph)
if acc_val > best_acc:
best_acc, best_model = acc_val, copy.deepcopy(classifier)

print(f"{i+1} {loss_val:.4f} {acc_val:.3f} acc_test={acc_test:.3f}")

_, _, acc_test = evaluate(best_model, embeds, graph)
print(f"Final test acc: {acc_test:.4f}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GGCM")
parser.add_argument(
"--dataset",
type=str,
default="citeseer",
choices=["citeseer", "cora", "pubmed"],
help="Dataset to use.",
)
parser.add_argument("--decline", type=float, default=1, help="Decline.")
parser.add_argument("--alpha", type=float, default=0.15, help="Alpha.")
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epochs to train."
)
parser.add_argument(
"--lr", type=float, default=0.13, help="Initial learning rate."
)
    parser.add_argument(
        "--layer_num", type=int, default=16, help="Number of propagation layers."
    )
parser.add_argument(
"--negative_rate",
type=float,
default=20.0,
help="Negative sampling rate for a negative graph.",
)
    parser.add_argument(
        "--wd",
        type=float,
        default=2e-3,
        help="Weight decay (L2 loss on parameters).",
    )
    parser.add_argument(
        "--decline_neg", type=float, default=1.0, help="Decay of beta_neg (IGC)."
    )
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="device to use",
)
args, _ = parser.parse_known_args()

main(args)
41 changes: 41 additions & 0 deletions examples/pytorch/ggcm/utils.py
@@ -0,0 +1,41 @@
import dgl
import dgl.sparse as dglsp
import networkx as nx
import torch
import torch.nn as nn


class Classifier(nn.Module):
def __init__(self, in_feats, n_classes):
super(Classifier, self).__init__()
self.fc = nn.Linear(in_feats, n_classes)
self.reset_parameters()

def reset_parameters(self):
self.fc.reset_parameters()

def forward(self, x):
return self.fc(x)


def symmetric_normalize_adjacency(graph):
"""Symmetric normalize graph adjacency matrix."""
indices = torch.stack(graph.edges())
n = graph.num_nodes()
adj = dglsp.spmatrix(indices, shape=(n, n))
deg_invsqrt = dglsp.diag(adj.sum(0)) ** -0.5
return deg_invsqrt @ adj @ deg_invsqrt


def inverse_graph_convolution(edge_num, node_num, I_N):
    """Return the re-normalized adjacency of a random regular "negative" graph.

    ``edge_num`` is used as the node degree of the sampled regular graph.
    """
    graph = dgl.from_networkx(nx.random_regular_graph(edge_num, node_num))
    indices = torch.stack(graph.edges())
    adj = dglsp.spmatrix(indices, shape=(node_num, node_num)).coalesce()

    # re-normalization trick: (2 * I - A) / (degree + 2)
    adj_sym_nor = dglsp.sub(2 * I_N, adj) / (edge_num + 2)
    return adj_sym_nor


def lazy_random_walk(adj, beta, I_N):
    """Return the lazy random-walk operator (1 - beta) * I + beta * adj."""
    return dglsp.add((1 - beta) * I_N, beta * adj)
44 changes: 44 additions & 0 deletions examples/pytorch/ogc/README.md
@@ -0,0 +1,44 @@
# Optimized Graph Convolution (OGC)

This DGL example implements the OGC method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
With only one trainable layer, OGC is a very simple but powerful graph convolution method.
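
The embedding update implemented in `ogc.py` (shown later in this PR) alternates a lazy graph convolution (LGC) step, which smooths the embeddings along the graph, with a supervised embedding (SEB) step, which nudges the embeddings of labelled nodes towards their labels through the linear classifier. Below is a minimal dense-tensor sketch of one such update; `W` stands for the classifier weight matrix (shape `num_classes x num_feats`, as in `torch.nn.Linear`), and the function name and other variable names are illustrative rather than the exact API of the example.

```python
import torch


def ogc_step_sketch(embeds, lazy_adj, W, labels_onehot, train_mask, lr_sup=0.001):
    """One OGC embedding update: LGC smoothing followed by the SEB step."""
    # Lazy graph convolution: smooth the embeddings over the graph.
    embeds = lazy_adj @ embeds

    # Supervised embedding (SEB) step, restricted to labelled nodes
    # (the LIM trick): a gradient step on ||Y - X W^T||^2 w.r.t. X.
    pred = embeds @ W.T
    mask = train_mask.float().unsqueeze(1)
    deriv_sup = 2 * (mask * (pred - labels_onehot)) @ W
    return embeds - lr_sup * deriv_sup
```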


## Example Implementor

This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.


## Dependencies

Python 3.11.5<br>
PyTorch 2.0.1<br>
DGL 1.1.2<br>
scikit-learn 1.3.1<br>


## Dataset

This example uses DGL's built-in Cora, Pubmed, and Citeseer datasets, summarized below:

| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |
| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |
| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |


## Usage

```bash
python main.py --dataset cora
python main.py --dataset citeseer
python main.py --dataset pubmed
```

## Performance

| Dataset | Cora | Citeseer | Pubmed |
| :-: | :-: | :-: | :-: |
| OGC (DGL) | **86.9(±0.2)** | **77.4(±0.1)** | **83.6(±0.1)** |
| OGC (Reported) | **86.9(±0.0)** | **77.4(±0.0)** | 83.4(±0.0) |
44 changes: 44 additions & 0 deletions examples/pytorch/ogc/ogc.py
@@ -0,0 +1,44 @@
import dgl.sparse as dglsp
import torch.nn as nn
import torch.nn.functional as F

from utils import LinearNeuralNetwork


class OGC(nn.Module):
def __init__(self, graph):
super(OGC, self).__init__()
self.linear_clf = LinearNeuralNetwork(
nfeat=graph.ndata["feat"].shape[1],
nclass=graph.ndata["label"].max().item() + 1,
bias=False,
)

self.label = graph.ndata["label"]
self.label_one_hot = F.one_hot(graph.ndata["label"]).float()
        # LIM trick: use only the training labels here; otherwise, use both
        # the train and val sets to construct this matrix.
self.label_idx_mat = dglsp.diag(graph.ndata["train_mask"]).float()

self.test_mask = graph.ndata["test_mask"]
self.tv_mask = graph.ndata["train_mask"] + graph.ndata["val_mask"]

def forward(self, x):
return self.linear_clf(x)

def update_embeds(self, embeds, lazy_adj, args):
"""Update classifier's weight by training a linear supervised model."""
pred_label = self(embeds).data
clf_weight = self.linear_clf.W.weight.data

# Update the smoothness loss via LGC.
embeds = dglsp.spmm(lazy_adj, embeds)

# Update the supervised loss via SEB.
deriv_sup = 2 * dglsp.matmul(
dglsp.spmm(self.label_idx_mat, -self.label_one_hot + pred_label),
clf_weight,
)
embeds = embeds - args.lr_sup * deriv_sup

args.lr_sup = args.lr_sup * args.decline
return embeds