Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
hwwang55 committed Nov 6, 2018
1 parent 0d3e23f commit a709f7f
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/GraphGAN/graph_gan.py
Expand Up @@ -138,15 +138,13 @@ def train(self):
self.saver.save(self.sess, config.model_log + "model.checkpoint")

# D-steps
center_nodes = []
neighbor_nodes = []
labels = []
for d_epoch in range(config.n_epochs_dis):
center_nodes = []
neighbor_nodes = []
labels = []

# generate new nodes for the discriminator for every dis_interval iterations
if d_epoch % config.dis_interval == 0:
center_nodes, neighbor_nodes, labels = self.prepare_data_for_d()

# training
train_size = len(center_nodes)
start_list = list(range(0, train_size, config.batch_size_dis))
Expand All @@ -159,10 +157,10 @@ def train(self):
self.discriminator.label: np.array(labels[start:end])})

# G-steps
node_1 = []
node_2 = []
reward = []
for g_epoch in range(config.n_epochs_gen):
node_1 = []
node_2 = []
reward = []
if g_epoch % config.gen_interval == 0:
node_1, node_2, reward = self.prepare_data_for_g()

Expand Down

0 comments on commit a709f7f

Please sign in to comment.