diff --git a/.gitignore b/.gitignore index 1dd7612f..3f9d9a44 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,5 @@ cscope.* # config file /config local_scripts/ + +profiler/ \ No newline at end of file diff --git a/examples/pytorch/math_word_problem/mawps/src/evaluation.py b/examples/pytorch/math_word_problem/mawps/src/evaluation.py index c26280cc..c4f0fb87 100644 --- a/examples/pytorch/math_word_problem/mawps/src/evaluation.py +++ b/examples/pytorch/math_word_problem/mawps/src/evaluation.py @@ -1,8 +1,8 @@ -from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase - import sympy from sympy.parsing.sympy_parser import parse_expr +from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase + class SolutionMatch(EvaluationMetricBase): def __init__(self): diff --git a/examples/pytorch/question_generation/config/squad_split2/rgcn_dependency.json b/examples/pytorch/question_generation/config/squad_split2/rgcn_dependency.json new file mode 100644 index 00000000..3536da4e --- /dev/null +++ b/examples/pytorch/question_generation/config/squad_split2/rgcn_dependency.json @@ -0,0 +1,12 @@ +{ +"config_path": "examples/pytorch/question_generation/config/squad_split2/qg.yaml", +"model_args.graph_construction_args.graph_construction_share.topology_subdir": "DependencyGraphForRGCN", +"model_args.graph_construction_args.graph_construction_private.edge_strategy": "heterogeneous", +"model_args.graph_construction_args.graph_construction_private.merge_strategy": "tailhead", +"model_args.graph_construction_args.graph_construction_private.sequential_link": true, +"model_args.graph_construction_args.graph_construction_private.as_node": false, +"model_args.graph_embedding_name": "rgcn", +"model_args.graph_embedding_args.graph_embedding_private.num_rels": 80, +"model_args.graph_embedding_args.graph_embedding_private.num_bases": 4, +"checkpoint_args.out_dir": "out/squad_split2/rgcn_dependency_ckpt" +} diff --git a/examples/pytorch/question_generation/main.py b/examples/pytorch/question_generation/main.py index 326c9a80..e35cf87d 100644 --- a/examples/pytorch/question_generation/main.py +++ b/examples/pytorch/question_generation/main.py @@ -26,6 +26,8 @@ from graph4nlp.pytorch.modules.utils.generic_utils import EarlyStopping, to_cuda from graph4nlp.pytorch.modules.utils.logger import Logger +from examples.pytorch.semantic_parsing.graph2seq.rgcn_lib.graph2seq import RGCNGraph2Seq + from .fused_embedding_construction import FusedEmbeddingConstruction @@ -39,7 +41,10 @@ def __init__(self, vocab, config): ] # build Graph2Seq model - self.g2s = Graph2Seq.from_args(config, self.vocab) + if config["model_args"]["graph_embedding_name"] == "rgcn": + self.g2s = RGCNGraph2Seq.from_args(config, self.vocab) + else: + self.g2s = Graph2Seq.from_args(config, self.vocab) if "w2v" in self.g2s.graph_initializer.embedding_layer.word_emb_layers: self.word_emb = self.g2s.graph_initializer.embedding_layer.word_emb_layers[ diff --git a/examples/pytorch/rgcn/rgcn.py b/examples/pytorch/rgcn/rgcn.py index 0779e904..7f738a54 100644 --- a/examples/pytorch/rgcn/rgcn.py +++ b/examples/pytorch/rgcn/rgcn.py @@ -46,7 +46,6 @@ def __init__( num_bases=None, use_self_loop=True, dropout=0.0, - device="cuda", ): super(RGCN, self).__init__() self.num_layers = num_layers @@ -54,7 +53,6 @@ def __init__( self.num_bases = num_bases self.use_self_loop = use_self_loop self.dropout = dropout - self.device = device self.RGCN_layers = nn.ModuleList() @@ -185,35 +183,31 @@ def __init__( self_loop=False, dropout=0.0, 
layer_norm=False, - device="cuda", ): super(RGCNLayer, self).__init__() self.linear_dict = { - i: nn.Linear(input_size, output_size, bias=bias, device=device) for i in range(num_rels) + i: nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels) } # self.linear_r = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) self.bias = bias self.activation = activation self.self_loop = self_loop self.layer_norm = layer_norm - self.device = device # bias if self.bias: - self.h_bias = nn.Parameter(torch.Tensor(output_size)).to(device) + self.h_bias = nn.Parameter(torch.Tensor(output_size)) nn.init.zeros_(self.h_bias) # TODO(minjie): consider remove those options in the future to make # the module only about graph convolution. # layer norm if self.layer_norm: - self.layer_norm_weight = nn.LayerNorm( - output_size, elementwise_affine=True, device=device - ) + self.layer_norm_weight = nn.LayerNorm(output_size, elementwise_affine=True) # weight for self loop if self.self_loop: - self.loop_weight = nn.Parameter(torch.Tensor(input_size, output_size)).to(device) + self.loop_weight = nn.Parameter(torch.Tensor(input_size, output_size)) nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain("relu")) self.dropout = nn.Dropout(dropout) diff --git a/examples/pytorch/semantic_parsing/graph2seq/config/train_dep_rgcn_bi_sep.json b/examples/pytorch/semantic_parsing/graph2seq/config/train_dep_rgcn_bi_sep.json new file mode 100644 index 00000000..b19c6336 --- /dev/null +++ b/examples/pytorch/semantic_parsing/graph2seq/config/train_dep_rgcn_bi_sep.json @@ -0,0 +1,5 @@ +{ + "config_path": "examples/pytorch/semantic_parsing/graph2seq/config/dependency_rgcn_undirected.yaml", + "model_args.graph_embedding_args.graph_embedding_share.direction_option": "bi_sep", + "training_args.log_file": "examples/pytorch/semantic_parsing/graph2seq/log/dependency_rgcn_bi_sep.txt" +} diff --git a/examples/pytorch/semantic_parsing/graph2seq/config/train_dep_rgcn_undirected.json b/examples/pytorch/semantic_parsing/graph2seq/config/train_dep_rgcn_undirected.json new file mode 100644 index 00000000..77deec0e --- /dev/null +++ b/examples/pytorch/semantic_parsing/graph2seq/config/train_dep_rgcn_undirected.json @@ -0,0 +1,5 @@ +{ + "config_path": "examples/pytorch/semantic_parsing/graph2seq/config/dependency_rgcn_undirected.yaml", + "model_args.graph_embedding_args.graph_embedding_share.direction_option": "undirected", + "training_args.log_file": "examples/pytorch/semantic_parsing/graph2seq/log/dependency_rgcn_undirected.txt" +} diff --git a/examples/pytorch/semantic_parsing/graph2seq/main_rgcn.py b/examples/pytorch/semantic_parsing/graph2seq/main_rgcn.py index 5b3b7a24..eb683ebc 100644 --- a/examples/pytorch/semantic_parsing/graph2seq/main_rgcn.py +++ b/examples/pytorch/semantic_parsing/graph2seq/main_rgcn.py @@ -52,7 +52,7 @@ def _build_logger(self, log_file): import os log_folder = os.path.split(log_file)[0] - if not os.path.exists(log_file): + if not os.path.exists(log_folder): os.makedirs(log_folder) self.logger = get_log(log_file) diff --git a/examples/pytorch/semantic_parsing/graph2seq/rgcn_lib/graph2seq.py b/examples/pytorch/semantic_parsing/graph2seq/rgcn_lib/graph2seq.py index 917264c8..0a43a59a 100644 --- a/examples/pytorch/semantic_parsing/graph2seq/rgcn_lib/graph2seq.py +++ b/examples/pytorch/semantic_parsing/graph2seq/rgcn_lib/graph2seq.py @@ -1,6 +1,6 @@ from graph4nlp.pytorch.models.graph2seq import Graph2Seq - -from examples.pytorch.rgcn.rgcn import RGCN +# from 
examples.pytorch.rgcn.rgcn import RGCN +from graph4nlp.pytorch.modules.graph_embedding_learning.rgcn import RGCN class RGCNGraph2Seq(Graph2Seq): @@ -74,10 +74,12 @@ def __init__( def _build_gnn_encoder( self, + gnn, num_layers, input_size, hidden_size, output_size, + direction_option, feats_dropout, gnn_num_rels=80, gnn_num_bases=4, @@ -89,6 +91,8 @@ def _build_gnn_encoder( hidden_size, output_size, num_rels=gnn_num_rels, - num_bases=gnn_num_bases, - dropout=feats_dropout, + direction_option=direction_option, + # num_bases=gnn_num_bases, + # dropout=feats_dropout, + feat_drop=feats_dropout, ) diff --git a/examples/pytorch/text_classification/config/CAirline/gat_bi_sep_dependency.json b/examples/pytorch/text_classification/config/CAirline/gat_bi_sep_dependency.json index 9fa30f7b..20c14193 100644 --- a/examples/pytorch/text_classification/config/CAirline/gat_bi_sep_dependency.json +++ b/examples/pytorch/text_classification/config/CAirline/gat_bi_sep_dependency.json @@ -7,5 +7,5 @@ "model_args.graph_embedding_args.graph_embedding_private.negative_slope": "0.2", "model_args.graph_embedding_args.graph_embedding_private.residual": "false", "model_args.graph_embedding_args.graph_embedding_private.allow_zero_in_degree": "true", -"checkpoint_args.out_dir": "out/trec/gat_bi_sep_dependency_ckpt" +"checkpoint_args.out_dir": "out/CAirline/gat_bi_sep_dependency_ckpt" } diff --git a/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_constituency.json b/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_constituency.json index 4ad1b6ff..ffc9be49 100644 --- a/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_constituency.json +++ b/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_constituency.json @@ -2,5 +2,5 @@ "config_path": "examples/pytorch/text_classification/config/CAirline/text_clf.yaml", "model_args.graph_construction_name": "constituency", "model_args.graph_construction_args.graph_construction_share.topology_subdir": "constituency_graph", -"checkpoint_args.out_dir": "out/trec/ggnn_bi_sep_constituency_ckpt" +"checkpoint_args.out_dir": "out/CAirline/ggnn_bi_sep_constituency_ckpt" } diff --git a/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb.json b/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb.json index 12b10e7b..b5470993 100644 --- a/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb.json +++ b/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb.json @@ -13,5 +13,5 @@ "model_args.graph_embedding_args.graph_embedding_share.hidden_size": "300", "model_args.graph_embedding_args.graph_embedding_share.output_size": "300", "model_args.graph_embedding_args.graph_embedding_private.use_edge_weight": "true", -"checkpoint_args.out_dir": "out/trec/ggnn_bi_sep_node_emb_ckpt" +"checkpoint_args.out_dir": "out/CAirline/ggnn_bi_sep_node_emb_ckpt" } diff --git a/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb_refined_dependency.json b/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb_refined_dependency.json index 946f50a4..53b77370 100644 --- a/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb_refined_dependency.json +++ b/examples/pytorch/text_classification/config/CAirline/ggnn_bi_sep_node_emb_refined_dependency.json @@ -15,5 +15,5 @@ "model_args.graph_embedding_args.graph_embedding_share.hidden_size": "300", "model_args.graph_embedding_args.graph_embedding_share.output_size": "300", 
"model_args.graph_embedding_args.graph_embedding_private.use_edge_weight": "true", -"checkpoint_args.out_dir": "out/trec/ggnn_bi_sep_node_emb_refined_dependency_ckpt" +"checkpoint_args.out_dir": "out/CAirline/ggnn_bi_sep_node_emb_refined_dependency_ckpt" } diff --git a/examples/pytorch/text_classification/config/CAirline/rgcn_dependency.json b/examples/pytorch/text_classification/config/CAirline/rgcn_dependency.json new file mode 100644 index 00000000..14f8f181 --- /dev/null +++ b/examples/pytorch/text_classification/config/CAirline/rgcn_dependency.json @@ -0,0 +1,13 @@ +{ +"config_path": "examples/pytorch/text_classification/config/CAirline/text_clf.yaml", +"model_args.graph_construction_args.graph_construction_share.topology_subdir": "dependency_graph_for_rgcn", +"model_args.graph_construction_args.graph_construction_private.edge_strategy": "heterogeneous", +"model_args.graph_construction_args.graph_construction_private.merge_strategy": "tailhead", +"model_args.graph_construction_args.graph_construction_private.sequential_link": true, +"model_args.graph_construction_args.graph_construction_private.as_node": false, +"model_args.graph_embedding_name": "rgcn", +"model_args.graph_embedding_args.graph_embedding_share.direction_option": "undirected", +"model_args.graph_embedding_args.graph_embedding_private.num_rels": 80, +"model_args.graph_embedding_args.graph_embedding_private.num_bases": 4, +"checkpoint_args.out_dir": "out/CAirline/rgcn_dependency_ckpt" +} diff --git a/examples/pytorch/text_classification/config/trec/rgcn_dependency.json b/examples/pytorch/text_classification/config/trec/rgcn_dependency.json new file mode 100644 index 00000000..7253078a --- /dev/null +++ b/examples/pytorch/text_classification/config/trec/rgcn_dependency.json @@ -0,0 +1,14 @@ +{ +"config_path": "examples/pytorch/text_classification/config/trec/text_clf.yaml", +"model_args.graph_construction_args.graph_construction_share.topology_subdir": "dependency_graph_for_rgcn", +"model_args.graph_construction_args.graph_construction_private.edge_strategy": "heterogeneous", +"model_args.graph_construction_args.graph_construction_private.merge_strategy": "tailhead", +"model_args.graph_construction_args.graph_construction_private.sequential_link": true, +"model_args.graph_construction_args.graph_construction_private.as_node": false, +"model_args.graph_embedding_name": "rgcn", +"model_args.graph_embedding_args.graph_embedding_share.direction_option": "undirected", +"model_args.graph_embedding_args.graph_embedding_private.num_rels": 80, +"model_args.graph_embedding_args.graph_embedding_private.num_bases": 4, +"training_args.lr": "0.002", +"checkpoint_args.out_dir": "out/trec/rgcn_dependency_ckpt" +} diff --git a/examples/pytorch/text_classification/readme.md b/examples/pytorch/text_classification/readme.md index fbf30db5..d8039858 100644 --- a/examples/pytorch/text_classification/readme.md +++ b/examples/pytorch/text_classification/readme.md @@ -29,12 +29,12 @@ TREC Results ------- -| GraphType\GNN | GAT-BiSep | GAT-BiFuse | GraphSAGE-BiSep | GraphSAGE-BiFuse | GGNN-BiSep | GGNN-BiFuse | -| ------------- | ------------- | --------------| ------------------- | ----------------- |-------------- | ------------- | -| Dependency | 0.9480 | 0.9460 | 0.942 | 0.958 | 0.954 | 0.9440 | -| Constituency | 0.9420 | 0.9300 | 0.952 | 0.950 | 0.952 | 0.9400 | -| NodeEmb | N/A | N/A | 0.930 | 0.908 | | | -| NodeEmbRefined | N/A | N/A | 0.940 | 0.926 | | | +| GraphType\GNN | GAT-BiSep | GAT-BiFuse | GraphSAGE-BiSep | GraphSAGE-BiFuse | 
GGNN-BiSep | GGNN-BiFuse | RGCN | +| ------------- | ------------- | --------------| ------------------- | ----------------- |-------------- | ------------- | ----- | +| Dependency | 0.9480 | 0.9460 | 0.942 | 0.958 | 0.954 | 0.944 | 0.946 | +| Constituency | 0.9420 | 0.9300 | 0.952 | 0.950 | 0.952 | 0.94 | N/A | +| NodeEmb | N/A | N/A | 0.930 | 0.908 | N/A | N/A | N/A | +| NodeEmbRefined | N/A | N/A | 0.940 | 0.926 | N/A | N/A | N/A | @@ -42,10 +42,10 @@ CAirline Results ------- -| GraphType\GNN | GAT-BiSep | GGNN-BiSep |GraphSage-BiSep| -| -------------- | ------------ | ------------- |---------------| -| Dependency | 0.7496 | 0.8020 | 0.7977 | -| Constituency | 0.7846 | 0.7933 | 0.7948 | -| NodeEmb | N/A | 0.8108 | 0.8108 | -| NodeEmbRefined | N/A | 0.7991 | 0.8020 | +| GraphType\GNN | GAT-BiSep | GGNN-BiSep |GraphSage-BiSep| RGCN | +| -------------- | ------------ | ------------- |---------------|---------------| +| Dependency | 0.7496 | 0.8020 | 0.7977 | 0.7525 | +| Constituency | 0.7846 | 0.7933 | 0.7948 | N/A | +| NodeEmb | N/A | 0.8108 | 0.8108 | N/A | +| NodeEmbRefined | N/A | 0.7991 | 0.8020 | N/A | diff --git a/examples/pytorch/text_classification/run_text_classifier.py b/examples/pytorch/text_classification/run_text_classifier.py index 037189e6..c316d772 100644 --- a/examples/pytorch/text_classification/run_text_classifier.py +++ b/examples/pytorch/text_classification/run_text_classifier.py @@ -32,6 +32,8 @@ from graph4nlp.pytorch.modules.utils.generic_utils import EarlyStopping, to_cuda from graph4nlp.pytorch.modules.utils.logger import Logger +from graph4nlp.pytorch.modules.graph_embedding_learning.rgcn import RGCN + torch.multiprocessing.set_sharing_strategy("file_system") @@ -217,6 +219,30 @@ def __init__(self, vocab, label_model, config): "graph_embedding_private" ]["use_edge_weight"], ) + elif config["model_args"]["graph_embedding_name"] == "rgcn": + self.gnn = RGCN( + config["model_args"]["graph_embedding_args"]["graph_embedding_share"]["num_layers"], + config["model_args"]["graph_embedding_args"]["graph_embedding_share"]["input_size"], + config["model_args"]["graph_embedding_args"]["graph_embedding_share"][ + "hidden_size" + ], + config["model_args"]["graph_embedding_args"]["graph_embedding_share"][ + "output_size" + ], + num_rels=config["model_args"]["graph_embedding_args"]["graph_embedding_private"][ + "num_rels" + ], + direction_option=config["model_args"]["graph_embedding_args"][ + "graph_embedding_share" + ]["direction_option"], + feat_drop=config["model_args"]["graph_embedding_args"]["graph_embedding_share"][ + "feat_drop" + ], + regularizer="basis", + num_bases=config["model_args"]["graph_embedding_args"]["graph_embedding_private"][ + "num_bases" + ], + ) else: raise RuntimeError( "Unknown gnn type: {}".format(config["model_args"]["graph_embedding_name"]) diff --git a/graph4nlp/pytorch/data/data.py b/graph4nlp/pytorch/data/data.py index 67ccafa0..383c6647 100644 --- a/graph4nlp/pytorch/data/data.py +++ b/graph4nlp/pytorch/data/data.py @@ -176,7 +176,7 @@ def add_nodes(self, node_num: int, ntypes: List[str] = None): ) if not self.is_hetero: - if ntypes is not None: + if ntypes is not None and len(set(ntypes)) > 1: raise ValueError( "The graph is homogeneous, ntypes should be None. 
Got {}".format(ntypes) ) @@ -787,7 +787,7 @@ def _data_dict(self) -> Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Ten ) return data_dict - def to_dgl(self) -> dgl.DGLGraph: + def to_dgl(self) -> dgl.DGLHeteroGraph: """ Convert to dgl.DGLGraph Note that there will be some information loss when calling this function, @@ -796,8 +796,8 @@ def to_dgl(self) -> dgl.DGLGraph: Returns ------- - g : dgl.DGLGraph - The converted dgl.DGLGraph + g : dgl.DGLHeteroGraph + The converted dgl.DGLHeteroGraph """ u, v = self._edge_indices.src, self._edge_indices.tgt num_nodes = self.get_node_num() @@ -903,13 +903,13 @@ def make_num_nodes_dict( return dgl_g - def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False): + def from_dgl(self, dgl_g: dgl.DGLHeteroGraph, is_hetero=False): """ - Build the graph from dgl.DGLGraph + Build the graph from dgl.DGLHeteroGraph Parameters ---------- - dgl_g : dgl.DGLGraph + dgl_g : dgl.DGLHeteroGraph The source graph """ if not (self.get_edge_num() == 0 and self.get_node_num() == 0): @@ -950,6 +950,10 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False): processed_node_types = False node_feat_dict = {} for feature_name, data_dict in node_data.items(): + if not isinstance(data_dict, Dict): + # DGL will return tensor if ntype is single + # This can happen when graph is a multigraph + data_dict = {dgl_g.ntypes[0]: data_dict} if not processed_node_types: for node_type, node_feature in data_dict.items(): ntypes += [node_type] * len(node_feature) @@ -967,7 +971,8 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False): num_edges = dgl_g.num_edges(etype) src_type, r_type, dst_type = etype srcs, dsts = dgl_g.find_edges( - torch.tensor(list(range(num_edges)), dtype=torch.long), etype + torch.tensor(list(range(num_edges)), dtype=torch.long, device=dgl_g.device), + etype, ) srcs, dsts = ( srcs.detach().cpu().numpy().tolist(), @@ -1386,8 +1391,9 @@ def from_dgl(g: dgl.DGLGraph) -> GraphData: GraphData The converted graph in GraphData format. """ - graph = GraphData(is_hetero=not g.is_homogeneous) - graph.from_dgl(g, is_hetero=not g.is_homogeneous) + dgl_g_is_hetero = (not g.is_homogeneous) or g.is_multigraph + graph = GraphData(is_hetero=dgl_g_is_hetero) + graph.from_dgl(g, is_hetero=dgl_g_is_hetero) return graph diff --git a/graph4nlp/pytorch/models/graph2seq.py b/graph4nlp/pytorch/models/graph2seq.py index 682da3e5..5b2ea919 100644 --- a/graph4nlp/pytorch/models/graph2seq.py +++ b/graph4nlp/pytorch/models/graph2seq.py @@ -26,7 +26,7 @@ class Graph2Seq(Graph2XBase): >>> "It is just a how-to-use example." 
>>> from graph4nlp.pytorch.modules.config import get_basic_args >>> opt = get_basic_args(graph_construction_name="node_emb", graph_embedding_name="gat", decoder_name="stdrnn") - >>> graph2seq = Graph2Seq.from_args(opt=opt, vocab_model=vocab_model, device=torch.device("cuda:0")) + >>> graph2seq = Graph2Seq.from_args(opt=opt, vocab_model=vocab_model) >>> batch_graph = [GraphData() for _ in range(2)] >>> tgt_seq = torch.Tensor([[1, 2, 3], [4, 5, 6]]) >>> seq_out, _, _ = graph2seq(batch_graph=batch_graph, tgt_seq=tgt_seq) diff --git a/graph4nlp/pytorch/modules/graph_embedding_learning/rgcn.py b/graph4nlp/pytorch/modules/graph_embedding_learning/rgcn.py new file mode 100644 index 00000000..8d568519 --- /dev/null +++ b/graph4nlp/pytorch/modules/graph_embedding_learning/rgcn.py @@ -0,0 +1,826 @@ +import dgl +import dgl.function as fn +import torch +import torch.nn as nn +from dgl.nn.pytorch.linear import TypedLinear +import dgl.nn as dglnn +import typing as tp + +from .base import GNNBase, GNNLayerBase +from ...data import GraphData, from_dgl + +# The implementation of RGCN is copied from DGL +class RelGraphConvLayer(nn.Module): + r"""Relational graph convolution layer. + + Parameters + ---------- + in_feat : int + Input feature size. + out_feat : int + Output feature size. + rel_names : list[str] + Relation names. + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: None. + weight : bool, optional + True if a linear layer is applied after message passing. Default: True + bias : bool, optional + True if bias is added. Default: True + activation : callable, optional + Activation function. Default: None + self_loop : bool, optional + True to include self loop message. Default: False + dropout : float, optional + Dropout rate. Default: 0.0 + """ + + def __init__( + self, + in_feat, + out_feat, + num_rels, + num_bases, + *, + weight=True, + bias=True, + activation=None, + self_loop=False, + dropout=0.0, + ): + super(RelGraphConvLayer, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.num_rels = num_rels + self.num_bases = num_bases + self.bias = bias + self.activation = activation + self.self_loop = self_loop + + self.conv = dglnn.HeteroGraphConv( + { + f"rel_{rel}": dglnn.GraphConv(in_feat, out_feat, norm="right", weight=False, bias=False) + for rel in range(num_rels) + } + ) + + self.use_weight = weight + self.use_basis = num_bases < self.num_rels and weight + if self.use_weight: + if self.use_basis: + self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, self.num_rels) + else: + self.weight = nn.Parameter(torch.Tensor(self.num_rels, in_feat, out_feat)) + nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain("relu")) + + # bias + if bias: + self.h_bias = nn.Parameter(torch.Tensor(out_feat)) + nn.init.zeros_(self.h_bias) + + # weight for self loop + if self.self_loop: + self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat)) + nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain("relu")) + + self.dropout = nn.Dropout(dropout) + self.etype_map = {} + + def forward(self, g: dgl.DGLHeteroGraph, inputs: tp.Dict[str, torch.Tensor]): + """Forward computation + + Parameters + ---------- + g : DGLHeteroGraph + Input graph. + inputs : dict[str, torch.Tensor] + Node feature for each node type. + + Returns + ------- + dict[str, torch.Tensor] + New node features for each node type. 
+ """ + g = g.local_var() + + # def create_new_graph(): + + + new_canonical_etypes = [] + new_etypes = [] + for src_type, edge_type, dst_type in g.canonical_etypes: + new_edge_type = self.etype_map.setdefault(edge_type, f"rel_{len(self.etype_map)}") + new_canonical_etypes.append((src_type, new_edge_type, dst_type)) + new_etypes.append(new_edge_type) + g._etypes = new_etypes + g._canonical_etypes = new_canonical_etypes + g._etype2canonical = {etype: canonical_etype for etype, canonical_etype in zip(new_etypes, new_canonical_etypes)} + g._etypes_invmap = {canonical_etype: i for i, canonical_etype in enumerate(new_canonical_etypes)} + + if self.use_weight: + weight = self.basis() if self.use_basis else self.weight + wdict = { + f"rel_{i}": {"weight": w.squeeze(0)} + for i, w in enumerate(torch.split(weight, 1, dim=0)) + } + else: + wdict = {} + + if g.is_block: + inputs_src = inputs + inputs_dst = {k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()} + else: + inputs_src = inputs_dst = inputs + + hs = self.conv(g, inputs, mod_kwargs=wdict) + + def _apply(ntype, h): + if self.self_loop: + h = h + torch.matmul(inputs_dst[ntype], self.loop_weight) + if self.bias: + h = h + self.h_bias + if self.activation: + h = self.activation(h) + return self.dropout(h) + + return {ntype: _apply(ntype, h) for ntype, h in hs.items()} + + +class RGCN(GNNBase): + r"""Multi-layered `RGCN Network `__ + + .. math:: + TODO:Add Calculation. + + Parameters + ---------- + num_layers: int + Number of RGCN layers. + input_size : int, or pair of ints + Input feature size. + hidden_size: int list of int + Hidden layer size. + If a scalar is given, the sizes of all the hidden layers are the same. + If a list of scalar is given, each element in the list is the size of each hidden layer. + Example: [100,50] + output_size : int + Output feature size. + rel_names : List[str] + List of relation names. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + feat_drop : float, optional + dropout rate. 
Default: ``0.0`` + """ + + def __init__( + self, + num_layers, + input_size, + hidden_size, + output_size, + num_rels=None, + direction_option=None, + bias=True, + activation=None, + self_loop=True, + feat_drop=0.0, + regularizer="none", + num_bases=4, + ): + super(RGCN, self).__init__() + self.num_layers = num_layers + self.num_rels = num_rels + self.self_loop = self_loop + self.feat_drop = feat_drop + self.direction_option = direction_option + self.activation = activation + self.bias = bias + self.RGCN_layers = nn.ModuleList() + self.regularizer = regularizer + self.num_basis = num_bases + + # if isinstance(self.num_rels, int): + # self.num_rels = [str(i) for i in range(self.num_rels)] + + # transform the hidden size format + if self.num_layers > 1 and type(hidden_size) is int: + hidden_size = [hidden_size for i in range(self.num_layers - 1)] + + if self.num_layers > 1: + # input projection + self.RGCN_layers.append( + RGCNLayer( + input_size, + hidden_size[0], + num_rels=self.num_rels, + direction_option=self.direction_option, + bias=self.bias, + activation=self.activation, + self_loop=self.self_loop, + feat_drop=self.feat_drop, + regularizer=regularizer, + num_bases=num_bases, + ) + ) + # hidden layers + for l in range(1, self.num_layers - 1): + # due to multi-head, the input_size = hidden_size * num_heads + self.RGCN_layers.append( + RGCNLayer( + hidden_size[l - 1], + hidden_size[l], + num_rels=self.num_rels, + direction_option=self.direction_option, + bias=self.bias, + activation=self.activation, + self_loop=self.self_loop, + feat_drop=self.feat_drop, + regularizer=regularizer, + num_bases=num_bases, + ) + ) + # output projection + self.RGCN_layers.append( + RGCNLayer( + hidden_size[-1] if self.num_layers > 1 else input_size, + output_size, + num_rels=self.num_rels, + direction_option=self.direction_option, + bias=self.bias, + activation=self.activation, + self_loop=self.self_loop, + feat_drop=self.feat_drop, + regularizer=regularizer, + num_bases=num_bases, + ) + ) + # Print named parameters + # for k, v in self.named_parameters(): + # print(f'{k}: {v}') + + def forward(self, graph: GraphData): + r"""Compute RGCN layer. + + Parameters + ---------- + graph : GraphData + The graph with node feature stored in the feature field named as + "node_feat". + The node features are used for message passing. + + Returns + ------- + graph : GraphData + The graph with generated node embedding stored in the feature field + named as "node_emb". 
+ """ + # feat = graph.node_features["node_feat"] + # if self.direction_option == "bi_sep": + # h = [feat, feat] + # else: + # h = feat + + # get the node feature tensor from graph + g = graph.to_dgl() # transfer the current NLPgraph to DGL graph + h: torch.Tensor = g.ndata["node_feat"] + + # Make node feature dictionary + feat_dict: tp.Dict[str, torch.Tensor] = {} + import numpy as np + node_types = np.array(graph.ntypes,) + for i in range(max(node_types) + 1): + index = torch.tensor(np.where(node_types == i)[0], device=graph.device) + feat_dict[i] = torch.index_select(h, 0, index) + + # output projection + if self.num_layers > 1: + for l in range(0, self.num_layers - 1): + h = self.RGCN_layers[l](g, feat_dict) + + h = self.RGCN_layers[-1](g, h) + + if self.direction_option == "bi_sep": + logits = torch.cat(logits, -1) + + # Unpack node feature dictionary + if len(g.ntypes) == 1: + h = h[0] + g.ndata["node_emb"] = h # put the results into the NLPGraph + graph_data = from_dgl(g=g) + if graph.batch is not None: + graph_data.copy_batch_info(graph) + return graph_data + + +class RGCNLayer(GNNLayerBase): + r"""A wrapper for RGCNLayer. + + .. math:: + TODO + + Parameters + ---------- + input_size : int, or pair of ints + Input feature size. + output_size : int + Output feature size. + num_rels: int + number of relations + regularizer : str, optional + Which weight regularizer to use "basis" or "bdd": + - "basis" is short for basis-decomposition. + - "bdd" is short for block-diagonal-decomposition. + Default applies no regularization. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. + bias : bool, optional + True if bias is added. Default: ``True``. + activation : callable, optional + Activation function. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + feat_drop : float, optional + Dropout rate. Default: ``0.0`` + layer_norm: float, optional + Add layer norm. Default: ``False`` + """ + + def __init__( + self, + input_size, + output_size, + num_rels, + direction_option=None, + bias=True, + activation=None, + self_loop=False, + feat_drop=0.0, + layer_norm=False, + regularizer=None, + num_bases=None, + ): + super(RGCNLayer, self).__init__() + if direction_option == "undirected": + self.model = UndirectedRGCNLayer( + input_size, + output_size, + num_rels=num_rels, + bias=bias, + activation=activation, + self_loop=self_loop, + feat_drop=feat_drop, + layer_norm=layer_norm, + regularizer=regularizer, + num_bases=num_bases, + ) + elif direction_option == "bi_sep": + self.model = BiSepRGCNLayer( + input_size, + output_size, + num_rels=num_rels, + bias=bias, + activation=activation, + self_loop=self_loop, + feat_drop=feat_drop, + layer_norm=layer_norm, + regularizer=regularizer, + num_bases=num_bases, + ) + elif direction_option == "bi_fuse": + self.model = BiFuseRGCNLayer( + input_size, + output_size, + num_rels=num_rels, + bias=bias, + activation=activation, + self_loop=self_loop, + feat_drop=feat_drop, + layer_norm=layer_norm, + regularizer=regularizer, + num_bases=num_bases, + ) + else: + raise RuntimeError("Unknown `direction_option` value: {}".format(direction_option)) + + def forward(self, graph: dgl.DGLHeteroGraph, feat: tp.Dict[str, torch.Tensor]): + r"""Compute graph attention network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. 
+ feat : torch.Tensor or pair of torch.Tensor + If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where + :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. + If a pair of torch.Tensor is given, the pair must contain two tensors of shape + :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, D_{out})` + where :math:`D_{out}` is size of output feature. + """ + return self.model(graph, feat) + + +class UndirectedRGCNLayer(GNNLayerBase): + r"""An undirected RGCN layer. + + .. math:: + TODO + + Parameters + ---------- + input_size : int, or pair of ints + Input feature size. + output_size : int + Output feature size. + num_rels: int + number of relations + regularizer : str, optional + Which weight regularizer to use "basis" or "bdd": + - "basis" is short for basis-decomposition. + - "bdd" is short for block-diagonal-decomposition. + Default applies no regularization. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. + bias : bool, optional + True if bias is added. Default: ``True``. + activation : callable, optional + Activation function. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + feat_drop : float, optional + Dropout rate. Default: ``0.0`` + layer_norm: bool, optional + Add layer norm. Default: ``False`` + """ + + def __init__( + self, + input_size, + output_size, + num_rels, + bias=True, + activation=None, + self_loop=False, + feat_drop=0.0, + layer_norm=False, + regularizer=None, + num_bases=None, + dropout=0.0, + **kwargs, + ): + super(UndirectedRGCNLayer, self).__init__() + self.layer = RelGraphConvLayer( + in_feat=input_size, + out_feat=output_size, + num_rels=num_rels, + num_bases=num_bases, + activation=activation, + self_loop=self_loop, + dropout=feat_drop, + ) + + def forward(self, g: dgl.DGLHeteroGraph, feat: tp.Dict[str, torch.Tensor]): + return self.layer(g, feat) + + +class BiFuseRGCNLayer(GNNLayerBase): + r"""A Bidirectional version for RGCNLayer, with an additional fuse layer. + + .. math:: + TODO + + Parameters + ---------- + input_size : int, or pair of ints + Input feature size. + output_size : int + Output feature size. + num_rels: int + number of relations + regularizer : str, optional + Which weight regularizer to use "basis" or "bdd": + - "basis" is short for basis-decomposition. + - "bdd" is short for block-diagonal-decomposition. + Default applies no regularization. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. + bias : bool, optional + True if bias is added. Default: ``True``. + activation : callable, optional + Activation function. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + feat_drop : float, optional + Dropout rate. Default: ``0.0`` + layer_norm: bool, optional + Add layer norm. Default: ``False`` + regularizer: str, optional + Which weight regularizer to use "basis" or "bdd": + - "basis" is short for basis-decomposition. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``.
+ """ + + def __init__( + self, + input_size, + output_size, + num_rels, + bias=True, + activation=None, + self_loop=False, + feat_drop=0.0, + layer_norm=False, + regularizer=None, + num_bases=None, + ): + super(BiFuseRGCNLayer, self).__init__() + self.ln_fwd = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) + self.ln_bwd = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) + + # self.linear_dict_forward = nn.ModuleDict( + # {str(i): nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels)} + # ) + # self.linear_dict_backward = nn.ModuleDict( + # {str(i): nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels)} + # ) + + # self.linear_r = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) + self.bias = bias + self.activation = activation + self.self_loop = self_loop + self.layer_norm = layer_norm + + # bias + if self.bias: + self.h_bias_forward = nn.Parameter(torch.Tensor(output_size)) + nn.init.zeros_(self.h_bias_forward) + self.h_bias_backward = nn.Parameter(torch.Tensor(output_size)) + nn.init.zeros_(self.h_bias_backward) + + # layer norm + if self.layer_norm: + self.layer_norm_weight_forward = nn.LayerNorm(output_size, elementwise_affine=True) + self.layer_norm_weight_backward = nn.LayerNorm(output_size, elementwise_affine=True) + + # weight for self loop + if self.self_loop: + self.loop_weight_forward = nn.Parameter(torch.Tensor(input_size, output_size)) + nn.init.xavier_uniform_(self.loop_weight_forward, gain=nn.init.calculate_gain("relu")) + + self.loop_weight_backward = nn.Parameter(torch.Tensor(input_size, output_size)) + nn.init.xavier_uniform_(self.loop_weight_backward, gain=nn.init.calculate_gain("relu")) + + self.fuse_linear = nn.Linear(4 * output_size, output_size, bias=True) + self.dropout = nn.Dropout(feat_drop) + + def forward(self, g: dgl.DGLHeteroGraph, feat: torch.Tensor, norm=None): + def message(edges, g, direction): + """Message function.""" + # linear_dict = ( + # self.linear_dict_forward if direction == "forward" else self.linear_dict_backward + # ) + # ln = linear_dict[str(g.canonical_etypes.index(edges._etype))] + # m = ln(edges.src["h"]) + + ln = self.ln_fwd if direction == "forward" else self.ln_bwd + etypes = torch.tensor( + [g.canonical_etypes.index(edges._etype)] * edges.src["h"].shape[0] + ).to(edges.src["h"].device) + m = ln(edges.src["h"], etypes) + if "norm" in edges.data: + m = m * edges.data["norm"] + return {"m": m} + + # self.presorted = presorted + with g.local_scope(): + g.srcdata["h"] = feat + if norm is not None: + g.edata["norm"] = norm + # g.edata['etype'] = etypes + # message passing + from functools import partial + + update_dict = { + etype: (partial(message, g=g, direction="forward"), fn.sum("m", "h")) + for etype in g.canonical_etypes + } + g.multi_update_all(etype_dict=update_dict, cross_reducer="sum") + # g.update_all(self.message, fn.sum('m', 'h')) + # apply bias and activation + h = g.dstdata["h"] + if self.layer_norm: + h = self.layer_norm_weight_forward(h) + if self.bias: + h = h + self.h_bias_forward + if self.self_loop: + h = h + feat[: g.num_dst_nodes()] @ self.loop_weight_forward + h_forward = h + + g = g.reverse() + with g.local_scope(): + g.srcdata["h"] = feat + if norm is not None: + g.edata["norm"] = norm + # g.edata['etype'] = etypes + # message passing + from functools import partial + + update_dict = { + etype: (partial(message, g=g, direction="backward"), fn.sum("m", "h")) + for etype in g.canonical_etypes + } + 
g.multi_update_all(etype_dict=update_dict, cross_reducer="sum") + # g.update_all(self.message, fn.sum('m', 'h')) + # apply bias and activation + h = g.dstdata["h"] + if self.layer_norm: + h = self.layer_norm_weight_backward(h) + if self.bias: + h = h + self.h_bias_backward + if self.self_loop: + h = h + feat[: g.num_dst_nodes()] @ self.loop_weight_backward + h_backward = h + + fuse_vector = torch.cat( + [h_forward, h_backward, h_forward * h_backward, h_forward - h_backward], dim=-1 + ) + fuse_gate_vector = torch.sigmoid(self.fuse_linear(fuse_vector)) + h = fuse_gate_vector * h_forward + (1 - fuse_gate_vector) * h_backward + + if self.activation: + h = self.activation(h) + h = self.dropout(h) + return h + + +class BiSepRGCNLayer(GNNLayerBase): + r"""A Bidirectional version for RGCNLayer. + + .. math:: + TODO + + Parameters + ---------- + input_size : int, or pair of ints + Input feature size. + output_size : int + Output feature size. + num_rels: int + number of relations + regularizer : str, optional + Which weight regularizer to use "basis" or "bdd": + - "basis" is short for basis-decomposition. + - "bdd" is short for block-diagonal-decomposition. + Default applies no regularization. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. + bias : bool, optional + True if bias is added. Default: ``True``. + activation : callable, optional + Activation function. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + feat_drop : float, optional + Dropout rate. Default: ``0.0`` + layer_norm: float, optional + Add layer norm. Default: ``False`` + regularizer: str, optional + Which weight regularizer to use "basis" or "bdd": + - "basis" is short for basis-decomposition. + num_bases : int, optional + Number of bases. Needed when ``regularizer`` is specified. Default: ``None``. 
+ """ + + def __init__( + self, + input_size, + output_size, + num_rels, + bias=True, + activation=None, + self_loop=False, + feat_drop=0.0, + layer_norm=False, + regularizer=None, + num_bases=None, + ): + super(BiSepRGCNLayer, self).__init__() + self.ln_fwd = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) + self.ln_bwd = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) + + # self.linear_dict_forward = nn.ModuleDict( + # {str(i): nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels)} + # ) + # self.linear_dict_backward = nn.ModuleDict( + # {str(i): nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels)} + # ) + + # self.linear_r = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases) + self.bias = bias + self.activation = activation + self.self_loop = self_loop + self.layer_norm = layer_norm + + # bias + if self.bias: + self.h_bias_forward = nn.Parameter(torch.Tensor(output_size)) + nn.init.zeros_(self.h_bias_forward) + self.h_bias_backward = nn.Parameter(torch.Tensor(output_size)) + nn.init.zeros_(self.h_bias_backward) + + # layer norm + if self.layer_norm: + self.layer_norm_weight_forward = nn.LayerNorm(output_size, elementwise_affine=True) + self.layer_norm_weight_backward = nn.LayerNorm(output_size, elementwise_affine=True) + + # weight for self loop + if self.self_loop: + self.loop_weight_forward = nn.Parameter(torch.Tensor(input_size, output_size)) + nn.init.xavier_uniform_(self.loop_weight_forward, gain=nn.init.calculate_gain("relu")) + + self.loop_weight_backward = nn.Parameter(torch.Tensor(input_size, output_size)) + nn.init.xavier_uniform_(self.loop_weight_backward, gain=nn.init.calculate_gain("relu")) + + self.dropout = nn.Dropout(feat_drop) + + def forward(self, g: dgl.DGLHeteroGraph, feat: torch.Tensor, norm=None): + def message(edges, g, direction): + """Message function.""" + # linear_dict = ( + # self.linear_dict_forward if direction == "forward" else self.linear_dict_backward + # ) + # ln = linear_dict[str(g.canonical_etypes.index(edges._etype))] + ln = self.ln_fwd if direction == "forward" else self.ln_bwd + etypes = torch.tensor( + [g.canonical_etypes.index(edges._etype)] * edges.src["h"].shape[0] + ).to(edges.src["h"].device) + m = ln(edges.src["h"], etypes) + if "norm" in edges.data: + m = m * edges.data["norm"] + return {"m": m} + + feat_forward, feat_backward = feat + # self.presorted = presorted + with g.local_scope(): + g.srcdata["h"] = feat_forward + if norm is not None: + g.edata["norm"] = norm + # g.edata['etype'] = etypes + # message passing + from functools import partial + + update_dict = { + etype: (partial(message, g=g, direction="forward"), fn.sum("m", "h")) + for etype in g.canonical_etypes + } + g.multi_update_all(etype_dict=update_dict, cross_reducer="sum") + # g.update_all(self.message, fn.sum('m', 'h')) + # apply bias and activation + h = g.dstdata["h"] + if self.layer_norm: + h = self.layer_norm_weight_forward(h) + if self.bias: + h = h + self.h_bias_forward + if self.self_loop: + h = h + feat_forward[: g.num_dst_nodes()] @ self.loop_weight_forward + h_forward = h + + g = g.reverse() + with g.local_scope(): + g.srcdata["h"] = feat_backward + if norm is not None: + g.edata["norm"] = norm + # g.edata['etype'] = etypes + # message passing + from functools import partial + + update_dict = { + etype: (partial(message, g=g, direction="backward"), fn.sum("m", "h")) + for etype in g.canonical_etypes + } + g.multi_update_all(etype_dict=update_dict, 
cross_reducer="sum") + # g.update_all(self.message, fn.sum('m', 'h')) + # apply bias and activation + h = g.dstdata["h"] + if self.layer_norm: + h = self.layer_norm_weight_backward(h) + if self.bias: + h = h + self.h_bias_backward + if self.self_loop: + h = h + feat_backward[: g.num_dst_nodes()] @ self.loop_weight_backward + h_backward = h + + if self.activation: + h_forward = self.activation(h_forward) + h_backward = self.activation(h_backward) + h_forward = self.dropout(h_forward) + h_backward = self.dropout(h_backward) + return [h_forward, h_backward] diff --git a/graph4nlp/pytorch/test/data_structure/test_graphdata.py b/graph4nlp/pytorch/test/data_structure/test_graphdata.py index 7d438715..0cf6d8dd 100644 --- a/graph4nlp/pytorch/test/data_structure/test_graphdata.py +++ b/graph4nlp/pytorch/test/data_structure/test_graphdata.py @@ -1,14 +1,13 @@ import gc import time import matplotlib.pyplot as plt +import pytest import torch import torch.nn as nn from graph4nlp.pytorch.data import GraphData, from_batch, from_dgl, to_batch from graph4nlp.pytorch.data.utils import EdgeNotFoundException, SizeMismatchException -import pytest - def fail_here(): raise Exception("The above line of code shouldn't be executed normally") diff --git a/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_aifb.yaml b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_aifb.yaml new file mode 100644 index 00000000..38c59cf6 --- /dev/null +++ b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_aifb.yaml @@ -0,0 +1,10 @@ +num_hidden_layers: 1 +hidden_size: 16 +dataset: 'aifb' +direction_option: "undirected" +self_loop: False +bias: True +feat_drop: 0.0 +lr: 0.01 +wd: 0.0005 +num_epochs: 200 diff --git a/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_am.yaml b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_am.yaml new file mode 100644 index 00000000..a2a2a0bc --- /dev/null +++ b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_am.yaml @@ -0,0 +1,10 @@ +num_hidden_layers: 1 +hidden_size: 16 +dataset: 'am' +direction_option: "undirected" +self_loop: False +bias: True +feat_drop: 0.0 +lr: 0.01 +wd: 0.0005 +num_epochs: 200 diff --git a/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_bgs.yaml b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_bgs.yaml new file mode 100644 index 00000000..d599e1b3 --- /dev/null +++ b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_bgs.yaml @@ -0,0 +1,10 @@ +num_hidden_layers: 1 +hidden_size: 16 +dataset: 'bgs' +direction_option: "undirected" +self_loop: False +bias: True +feat_drop: 0.0 +lr: 0.01 +wd: 0.0005 +num_epochs: 200 diff --git a/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_mutag.yaml b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_mutag.yaml new file mode 100644 index 00000000..814fe76c --- /dev/null +++ b/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts/run_rgcn_mutag.yaml @@ -0,0 +1,10 @@ +num_hidden_layers: 1 +hidden_size: 16 +dataset: 'mutag' +direction_option: "undirected" +self_loop: False +bias: True +feat_drop: 0.0 +lr: 0.01 +wd: 0.0005 +num_epochs: 200 diff --git a/graph4nlp/pytorch/test/graph_embedding/run_rgcn.py b/graph4nlp/pytorch/test/graph_embedding/run_rgcn.py new file mode 100644 index 00000000..75762a03 --- /dev/null +++ b/graph4nlp/pytorch/test/graph_embedding/run_rgcn.py @@ -0,0 +1,239 @@ +import argparse +import dgl +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl.data.rdf import AIFBDataset, 
AMDataset, BGSDataset, MUTAGDataset + +from torchmetrics.functional import accuracy + +from ...data.data import GraphData, from_dgl +from ...modules.graph_embedding_learning.rgcn import RGCNLayer +from ...modules.utils.generic_utils import get_config + + +# Load dataset +# Reference: dgl/examples/pytorch/rgcn/entity_utils.py +# (https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn/entity_utils.py) +def load_data(data_name="aifb", get_norm=False, inv_target=False): + if data_name == "aifb": + dataset = AIFBDataset() + # Test Accuracy: + # 0.9444, 0.8889, 0.9722, 0.9167, 0.9444 without enorm. + # 0.8611, 0.8889, 0.8889, 0.8889, 0.8333 + # avg: 0.93332 (without enorm) + # avg: 0.87222 + # DGL: 0.8889, 0.8889, 0.8056, 0.8889, 0.8611 + # DGL avg: 0.86668 + # paper: 0.9583 + # note: Could stuck at Local minimum of train loss between 0.2-0.35. + elif data_name == "mutag": + dataset = MUTAGDataset() + # Test Accuracy: + # 0.6912, 0.7500, 0.7353, 0.6324, 0.7353 + # avg: 0.68884 + # DGL: 0.6765, 0.7059, 0.7353, 0.6765, 0.6912 + # DGL avg: 0.69724 + # paper: 0.7323 + # note: Could stuck at local minimum of train acc: 0.3897 & loss 0.6931 + elif data_name == "bgs": + dataset = BGSDataset() + # Test Accuracy: + # 0.8966, 0.9310, 0.8966, 0.7931, 0.8621 + # avg: 0.87588 + # DGL: 0.7931, 0.9310, 0.8966, 0.8276, 0.8966 + # DGL avg: 0.86898 + # paper: 0.8310 + # note: Could stuck at local minimum of train acc: 0.6325 & loss: 0.6931 + else: + dataset = AMDataset() + # Test Accuracy: + # 0.7525, 0.7374, 0.7424, 0.7424, 0.7424 + # avg: 0.74342 + # DGL: 0.7677, 0.7677, 0.7323, 0.7879, 0.7677 + # DGL avg: 0.76466 + # paper: 0.8929 + # note: args.hidden_size is 10. + # Could stuck at local minimum of train loss: 0.3-0.5 + + # Load hetero-graph + hg = dataset[0] + + num_rels = len(hg.canonical_etypes) + category = dataset.predict_category + num_classes = dataset.num_classes + labels = hg.nodes[category].data.pop("labels") + train_mask = hg.nodes[category].data.pop("train_mask") + test_mask = hg.nodes[category].data.pop("test_mask") + train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() + test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze() + + if get_norm: + # Calculate normalization weight for each edge, + # 1. / d, d is the degree of the destination node + for cetype in hg.canonical_etypes: + hg.edges[cetype].data["norm"] = dgl.norm_by_dst(hg, cetype).unsqueeze(1) + edata = ["norm"] + else: + edata = None + category_id = hg.ntypes.index(category) + g = dgl.to_homogeneous(hg, edata=edata) + node_ids = torch.arange(g.num_nodes()) + + # find out the target node ids in g + loc = g.ndata["_TYPE"] == category_id + target_idx = node_ids[loc] + + if inv_target: + # Map global node IDs to type-specific node IDs. 
This is required for + # looking up type-specific labels in a minibatch + inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64) + inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype) + return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target + else: + return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx + + +class MyModel(nn.Module): + def __init__( + self, + num_layers, + input_size, + hidden_size, + output_size, + num_rels, + direction_option=None, + bias=True, + activation=None, + self_loop=True, + feat_drop=0.0, + regularizer="none", + num_bases=4, + num_nodes=100, + ): + super(MyModel, self).__init__() + self.emb = nn.Embedding(num_nodes, hidden_size) + self.layer_1 = RGCNLayer( + input_size, + hidden_size, + num_rels=num_rels, + direction_option=direction_option, + bias=bias, + activation=activation, + self_loop=self_loop, + feat_drop=feat_drop, + regularizer=regularizer, + num_bases=num_bases, + ) + self.layer_2 = RGCNLayer( + hidden_size, + output_size, + num_rels=num_rels, + direction_option=direction_option, + bias=bias, + activation=activation, + self_loop=self_loop, + feat_drop=feat_drop, + regularizer=regularizer, + num_bases=num_bases, + ) + for k, v in self.named_parameters(): + print(f'{k} => {v}') + + def forward(self, g: GraphData): + node_features = self.emb(torch.IntTensor(list(range(g.get_node_num()))).to('cuda:0')) + dgl_g = g.to_dgl() + + # Make node feature dictionary + import typing as tp + feat_dict: tp.Dict[str, torch.Tensor] = {} + import numpy as np + node_types = np.array(g.ntypes,) + for i in set(node_types): + index = torch.tensor(np.where(node_types == i)[0], device=g.device) + feat_dict[i] = torch.index_select(node_features, 0, index) + + x1 = self.layer_1(dgl_g, feat_dict) + x2 = self.layer_2(dgl_g, x1) + return x2 + + +def main(config): + import mlflow + mlflow.set_tracking_uri("http://192.168.190.202:45250") + mlflow.set_experiment("rgcn_debug") + mlflow.start_run(run_name=f"rgcn_debug_{config['dataset']}") + + g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data( + data_name=config["dataset"], get_norm=True + ) + + # graph = from_dgl(g, is_hetero=False) + device = "cuda:0" + graph = from_dgl(g).to(device) + labels = labels.to(device) + num_nodes = graph.get_node_num() + my_model = MyModel( + num_layers=config["num_hidden_layers"] + 1, + input_size=config["hidden_size"], + hidden_size=config["hidden_size"], + output_size=num_classes, + direction_option=config["direction_option"], + bias=config["bias"], + activation=F.relu, + num_rels=num_rels, + self_loop=config["self_loop"], + feat_drop=config["feat_drop"], + regularizer="basis", + num_bases=num_rels, + num_nodes=num_nodes, + ).to(device) + optimizer = torch.optim.Adam( + my_model.parameters(), + lr=config["lr"], + weight_decay=config["wd"], + ) + print("start training...") + my_model.train() + for epoch in range(config["num_epochs"]): + logits = my_model(graph)['_N'] + logits = logits[target_idx] + loss = F.cross_entropy(logits[train_idx], labels[train_idx]) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_acc = accuracy(logits[train_idx].argmax(dim=1), labels[train_idx]).item() + print( + "Epoch {:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format( + epoch, train_acc, loss.item() + ) + ) + mlflow.log_metric("loss", loss.item(), step=epoch) + mlflow.log_metric("train_acc", train_acc, step=epoch) + + + print() + # Save Model + # torch.save(model.state_dict(), 
"./rgcn_model.pt") + print("start evaluating...") + my_model.eval() + with torch.no_grad(): + logits = my_model(graph)['_N'] + logits = logits[target_idx] + test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item() + print("Test Accuracy: {:.4f}".format(test_acc)) + + mlflow.log_metric("test_acc", test_acc) + mlflow.end_run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-config", type=str, help="path to the config file") + parser.add_argument("--grid_search", action="store_true", help="flag: grid search") + cfg = vars(parser.parse_args()) + config = get_config(cfg["config"]) + print(config) + main(config) diff --git a/graph4nlp/pytorch/test/graph_embedding/test_rgcn_perf.sh b/graph4nlp/pytorch/test/graph_embedding/test_rgcn_perf.sh new file mode 100755 index 00000000..a278a813 --- /dev/null +++ b/graph4nlp/pytorch/test/graph_embedding/test_rgcn_perf.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +export test_module=graph4nlp.pytorch.test.graph_embedding.run_rgcn +export python_command="python -m" +export config_root=/student/wangsaizhuo/Codes/graph4nlp/graph4nlp/pytorch/test/graph_embedding/rgcn_scripts + +test_routine() +{ + for dataset in {aifb,am,bgs,mutag} + do + ${python_command} ${test_module} -config ${config_root}/run_rgcn_${dataset}.yaml & + done + wait +} + + +# Test RGCN-Hetero Implementation on dgl benchmarks +git checkout rgcn-integration +test_routine() + +# Test RGCN-Homo Implementation on dgl benchmarks +git checkout debug-orig-rgcn +test_routine()