Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
orcax committed Jun 22, 2019
1 parent 6ab016f commit 5f869c5
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 31 deletions.
1 change: 0 additions & 1 deletion kg_env.py
Expand Up @@ -222,7 +222,6 @@ def batch_step(self, batch_act_idx):
act_idx = batch_act_idx[i]
_, curr_node_type, curr_node_id = self._batch_path[i][-1]
relation, next_node_id = self._batch_curr_actions[i][act_idx]
# TODO: check act_idx is valid?
if relation == SELF_LOOP:
next_node_type = curr_node_type
else:
Expand Down
18 changes: 1 addition & 17 deletions knowledge_graph.py
Expand Up @@ -204,23 +204,6 @@ def heuristic_search(self, uid, pid, pattern_id, trim_edges=False):
intersect_nodes = wids_u.intersection(wids_u_p)
tmp_paths = [(uid, x, uid_p, pid) for x in intersect_nodes]
paths.extend(tmp_paths)
# elif len(pattern) == 5: # DOES NOT WORK SO FAR!
# nodes_from_user = set(self.G[USER][uid][pattern[1][0]]) # USER->MENTION->WORD
# nodes_from_product = set(self.G[PRODUCT][pid][pattern[-1][0]])
# if pattern[-2][1] == USER:
# nodes_from_product.difference([uid])
# count = 0
# for wid in nodes_from_user:
# pids_from_wid = set(self.G[WORD][wid][pattern[2][0]]) # USER->MENTION->WORD->DESCRIBE->PRODUCT
# pids_from_wid = pids_from_wid.difference([pid]) # exclude target product
# for nid in nodes_from_product:
# if pattern[-2][1] == WORD:
# if nid == wid:
# continue
# other_pids = set(self.G[pattern[-2][1]][nid][pattern[-2][0]])
# intersect_nodes = pids_from_wid.intersection(other_pids)
# count += len(intersect_nodes)
# return count

return paths

Expand All @@ -236,3 +219,4 @@ def check_test_path(dataset_str, kg):
count += len(tmp_path)
if count == 0:
print(uid, pid)

3 changes: 1 addition & 2 deletions preprocess.py
Expand Up @@ -27,7 +27,7 @@ def generate_labels(dataset, mode='train'):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default=CELL, help='One of {BEAUTY, CELL, CD, CLOTH}.')
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {BEAUTY, CELL, CD, CLOTH}.')
args = parser.parse_args()

# Create AmazonDataset instance for dataset.
Expand Down Expand Up @@ -55,7 +55,6 @@ def main():
# =========== END =========== #



if __name__ == '__main__':
main()

3 changes: 1 addition & 2 deletions test_agent.py
Expand Up @@ -13,7 +13,6 @@
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical
from tensorboardX import SummaryWriter
import threading
from functools import reduce

Expand Down Expand Up @@ -220,7 +219,7 @@ def test(args):
if __name__ == '__main__':
boolean = lambda x: (str(x).lower() == 'true')
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='beauty', help='One of {cloth, beauty, cell, cd}')
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {cloth, beauty, cell, cd}')
parser.add_argument('--name', type=str, default='train_agent', help='directory name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
Expand Down
10 changes: 1 addition & 9 deletions train_agent.py
Expand Up @@ -10,7 +10,6 @@
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical
from tensorboardX import SummaryWriter

from knowledge_graph import KnowledgeGraph
from kg_env import BatchKGEnvironment
Expand Down Expand Up @@ -130,8 +129,6 @@ def get_batch(self):


def train(args):
train_writer = SummaryWriter(args.log_dir)

env = BatchKGEnvironment(args.dataset, args.max_acts, max_path_len=args.max_path_len, state_history=args.state_history)
uids = list(env.kg(USER).keys())
dataloader = ACDataLoader(uids, args.batch_size)
Expand Down Expand Up @@ -185,11 +182,6 @@ def train(args):
' | vloss={:.5f}'.format(avg_vloss) +
' | entropy={:.5f}'.format(avg_entropy) +
' | reward={:.5f}'.format(avg_reward))
train_writer.add_scalar('train/loss', avg_loss, step)
train_writer.add_scalar('train/ploss', avg_ploss, step)
train_writer.add_scalar('train/vloss', avg_vloss, step)
train_writer.add_scalar('train/entropy', avg_entropy, step)
train_writer.add_scalar('train/reward', avg_reward, step)
### END of epoch ###

policy_file = '{}/policy_model_epoch_{}.ckpt'.format(args.log_dir, epoch)
Expand All @@ -199,7 +191,7 @@ def train(args):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='beauty', help='One of {clothing, cell, beauty, cd}')
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {clothing, cell, beauty, cd}')
parser.add_argument('--name', type=str, default='train_agent', help='directory name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
Expand Down

0 comments on commit 5f869c5

Please sign in to comment.