Skip to content

Commit

Permalink
allow setting layer number for generator and discriminator separately
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaxuanYou committed Jul 26, 2018
1 parent b080018 commit 84e5d29
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
13 changes: 7 additions & 6 deletions rl-baselines/baselines/ppo1/gcn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,13 @@ def discriminator_net(ob,args,name='d_net'):
if args.bn==1:
ob_node = tf.layers.batch_normalization(ob_node,axis=-1)
emb_node = GCN_batch(ob['adj'], ob_node, args.emb_size, name='gcn1',aggregate=args.gcn_aggregate)
for i in range(args.layer_num_d - 2):
if args.bn==1:
emb_node = tf.layers.batch_normalization(emb_node,axis=-1)
emb_node = GCN_batch(ob['adj'], emb_node, args.emb_size, name='gcn1_'+str(i+1),aggregate=args.gcn_aggregate)
if args.bn==1:
emb_node = tf.layers.batch_normalization(emb_node,axis=-1)
emb_node = GCN_batch(ob['adj'], emb_node, args.emb_size, name='gcn2',aggregate=args.gcn_aggregate)
if args.bn==1:
emb_node = tf.layers.batch_normalization(emb_node,axis=-1)
emb_node = GCN_batch(ob['adj'], emb_node, args.emb_size, is_act=False, is_normalize=(args.bn == 0), name='gcn3',aggregate=args.gcn_aggregate)
emb_node = GCN_batch(ob['adj'], emb_node, args.emb_size, is_act=False, is_normalize=(args.bn == 0), name='gcn2',aggregate=args.gcn_aggregate)
if args.bn==1:
emb_node = tf.layers.batch_normalization(emb_node,axis=-1)
# emb_graph = tf.reduce_max(tf.squeeze(emb_node2, axis=1),axis=1) # B*f
Expand Down Expand Up @@ -186,7 +187,7 @@ def _init(self, ob_space, ac_space, kind, atom_type_num,args):
emb_node = GCN_batch(ob['adj'], ob_node, args.emb_size, name='gcn1',aggregate=args.gcn_aggregate)
if args.bn == 1:
emb_node = tf.layers.batch_normalization(emb_node, axis=-1)
for i in range(args.layer_num-2):
for i in range(args.layer_num_g-2):
if args.has_residual==1:
emb_node = GCN_batch(ob['adj'], emb_node, args.emb_size, name='gcn1_'+str(i+1),aggregate=args.gcn_aggregate)+self.emb_node1
elif args.has_concat==1:
Expand Down Expand Up @@ -367,7 +368,7 @@ def GCN_emb(ob,args):
axis=-1)
else:
emb_node1 = GCN_batch(ob['adj'], ob_node, args.emb_size, name='gcn1', aggregate=args.gcn_aggregate)
for i in range(args.layer_num - 2):
for i in range(args.layer_num_g - 2):
if args.has_residual == 1:
emb_node1 = GCN_batch(ob['adj'], emb_node1, args.emb_size, name='gcn1_' + str(i + 1),
aggregate=args.gcn_aggregate) + emb_node1
Expand Down
5 changes: 3 additions & 2 deletions run_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,14 @@ def molecule_arg_parser():
parser.add_argument('--curriculum_step', type=int, default=200)
parser.add_argument('--supervise_time', type=int, default=4)
parser.add_argument('--normalize_adj', type=int, default=0)
parser.add_argument('--layer_num', type=int, default=3)
parser.add_argument('--layer_num_g', type=int, default=3)
parser.add_argument('--layer_num_d', type=int, default=3)
parser.add_argument('--graph_emb', type=int, default=0)
parser.add_argument('--stop_shift', type=int, default=-3)
parser.add_argument('--has_residual', type=int, default=0)
parser.add_argument('--has_concat', type=int, default=0)
parser.add_argument('--has_feature', type=int, default=0)
parser.add_argument('--emb_size', type=int, default=64)
parser.add_argument('--emb_size', type=int, default=128) # default 64
parser.add_argument('--gcn_aggregate', type=str, default='mean')# sum, mean, concat
parser.add_argument('--gan_type', type=str, default='normal')# normal, recommend, wgan
parser.add_argument('--gate_sum_d', type=int, default=0)
Expand Down

0 comments on commit 84e5d29

Please sign in to comment.