Skip to content

Commit

Permalink
Remove redundant code
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukium committed Mar 22, 2019
1 parent 5a14a63 commit 5947af9
Showing 1 changed file with 11 additions and 21 deletions.
32 changes: 11 additions & 21 deletions chainer_chemistry/models/prediction/set_up_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,53 +34,43 @@ def set_up_predictor(
postprocess_fn (chainer.FunctionNode or None):
postprocess function for prediction.
"""
mlp = MLP(out_dim=class_num, hidden_dim=n_unit)
mlp = MLP(out_dim=class_num, hidden_dim=n_unit) # type: Optional[MLP]

if method == 'nfp':
print('Training an NFP predictor...')
nfp = NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(nfp, mlp, label_scaler, postprocess_fn)
conv = NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
elif method == 'ggnn':
print('Training a GGNN predictor...')
ggnn = GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(ggnn, mlp, label_scaler, postprocess_fn)
conv = GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
elif method == 'schnet':
print('Training an SchNet predictor...')
schnet = SchNet(
conv = SchNet(
out_dim=class_num, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(schnet, None, label_scaler,
postprocess_fn)
mlp = None
elif method == 'weavenet':
print('Training a WeaveNet predictor...')
n_atom = 20
n_sub_layer = 1
weave_channels = [50] * conv_layers

weavenet = WeaveNet(
conv = WeaveNet(
weave_channels=weave_channels,
hidden_dim=n_unit,
n_sub_layer=n_sub_layer,
n_atom=n_atom)
predictor = GraphConvPredictor(weavenet, mlp, label_scaler,
postprocess_fn)
elif method == 'rsgcn':
print('Training an RSGCN predictor...')
rsgcn = RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(rsgcn, mlp, label_scaler,
postprocess_fn)
conv = RSGCN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
elif method == 'relgcn':
print('Training a Relational GCN predictor...')
num_edge_type = 4
relgcn = RelGCN(
conv = RelGCN(
out_channels=n_unit, num_edge_type=num_edge_type, scale_adj=True)
predictor = GraphConvPredictor(relgcn, mlp, label_scaler,
postprocess_fn)
elif method == 'relgat':
print('Training a Relational GAT predictor...')
relgat = RelGAT(
out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
predictor = GraphConvPredictor(relgat, mlp, label_scaler,
postprocess_fn)
conv = RelGAT(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers)
else:
raise ValueError('[ERROR] Invalid method: {}'.format(method))

predictor = GraphConvPredictor(conv, mlp, label_scaler, postprocess_fn)
return predictor

0 comments on commit 5947af9

Please sign in to comment.