In [1]:
import os
import time
import argparse
import tensorflow as tf
from sampler import WarpSampler
from model import Model
from tqdm import tqdm
from util import *

In [2]:
def str2bool(s):
    """Convert the exact strings 'True' / 'False' to the matching bool.

    Raises:
        ValueError: if ``s`` is anything other than 'True' or 'False'
            (the check is case-sensitive).
    """
    literal_to_bool = {'True': True, 'False': False}
    if s in literal_to_bool:
        return literal_to_bool[s]
    raise ValueError('Not a valid boolean string')

In [3]:
def get_args(dataset, train_dir, batch_size, lr, maxlen, hidden_units, num_blocks, num_epochs, num_heads, dropout_rate, l2_emb):
    """Build the training configuration as an ``argparse.Namespace``.

    The values are routed through an ``ArgumentParser`` so they get the same
    type conversion as command-line flags would (``int`` for batch_size,
    maxlen, hidden_units, num_blocks, num_epochs and num_heads; ``float`` for
    lr, dropout_rate and l2_emb; dataset and train_dir stay strings).

    Each argument may be given either as a string (e.g. ``'128'``) or as a
    native Python value (e.g. ``128``); values are stringified before parsing
    because ``parse_args`` only accepts a list of strings.

    Returns:
        argparse.Namespace with one attribute per parameter of this function.

    Raises:
        SystemExit: raised by argparse when a value cannot be converted to the
            declared type.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', required=True)
    parser.add_argument('--train_dir', required=True)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=50, type=int)
    parser.add_argument('--hidden_units', default=50, type=int)
    parser.add_argument('--num_blocks', default=2, type=int)
    parser.add_argument('--num_epochs', default=201, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.5, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)

    flag_values = [
        ('--dataset', dataset),
        ('--train_dir', train_dir),
        ('--batch_size', batch_size),
        ('--lr', lr),
        ('--maxlen', maxlen),
        ('--hidden_units', hidden_units),
        ('--num_blocks', num_blocks),
        ('--num_epochs', num_epochs),
        ('--num_heads', num_heads),
        ('--dropout_rate', dropout_rate),
        ('--l2_emb', l2_emb),
    ]

    # argparse indexes into every token, so a raw int/float would raise
    # TypeError; str() makes both '128' and 128 acceptable inputs.
    argv = []
    for flag, value in flag_values:
        argv.extend([flag, str(value)])

    return parser.parse_args(args=argv)

In [4]:
def model_train(args):
    """Train the model described by ``args`` and log periodic evaluation results.

    Side effects:
      * creates ``tmp/<dataset>_<train_dir>/`` if it does not exist;
      * writes the sorted ``key,value`` pairs of ``args`` to ``args.txt`` there;
      * every 20th epoch evaluates on the test and validation splits, prints
        the first two entries of each result as NDCG@10 / HR@10, and appends
        both results to ``log.txt``.

    Args:
        args: namespace providing at least ``dataset``, ``train_dir``,
            ``batch_size``, ``maxlen`` and ``num_epochs``; it is also handed to
            ``Model``, ``evaluate`` and ``evaluate_valid`` unchanged.

    Raises:
        Any exception raised during setup or training is propagated to the
        caller after the sampler, the TensorFlow session and the log file have
        been closed.
    """
    run_dir = os.path.join('tmp', args.dataset + '_' + args.train_dir)
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, 'args.txt'), 'w') as args_file:
        args_file.write('\n'.join(
            [str(k) + ',' + str(v) for k, v in sorted(list(vars(args).items()), key=lambda x: x[0])]))

    dataset = data_partition(args.dataset)
    [user_train, user_valid, user_test, usernum, itemnum] = dataset

    num_batch = len(user_train) // args.batch_size
    if num_batch == 0:
        # Integer division yields zero steps per epoch when there are fewer
        # training users than batch_size; make that visible instead of
        # silently evaluating a model that was never updated.
        print('warning: %d training users < batch_size %d, so each epoch runs 0 training steps'
              % (len(user_train), args.batch_size))

    total_len = sum(len(user_train[u]) for u in user_train)
    print('average sequence length: %.2f' % (total_len / len(user_train)))

    sess = None
    sampler = None
    with open(os.path.join(run_dir, 'log.txt'), 'w') as log_file:
        try:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            sess = tf.Session(config=config)

            sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size,
                                  maxlen=args.maxlen, n_workers=3)
            model = Model(usernum, itemnum, args)
            # global_variables_initializer is the non-deprecated replacement
            # for initialize_all_variables.
            sess.run(tf.global_variables_initializer())

            T = 0.0  # accumulated training time, excluding evaluation
            t0 = time.time()

            for epoch in range(1, args.num_epochs + 1):

                for _step in tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
                    u, seq, pos, neg = sampler.next_batch()
                    sess.run([model.auc, model.loss, model.train_op],
                             {model.u: u, model.input_seq: seq, model.pos: pos, model.neg: neg,
                              model.is_training: True})

                if epoch % 20 == 0:
                    T += time.time() - t0
                    print('Evaluating', end=' ')
                    t_test = evaluate(model, dataset, args, sess)
                    t_valid = evaluate_valid(model, dataset, args, sess)
                    print('')
                    print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)' % (
                        epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

                    log_file.write(str(t_valid) + ' ' + str(t_test) + '\n')
                    log_file.flush()
                    # Restart the clock so evaluation time is not counted.
                    t0 = time.time()
        finally:
            # Release resources on success and on failure alike; the exception
            # (if any) keeps propagating so the real traceback is visible.
            if sampler is not None:
                sampler.close()
            if sess is not None:
                sess.close()

    print("Done")

In [5]:
# --- Experiment identity -----------------------------------------------------
# Outputs are written under tmp/<dataset>_<train_dir>/ by model_train.
dataset = 'ml-1m_1000_demo'
train_dir = 'train'

# --- Hyper-parameters --------------------------------------------------------
# Kept as strings because get_args forwards them to argparse, which performs
# the int/float conversion.
lr = '0.001'
l2_emb = '0.0'
dropout_rate = '0.5'

batch_size = '128'
num_epochs = '21'

maxlen = '50'
hidden_units = '50'
num_blocks = '2'
num_heads = '1'

In [6]:
# Parse the settings into a namespace, then run training; keyword arguments
# guard against silently swapping two of the eleven values.
args = get_args(
    dataset=dataset,
    train_dir=train_dir,
    batch_size=batch_size,
    lr=lr,
    maxlen=maxlen,
    hidden_units=hidden_units,
    num_blocks=num_blocks,
    num_epochs=num_epochs,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    l2_emb=l2_emb,
)
model_train(args)

average sequence length: 109.11


W0922 20:42:27.834063 4446111168 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/SASRec/model.py:6: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0922 20:42:27.857807 4446111168 deprecation.py:323] From /Users/fumiyo_ito/Documents/git/SASRec/model.py:13: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
W0922 20:42:27.875936 4446111168 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/SASRec/model.py:15: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0922 20:42:27.885945 4446111168 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/SASRec/modules.py:117: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0922 20:42:29.714962 4446111168 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2

Evaluating 

                

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>

epoch:20, time: 0.110324(s), valid (NDCG@10: 0.0370, HR@10: 0.1111), test (NDCG@10: 0.0000, HR@10: 0.0000)
Done




In [7]:
import slackweb

# Notify Slack that the experiment has finished.
# The incoming-webhook URL is a credential, so it is read from the
# SLACK_WEBHOOK_URL environment variable instead of being stored in the
# notebook.
# NOTE(review): the webhook URL that was previously hardcoded here has been
# exposed and should be revoked and regenerated in Slack.
slack_webhook_url = os.environ.get("SLACK_WEBHOOK_URL")
if slack_webhook_url:
    slack = slackweb.Slack(url=slack_webhook_url)
    print(slack.notify(text="SASRec {}の実験が完了しました〜！".format(dataset)))
else:
    print("SLACK_WEBHOOK_URL is not set; skipping Slack notification")

'ok'