Skip to content

Commit

Permalink
update policy nn
Browse files Browse the repository at this point in the history
  • Loading branch information
orcax committed Jun 20, 2019
1 parent 4cdd897 commit 5189e34
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 26 deletions.
27 changes: 15 additions & 12 deletions batch_actor_critic.py
Expand Up @@ -14,11 +14,9 @@
from torch.distributions import Categorical
from tensorboardX import SummaryWriter

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from kgrl.knowledge_graph import KnowledgeGraph
from kgrl.batch_env import BatchKGEnvironment
from kgrl.kg_utils import *
from knowledge_graph import KnowledgeGraph
from batch_env import BatchKGEnvironment
from utils import *

logger = None

Expand Down Expand Up @@ -148,11 +146,9 @@ def get_batch(self):


def train(args):
global logger
train_writer = SummaryWriter(args.log_dir)

env = BatchKGEnvironment(args.dataset, args.max_acts, max_path_len=args.max_path_len, embed_hop=args.hop,
state_history=args.state_history)
env = BatchKGEnvironment(args.dataset, args.max_acts, max_path_len=args.max_path_len, state_history=args.state_history)
uids = list(env.kg(USER).keys())
dataloader = ACDataLoader(uids, args.batch_size)
model = ActorCritic(env.state_dim, env.act_dim, gamma=args.gamma, hidden_sizes=args.hidden).to(args.device)
Expand Down Expand Up @@ -222,10 +218,10 @@ def train(args):
torch.save(model.state_dict(), policy_file)


if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cd', help='One of {clothing, cell, beauty, cd}')
parser.add_argument('--name', type=str, default='train_bac_hop1_acts250', help='directory name.')
parser.add_argument('--dataset', type=str, default='beauty', help='One of {clothing, cell, beauty, cd}')
parser.add_argument('--name', type=str, default='train_bac', help='directory name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
parser.add_argument('--epochs', type=int, default=50, help='Max number of epochs.')
Expand All @@ -243,11 +239,18 @@ def train(args):

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
args.log_dir = DATA_DIR[args.dataset] + '/' + args.name

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log.txt')
logger.info(args)

set_random_seed(args.seed)
train(args)


if __name__ == '__main__':
main()
16 changes: 6 additions & 10 deletions batch_env.py
Expand Up @@ -8,10 +8,8 @@
import torch
from datetime import datetime

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from kgrl.knowledge_graph import KnowledgeGraph
from kgrl.kg_utils import *
from knowledge_graph import KnowledgeGraph
from utils import *


class KGState(object):
Expand Down Expand Up @@ -41,12 +39,12 @@ def __call__(self, user_embed, node_embed, last_node_embed, last_relation_embed,


class BatchKGEnvironment(object):
def __init__(self, dataset_str, max_acts, max_path_len=3, state_history=1, embed_hop=1):
def __init__(self, dataset_str, max_acts, max_path_len=3, state_history=1):
self.max_acts = max_acts
self.act_dim = max_acts + 1 # Add self-loop action, whose act_idx is always 0.
self.max_num_nodes = max_path_len + 1 # max number of hops (= #nodes - 1)
self.kg = load_kg(dataset_str)
self.embeds = load_embed(dataset_str, embed_hop)
self.embeds = load_embed(dataset_str)
self.embed_size = self.embeds[USER].shape[1]
self.embeds[SELF_LOOP] = (np.zeros(self.embed_size), 0.0)
self.state_gen = KGState(self.embed_size, history_len=state_history)
Expand Down Expand Up @@ -264,14 +262,12 @@ def main():
random.seed(123)
np.random.seed(123)

dataset_str = 'beauty'
dataset_str = 'cloth'
max_acts = 250
max_path_len = 3
env = BatchKGEnvironment(dataset_str, max_acts, max_path_len, embed_hop=2)
env = BatchKGEnvironment(dataset_str, max_acts, max_path_len)
print(env.u_p_scales.shape, len(env.kg(USER).keys()))

return

# Test case 1
print('Test case 1')
t1 = datetime.now()
Expand Down
9 changes: 5 additions & 4 deletions train_transe_model.py
Expand Up @@ -5,8 +5,6 @@
import argparse
import random
import numpy as np
import logging
import logging.handlers
import torch
import torch.nn as nn
import torch.optim as optim
Expand Down Expand Up @@ -116,7 +114,7 @@ def main():
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {beauty, cd, cell, clothing}.')
parser.add_argument('--name', type=str, default='train_transe_model', help='model name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='cuda:1', help='gpu device.')
parser.add_argument('--gpu', type=str, default='1', help='gpu device.')
parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=64, help='batch size.')
parser.add_argument('--lr', type=float, default=0.5, help='learning rate.')
Expand All @@ -128,12 +126,15 @@ def main():
parser.add_argument('--steps_per_checkpoint', type=int, default=200, help='Number of steps for checkpoint.')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log.txt')
args.device = torch.device(args.gpu) if torch.cuda.is_available() else 'cpu'
logger.info(args)

set_random_seed(args.seed)
Expand Down
1 change: 1 addition & 0 deletions utils.py
Expand Up @@ -4,6 +4,7 @@
import random
import pickle
import logging
import logging.handlers
import numpy as np
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfTransformer
Expand Down

0 comments on commit 5189e34

Please sign in to comment.