diff --git a/src/GraphGAN/graph_gan.py b/src/GraphGAN/graph_gan.py index c2a5831..c0fde9c 100644 --- a/src/GraphGAN/graph_gan.py +++ b/src/GraphGAN/graph_gan.py @@ -138,15 +138,13 @@ def train(self): self.saver.save(self.sess, config.model_log + "model.checkpoint") # D-steps + center_nodes = [] + neighbor_nodes = [] + labels = [] for d_epoch in range(config.n_epochs_dis): - center_nodes = [] - neighbor_nodes = [] - labels = [] - # generate new nodes for the discriminator for every dis_interval iterations if d_epoch % config.dis_interval == 0: center_nodes, neighbor_nodes, labels = self.prepare_data_for_d() - # training train_size = len(center_nodes) start_list = list(range(0, train_size, config.batch_size_dis)) @@ -159,10 +157,10 @@ def train(self): self.discriminator.label: np.array(labels[start:end])}) # G-steps + node_1 = [] + node_2 = [] + reward = [] for g_epoch in range(config.n_epochs_gen): - node_1 = [] - node_2 = [] - reward = [] if g_epoch % config.gen_interval == 0: node_1, node_2, reward = self.prepare_data_for_g()