In [1]:
import dataset
import os
import TransFM
import FM
import PRME_FM
import HRM_FM
import sys
import argparse
import tensorflow as tf

In [2]:
def parse_args( filename,     model,            features,       features_file,  max_iters,       
                num_dims,     linear_reg,       emb_reg,        trans_reg,      init_mean,     
                starting_lr,  lr_decay_factor,  lr_decay_freq,  eval_freq,      quit_delta ):
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename',
        help='Filename of the input dataset.',
        required=True)
    parser.add_argument('--model',
        help='Model to run.',
        choices=['TransFM', 'FM', 'PRME-FM', 'HRM-FM'],
        required=True)
    parser.add_argument('--features',
        help='Which features to include.',
        choices=['none', 'categories', 'time', 'content', 'geo'],
        default='none')
    parser.add_argument('--features_file',
        help='Filename(s) for content features. For content features, provide '
        '<user filename>,<item filename>. For categories and geo, provide a single '
        'filename. Temporal data should be included within the dataset file itself.')
    parser.add_argument('--max_iters',
        help='Max number of iterations to run',
        default=1000000,
        type=int)
    parser.add_argument('--num_dims',
        help='Model dimensionality.',
        default=10,
        type=int)
    parser.add_argument('--linear_reg',
        help='L2 regularization: linear_reg.',
        default=1.0,
        type=float)
    parser.add_argument('--emb_reg',
        help='L2 regularization: embbeding regularization.',
        default=1.0,
        type=float)
    parser.add_argument('--trans_reg',
        help='L2 regularization: translation regularization.',
        default=1.0,
        type=float)
    parser.add_argument('--init_mean',
        help='Initialization mean for model parameters.',
        default=0.1,
        type=float)
    parser.add_argument('--starting_lr',
        help='Initial learning rate.',
        default=0.001,
        type=float)
    parser.add_argument('--lr_decay_factor',
        help='Decay factor for learning rate.',
        default=1.0,
        type=float)
    parser.add_argument('--lr_decay_freq',
        help='Frequency at which to decay learning rate.',
        default=1000,
        type=int)
    parser.add_argument('--eval_freq',
        help='Frequency at which to evaluate model.',
        default=50,
        type=int)
    parser.add_argument('--quit_delta',
        help='Number of iterations at which to quit if no improvement.',
        default=1000,
        type=int)
    args = parser.parse_args(args = [ 
                            '--filename',        filename,
                            '--model',           model,
                            '--features',        features,
                            '--features_file',   features_file, 
                            '--max_iters',       max_iters,
                            '--num_dims',        num_dims,
                            '--linear_reg',      linear_reg,
                            '--emb_reg',         emb_reg,
                            '--trans_reg',       trans_reg,
                            '--init_mean',       init_mean,
                            '--starting_lr',     starting_lr,
                            '--lr_decay_factor', lr_decay_factor,
                            '--lr_decay_freq',   lr_decay_freq,
                            '--eval_freq' ,      eval_freq,
                            '--quit_delta' ,     quit_delta
                    ])
    print(args)
    print('')
    return args

In [3]:
def train_transrec(dataset, args):
    if args.model == 'TransFM':
        model = TransFM.TransFM(dataset, args)
#     elif args.model == 'FM':
#         model = FM.FM(dataset, args)
#     elif args.model == 'PRME-FM':
#         model = PRME_FM.PRME_FM(dataset, args)
#     elif args.model == 'HRM-FM':
#         model = HRM_FM.HRM_FM(dataset, args)

    val_auc, test_auc,  var_emb_factors, var_trans_factors, g = model.train()

    print('')
    print(args)
    print('Validation AUC  = ' + str(val_auc))
    print('Test AUC        = ' + str(test_auc))
    return (val_auc, test_auc,  var_emb_factors, var_trans_factors, g)

### main script

