Copyright (C) 2016 - 2019 Pinard Liu(liujianping-ok@163.com)

https://www.cnblogs.com/pinard

Permission given to modify the code as long as you keep this declaration at the top

用tensorflow学习贝叶斯个性化排序(BPR) https://www.cnblogs.com/pinard/p/9163481.html

In [18]:
import numpy
import tensorflow as tf
import os
import random
from collections import defaultdict

def load_data(data_path):
    user_ratings = defaultdict(set)
    max_u_id = -1
    max_i_id = -1
    with open(data_path, 'r') as f:
        for line in f.readlines():
            u, i, _, _ = line.split("\t")
            u = int(u)
            i = int(i)
            user_ratings[u].add(i)
            max_u_id = max(u, max_u_id)
            max_i_id = max(i, max_i_id)
    print ("max_u_id:", max_u_id)
    print ("max_i_id:", max_i_id)
    return max_u_id, max_i_id, user_ratings
    

data_path = os.path.join('D:\\tmp\\ml-100k', 'u.data')
user_count, item_count, user_ratings = load_data(data_path)


max_u_id: 943
max_i_id: 1682


In [19]:
print (user_ratings)

defaultdict(<class 'set'>, {196: {257, 8, 393, 13, 269, 655, 1022, 663, 25, 153, 411, 285, 286, 287, 428, 173, 306, 692, 66, 67, 580, 70, 202, 845, 340, 1241, 94, 1118, 108, 238, 111, 1007, 110, 242, 116, 762, 251, 381, 382}, 186: {1033, 12, 1042, 1046, 540, 31, 546, 550, 38, 554, 44, 53, 566, 55, 568, 56, 1083, 71, 588, 77, 591, 79, 595, 596, 95, 98, 100, 106, 117, 118, 121, 147, 148, 159, 684, 689, 177, 1213, 203, 717, 225, 226, 1253, 742, 237, 754, 243, 250, 1277, 257, 770, 258, 263, 269, 281, 288, 291, 294, 295, 298, 299, 300, 302, 303, 306, 820, 1336, 829, 322, 327, 330, 331, 332, 333, 338, 356, 1385, 880, 887, 1399, 385, 405, 406, 925, 934, 939, 977, 470, 983, 988, 477, 1016}, 22: {2, 515, 4, 523, 526, 17, 21, 24, 29, 546, 550, 554, 50, 53, 566, 568, 62, 68, 79, 80, 85, 89, 94, 96, 105, 109, 110, 117, 118, 121, 636, 127, 128, 648, 651, 144, 153, 154, 665, 161, 163, 167, 168, 683, 172, 173, 684, 687, 176, 175, 688, 174, 692, 181, 184, 186, 187, 194, 195, 712, 201, 202, 204, 208, 2

In [20]:
def generate_test(user_ratings):
    user_test = dict()
    for u, i_list in user_ratings.items():
        user_test[u] = random.sample(user_ratings[u], 1)[0]
    return user_test

user_ratings_test = generate_test(user_ratings)

In [21]:
print (user_ratings_test)

{196: 153, 186: 588, 22: 290, 244: 596, 166: 984, 298: 196, 115: 234, 253: 192, 305: 89, 6: 64, 62: 228, 286: 1411, 200: 771, 210: 692, 224: 1401, 303: 577, 122: 70, 194: 402, 291: 941, 234: 16, 119: 168, 167: 288, 299: 198, 308: 411, 95: 183, 38: 94, 102: 667, 63: 475, 160: 117, 50: 15, 301: 334, 225: 603, 290: 484, 97: 175, 157: 250, 181: 1187, 278: 301, 276: 710, 7: 52, 10: 56, 284: 906, 201: 715, 287: 200, 246: 97, 242: 1357, 249: 124, 99: 694, 178: 259, 251: 258, 81: 186, 260: 322, 25: 968, 59: 959, 72: 380, 87: 163, 42: 732, 292: 276, 20: 210, 13: 29, 138: 742, 60: 9, 57: 1011, 223: 717, 189: 863, 243: 699, 92: 278, 241: 332, 254: 234, 293: 1046, 127: 228, 222: 11, 267: 218, 11: 15, 8: 294, 162: 179, 279: 461, 145: 98, 28: 227, 135: 258, 32: 298, 90: 1134, 216: 697, 250: 1137, 271: 265, 265: 245, 198: 727, 168: 325, 110: 658, 58: 347, 237: 357, 94: 154, 128: 283, 44: 161, 264: 70, 41: 58, 82: 546, 262: 419, 174: 451, 43: 82, 84: 1033, 269: 525, 259: 108, 85: 241, 213: 193, 121: 2

In [22]:
print (random.sample(user_ratings[1], 1)[0])

128


In [23]:
def generate_train_batch(user_ratings, user_ratings_test, item_count, batch_size=512):
    t = []
    for b in range(batch_size):
        u = random.sample(user_ratings.keys(), 1)[0]
        i = random.sample(user_ratings[u], 1)[0]
        while i == user_ratings_test[u]:
            i = random.sample(user_ratings[u], 1)[0]
        
        j = random.randint(1, item_count)
        while j in user_ratings[u]:
            j = random.randint(1, item_count)
        t.append([u, i, j])
    return numpy.asarray(t)

In [24]:
print (generate_train_batch(user_ratings, user_ratings_test, item_count, batch_size=512))

[[ 181  681   46]
 [ 133  258  568]
 [ 679  223  101]
 ...
 [  81  476 1566]
 [ 602  358 1281]
 [ 494  237  200]]


In [25]:
def generate_test_batch(user_ratings, user_ratings_test, item_count):
    for u in user_ratings.keys():
        t = []
        i = user_ratings_test[u]
        for j in range(1, item_count+1):
            if not (j in user_ratings[u]):
                t.append([u, i, j])
        yield numpy.asarray(t)

In [26]:
for uij in generate_test_batch(user_ratings, user_ratings_test, item_count):
    print (uij)

[[ 196  153    1]
 [ 196  153    2]
 [ 196  153    3]
 ...
 [ 196  153 1680]
 [ 196  153 1681]
 [ 196  153 1682]]
[[ 186  588    1]
 [ 186  588    2]
 [ 186  588    3]
 ...
 [ 186  588 1680]
 [ 186  588 1681]
 [ 186  588 1682]]
[[  22  290    1]
 [  22  290    3]
 [  22  290    5]
 ...
 [  22  290 1680]
 [  22  290 1681]
 [  22  290 1682]]
[[ 244  596    2]
 [ 244  596    4]
 [ 244  596    5]
 ...
 [ 244  596 1680]
 [ 244  596 1681]
 [ 244  596 1682]]
[[ 166  984    1]
 [ 166  984    2]
 [ 166  984    3]
 ...
 [ 166  984 1680]
 [ 166  984 1681]
 [ 166  984 1682]]
[[ 298  196    2]
 [ 298  196    3]
 [ 298  196    4]
 ...
 [ 298  196 1680]
 [ 298  196 1681]
 [ 298  196 1682]]
[[ 115  234    1]
 [ 115  234    2]
 [ 115  234    3]
 ...
 [ 115  234 1680]
 [ 115  234 1681]
 [ 115  234 1682]]
[[ 253  192    2]
 [ 253  192    3]
 [ 253  192    5]
 ...
 [ 253  192 1680]
 [ 253  192 1681]
 [ 253  192 1682]]
[[ 305   89    3]
 [ 305   89    4]
 [ 305   89    5]
 ...
 [ 305   89 1680]
 [ 305   89

[[ 131  750    2]
 [ 131  750    3]
 [ 131  750    4]
 ...
 [ 131  750 1680]
 [ 131  750 1681]
 [ 131  750 1682]]
[[ 230  484    2]
 [ 230  484    3]
 [ 230  484    4]
 ...
 [ 230  484 1680]
 [ 230  484 1681]
 [ 230  484 1682]]
[[ 126  315    1]
 [ 126  315    2]
 [ 126  315    3]
 ...
 [ 126  315 1680]
 [ 126  315 1681]
 [ 126  315 1682]]
[[ 231  126    2]
 [ 231  126    3]
 [ 231  126    4]
 ...
 [ 231  126 1680]
 [ 231  126 1681]
 [ 231  126 1682]]
[[ 280   76    6]
 [ 280   76   10]
 [ 280   76   14]
 ...
 [ 280   76 1680]
 [ 280   76 1681]
 [ 280   76 1682]]
[[ 288  173    1]
 [ 288  173    2]
 [ 288  173    3]
 ...
 [ 288  173 1680]
 [ 288  173 1681]
 [ 288  173 1682]]
[[ 152 1301    1]
 [ 152 1301    2]
 [ 152 1301    3]
 ...
 [ 152 1301 1680]
 [ 152 1301 1681]
 [ 152 1301 1682]]
[[ 217  685    1]
 [ 217  685    3]
 [ 217  685    4]
 ...
 [ 217  685 1680]
 [ 217  685 1681]
 [ 217  685 1682]]
[[  79  301    2]
 [  79  301    3]
 [  79  301    4]
 ...
 [  79  301 1680]
 [  79  301

 [ 183  225 1682]]
[[ 328  427    1]
 [ 328  427    2]
 [ 328  427    3]
 ...
 [ 328  427 1680]
 [ 328  427 1681]
 [ 328  427 1682]]
[[ 322   50    2]
 [ 322   50    3]
 [ 322   50    4]
 ...
 [ 322   50 1680]
 [ 322   50 1681]
 [ 322   50 1682]]
[[ 330  168    2]
 [ 330  168    3]
 [ 330  168    4]
 ...
 [ 330  168 1680]
 [ 330  168 1681]
 [ 330  168 1682]]
[[  27  281    1]
 [  27  281    2]
 [  27  281    3]
 ...
 [  27  281 1680]
 [  27  281 1681]
 [  27  281 1682]]
[[ 331   59    2]
 [ 331   59    3]
 [ 331   59    4]
 ...
 [ 331   59 1680]
 [ 331   59 1681]
 [ 331   59 1682]]
[[ 332  218    2]
 [ 332  218    3]
 [ 332  218    4]
 ...
 [ 332  218 1680]
 [ 332  218 1681]
 [ 332  218 1682]]
[[ 329  288    1]
 [ 329  288    2]
 [ 329  288    3]
 ...
 [ 329  288 1680]
 [ 329  288 1681]
 [ 329  288 1682]]
[[  86  326    1]
 [  86  326    2]
 [  86  326    3]
 ...
 [  86  326 1680]
 [  86  326 1681]
 [  86  326 1682]]
[[ 139  127    1]
 [ 139  127    2]
 [ 139  127    3]
 ...
 [ 139  12

[[ 472  651    5]
 [ 472  651    6]
 [ 472  651    8]
 ...
 [ 472  651 1680]
 [ 472  651 1681]
 [ 472  651 1682]]
[[ 465  845    2]
 [ 465  845    3]
 [ 465  845    4]
 ...
 [ 465  845 1680]
 [ 465  845 1681]
 [ 465  845 1682]]
[[ 463  744    2]
 [ 463  744    4]
 [ 463  744    5]
 ...
 [ 463  744 1680]
 [ 463  744 1681]
 [ 463  744 1682]]
[[ 471  140    2]
 [ 471  140    3]
 [ 471  140    4]
 ...
 [ 471  140 1680]
 [ 471  140 1681]
 [ 471  140 1682]]
[[ 474  131    1]
 [ 474  131    2]
 [ 474  131    3]
 ...
 [ 474  131 1680]
 [ 474  131 1681]
 [ 474  131 1682]]
[[ 469  607    1]
 [ 469  607    2]
 [ 469  607    3]
 ...
 [ 469  607 1680]
 [ 469  607 1681]
 [ 469  607 1682]]
[[ 464  258    1]
 [ 464  258    2]
 [ 464  258    3]
 ...
 [ 464  258 1680]
 [ 464  258 1681]
 [ 464  258 1682]]
[[ 476  890    1]
 [ 476  890    2]
 [ 476  890    3]
 ...
 [ 476  890 1680]
 [ 476  890 1681]
 [ 476  890 1682]]
[[ 478  710    2]
 [ 478  710    3]
 [ 478  710    4]
 ...
 [ 478  710 1680]
 [ 478  710

[[ 692 1023    2]
 [ 692 1023    3]
 [ 692 1023    4]
 ...
 [ 692 1023 1680]
 [ 692 1023 1681]
 [ 692 1023 1682]]
[[ 690  705    2]
 [ 690  705    3]
 [ 690  705    5]
 ...
 [ 690  705 1680]
 [ 690  705 1681]
 [ 690  705 1682]]
[[ 689  298    2]
 [ 689  298    3]
 [ 689  298    4]
 ...
 [ 689  298 1680]
 [ 689  298 1681]
 [ 689  298 1682]]
[[ 686  234    1]
 [ 686  234    3]
 [ 686  234    4]
 ...
 [ 686  234 1680]
 [ 686  234 1681]
 [ 686  234 1682]]
[[ 693  499    1]
 [ 693  499    2]
 [ 693  499    3]
 ...
 [ 693  499 1680]
 [ 693  499 1681]
 [ 693  499 1682]]
[[ 688  898    1]
 [ 688  898    2]
 [ 688  898    3]
 ...
 [ 688  898 1680]
 [ 688  898 1681]
 [ 688  898 1682]]
[[ 697 1012    2]
 [ 697 1012    3]
 [ 697 1012    4]
 ...
 [ 697 1012 1680]
 [ 697 1012 1681]
 [ 697 1012 1682]]
[[ 698  968    2]
 [ 698  968    3]
 [ 698  968    4]
 ...
 [ 698  968 1680]
 [ 698  968 1681]
 [ 698  968 1682]]
[[ 670  521    1]
 [ 670  521    2]
 [ 670  521    3]
 ...
 [ 670  521 1680]
 [ 670  521

 [ 911  176 1682]]
[[ 912  186    1]
 [ 912  186    2]
 [ 912  186    3]
 ...
 [ 912  186 1680]
 [ 912  186 1681]
 [ 912  186 1682]]
[[ 914 1406    1]
 [ 914 1406    2]
 [ 914 1406    3]
 ...
 [ 914 1406 1680]
 [ 914 1406 1681]
 [ 914 1406 1682]]
[[ 918  428    2]
 [ 918  428    3]
 [ 918  428    4]
 ...
 [ 918  428 1680]
 [ 918  428 1681]
 [ 918  428 1682]]
[[ 919  124    2]
 [ 919  124    3]
 [ 919  124    6]
 ...
 [ 919  124 1680]
 [ 919  124 1681]
 [ 919  124 1682]]
[[ 921 1279    2]
 [ 921 1279    3]
 [ 921 1279    4]
 ...
 [ 921 1279 1680]
 [ 921 1279 1681]
 [ 921 1279 1682]]
[[ 910  508    2]
 [ 910  508    4]
 [ 910  508    5]
 ...
 [ 910  508 1680]
 [ 910  508 1681]
 [ 910  508 1682]]
[[ 913  310    2]
 [ 913  310    3]
 [ 913  310    5]
 ...
 [ 913  310 1680]
 [ 913  310 1681]
 [ 913  310 1682]]
[[ 915  346    1]
 [ 915  346    2]
 [ 915  346    3]
 ...
 [ 915  346 1680]
 [ 915  346 1681]
 [ 915  346 1682]]
[[ 922   11    2]
 [ 922   11    3]
 [ 922   11    4]
 ...
 [ 922   1

In [27]:
def bpr_mf(user_count, item_count, hidden_dim):
    u = tf.placeholder(tf.int32, [None])
    i = tf.placeholder(tf.int32, [None])
    j = tf.placeholder(tf.int32, [None])

    with tf.device("/cpu:0"):
        user_emb_w = tf.get_variable("user_emb_w", [user_count+1, hidden_dim], 
                            initializer=tf.random_normal_initializer(0, 0.1))
        item_emb_w = tf.get_variable("item_emb_w", [item_count+1, hidden_dim], 
                                initializer=tf.random_normal_initializer(0, 0.1))
        
        u_emb = tf.nn.embedding_lookup(user_emb_w, u)
        i_emb = tf.nn.embedding_lookup(item_emb_w, i)
        j_emb = tf.nn.embedding_lookup(item_emb_w, j)
    
    # MF predict: u_i > u_j
    x = tf.reduce_sum(tf.multiply(u_emb, (i_emb - j_emb)), 1, keep_dims=True)
    
    # AUC for one user:
    # reasonable iff all (u,i,j) pairs are from the same user
    # 
    # average AUC = mean( auc for each user in test set)
    mf_auc = tf.reduce_mean(tf.to_float(x > 0))
    
    l2_norm = tf.add_n([
            tf.reduce_sum(tf.multiply(u_emb, u_emb)), 
            tf.reduce_sum(tf.multiply(i_emb, i_emb)),
            tf.reduce_sum(tf.multiply(j_emb, j_emb))
        ])
    
    regulation_rate = 0.0001
    bprloss = regulation_rate * l2_norm - tf.reduce_mean(tf.log(tf.sigmoid(x)))
    
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(bprloss)
    return u, i, j, mf_auc, bprloss, train_op

In [28]:
em =  tf.constant([[1,2],[3,4],[5,6]])
ccc = tf.nn.embedding_lookup(em, 2)
session111 = tf.Session() 
print (session111.run(ccc))

[5 6]


In [29]:
with tf.Graph().as_default(), tf.Session() as session:
    u, i, j, mf_auc, bprloss, train_op = bpr_mf(user_count, item_count, 20)
    session.run(tf.initialize_all_variables())
    for epoch in range(1, 4):
        _batch_bprloss = 0
        for k in range(1, 5000): # uniform samples from training set
            uij = generate_train_batch(user_ratings, user_ratings_test, item_count)

            _bprloss, _train_op = session.run([bprloss, train_op], 
                                feed_dict={u:uij[:,0], i:uij[:,1], j:uij[:,2]})
            _batch_bprloss += _bprloss
        
        print ("epoch: ", epoch)
        print ("bpr_loss: ", _batch_bprloss / k)
        print ("_train_op")

        user_count = 0
        _auc_sum = 0.0

        # each batch will return only one user's auc
        for t_uij in generate_test_batch(user_ratings, user_ratings_test, item_count):

            _auc, _test_bprloss = session.run([mf_auc, bprloss],
                                    feed_dict={u:t_uij[:,0], i:t_uij[:,1], j:t_uij[:,2]}
                                )
            user_count += 1
            _auc_sum += _auc
        print ("test_loss: ", _test_bprloss, "test_auc: ", _auc_sum/user_count)
        print ("")
    variable_names = [v.name for v in tf.trainable_variables()]
    values = session.run(variable_names)
    for k,v in zip(variable_names, values):
        print("Variable: ", k)
        print("Shape: ", v.shape)
        print(v)

epoch:  1
bpr_loss:  0.7236263042427249
_train_op
test_loss:  0.76150036 test_auc:  0.4852939894020929

epoch:  2
bpr_loss:  0.7229681559433149
_train_op
test_loss:  0.76061743 test_auc:  0.48528061393838007

epoch:  3
bpr_loss:  0.7223725006756341
_train_op
test_loss:  0.7597519 test_auc:  0.4852617720521252

Variable:  user_emb_w:0
Shape:  (944, 20)
[[ 0.08105529  0.04270628 -0.12196594 ...  0.02729403  0.1556453
  -0.07148876]
 [ 0.0729574   0.01720054 -0.08198593 ...  0.05565814 -0.0372898
   0.11935959]
 [ 0.03591165 -0.11786834  0.04123168 ...  0.06533947  0.11889934
  -0.19697346]
 ...
 [-0.05796075 -0.00695129  0.07784595 ... -0.03869986  0.10723818
   0.01293885]
 [ 0.13237114 -0.07055715 -0.05505611 ...  0.16433473  0.04535925
   0.0701588 ]
 [-0.2069717   0.04607181  0.07822093 ...  0.03704183  0.07326393
   0.06110878]]
Variable:  item_emb_w:0
Shape:  (1683, 20)
[[ 0.09130769 -0.16516572  0.06490657 ...  0.03657753 -0.02265425
   0.1437734 ]
 [ 0.02463264  0.13691436 -0.017

In [30]:
print (values[0].shape)
print (values[0])

(944, 20)
[[ 0.08105529  0.04270628 -0.12196594 ...  0.02729403  0.1556453
  -0.07148876]
 [ 0.0729574   0.01720054 -0.08198593 ...  0.05565814 -0.0372898
   0.11935959]
 [ 0.03591165 -0.11786834  0.04123168 ...  0.06533947  0.11889934
  -0.19697346]
 ...
 [-0.05796075 -0.00695129  0.07784595 ... -0.03869986  0.10723818
   0.01293885]
 [ 0.13237114 -0.07055715 -0.05505611 ...  0.16433473  0.04535925
   0.0701588 ]
 [-0.2069717   0.04607181  0.07822093 ...  0.03704183  0.07326393
   0.06110878]]


In [31]:
print (values[0][0].shape)
print (values[0][0])

(20,)
[ 0.08105529  0.04270628 -0.12196594 -0.0118052   0.00052658  0.03041512
  0.08178367  0.00329294  0.07350662  0.09376822  0.08431987 -0.0375998
  0.08964082  0.20457171  0.08042991  0.07238016 -0.00179652  0.02729403
  0.1556453  -0.07148876]


In [32]:
session1 = tf.Session()
u1_dim = tf.expand_dims(values[0][0], 0)
print (u1_dim.shape)
print (session1.run(u1_dim))

(1, 20)
[[ 0.08105529  0.04270628 -0.12196594 -0.0118052   0.00052658  0.03041512
   0.08178367  0.00329294  0.07350662  0.09376822  0.08431987 -0.0375998
   0.08964082  0.20457171  0.08042991  0.07238016 -0.00179652  0.02729403
   0.1556453  -0.07148876]]


In [34]:
print (u1_dim.shape)
print (values[1].shape)
u0_all = tf.matmul(u1_dim, values[1],transpose_b=True)

(1, 20)
(1683, 20)


In [35]:
print (u0_all.shape)
print (session1.run(u0_all))

(1, 1683)
[[-0.07065812 -0.02992807 -0.01091636 ... -0.03492806 -0.01390784
   0.04102187]]


In [38]:
session1 = tf.Session()
u1_dim = tf.expand_dims(values[0][0], 0)
u1_all = tf.matmul(u1_dim, values[1],transpose_b=True)
result_1 = session1.run(u0_all)
print (result_1)

[[-0.07065812 -0.02992807 -0.01091636 ... -0.03492806 -0.01390784
   0.04102187]]


In [48]:
print("以下是给用户0的推荐：")
p = numpy.squeeze(result_1)
p[numpy.argsort(p)[:-5]] = 0
for index in range(len(p)):
    if p[index] != 0:
        print (index, p[index])

以下是给用户0的推荐：
405 0.11510968
1117 0.12954654
1256 0.099157736
1260 0.12162529
1627 0.09925515
