In [1]:
import os
import time
import argparse
import tensorflow as tf
from sampler import WarpSampler
from model import Model
from tqdm import tqdm
from util import *

In [2]:
def str2bool(s):
    """Convert the exact strings 'True' / 'False' to the matching bool.

    Only these two spellings are accepted (case-sensitive); this makes the
    function usable as an argparse ``type=`` converter.

    Args:
        s: the string to convert.

    Returns:
        True for 'True', False for 'False'.

    Raises:
        ValueError: for any other value.
    """
    truth_table = {'True': True, 'False': False}
    if s in truth_table:
        return truth_table[s]
    raise ValueError('Not a valid boolean string')

In [3]:
# Command-line interface for a training run: --dataset and --train_dir are
# required, every other option has a default. In this notebook the values are
# supplied by get_args() instead of sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True)
parser.add_argument('--train_dir', required=True)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--hidden_units', default=50, type=int)
parser.add_argument('--num_blocks', default=2, type=int)
parser.add_argument('--num_epochs', default=201, type=int)
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
# Trailing ';' keeps the returned _StoreAction repr out of the cell output.
parser.add_argument('--l2_emb', default=0.0, type=float);

_StoreAction(option_strings=['--l2_emb'], dest='l2_emb', nargs=None, const=None, default=0.0, type=<class 'float'>, choices=None, help=None, metavar=None)

In [4]:
# Run configuration. Every value is a string because get_args() splices these
# names into an argv-style list for argparse.
# NOTE: a later cell rebinds `dataset` to the loaded data; re-run this cell
# before calling get_args() again.
dataset, train_dir = 'Beauty', 'train'
batch_size, lr = '128', '0.001'
maxlen, hidden_units = '50', '50'
num_blocks, num_epochs, num_heads = '2', '201', '1'
dropout_rate, l2_emb = '0.5', '0.0'

In [5]:
def get_args():
    """Parse the notebook's module-level config values with ``parser``.

    The config names (``dataset``, ``train_dir``, ``batch_size``, ...) are
    assembled into an argv-style list and passed to ``parser.parse_args``, so
    the notebook reuses the CLI's option definitions, defaults and type
    conversions.

    Numeric config values (int/float) are converted with ``str()`` first, so
    ``batch_size = 128`` works as well as ``batch_size = '128'``.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        TypeError: if a config value is neither a string nor a number. The
            usual cause is that a later cell rebound ``dataset`` to the loaded
            data -- re-run the config cell first.
    """
    options = [
        ('--dataset', dataset),
        ('--train_dir', train_dir),
        ('--batch_size', batch_size),
        ('--lr', lr),
        ('--maxlen', maxlen),
        ('--hidden_units', hidden_units),
        ('--num_blocks', num_blocks),
        ('--num_epochs', num_epochs),
        ('--num_heads', num_heads),
        ('--dropout_rate', dropout_rate),
        ('--l2_emb', l2_emb),
    ]
    argv = []
    for flag, value in options:
        if isinstance(value, (int, float)):
            value = str(value)
        elif not isinstance(value, str):
            # argparse would otherwise fail deep inside with an opaque error.
            raise TypeError(
                '%s expects a string config value, got %s; re-run the config '
                'cell (a later cell may have rebound this name)'
                % (flag, type(value).__name__))
        argv.extend((flag, value))
    return parser.parse_args(args=argv)

In [6]:
# Parse the config-cell strings into the argparse.Namespace used by the cells below.
args = get_args()

In [7]:
# Create the run directory "<dataset>_<train_dir>" (no error if it already
# exists) and record every parsed argument as a "name,value" line, sorted by
# name, so the run's configuration can be recovered later.
run_dir = args.dataset + '_' + args.train_dir
os.makedirs(run_dir, exist_ok=True)
with open(os.path.join(run_dir, 'args.txt'), 'w') as args_file:
    args_file.write('\n'.join(
        str(name) + ',' + str(value)
        for name, value in sorted(vars(args).items(), key=lambda item: item[0])))