In [4]:
# config
filename        = 'ratings.csv' 
model           = 'TransFM'  
features        = 'none' 
features_file   = 'none' 
max_iters       = '1000000' 
num_dims        = '10' 
linear_reg      = '10.0' 
emb_reg         = '1.0'
trans_reg       = '0.1' 
init_mean       = '0.1' 
starting_lr     = '0.02' 
lr_decay_factor = '1.0' 
lr_decay_freq   = '1000' 
eval_freq       = '50' 
quit_delta      = '1000'

In [5]:
if __name__ == '__main__':
    args = parse_args(  filename,     model,            features,       features_file,  max_iters,       
                        num_dims,     linear_reg,       emb_reg,        trans_reg,      init_mean,     
                        starting_lr,  lr_decay_factor,  lr_decay_freq,  eval_freq,      quit_delta )
    d = dataset.Dataset(args.filename, args)
    val_auc, test_auc,  var_emb_factors, var_trans_factors, g = train_transrec(d, args)

Namespace(emb_reg=1.0, eval_freq=50, features='none', features_file='none', filename='ratings.csv', init_mean=0.1, linear_reg=10.0, lr_decay_factor=1.0, lr_decay_freq=1000, max_iters=1000000, model='TransFM', num_dims=10, quit_delta=1000, starting_lr=0.02, trans_reg=0.1)

First pass
	num_users = 20
	num_items = 1768
	df_shape  = (2999, 4)
Collected user counts...
Collected item counts...
User filtering done...


W0807 20:54:24.865934 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:106: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0807 20:54:24.895255 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:122: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0807 20:54:24.903903 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:25: The name tf.sparse_tensor_dense_matmul is deprecated. Please use tf.sparse.sparse_dense_matmul instead.

W0807 20:54:24.915277 4597990848 deprecation.py:506] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:31: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
W0807 20:54:24.923309 4597990848 deprecati

Item filtering done...
Second pass
	num_users = 20
	num_items = 80
	df_shape  = (520, 4)
Normalizing temporal values...
Constructing datasets...
In class TransFM
Feature dimension = 180


W0807 20:54:25.006724 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:158: The name tf.log is deprecated. Please use tf.math.log instead.

W0807 20:54:25.016883 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:162: The name tf.train.exponential_decay is deprecated. Please use tf.compat.v1.train.exponential_decay instead.

W0807 20:54:25.026036 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:165: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.

W0807 20:54:25.538912 4597990848 deprecation.py:323] From /Users/fumiyo_ito/Documents/git/TransFM/TransFM.py:169: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
W0807 20:54:25.540326 4597990848 deprecation_wrapper.py:119] From /Users/fumiyo_ito/Documents/

Epoch: 0 	Loss = 332.70440673828125
	Val AUC = 0.3842105	Test AUC = 0.3894737
	Current max = 0.3842104971408844 at epoch 0
Epoch: 1 	Loss = 323.8565673828125
Epoch: 2 	Loss = 317.20648193359375
Epoch: 3 	Loss = 310.3708190917969
Epoch: 4 	Loss = 305.685791015625
Epoch: 5 	Loss = 296.2783203125
Epoch: 6 	Loss = 291.52752685546875
Epoch: 7 	Loss = 280.2460021972656
Epoch: 8 	Loss = 274.98321533203125
Epoch: 9 	Loss = 269.7142639160156
Epoch: 10 	Loss = 254.27789306640625
Epoch: 11 	Loss = 251.39208984375
Epoch: 12 	Loss = 248.86500549316406
Epoch: 13 	Loss = 240.76690673828125
Epoch: 14 	Loss = 228.2826385498047
Epoch: 15 	Loss = 214.69866943359375
Epoch: 16 	Loss = 214.8921356201172
Epoch: 17 	Loss = 216.76852416992188
Epoch: 18 	Loss = 200.03118896484375
Epoch: 19 	Loss = 200.35409545898438
Epoch: 20 	Loss = 195.39126586914062
Epoch: 21 	Loss = 205.00924682617188
Epoch: 22 	Loss = 199.80052185058594
Epoch: 23 	Loss = 207.9153289794922
Epoch: 24 	Loss = 190.15225219726562
Epoch: 25 	Los

