diff --git a/examples/README.md b/examples/README.md
index 3a9da46c57c7..08964a0dceeb 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -5,6 +5,13 @@ The folder contains example implementations of selected research papers related
 * For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)
 
 To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).
+
+## 2023
+
+- Zheng Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2309.13599)
+  - Example code: [PyTorch](../examples/pytorch/ogc)
+  - Tags: semi-supervised node classification
+
 ## 2022
 - Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
   - Example code: [PyTorch](../examples/labor/train_lightning.py)
diff --git a/examples/pytorch/ggcm/README.md b/examples/pytorch/ggcm/README.md
new file mode 100644
index 000000000000..1f105b0701b2
--- /dev/null
+++ b/examples/pytorch/ggcm/README.md
@@ -0,0 +1,41 @@
+# DGL Implementation of GGCM
+
+This DGL example implements the GGCM method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
+The authors' original implementation can be found [here](https://github.com/zhengwang100/ogc_ggcm).
+
+
+## Example Implementor
+
+This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.
+
+
+## Dependencies
+Python 3.11.5
+PyTorch 2.0.1
+DGL 1.1.2
+scikit-learn 1.3.1
+
+
+## Dataset
+This example uses DGL's built-in Citeseer, Cora, and Pubmed datasets, summarized below:
+
+| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
+| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
+| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1,000 |
+| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1,000 |
+| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1,000 |
+
+
+## Usage
+Run one of the following commands (available datasets: "cora", "citeseer", "pubmed"):
+```bash
+python train.py --dataset citeseer
+python train.py --dataset cora --decline 1.0 --alpha 0.15 --epochs 100 --lr 0.2 --layer_num 16 --negative_rate 20.0 --wd 1e-5 --decline_neg 0.5
+python train.py --dataset pubmed --decline 1.0 --alpha 0.1 --epochs 100 --lr 0.2 --layer_num 16 --negative_rate 20.0 --wd 2e-5 --decline_neg 0.5
+```
+
+## Performance
+
+| Dataset | Citeseer | Cora | Pubmed |
+| :-: | :-: | :-: | :-: |
+| GGCM (DGL) | 74.1 | 83.5 | 80.7 |
+| GGCM (reported) | 74.2 | 83.6 | 80.8 |
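+
+
+## How it works
+
+In brief (following `ggcm.py`): each of the `layer_num` iterations applies a lazy graph convolution (LGC) step over the symmetrically normalized adjacency $\hat{A} = D^{-1/2} A D^{-1/2}$,
+
+$$X \leftarrow \big((1 - \beta)\, I + \beta\, \hat{A}\big)\, X,$$
+
+together with an inverse graph convolution (IGC) step of the same lazy form over a freshly sampled random regular graph with adjacency $\hat{A}_{neg} = (2I - A_{reg}) / (k + 2)$, where $k$ is the sampled degree. The two results are averaged and accumulated across layers; the accumulated mean is then mixed with the raw features with weight `alpha`, while `beta` and `beta_neg` decay by `decline` and `decline_neg` per layer.
+
+The snippet below is a minimal sketch of calling `GGCM` programmatically instead of through `train.py`; the `argparse.Namespace` merely stands in for the parsed command-line arguments, with the same field names that `get_embedding` reads.
+
+```python
+import argparse
+
+from dgl import AddSelfLoop
+from dgl.data import CoraGraphDataset
+
+from ggcm import GGCM
+from utils import Classifier
+
+# Stand-in for the command-line arguments parsed in train.py.
+args = argparse.Namespace(
+    layer_num=16,
+    alpha=0.15,
+    decline=1.0,
+    decline_neg=0.5,
+    negative_rate=20.0,
+    device="cpu",
+)
+
+graph = CoraGraphDataset(transform=AddSelfLoop())[0]
+
+# Parameter-free propagation: smooth the raw features into embeddings.
+embeds = GGCM().get_embedding(graph, args)
+
+# The embeddings then feed a plain linear classifier, as in train.py.
+clf = Classifier(in_feats=embeds.shape[1], n_classes=7)  # Cora has 7 classes.
+logits = clf(embeds)
+```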
diff --git a/examples/pytorch/ggcm/ggcm.py b/examples/pytorch/ggcm/ggcm.py
new file mode 100644
index 000000000000..f02107fdb3cd
--- /dev/null
+++ b/examples/pytorch/ggcm/ggcm.py
@@ -0,0 +1,55 @@
+import dgl.sparse as dglsp
+
+import torch
+import torch.nn as nn
+
+from utils import (
+    inverse_graph_convolution,
+    lazy_random_walk,
+    symmetric_normalize_adjacency,
+)
+
+
+class GGCM(nn.Module):
+    def __init__(self):
+        super(GGCM, self).__init__()
+
+    def get_embedding(self, graph, args):
+        # Get the learned node embeddings.
+        beta = 1.0
+        beta_neg = 1.0
+        layer_num, alpha = args.layer_num, args.alpha
+        device = args.device
+        features = graph.ndata["feat"]
+        orig_feats = features.clone()
+        temp_sum = torch.zeros_like(features)
+
+        node_num = features.shape[0]
+        I_N = dglsp.identity((node_num, node_num))
+        A_hat = symmetric_normalize_adjacency(graph)
+
+        # Degree of the random regular graph used for the inverse
+        # (negative) adjacency.
+        edge_num = int(args.negative_rate * graph.num_edges() / node_num)
+        # Round up to an even degree: networkx.random_regular_graph
+        # requires node_num * edge_num to be even.
+        edge_num = ((edge_num + 1) // 2) * 2
+
+        for _ in range(layer_num):
+            # Inverse graph convolution (IGC), lazy version.
+            neg_A_hat = inverse_graph_convolution(edge_num, node_num, I_N).to(
+                device
+            )
+            inv_lazy_A = lazy_random_walk(neg_A_hat, beta_neg, I_N).to(device)
+            inv_features = dglsp.spmm(inv_lazy_A, features)
+
+            # Lazy graph convolution (LGC).
+            lazy_A = lazy_random_walk(A_hat, beta, I_N).to(device)
+            features = dglsp.spmm(lazy_A, features)
+
+            # Accumulate for the multi-scale version.
+            temp_sum += (features + inv_features) / 2.0
+            beta *= args.decline
+            beta_neg *= args.decline_neg
+        embeds = alpha * orig_feats + (1 - alpha) * (
+            temp_sum / (layer_num * 1.0)
+        )
+        return embeds
diff --git a/examples/pytorch/ggcm/train.py b/examples/pytorch/ggcm/train.py
new file mode 100644
index 000000000000..0f8c87fae3f4
--- /dev/null
+++ b/examples/pytorch/ggcm/train.py
@@ -0,0 +1,124 @@
+import argparse
+import copy
+
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from dgl import AddSelfLoop
+from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
+
+from ggcm import GGCM
+from utils import Classifier
+
+
+def evaluate(model, embeds, graph):
+    model.eval()
+    with torch.no_grad():
+        output = model(embeds)
+        pred = output.argmax(dim=-1)
+        label = graph.ndata["label"]
+        val_mask, test_mask = graph.ndata["val_mask"], graph.ndata["test_mask"]
+        loss = F.cross_entropy(output[val_mask], label[val_mask])
+        accs = []
+        for mask in [val_mask, test_mask]:
+            accs.append(float((pred[mask] == label[mask]).sum() / mask.sum()))
+    return loss.item(), accs[0], accs[1]
+
+
+def main(args):
+    # Prepare the data.
+    transform = AddSelfLoop()
+    if args.dataset == "cora":
+        data = CoraGraphDataset(transform=transform)
+    elif args.dataset == "citeseer":
+        data = CiteseerGraphDataset(transform=transform)
+    elif args.dataset == "pubmed":
+        data = PubmedGraphDataset(transform=transform)
+    else:
+        raise ValueError("Unknown dataset: {}".format(args.dataset))
+
+    graph = data[0].to(args.device)
+    features = graph.ndata["feat"]
+    train_mask = graph.ndata["train_mask"]
+    in_feats = features.shape[1]
+    n_classes = data.num_classes
+
+    # Get the node embeddings.
+    ggcm = GGCM()
+    embeds = ggcm.get_embedding(graph, args)
+
+    # Create the classifier model.
+    classifier = Classifier(in_feats, n_classes)
+    optimizer = optim.Adam(
+        classifier.parameters(), lr=args.lr, weight_decay=args.wd
+    )
+
+    # Train the classifier.
+    best_acc = -1
+    for i in range(args.epochs):
+        classifier.train()
+        output = classifier(embeds)
+        loss = F.cross_entropy(
+            output[train_mask], graph.ndata["label"][train_mask]
+        )
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        loss_val, acc_val, acc_test = evaluate(classifier, embeds, graph)
+        if acc_val > best_acc:
+            best_acc, best_model = acc_val, copy.deepcopy(classifier)
+
+        print(
+            f"epoch {i + 1} loss_val {loss_val:.4f} "
+            f"acc_val {acc_val:.3f} acc_test {acc_test:.3f}"
+        )
+
+    _, _, acc_test = evaluate(best_model, embeds, graph)
+    print(f"Final test acc: {acc_test:.4f}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="GGCM")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="citeseer",
+        choices=["citeseer", "cora", "pubmed"],
+        help="Dataset to use.",
+    )
+    parser.add_argument(
+        "--decline", type=float, default=1.0, help="Decay rate of beta."
+    )
+    parser.add_argument(
+        "--alpha",
+        type=float,
+        default=0.15,
+        help="Weight of the raw features in the final embeddings.",
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=100, help="Number of epochs to train."
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.13, help="Initial learning rate."
+    )
+    parser.add_argument(
+        "--layer_num", type=int, default=16, help="Number of propagation layers."
+    )
+    parser.add_argument(
+        "--negative_rate",
+        type=float,
+        default=20.0,
+        help="Negative sampling rate for the negative graph.",
+    )
+    parser.add_argument(
+        "--wd",
+        type=float,
+        default=2e-3,
+        help="Weight decay (L2 loss on parameters).",
+    )
+    parser.add_argument(
+        "--decline_neg", type=float, default=1.0, help="Decay rate of beta_neg."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="device to use",
+    )
+    args, _ = parser.parse_known_args()
+
+    main(args)
diff --git a/examples/pytorch/ggcm/utils.py b/examples/pytorch/ggcm/utils.py
new file mode 100644
index 000000000000..877736b22654
--- /dev/null
+++ b/examples/pytorch/ggcm/utils.py
@@ -0,0 +1,41 @@
+import dgl
+import dgl.sparse as dglsp
+import networkx as nx
+import torch
+import torch.nn as nn
+
+
+class Classifier(nn.Module):
+    def __init__(self, in_feats, n_classes):
+        super(Classifier, self).__init__()
+        self.fc = nn.Linear(in_feats, n_classes)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.fc.reset_parameters()
+
+    def forward(self, x):
+        return self.fc(x)
+
+
+def symmetric_normalize_adjacency(graph):
+    """Symmetrically normalize the graph adjacency matrix:
+    A_hat = D^(-1/2) @ A @ D^(-1/2)."""
+    indices = torch.stack(graph.edges())
+    n = graph.num_nodes()
+    adj = dglsp.spmatrix(indices, shape=(n, n))
+    deg_invsqrt = dglsp.diag(adj.sum(0)) ** -0.5
+    return deg_invsqrt @ adj @ deg_invsqrt
+
+
+def inverse_graph_convolution(edge_num, node_num, I_N):
+    """Build the negative adjacency (2I - A_reg) / (edge_num + 2) from a
+    random regular graph with degree edge_num."""
+    graph = dgl.from_networkx(nx.random_regular_graph(edge_num, node_num))
+    indices = torch.stack(graph.edges())
+    adj = dglsp.spmatrix(indices, shape=(node_num, node_num)).coalesce()
+
+    # Re-normalization trick.
+    adj_sym_nor = dglsp.sub(2 * I_N, adj) / (edge_num + 2)
+    return adj_sym_nor
+
+
+def lazy_random_walk(adj, beta, I_N):
+    """Lazy random-walk matrix (1 - beta) * I + beta * adj."""
+    return dglsp.add((1 - beta) * I_N, beta * adj)
diff --git a/examples/pytorch/ogc/README.md b/examples/pytorch/ogc/README.md
new file mode 100644
index 000000000000..ca6a9c087933
--- /dev/null
+++ b/examples/pytorch/ogc/README.md
@@ -0,0 +1,44 @@
+# Optimized Graph Convolution (OGC)
+
+This DGL example implements the OGC method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
+With only one trainable layer, OGC is a very simple but powerful graph convolution method.
+
+
+## Example Implementor
+
+This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.
+
+
+## Dependencies
+
+Python 3.11.5
+PyTorch 2.0.1
+DGL 1.1.2
+scikit-learn 1.3.1
+
+
+## Dataset
+
+This example uses DGL's built-in Citeseer, Cora, and Pubmed datasets, summarized below:
+
+| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
+| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
+| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1,000 |
+| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1,000 |
+| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1,000 |
+
+
+## Usage
+
+```bash
+python train.py --dataset cora
+python train.py --dataset citeseer
+python train.py --dataset pubmed
+```
+
+## Performance
+
+| Dataset | Cora | Citeseer | Pubmed |
+| :-: | :-: | :-: | :-: |
+| OGC (DGL) | **86.9(±0.2)** | **77.4(±0.1)** | **83.6(±0.1)** |
+| OGC (Reported) | **86.9(±0.0)** | **77.4(±0.0)** | 83.4(±0.0) |
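+
+## How it works
+
+In brief (following `ogc.py`), each OGC iteration updates the node embeddings $U$ in two steps. A lazy graph convolution (LGC) step smooths them over the graph,
+
+$$U \leftarrow \big((1 - \beta)\, I_N + \beta\, \hat{A}\big)\, U, \qquad \hat{A} = D^{-1/2} A D^{-1/2},$$
+
+and a supervised embedding (SEB) step takes a gradient step on the squared-error loss of the linear classifier $W$, restricted to the labeled nodes by the diagonal indicator matrix $S$ (the LIM trick):
+
+$$U \leftarrow U - \eta_{sup} \cdot 2\, S\, (\hat{Y} - Y)\, W,$$
+
+where $Y$ is the one-hot label matrix and $\hat{Y}$ holds the current predictions. The step size $\eta_{sup}$ (`lr_sup`) decays by the factor `decline` after each iteration, and the classifier itself is trained jointly with plain SGD in `train.py`.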
diff --git a/examples/pytorch/ogc/ogc.py b/examples/pytorch/ogc/ogc.py
new file mode 100644
index 000000000000..0af54ba0ce6c
--- /dev/null
+++ b/examples/pytorch/ogc/ogc.py
@@ -0,0 +1,44 @@
+import dgl.sparse as dglsp
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils import LinearNeuralNetwork
+
+
+class OGC(nn.Module):
+    def __init__(self, graph):
+        super(OGC, self).__init__()
+        self.linear_clf = LinearNeuralNetwork(
+            nfeat=graph.ndata["feat"].shape[1],
+            nclass=graph.ndata["label"].max().item() + 1,
+            bias=False,
+        )
+
+        self.label = graph.ndata["label"]
+        self.label_one_hot = F.one_hot(graph.ndata["label"]).float()
+        # LIM trick: restrict the supervised signal to the train set;
+        # otherwise, use both the train and val sets to build this matrix.
+        self.label_idx_mat = dglsp.diag(graph.ndata["train_mask"]).float()
+
+        self.test_mask = graph.ndata["test_mask"]
+        self.tv_mask = graph.ndata["train_mask"] + graph.ndata["val_mask"]
+
+    def forward(self, x):
+        return self.linear_clf(x)
+
+    def update_embeds(self, embeds, lazy_adj, args):
+        """Update the node embeddings by one LGC smoothing step and one
+        supervised (SEB) gradient step."""
+        pred_label = self(embeds).data
+        clf_weight = self.linear_clf.W.weight.data
+
+        # Smooth the embeddings via LGC.
+        embeds = dglsp.spmm(lazy_adj, embeds)
+
+        # Take a gradient step on the supervised loss via SEB.
+        deriv_sup = 2 * dglsp.matmul(
+            dglsp.spmm(self.label_idx_mat, -self.label_one_hot + pred_label),
+            clf_weight,
+        )
+        embeds = embeds - args.lr_sup * deriv_sup
+
+        args.lr_sup = args.lr_sup * args.decline
+        return embeds
diff --git a/examples/pytorch/ogc/train.py b/examples/pytorch/ogc/train.py
new file mode 100644
index 000000000000..d78c63ae4ddb
--- /dev/null
+++ b/examples/pytorch/ogc/train.py
@@ -0,0 +1,126 @@
+import argparse
+import time
+
+import dgl.sparse as dglsp
+
+import torch.nn.functional as F
+import torch.optim as optim
+from dgl import AddSelfLoop
+from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
+
+from ogc import OGC
+from utils import model_test, symmetric_normalize_adjacency
+
+
+def train(model, embeds, lazy_adj, args):
+    patience = 0
+    _, _, last_acc, last_output = model_test(model, embeds)
+
+    tv_mask = model.tv_mask
+    optimizer = optim.SGD(model.parameters(), lr=args.lr_clf)
+
+    # Run at most 64 joint classifier/embedding update iterations.
+    for i in range(64):
+        model.train()
+        output = model(embeds)
+        loss_tv = F.mse_loss(
+            output[tv_mask], model.label_one_hot[tv_mask], reduction="sum"
+        )
+        optimizer.zero_grad()
+        loss_tv.backward()
+        optimizer.step()
+
+        # Update the node embeddings by LGC and SEB jointly.
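+        # LGC smooths the embeddings over lazy_adj, and SEB subtracts the
+        # gradient of the supervised squared-error loss with respect to the
+        # embeddings; update_embeds also decays args.lr_sup by args.decline
+        # (see ogc.py).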
+        embeds = model.update_embeds(embeds, lazy_adj, args)
+
+        loss_tv, acc_tv, acc_test, pred = model_test(model, embeds)
+        print(
+            "epoch {} loss_tv {:.4f} acc_tv {:.4f} acc_test {:.4f}".format(
+                i + 1, loss_tv, acc_tv, acc_test
+            )
+        )
+
+        # Stop early once the predictions barely change between iterations.
+        sim_rate = float(int((pred == last_output).sum()) / int(pred.shape[0]))
+        if sim_rate > args.max_sim_rate:
+            patience += 1
+            if patience > args.max_patience:
+                break
+        last_acc = acc_test
+        last_output = pred
+    return last_acc
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="citeseer",
+        choices=["cora", "citeseer", "pubmed"],
+        help="dataset to use",
+    )
+    parser.add_argument(
+        "--decline", type=float, default=0.9, help="decay rate of lr_sup"
+    )
+    parser.add_argument(
+        "--lr_sup",
+        type=float,
+        default=0.001,
+        help="learning rate for the supervised (SEB) embedding update",
+    )
+    parser.add_argument(
+        "--lr_clf",
+        type=float,
+        default=0.5,
+        help="learning rate for the linear classifier",
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        default=0.1,
+        help="probability that a node moves to one of its neighbors "
+        "in the lazy random walk",
+    )
+    parser.add_argument(
+        "--max_sim_rate",
+        type=float,
+        default=0.995,
+        help="max label prediction similarity between iterations",
+    )
+    parser.add_argument(
+        "--max_patience",
+        type=int,
+        default=2,
+        help="tolerance for consecutively similar test predictions",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="device to use",
+    )
+    args, _ = parser.parse_known_args()
+
+    # Load and preprocess the dataset.
+    transform = AddSelfLoop()
+    if args.dataset == "cora":
+        data = CoraGraphDataset(transform=transform)
+    elif args.dataset == "citeseer":
+        data = CiteseerGraphDataset(transform=transform)
+    elif args.dataset == "pubmed":
+        data = PubmedGraphDataset(transform=transform)
+    else:
+        raise ValueError("Unknown dataset: {}".format(args.dataset))
+    graph = data[0].to(args.device)
+    features = graph.ndata["feat"]
+    adj = symmetric_normalize_adjacency(graph)
+    I_N = dglsp.identity((features.shape[0], features.shape[0]))
+    # Lazy random walk (also known as lazy graph convolution).
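+    # With probability beta the walk moves to a neighbor and otherwise stays
+    # at the current node, i.e. lazy_adj = (1 - beta) * I_N + beta * A_hat.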
+    lazy_adj = dglsp.add((1 - args.beta) * I_N, args.beta * adj).to(args.device)
+
+    model = OGC(graph).to(args.device)
+    start_time = time.time()
+    res = train(model, features, lazy_adj, args)
+    time_tot = time.time() - start_time
+
+    print(f"Test Acc: {res:.4f}")
+    print(f"Total Time: {time_tot:.4f}s")
diff --git a/examples/pytorch/ogc/utils.py b/examples/pytorch/ogc/utils.py
new file mode 100644
index 000000000000..95b61b6b07bf
--- /dev/null
+++ b/examples/pytorch/ogc/utils.py
@@ -0,0 +1,35 @@
+import dgl.sparse as dglsp
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LinearNeuralNetwork(nn.Module):
+    def __init__(self, nfeat, nclass, bias=True):
+        super(LinearNeuralNetwork, self).__init__()
+        self.W = nn.Linear(nfeat, nclass, bias=bias)
+
+    def forward(self, x):
+        return self.W(x)
+
+
+def symmetric_normalize_adjacency(graph):
+    """Symmetrically normalize the graph adjacency matrix:
+    A_hat = D^(-1/2) @ A @ D^(-1/2)."""
+    indices = torch.stack(graph.edges())
+    n = graph.num_nodes()
+    adj = dglsp.spmatrix(indices, shape=(n, n))
+    deg_invsqrt = dglsp.diag(adj.sum(0)) ** -0.5
+    return deg_invsqrt @ adj @ deg_invsqrt
+
+
+def model_test(model, embeds):
+    """Return the train+val loss, train+val accuracy, test accuracy, and
+    the predicted labels."""
+    model.eval()
+    with torch.no_grad():
+        output = model(embeds)
+        pred = output.argmax(dim=-1)
+        test_mask, tv_mask = model.test_mask, model.tv_mask
+        loss_tv = F.mse_loss(output[tv_mask], model.label_one_hot[tv_mask])
+        accs = []
+        for mask in [tv_mask, test_mask]:
+            accs.append(
+                float((pred[mask] == model.label[mask]).sum() / mask.sum())
+            )
+    return loss_tv.item(), accs[0], accs[1], pred