[Feature] Rgcn integration #584

Open · wants to merge 29 commits into base: develop

Commits (29)
5e28dfb  initial merge, awaiting test (xiao03, Nov 16, 2022)
1d794a9  add library code and test (xiao03, Nov 17, 2022)
5d57dab  add rgcn for QG (hugochan, Nov 19, 2022)
c4437c1  update config (xiao03, Nov 19, 2022)
ef70aa3  format the script (xiao03, Nov 19, 2022)
c20b9c2  remove device options from model (xiao03, Nov 19, 2022)
ca03f95  add direction_options for rgcn (xiao03, Nov 19, 2022)
af5fc60  format (xiao03, Nov 19, 2022)
e1fa01d  remove unused parameters (xiao03, Nov 19, 2022)
279e0b0  isort fix (xiao03, Nov 19, 2022)
124c205  isort fix again (xiao03, Nov 19, 2022)
a6c1d45  fix ci (AlanSwift, Nov 19, 2022)
e5d53a6  update rgcn & test case (xiao03, Nov 19, 2022)
570dc89  fix (AlanSwift, Nov 24, 2022)
06703a4  bug fix on bi sep (xiao03, Nov 24, 2022)
5b5f77a  bug fix on bi_fuse (xiao03, Nov 24, 2022)
6b0bdc8  Change implementation of RGCN linear layer to DGL impl (SaizhuoWang, Nov 24, 2022)
e5b9808  Merge branch 'add_rgnn_for_qg' of https://github.com/graph4ai/graph4n… (SaizhuoWang, Nov 24, 2022)
1327a97  Implemented regularizer in RGCN (SaizhuoWang, Dec 6, 2022)
9e39bb2  Sync RGCNLayer implementation (SaizhuoWang, Dec 9, 2022)
1e0c597  modified DGL benchmark test code for rgcn (SaizhuoWang, Dec 9, 2022)
9c82c72  add rgcn for text classification (hugochan, Nov 19, 2022)
82fbb9c  update readme (hugochan, Nov 20, 2022)
887f2d1  linter & update readme (hugochan, Dec 10, 2022)
1321d1e  switch to RGCNLayer implemented in https://github.com/graph4ai/graph4… (hugochan, Dec 11, 2022)
a6d0b98  Migrated RGCN layer to hetero version in DGL (SaizhuoWang, Dec 17, 2022)
96b4539  fixed rgcn interface with dgl issue with tricks (SaizhuoWang, Jan 4, 2023)
7121e02  Bugfix in GraphData and RGCN testing (SaizhuoWang, Jan 10, 2023)
8eb46ad  mlflow integration and test script (SaizhuoWang, Jan 14, 2023)
2 changes: 2 additions & 0 deletions .gitignore
@@ -149,3 +149,5 @@ cscope.*
 # config file
 /config
 local_scripts/
+
+profiler/
4 changes: 2 additions & 2 deletions examples/pytorch/math_word_problem/mawps/src/evaluation.py
@@ -1,8 +1,8 @@
-from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase
-
 import sympy
 from sympy.parsing.sympy_parser import parse_expr
 
+from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase
+
 
 class SolutionMatch(EvaluationMetricBase):
     def __init__(self):
@@ -0,0 +1,12 @@
{
"config_path": "examples/pytorch/question_generation/config/squad_split2/qg.yaml",
"model_args.graph_construction_args.graph_construction_share.topology_subdir": "DependencyGraphForRGCN",
"model_args.graph_construction_args.graph_construction_private.edge_strategy": "heterogeneous",
"model_args.graph_construction_args.graph_construction_private.merge_strategy": "tailhead",
"model_args.graph_construction_args.graph_construction_private.sequential_link": true,
"model_args.graph_construction_args.graph_construction_private.as_node": false,
"model_args.graph_embedding_name": "rgcn",
"model_args.graph_embedding_args.graph_embedding_private.num_rels": 80,
"model_args.graph_embedding_args.graph_embedding_private.num_bases": 4,
"checkpoint_args.out_dir": "out/squad_split2/rgcn_dependency_ckpt"
}
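These override files pair a base YAML config (`config_path`) with flat, dot-separated keys that address fields inside it. A minimal sketch of how such a file could be folded into the nested config — an illustration only, not the project's actual loader; the file name is hypothetical:

```python
import json

import yaml  # assumes PyYAML is available


def apply_overrides(config: dict, overrides: dict) -> dict:
    """Walk each dotted key down into the nested config and set the leaf value."""
    for dotted_key, value in overrides.items():
        if dotted_key == "config_path":
            continue  # names the base YAML file, not a config field
        *parents, leaf = dotted_key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return config


with open("qg_rgcn_overrides.json") as f:  # hypothetical file name
    overrides = json.load(f)
with open(overrides["config_path"]) as f:
    config = yaml.safe_load(f)
config = apply_overrides(config, overrides)
```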
7 changes: 6 additions & 1 deletion examples/pytorch/question_generation/main.py
@@ -26,6 +26,8 @@
 from graph4nlp.pytorch.modules.utils.generic_utils import EarlyStopping, to_cuda
 from graph4nlp.pytorch.modules.utils.logger import Logger
 
+from examples.pytorch.semantic_parsing.graph2seq.rgcn_lib.graph2seq import RGCNGraph2Seq
+
 from .fused_embedding_construction import FusedEmbeddingConstruction
 
@@ -39,7 +41,10 @@ def __init__(self, vocab, config):
         ]
 
         # build Graph2Seq model
-        self.g2s = Graph2Seq.from_args(config, self.vocab)
+        if config["model_args"]["graph_embedding_name"] == "rgcn":
+            self.g2s = RGCNGraph2Seq.from_args(config, self.vocab)
+        else:
+            self.g2s = Graph2Seq.from_args(config, self.vocab)
 
         if "w2v" in self.g2s.graph_initializer.embedding_layer.word_emb_layers:
             self.word_emb = self.g2s.graph_initializer.embedding_layer.word_emb_layers[
14 changes: 4 additions & 10 deletions examples/pytorch/rgcn/rgcn.py
@@ -46,15 +46,13 @@ def __init__(
         num_bases=None,
         use_self_loop=True,
         dropout=0.0,
-        device="cuda",
     ):
         super(RGCN, self).__init__()
         self.num_layers = num_layers
         self.num_rels = num_rels
         self.num_bases = num_bases
         self.use_self_loop = use_self_loop
         self.dropout = dropout
-        self.device = device
 
         self.RGCN_layers = nn.ModuleList()
 
@@ -185,35 +183,31 @@ def __init__(
         self_loop=False,
         dropout=0.0,
         layer_norm=False,
-        device="cuda",
     ):
         super(RGCNLayer, self).__init__()
         self.linear_dict = {
-            i: nn.Linear(input_size, output_size, bias=bias, device=device) for i in range(num_rels)
+            i: nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels)
         }
         # self.linear_r = TypedLinear(input_size, output_size, num_rels, regularizer, num_bases)
         self.bias = bias
         self.activation = activation
         self.self_loop = self_loop
         self.layer_norm = layer_norm
-        self.device = device
 
         # bias
         if self.bias:
-            self.h_bias = nn.Parameter(torch.Tensor(output_size)).to(device)
+            self.h_bias = nn.Parameter(torch.Tensor(output_size))
             nn.init.zeros_(self.h_bias)
 
         # TODO(minjie): consider remove those options in the future to make
         # the module only about graph convolution.
         # layer norm
         if self.layer_norm:
-            self.layer_norm_weight = nn.LayerNorm(
-                output_size, elementwise_affine=True, device=device
-            )
+            self.layer_norm_weight = nn.LayerNorm(output_size, elementwise_affine=True)
 
         # weight for self loop
         if self.self_loop:
-            self.loop_weight = nn.Parameter(torch.Tensor(input_size, output_size)).to(device)
+            self.loop_weight = nn.Parameter(torch.Tensor(input_size, output_size))
             nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain("relu"))
 
         self.dropout = nn.Dropout(dropout)
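With the hard-coded `device` gone, the module can be constructed device-agnostically and moved wholesale with `.to(device)`. One caveat worth noting: `self.linear_dict` above is a plain Python dict, which PyTorch does not register, so those layers would be invisible to `.to()`, `.parameters()`, and the optimizer. A sketch of the registered alternative (`nn.ModuleDict`), illustrative only and not what this PR ships:

```python
import torch
import torch.nn as nn


class PerRelationLinear(nn.Module):
    """Per-relation linear maps held in an nn.ModuleDict so they are
    registered as submodules (visible to .to(), .parameters(), etc.)."""

    def __init__(self, input_size: int, output_size: int, num_rels: int, bias: bool = True):
        super().__init__()
        self.linears = nn.ModuleDict(
            {str(i): nn.Linear(input_size, output_size, bias=bias) for i in range(num_rels)}
        )

    def forward(self, feat: torch.Tensor, rel: int) -> torch.Tensor:
        return self.linears[str(rel)](feat)


model = PerRelationLinear(300, 300, num_rels=80)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # every registered Linear moves in one call
```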
@@ -0,0 +1,5 @@
{
"config_path": "examples/pytorch/semantic_parsing/graph2seq/config/dependency_rgcn_undirected.yaml",
"model_args.graph_embedding_args.graph_embedding_share.direction_option": "bi_sep",
"training_args.log_file": "examples/pytorch/semantic_parsing/graph2seq/log/dependency_rgcn_bi_sep.txt"
}
@@ -0,0 +1,5 @@
{
"config_path": "examples/pytorch/semantic_parsing/graph2seq/config/dependency_rgcn_undirected.yaml",
"model_args.graph_embedding_args.graph_embedding_share.direction_option": "undirected",
"training_args.log_file": "examples/pytorch/semantic_parsing/graph2seq/log/dependency_rgcn_undirected.txt"
}
2 changes: 1 addition & 1 deletion examples/pytorch/semantic_parsing/graph2seq/main_rgcn.py
@@ -52,7 +52,7 @@ def _build_logger(self, log_file):
         import os
 
         log_folder = os.path.split(log_file)[0]
-        if not os.path.exists(log_file):
+        if not os.path.exists(log_folder):
             os.makedirs(log_folder)
         self.logger = get_log(log_file)
 
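The old code tested the log *file* rather than its folder: on any rerun the folder already exists while the fresh log file does not, so `os.makedirs` raises `FileExistsError`. An equivalent, slightly simpler form of the fix — a suggestion, not what the PR ships:

```python
import os


def ensure_log_folder(log_file: str) -> None:
    # exist_ok=True makes the explicit existence check unnecessary
    log_folder = os.path.split(log_file)[0]
    if log_folder:
        os.makedirs(log_folder, exist_ok=True)
```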
examples/pytorch/semantic_parsing/graph2seq/rgcn_lib/graph2seq.py
@@ -1,6 +1,6 @@
 from graph4nlp.pytorch.models.graph2seq import Graph2Seq
 
-from examples.pytorch.rgcn.rgcn import RGCN
+# from examples.pytorch.rgcn.rgcn import RGCN
+from graph4nlp.pytorch.modules.graph_embedding_learning.rgcn import RGCN
 
 
 class RGCNGraph2Seq(Graph2Seq):
@@ -74,10 +74,12 @@ def __init__(
 
     def _build_gnn_encoder(
         self,
         gnn,
         num_layers,
         input_size,
         hidden_size,
         output_size,
         direction_option,
         feats_dropout,
         gnn_num_rels=80,
         gnn_num_bases=4,
@@ -89,6 +91,8 @@ def _build_gnn_encoder(
             hidden_size,
             output_size,
             num_rels=gnn_num_rels,
-            num_bases=gnn_num_bases,
-            dropout=feats_dropout,
             direction_option=direction_option,
+            # num_bases=gnn_num_bases,
+            # dropout=feats_dropout,
+            feat_drop=feats_dropout,
         )
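Note the keyword change: the library-level RGCN takes the feature dropout rate as `feat_drop`, where the old example-level RGCN took `dropout`. A sketch of constructing the library encoder directly, with the argument names taken from the calls in this diff and purely illustrative values:

```python
from graph4nlp.pytorch.modules.graph_embedding_learning.rgcn import RGCN

encoder = RGCN(
    2,    # num_layers (illustrative)
    300,  # input_size
    300,  # hidden_size
    300,  # output_size
    num_rels=80,
    num_bases=4,
    direction_option="undirected",
    feat_drop=0.2,  # library keyword for the dropout rate
)
```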
@@ -7,5 +7,5 @@
 "model_args.graph_embedding_args.graph_embedding_private.negative_slope": "0.2",
 "model_args.graph_embedding_args.graph_embedding_private.residual": "false",
 "model_args.graph_embedding_args.graph_embedding_private.allow_zero_in_degree": "true",
-"checkpoint_args.out_dir": "out/trec/gat_bi_sep_dependency_ckpt"
+"checkpoint_args.out_dir": "out/CAirline/gat_bi_sep_dependency_ckpt"
 }
@@ -2,5 +2,5 @@
 "config_path": "examples/pytorch/text_classification/config/CAirline/text_clf.yaml",
 "model_args.graph_construction_name": "constituency",
 "model_args.graph_construction_args.graph_construction_share.topology_subdir": "constituency_graph",
-"checkpoint_args.out_dir": "out/trec/ggnn_bi_sep_constituency_ckpt"
+"checkpoint_args.out_dir": "out/CAirline/ggnn_bi_sep_constituency_ckpt"
 }
@@ -13,5 +13,5 @@
 "model_args.graph_embedding_args.graph_embedding_share.hidden_size": "300",
 "model_args.graph_embedding_args.graph_embedding_share.output_size": "300",
 "model_args.graph_embedding_args.graph_embedding_private.use_edge_weight": "true",
-"checkpoint_args.out_dir": "out/trec/ggnn_bi_sep_node_emb_ckpt"
+"checkpoint_args.out_dir": "out/CAirline/ggnn_bi_sep_node_emb_ckpt"
 }
@@ -15,5 +15,5 @@
 "model_args.graph_embedding_args.graph_embedding_share.hidden_size": "300",
 "model_args.graph_embedding_args.graph_embedding_share.output_size": "300",
 "model_args.graph_embedding_args.graph_embedding_private.use_edge_weight": "true",
-"checkpoint_args.out_dir": "out/trec/ggnn_bi_sep_node_emb_refined_dependency_ckpt"
+"checkpoint_args.out_dir": "out/CAirline/ggnn_bi_sep_node_emb_refined_dependency_ckpt"
 }
@@ -0,0 +1,13 @@
{
"config_path": "examples/pytorch/text_classification/config/CAirline/text_clf.yaml",
"model_args.graph_construction_args.graph_construction_share.topology_subdir": "dependency_graph_for_rgcn",
"model_args.graph_construction_args.graph_construction_private.edge_strategy": "heterogeneous",
"model_args.graph_construction_args.graph_construction_private.merge_strategy": "tailhead",
"model_args.graph_construction_args.graph_construction_private.sequential_link": true,
"model_args.graph_construction_args.graph_construction_private.as_node": false,
"model_args.graph_embedding_name": "rgcn",
"model_args.graph_embedding_args.graph_embedding_share.direction_option": "undirected",
"model_args.graph_embedding_args.graph_embedding_private.num_rels": 80,
"model_args.graph_embedding_args.graph_embedding_private.num_bases": 4,
"checkpoint_args.out_dir": "out/CAirline/rgcn_dependency_ckpt"
}
@@ -0,0 +1,14 @@
{
"config_path": "examples/pytorch/text_classification/config/trec/text_clf.yaml",
"model_args.graph_construction_args.graph_construction_share.topology_subdir": "dependency_graph_for_rgcn",
"model_args.graph_construction_args.graph_construction_private.edge_strategy": "heterogeneous",
"model_args.graph_construction_args.graph_construction_private.merge_strategy": "tailhead",
"model_args.graph_construction_args.graph_construction_private.sequential_link": true,
"model_args.graph_construction_args.graph_construction_private.as_node": false,
"model_args.graph_embedding_name": "rgcn",
"model_args.graph_embedding_args.graph_embedding_share.direction_option": "undirected",
"model_args.graph_embedding_args.graph_embedding_private.num_rels": 80,
"model_args.graph_embedding_args.graph_embedding_private.num_bases": 4,
"training_args.lr": "0.002",
"checkpoint_args.out_dir": "out/trec/rgcn_dependency_ckpt"
}
24 changes: 12 additions & 12 deletions examples/pytorch/text_classification/readme.md
@@ -29,23 +29,23 @@ TREC Results
 -------
 
 
-| GraphType\GNN | GAT-BiSep | GAT-BiFuse | GraphSAGE-BiSep | GraphSAGE-BiFuse | GGNN-BiSep | GGNN-BiFuse |
-| ------------- | ------------- | --------------| ------------------- | ----------------- |-------------- | ------------- |
-| Dependency | 0.9480 | 0.9460 | 0.942 | 0.958 | 0.954 | 0.9440 |
-| Constituency | 0.9420 | 0.9300 | 0.952 | 0.950 | 0.952 | 0.9400 |
-| NodeEmb | N/A | N/A | 0.930 | 0.908 | | |
-| NodeEmbRefined | N/A | N/A | 0.940 | 0.926 | | |
+| GraphType\GNN | GAT-BiSep | GAT-BiFuse | GraphSAGE-BiSep | GraphSAGE-BiFuse | GGNN-BiSep | GGNN-BiFuse | RGCN |
+| ------------- | ------------- | --------------| ------------------- | ----------------- |-------------- | ------------- | ----- |
+| Dependency | 0.9480 | 0.9460 | 0.942 | 0.958 | 0.954 | 0.944 | 0.946 |
+| Constituency | 0.9420 | 0.9300 | 0.952 | 0.950 | 0.952 | 0.94 | N/A |
+| NodeEmb | N/A | N/A | 0.930 | 0.908 | N/A | N/A | N/A |
+| NodeEmbRefined | N/A | N/A | 0.940 | 0.926 | N/A | N/A | N/A |
 
 
 
 CAirline Results
 -------
 
 
-| GraphType\GNN | GAT-BiSep | GGNN-BiSep |GraphSage-BiSep|
-| -------------- | ------------ | ------------- |---------------|
-| Dependency | 0.7496 | 0.8020 | 0.7977 |
-| Constituency | 0.7846 | 0.7933 | 0.7948 |
-| NodeEmb | N/A | 0.8108 | 0.8108 |
-| NodeEmbRefined | N/A | 0.7991 | 0.8020 |
+| GraphType\GNN | GAT-BiSep | GGNN-BiSep |GraphSage-BiSep| RGCN |
+| -------------- | ------------ | ------------- |---------------|---------------|
+| Dependency | 0.7496 | 0.8020 | 0.7977 | 0.7525 |
+| Constituency | 0.7846 | 0.7933 | 0.7948 | N/A |
+| NodeEmb | N/A | 0.8108 | 0.8108 | N/A |
+| NodeEmbRefined | N/A | 0.7991 | 0.8020 | N/A |

26 changes: 26 additions & 0 deletions examples/pytorch/text_classification/run_text_classifier.py
@@ -32,6 +32,8 @@
 from graph4nlp.pytorch.modules.utils.generic_utils import EarlyStopping, to_cuda
 from graph4nlp.pytorch.modules.utils.logger import Logger
 
+from graph4nlp.pytorch.modules.graph_embedding_learning.rgcn import RGCN
+
 torch.multiprocessing.set_sharing_strategy("file_system")
 
@@ -217,6 +219,30 @@ def __init__(self, vocab, label_model, config):
                     "graph_embedding_private"
                 ]["use_edge_weight"],
             )
+        elif config["model_args"]["graph_embedding_name"] == "rgcn":
+            self.gnn = RGCN(
+                config["model_args"]["graph_embedding_args"]["graph_embedding_share"]["num_layers"],
+                config["model_args"]["graph_embedding_args"]["graph_embedding_share"]["input_size"],
+                config["model_args"]["graph_embedding_args"]["graph_embedding_share"][
+                    "hidden_size"
+                ],
+                config["model_args"]["graph_embedding_args"]["graph_embedding_share"][
+                    "output_size"
+                ],
+                num_rels=config["model_args"]["graph_embedding_args"]["graph_embedding_private"][
+                    "num_rels"
+                ],
+                direction_option=config["model_args"]["graph_embedding_args"][
+                    "graph_embedding_share"
+                ]["direction_option"],
+                feat_drop=config["model_args"]["graph_embedding_args"]["graph_embedding_share"][
+                    "feat_drop"
+                ],
+                regularizer="basis",
+                num_bases=config["model_args"]["graph_embedding_args"]["graph_embedding_private"][
+                    "num_bases"
+                ],
+            )
         else:
             raise RuntimeError(
                 "Unknown gnn type: {}".format(config["model_args"]["graph_embedding_name"])
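The deeply nested lookups above are repetitive; a small dotted-path helper (hypothetical, not part of this PR) would shorten each one to a single call:

```python
def cfg_get(config: dict, dotted_key: str):
    """Resolve 'a.b.c' against nested dicts: config['a']['b']['c']."""
    node = config
    for part in dotted_key.split("."):
        node = node[part]
    return node


# e.g. the num_rels lookup from the diff above becomes:
# num_rels = cfg_get(config, "model_args.graph_embedding_args.graph_embedding_private.num_rels")
```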
26 changes: 16 additions & 10 deletions graph4nlp/pytorch/data/data.py
@@ -176,7 +176,7 @@ def add_nodes(self, node_num: int, ntypes: List[str] = None):
         )
 
         if not self.is_hetero:
-            if ntypes is not None:
+            if ntypes is not None and len(set(ntypes)) > 1:
                 raise ValueError(
                     "The graph is homogeneous, ntypes should be None. Got {}".format(ntypes)
                 )
@@ -787,7 +787,7 @@ def _data_dict(self) -> Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]]:
         )
         return data_dict
 
-    def to_dgl(self) -> dgl.DGLGraph:
+    def to_dgl(self) -> dgl.DGLHeteroGraph:
         """
         Convert to dgl.DGLGraph
         Note that there will be some information loss when calling this function,
@@ -796,8 +796,8 @@ def to_dgl(self) -> dgl.DGLHeteroGraph:
 
         Returns
         -------
-        g : dgl.DGLGraph
-            The converted dgl.DGLGraph
+        g : dgl.DGLHeteroGraph
+            The converted dgl.DGLHeteroGraph
         """
         u, v = self._edge_indices.src, self._edge_indices.tgt
         num_nodes = self.get_node_num()
@@ -903,13 +903,13 @@ def make_num_nodes_dict(
 
         return dgl_g
 
-    def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False):
+    def from_dgl(self, dgl_g: dgl.DGLHeteroGraph, is_hetero=False):
         """
-        Build the graph from dgl.DGLGraph
+        Build the graph from dgl.DGLHeteroGraph
 
         Parameters
         ----------
-        dgl_g : dgl.DGLGraph
+        dgl_g : dgl.DGLHeteroGraph
             The source graph
         """
         if not (self.get_edge_num() == 0 and self.get_node_num() == 0):
@@ -950,6 +950,10 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False):
             processed_node_types = False
             node_feat_dict = {}
             for feature_name, data_dict in node_data.items():
+                if not isinstance(data_dict, Dict):
+                    # DGL will return tensor if ntype is single
+                    # This can happen when graph is a multigraph
+                    data_dict = {dgl_g.ntypes[0]: data_dict}
                 if not processed_node_types:
                     for node_type, node_feature in data_dict.items():
                         ntypes += [node_type] * len(node_feature)
@@ -967,7 +971,8 @@ def from_dgl(self, dgl_g: dgl.DGLGraph, is_hetero=False):
             num_edges = dgl_g.num_edges(etype)
             src_type, r_type, dst_type = etype
             srcs, dsts = dgl_g.find_edges(
-                torch.tensor(list(range(num_edges)), dtype=torch.long), etype
+                torch.tensor(list(range(num_edges)), dtype=torch.long, device=dgl_g.device),
+                etype,
             )
             srcs, dsts = (
                 srcs.detach().cpu().numpy().tolist(),
@@ -1386,8 +1391,9 @@ def from_dgl(g: dgl.DGLGraph) -> GraphData:
     GraphData
         The converted graph in GraphData format.
     """
-    graph = GraphData(is_hetero=not g.is_homogeneous)
-    graph.from_dgl(g, is_hetero=not g.is_homogeneous)
+    dgl_g_is_hetero = (not g.is_homogeneous) or g.is_multigraph
+    graph = GraphData(is_hetero=dgl_g_is_hetero)
+    graph.from_dgl(g, is_hetero=dgl_g_is_hetero)
     return graph
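The new `isinstance(data_dict, Dict)` guard covers a real DGL quirk: for a heterograph with exactly one node type, `g.ndata[feat]` returns a bare tensor, while with several node types it returns a per-ntype dict. A small sketch of both shapes, with illustrative graphs only:

```python
import dgl
import torch

# One node type, two relation types: ndata access yields a plain tensor.
g1 = dgl.heterograph({
    ("n", "r1", "n"): (torch.tensor([0]), torch.tensor([1])),
    ("n", "r2", "n"): (torch.tensor([1]), torch.tensor([0])),
})
g1.ndata["h"] = torch.randn(2, 4)
print(type(g1.ndata["h"]))  # <class 'torch.Tensor'>

# Two node types: the same access yields a dict keyed by node type,
# which is the shape from_dgl originally assumed.
g2 = dgl.heterograph({("a", "r", "b"): (torch.tensor([0]), torch.tensor([0]))})
g2.nodes["a"].data["h"] = torch.randn(1, 4)
print(type(g2.ndata["h"]))  # <class 'dict'>
```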
2 changes: 1 addition & 1 deletion graph4nlp/pytorch/models/graph2seq.py
@@ -26,7 +26,7 @@ class Graph2Seq(Graph2XBase):
     >>> "It is just a how-to-use example."
     >>> from graph4nlp.pytorch.modules.config import get_basic_args
     >>> opt = get_basic_args(graph_construction_name="node_emb", graph_embedding_name="gat", decoder_name="stdrnn")
-    >>> graph2seq = Graph2Seq.from_args(opt=opt, vocab_model=vocab_model, device=torch.device("cuda:0"))
+    >>> graph2seq = Graph2Seq.from_args(opt=opt, vocab_model=vocab_model)
     >>> batch_graph = [GraphData() for _ in range(2)]
     >>> tgt_seq = torch.Tensor([[1, 2, 3], [4, 5, 6]])
     >>> seq_out, _, _ = graph2seq(batch_graph=batch_graph, tgt_seq=tgt_seq)