Epoch: 248 	Loss = 114.18403625488281
Epoch: 249 	Loss = 111.39147186279297
Epoch: 250 	Loss = 118.81715393066406
	Val AUC = 0.5947369	Test AUC = 0.4105263
	Current max = 0.5947368741035461 at epoch 250
Epoch: 251 	Loss = 108.0062255859375
Epoch: 252 	Loss = 114.02845764160156
Epoch: 253 	Loss = 122.79600524902344
Epoch: 254 	Loss = 117.54546356201172
Epoch: 255 	Loss = 117.70646667480469
Epoch: 256 	Loss = 124.0114974975586
Epoch: 257 	Loss = 122.44133758544922
Epoch: 258 	Loss = 128.00418090820312
Epoch: 259 	Loss = 124.11518859863281
Epoch: 260 	Loss = 121.8332290649414
Epoch: 261 	Loss = 123.89492797851562
Epoch: 262 	Loss = 123.82987976074219
Epoch: 263 	Loss = 123.4385757446289
Epoch: 264 	Loss = 110.68022155761719
Epoch: 265 	Loss = 112.1372299194336
Epoch: 266 	Loss = 122.97647857666016
Epoch: 267 	Loss = 119.14950561523438
Epoch: 268 	Loss = 118.5841064453125
Epoch: 269 	Loss = 115.86106872558594
Epoch: 270 	Loss = 112.53782653808594
Epoch: 271 	Loss = 122.21568298339844
Epoch

Epoch: 459 	Loss = 122.38712310791016
Epoch: 460 	Loss = 110.45098114013672
Epoch: 461 	Loss = 115.39352416992188
Epoch: 462 	Loss = 119.8653564453125
Epoch: 463 	Loss = 113.91155242919922
Epoch: 464 	Loss = 111.11509704589844
Epoch: 465 	Loss = 109.97491455078125
Epoch: 466 	Loss = 109.72807312011719
Epoch: 467 	Loss = 116.46180725097656
Epoch: 468 	Loss = 117.63858032226562
Epoch: 469 	Loss = 122.37547302246094
Epoch: 470 	Loss = 115.92105865478516
Epoch: 471 	Loss = 123.78897094726562
Epoch: 472 	Loss = 111.61540222167969
Epoch: 473 	Loss = 117.64091491699219
Epoch: 474 	Loss = 118.10282897949219
Epoch: 475 	Loss = 113.77517700195312
Epoch: 476 	Loss = 110.05783081054688
Epoch: 477 	Loss = 118.49537658691406
Epoch: 478 	Loss = 115.14232635498047
Epoch: 479 	Loss = 116.53555297851562
Epoch: 480 	Loss = 116.16671752929688
Epoch: 481 	Loss = 113.95562744140625
Epoch: 482 	Loss = 109.70361328125
Epoch: 483 	Loss = 122.18672180175781
Epoch: 484 	Loss = 117.33509826660156
Epoch: 485 	Loss

Epoch: 675 	Loss = 115.78050231933594
Epoch: 676 	Loss = 117.97261810302734
Epoch: 677 	Loss = 109.50732421875
Epoch: 678 	Loss = 108.64895629882812
Epoch: 679 	Loss = 109.79191589355469
Epoch: 680 	Loss = 112.60762786865234
Epoch: 681 	Loss = 114.31535339355469
Epoch: 682 	Loss = 106.91432189941406
Epoch: 683 	Loss = 119.62372589111328
Epoch: 684 	Loss = 112.45065307617188
Epoch: 685 	Loss = 110.17607116699219
Epoch: 686 	Loss = 113.23603820800781
Epoch: 687 	Loss = 117.04701232910156
Epoch: 688 	Loss = 111.80392456054688
Epoch: 689 	Loss = 113.09502410888672
Epoch: 690 	Loss = 108.977294921875
Epoch: 691 	Loss = 110.5811538696289
Epoch: 692 	Loss = 104.23531341552734
Epoch: 693 	Loss = 115.89739990234375
Epoch: 694 	Loss = 111.67095947265625
Epoch: 695 	Loss = 117.33832550048828
Epoch: 696 	Loss = 116.69204711914062
Epoch: 697 	Loss = 112.04045104980469
Epoch: 698 	Loss = 105.47467041015625
Epoch: 699 	Loss = 116.30430603027344
Epoch: 700 	Loss = 110.84098815917969
	Val AUC = 0.53684

