Refactoring Graph Warp Module #340
Conversation
I finished the implementation. Please review.
Codecov Report
@@            Coverage Diff             @@
##           master     #340      +/-   ##
==========================================
+ Coverage   81.65%   82.95%   +1.29%
==========================================
  Files         210      211       +1
  Lines        9647     9813     +166
==========================================
+ Hits         7877     8140     +263
+ Misses       1770     1673      -97
def mol_basic_info_feature(mol, atom_array, adj):
    n_atoms = mol.GetNumAtoms()
    assert n_atoms == len(atom_array)
    n_edges = adj.sum()
Just a comment: this is actually twice the number of edges, since a symmetric adjacency matrix counts each undirected edge in both directions.
done in new PR
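For reference, a minimal sketch of the point above, assuming a symmetric 0/1 adjacency matrix with no self-loops (the helper name is hypothetical, not part of this PR):

import numpy as np

def count_undirected_edges(adj):
    # adj.sum() counts each undirected edge twice (once per direction),
    # so the number of undirected edges is adj.sum() // 2.
    adj = np.asarray(adj)
    assert (adj == adj.T).all(), 'expects a symmetric adjacency matrix'
    return int(adj.sum()) // 2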
def construct_supernode_feature(mol, atom_array, adj, feature_functions=None):
    # largest_atomic_number=MAX_ATOMIC_NUM, out_size=-1):
Please delete this commented-out line.
Also delete the docstring.
done in new PR
@@ -22,80 +156,65 @@ class GWM(chainer.Chain):
             number of super-node observation attributes
         n_edge_types (int): number of edge types witin graphs.
         dropout_ratio (default=0.5); if > 0.0, perform dropout
-        tying_flag (default=false): enable if you want to share params across layers
+        tying_flag (default=false): enable if you want to share params across
+            layers
Please add docstring entries for activation, wgu_activation and gtu_activation.
done in new PR
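For reference, a sketch of what those entries might look like; the descriptions are only guesses from the argument names, to be corrected by the authors:

        # Docstring sketch (descriptions assumed from the names):
        #     activation: activation function for the updated node features
        #     wgu_activation: activation function used in the WarpGateUnit
        #     gtu_activation: activation function used in the GraphTransmitterUnit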
         if tying_flag:
-            num_layer = 1
+            n_layers = 1
This introduces a bug: self.n_layers must be set before this value is overridden.
Instead, refactor to remove the n_layers dependency and do not use step in the __call__ method.
TODO later.
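For reference, a minimal sketch of the ordering fix, with the constructor shown only in outline (the signature and link construction are placeholders, not the PR's actual code):

import chainer

class GWM(chainer.Chain):

    def __init__(self, n_layers=4, tying_flag=False):
        super(GWM, self).__init__()
        # Record the requested depth before overriding it, so that __call__
        # can still iterate over the right number of message-passing steps.
        self.n_layers = n_layers
        self.tying_flag = tying_flag
        if tying_flag:
            # Parameters are shared across layers, so build only one set.
            n_layers = 1
        # ... build n_layers sets of links inside self.init_scope() here ...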
-        :param adj: minibatch by bond_types by num_nodes by num_nodes 1/0 array.
-            Adjacency matrices over several bond types
+        :param adj: minibatch by bond_types by num_nodes by num_nodes 1/0
+            array. Adjacency matrices over several bond types
Can you change this to the Google docstring format?
done in new PR
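For reference, the same parameter in Google docstring style might read as follows (a sketch; only the wording from the snippet above is reused, and numpy.ndarray is an assumption about the concrete type):

        Args:
            adj (numpy.ndarray): adjacency matrices over several bond types,
                as a 1/0 array of shape
                (minibatch, bond_types, num_nodes, num_nodes).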
def check_forward(gwm, embed_atom_data, new_embed_atom_data, supernode):
    gwm.GRU_local.reset_state()
gwm.reset_state() is better.
done in new PR
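For reference, a minimal sketch of such a method to add to GWM; GRU_local appears in the test above, while GRU_super is an assumption about the module's other recurrent link:

    def reset_state(self):
        # Reset the recurrent state of the internal GRUs in one place, so
        # tests and callers do not reach into the module's child links.
        self.GRU_local.reset_state()
        # GRU_super is assumed to exist alongside GRU_local.
        self.GRU_super.reset_state()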
def check_backward(gwm, embed_atom_data, new_embed_atom_data, supernode,
                   y_grad, supernode_grad):
    gwm.GRU_local.reset_state()
gwm.reset_state() is better here as well.
done in new PR
        return merged


class SuperNodeTransmitterUnit(chainer.Chain):
Can you add a docstring?
done in new PR
        return g_trans


class GraphTransmitterUnit(chainer.Chain):
Can you add a docstring?
done in new PR
@@ -5,11 +5,145 @@
from chainer_chemistry.links import GraphLinear


class WarpGateUnit(chainer.Chain):
Can you add a docstring?
done in new PR
        elif output_type == 'super':
            LinearFunc = links.Linear
        else:
            raise ValueError
Show a proper error message, e.g.
ValueError('output_type = {} is unexpected. graph or super is supported.'.format(output_type))
done in new PR
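A sketch of the branch with the suggested message; the 'graph' case selecting GraphLinear is an assumption based on the import shown elsewhere in this diff:

        if output_type == 'graph':
            # Assumed: the graph case selects GraphLinear.
            LinearFunc = GraphLinear
        elif output_type == 'super':
            LinearFunc = links.Linear
        else:
            raise ValueError(
                'output_type = {} is unexpected. '
                'graph or super is supported.'.format(output_type))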
        super(SuperNodeTransmitterUnit, self).__init__()
        with self.init_scope():
            self.F_super = links.Linear(in_size=hidden_dim_super,
                                        out_size=hidden_dim)
Originally this was out_size=hidden_dim_super, is that ok?
The original code uses F_super in two places: to update the super-node feature itself and to build the message to the nodes for transmission.
F_super is now separated, so I think it's okay.
        # for local updates
        g_trans = self.F_super(g)
        # intermediate_h_super.shape == (mb, self.hidden_dim)
        g_trans = functions.tanh(g_trans)
How about returning g_trans here, and doing the expand_dims and broadcast later, in the module where it is actually needed? That way we can remove the dependency on the n_atoms argument.
TODO later
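For reference, a sketch of the caller-side broadcast this would move out of the transmitter unit (the call site and the mb, n_atoms, hidden_dim names are assumptions):

        # In the module that consumes the super-node message: g_trans has
        # shape (mb, hidden_dim), so broadcast it to every atom only here.
        g_trans = functions.expand_dims(g_trans, axis=1)
        g_trans = functions.broadcast_to(
            g_trans, (mb, n_atoms, hidden_dim))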
            *[links.Linear(in_size=hidden_dim_super,
                           out_size=hidden_dim_super)
              for _ in range(n_layers)]
        )
Did you separate F_super?
It is different from the original code, but separating F_super is better, so it's ok.
        # update for attention-message B h_i
        # h1.shape == (mb, atom, n_heads * ch)
        # Bh_i.shape == (mb, atom, self.n_heads * self.hidden_dim_super)
        Bh_i = self.B(h1)
I think broadcasting h1 is redundant and can be skipped to reduce computation.
done in new PR
        with self.init_scope():
            self.V_super = links.Linear(hidden_dim * n_heads, hidden_dim * n_heads)
            self.W_super = links.Linear(hidden_dim * n_heads, hidden_dim_super)
            self.B = GraphLinear(n_heads * hidden_dim, n_heads * hidden_dim_super)
self.B = GraphLinear(hidden_dim, n_heads * hidden_dim_super) when we skip the h1 broadcast.
done in new PR
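For reference, a sketch of the suggested shape change with toy sizes; the reshape that recovers the heads afterwards is an assumption about how Bh_i is consumed:

import numpy as np
import chainer.functions as functions
from chainer_chemistry.links import GraphLinear

# Toy sizes, only to illustrate the shapes discussed above.
mb, atom, hidden_dim, n_heads, hidden_dim_super = 2, 5, 16, 8, 32

# B maps each atom's hidden_dim features to all heads at once, so h does
# not need to be broadcast to n_heads copies before the linear layer.
B = GraphLinear(hidden_dim, n_heads * hidden_dim_super)

h = np.zeros((mb, atom, hidden_dim), dtype=np.float32)
Bh_i = B(h)  # (mb, atom, n_heads * hidden_dim_super)
Bh_i = functions.reshape(Bh_i, (mb, atom, n_heads, hidden_dim_super))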
        # intermediate_h.shape == (mb, self.n_heads * ch)
        h_trans = self.V_super(attention_sum)
        # compress heads
        h_trans = self.W_super(h_trans)
What is the meaning of applying a linear operation twice?
I think V_super is not necessary.
done in new PR
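For reference, two linear layers applied back to back with no nonlinearity in between compose into a single linear map, so one link of the combined shape suffices. A sketch with toy sizes matching the shape comments above:

import numpy as np
import chainer.links as links

mb, n_heads, ch, hidden_dim_super = 2, 8, 16, 32  # toy sizes

# Instead of V_super (n_heads*ch -> n_heads*ch) followed by
# W_super (n_heads*ch -> hidden_dim_super), a single Linear with the
# same input and output sizes expresses the same family of maps.
W_super = links.Linear(n_heads * ch, hidden_dim_super)

attention_sum = np.zeros((mb, n_heads * ch), dtype=np.float32)
h_trans = W_super(attention_sum)  # (mb, hidden_dim_super)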
        self.activation = activation

    def __call__(self, h, g):
        z = self.H(h) + self.G(g)
I think we can compute self.G(g) as a Linear layer followed by a broadcast to each atom.
TODO later
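For reference, a sketch of computing self.G(g) once per molecule and broadcasting it to each atom; the shapes of h and g, and G being a plain links.Linear, are assumptions:

        # Compute G(g) once per molecule (g.shape == (mb, hidden_dim_super)),
        # then broadcast the result to every atom instead of applying a
        # graph-wise linear to an already-broadcast copy of g.
        mb, atom, ch = h.shape
        g_part = self.G(g)                            # (mb, ch)
        g_part = functions.broadcast_to(
            functions.expand_dims(g_part, 1), (mb, atom, ch))
        z = self.H(h) + g_part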
Can you update the code? @mottodora
I will take over.
        h_j = functions.expand_dims(h, 1)
        # h_j.shape == (mb, self.n_heads, atom, ch)
        h_j = functions.broadcast_to(h_j, (mb, self.n_heads, atom, ch))
Apply V_super instead of the broadcast.
TODO later
Resolve #329