In [8]:
# Load the dataset splits.
# NOTE: this rebinds `dataset` (the dataset-name string from the config cell)
# to the 5-element result of data_partition; re-run the config cell before
# calling get_args() again.
dataset = data_partition(args.dataset)
user_train, user_valid, user_test, usernum, itemnum = dataset
assert len(user_train) > 0, 'training split is empty'

# Steps per epoch: floor division drops a trailing partial batch, but always
# run at least one step so a split smaller than batch_size still trains.
num_batch = max(1, len(user_train) // args.batch_size)

total_len = sum(len(user_train[u]) for u in user_train)
print('average sequence length: %.2f' % (total_len / len(user_train)))

average sequence length: 5.63


In [9]:
# Open the training log in the run directory (kept open for the training loop;
# closed in the final cleanup cell) and create a TF session with soft device
# placement and on-demand GPU memory growth enabled.
log_path = os.path.join(args.dataset + '_' + args.train_dir, 'log.txt')
f = open(log_path, 'w')
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

In [10]:
# Create the batch sampler (3 workers), build the model graph, and initialise
# all graph variables. tf.initialize_all_variables() is deprecated in favour
# of tf.global_variables_initializer().
sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size,
                      maxlen=args.maxlen, n_workers=3)
model = Model(usernum, itemnum, args)
sess.run(tf.global_variables_initializer())

W0908 19:07:54.612916 4512646592 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/SASRec/model.py:6: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0908 19:07:54.634287 4512646592 deprecation.py:323] From /Users/fumiyo_ito/Documents/git/SASRec/model.py:13: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
W0908 19:07:54.644150 4512646592 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/SASRec/model.py:15: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0908 19:07:54.647084 4512646592 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/SASRec/modules.py:117: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0908 19:07:55.451338 4512646592 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2

  File "/anaconda3/lib/python3.7/threading.py", line 1032, in join
    self._wait_for_tstate_lock()
  File "/anaconda3/lib/python3.7/threading.py", line 1048, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
KeyboardInterrupt
  File "/anaconda3/lib/python3.7/threading.py", line 1048, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt


In [11]:
# T: cumulative training time in seconds; t0: start of the current timing window.
T, t0 = 0.0, time.time()

In [12]:
# Train for args.num_epochs epochs of num_batch steps each. Every EVAL_EVERY
# epochs: add the wall-clock time since t0 to T, run evaluate() and
# evaluate_valid(), print NDCG@10 / HR@10 and append them to the log. t0 is
# reset only after evaluation, so T excludes evaluation time.
EVAL_EVERY = 20

for epoch in range(1, args.num_epochs + 1):
    # tqdm infers the bar length from len(range(...)); no list/total needed.
    for step in tqdm(range(num_batch), ncols=70, leave=False, unit='b'):
        u, seq, pos, neg = sampler.next_batch()
        auc, loss, _ = sess.run(
            [model.auc, model.loss, model.train_op],
            {model.u: u, model.input_seq: seq, model.pos: pos, model.neg: neg,
             model.is_training: True})

    if epoch % EVAL_EVERY == 0:
        T += time.time() - t0
        print('Evaluating', end=' ')
        t_test = evaluate(model, dataset, args, sess)
        t_valid = evaluate_valid(model, dataset, args, sess)
        print('')
        print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)' % (
            epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

        f.write(str(t_valid) + ' ' + str(t_test) + '\n')
        f.flush()
        t0 = time.time()

                                                                      

Evaluating . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 

  0%|                                  | 1/407 [00:00<00:49,  8.22b/s]


epoch:20, time: 911.770738(s), valid (NDCG@10: 0.2876, HR@10: 0.4543), test (NDCG@10: 0.2669, HR@10: 0.4309)


 91%|█████████████████████████████▏  | 371/407 [00:51<00:05,  7.09b/s]

KeyboardInterrupt: 

In [None]:
 # Release resources; sampler.close() runs even if closing the log fails.
try:
    f.close()
finally:
    sampler.close()
print("Done")