Epoch: 892 	Loss = 111.41249084472656
Epoch: 893 	Loss = 118.57699584960938
Epoch: 894 	Loss = 118.11397552490234
Epoch: 895 	Loss = 114.15238952636719
Epoch: 896 	Loss = 114.10604858398438
Epoch: 897 	Loss = 114.66154479980469
Epoch: 898 	Loss = 113.33744812011719
Epoch: 899 	Loss = 122.28619384765625
Epoch: 900 	Loss = 122.3060531616211
	Val AUC = 0.6052632	Test AUC = 0.44210523
	Current max = 0.6473683714866638 at epoch 650
Epoch: 901 	Loss = 117.63832092285156
Epoch: 902 	Loss = 108.71070861816406
Epoch: 903 	Loss = 120.24147033691406
Epoch: 904 	Loss = 108.29515075683594
Epoch: 905 	Loss = 107.79983520507812
Epoch: 906 	Loss = 111.2197265625
Epoch: 907 	Loss = 112.43321228027344
Epoch: 908 	Loss = 111.07456970214844
Epoch: 909 	Loss = 115.91896057128906
Epoch: 910 	Loss = 120.39311981201172
Epoch: 911 	Loss = 115.51187896728516
Epoch: 912 	Loss = 110.19670867919922
Epoch: 913 	Loss = 117.29129791259766
Epoch: 914 	Loss = 107.33535766601562
Epoch: 915 	Loss = 109.96764373779297
Epo

Epoch: 1121 	Loss = 116.22337341308594
Epoch: 1122 	Loss = 113.64131164550781
Epoch: 1123 	Loss = 115.61138916015625
Epoch: 1124 	Loss = 114.1863784790039
Epoch: 1125 	Loss = 111.52194213867188
Epoch: 1126 	Loss = 111.42095947265625
Epoch: 1127 	Loss = 105.88548278808594
Epoch: 1128 	Loss = 112.95368957519531
Epoch: 1129 	Loss = 109.40987396240234
Epoch: 1130 	Loss = 112.28767395019531
Epoch: 1131 	Loss = 110.94508361816406
Epoch: 1132 	Loss = 113.69905090332031
Epoch: 1133 	Loss = 111.93666076660156
Epoch: 1134 	Loss = 110.80155181884766
Epoch: 1135 	Loss = 115.6980209350586
Epoch: 1136 	Loss = 107.92442321777344
Epoch: 1137 	Loss = 123.94975280761719
Epoch: 1138 	Loss = 107.91273498535156
Epoch: 1139 	Loss = 109.6620864868164
Epoch: 1140 	Loss = 112.5634765625
Epoch: 1141 	Loss = 119.2214126586914
Epoch: 1142 	Loss = 107.28107452392578
Epoch: 1143 	Loss = 111.68490600585938
Epoch: 1144 	Loss = 113.05532836914062
Epoch: 1145 	Loss = 106.48202514648438
Epoch: 1146 	Loss = 114.957160949

Epoch: 1370 	Loss = 115.48799896240234
Epoch: 1371 	Loss = 117.61443328857422
Epoch: 1372 	Loss = 107.92263793945312
Epoch: 1373 	Loss = 117.39186096191406
Epoch: 1374 	Loss = 116.13459777832031
Epoch: 1375 	Loss = 117.59282684326172
Epoch: 1376 	Loss = 116.49526977539062
Epoch: 1377 	Loss = 117.15110778808594
Epoch: 1378 	Loss = 114.26109313964844
Epoch: 1379 	Loss = 115.47054290771484
Epoch: 1380 	Loss = 119.84515380859375
Epoch: 1381 	Loss = 115.66513061523438
Epoch: 1382 	Loss = 107.92220306396484
Epoch: 1383 	Loss = 106.47899627685547
Epoch: 1384 	Loss = 112.38616943359375
Epoch: 1385 	Loss = 118.3733901977539
Epoch: 1386 	Loss = 107.44451904296875
Epoch: 1387 	Loss = 109.49652862548828
Epoch: 1388 	Loss = 114.84075927734375
Epoch: 1389 	Loss = 124.51475524902344
Epoch: 1390 	Loss = 119.55625915527344
Epoch: 1391 	Loss = 121.72576904296875
Epoch: 1392 	Loss = 105.19804382324219
Epoch: 1393 	Loss = 107.03846740722656
Epoch: 1394 	Loss = 112.257568359375
Epoch: 1395 	Loss = 120.6231

Epoch: 1578 	Loss = 110.94107055664062
Epoch: 1579 	Loss = 114.26249694824219
Epoch: 1580 	Loss = 111.05779266357422
Epoch: 1581 	Loss = 106.29776000976562
Epoch: 1582 	Loss = 114.5297622680664
Epoch: 1583 	Loss = 112.3629379272461
Epoch: 1584 	Loss = 113.50942993164062
Epoch: 1585 	Loss = 117.16960144042969
Epoch: 1586 	Loss = 110.09820556640625
Epoch: 1587 	Loss = 105.72532653808594
Epoch: 1588 	Loss = 113.69436645507812
Epoch: 1589 	Loss = 110.93197631835938
Epoch: 1590 	Loss = 101.87983703613281
Epoch: 1591 	Loss = 111.55779266357422
Epoch: 1592 	Loss = 120.2549819946289
Epoch: 1593 	Loss = 115.23350524902344
Epoch: 1594 	Loss = 116.09346008300781
Epoch: 1595 	Loss = 111.77560424804688
Epoch: 1596 	Loss = 116.95569610595703
Epoch: 1597 	Loss = 111.72982788085938
Epoch: 1598 	Loss = 118.71575927734375
Epoch: 1599 	Loss = 111.99652099609375
Epoch: 1600 	Loss = 111.00364685058594
	Val AUC = 0.52105266	Test AUC = 0.34736842
	Current max = 0.6473683714866638 at epoch 650
Epoch: 1601 	Lo

In [6]:
print(train_transrec)

<function train_transrec at 0xb2e0ce950>


In [7]:
var_emb_factors

<tf.Variable 'emb_factors:0' shape=(180, 10) dtype=float32_ref>

In [8]:
var_trans_factors

<tf.Variable 'trans_factors:0' shape=(180, 10) dtype=float32_ref>

In [11]:
with g.as_default():
            sess = tf.Session()
            sess.run(tf.global_variables_initializer())
            var_trans_factors_array = sess.run(var_trans_factors)

In [12]:
print(var_trans_factors_array)

[[-0.02234516 -0.06075604  0.02572634 ...  0.06150661  0.05636995
   0.09556062]
 [ 0.06656756  0.08392347  0.0265552  ...  0.04572206  0.06093059
   0.09374323]
 [-0.02106845 -0.09611819 -0.07208109 ... -0.06502934 -0.06227753
  -0.05833266]
 ...
 [-0.04258239  0.01423626  0.01042209 ...  0.00273652 -0.0757273
  -0.05342245]
 [-0.01873877 -0.03166125  0.01890874 ... -0.08672202  0.09041398
  -0.01737066]
 [-0.00363672  0.06548946 -0.04040125 ...  0.06956197 -0.04008773
  -0.03523288]]


NameError: name 'pl_pos_indices' is not defined