-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
148 lines (120 loc) · 4.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import time
import pickle
import random
import numpy as np
import sys
from input import DataInput, DataInputTest
from model import Model
import tensorflow as tf
import warnings
warnings.filterwarnings('ignore')
# Default to GPU 1, but keep a CUDA_VISIBLE_DEVICES value the caller has
# already exported. Must run before TensorFlow initialises CUDA.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
# Fixed seeds for Python, NumPy and TensorFlow (graph-level) randomness.
random.seed(1234)
np.random.seed(1234)
tf.set_random_seed(1234)

train_batch_size = 32
test_batch_size = 512
predict_batch_size = 32
predict_users_num = 1000
predict_ads_num = 100

# dataset.pkl holds four consecutive pickles, read back in this order:
# train set, test set, category list, and a (user, item, category) count triple.
# encoding='bytes' suggests the file was written by Python 2 -- TODO confirm.
# SECURITY: pickle.load can execute arbitrary code; only load a trusted file.
with open('dataset.pkl', 'rb') as f:
    train_set = pickle.load(f, encoding='bytes')
    test_set = pickle.load(f, encoding='bytes')
    cate_list = pickle.load(f, encoding='bytes')
    user_count, item_count, cate_count = pickle.load(f, encoding='bytes')

# Best test GAUC seen so far; raised (and a checkpoint written) by _eval.
best_auc = 0.0
def calc_auc(raw_arr):
    """Compute ROC AUC from ``[noclick, click, score]`` records.

    Each record contributes ``record[0]`` negatives and ``record[1]``
    positives at prediction score ``record[2]``. Records are sorted by score
    (ascending) and the area is accumulated with the trapezoid rule. Records
    sharing the same score are merged into a single step, so tied
    predictions earn half credit regardless of their order in ``raw_arr``.

    Args:
        raw_arr: iterable of indexable ``[noclick, click, score]`` records.

    Returns:
        The AUC as a float.
        ``-0.5`` when the click total or the non-click total reaches the
        number of records (all clicks / all non-clicks, including an empty
        input), where AUC is undefined.
        ``None`` when one of the totals is zero without tripping that check,
        which can only happen if records are not one-hot counts.
    """
    from itertools import groupby

    # Sort by prediction value, from small to big.
    arr = sorted(raw_arr, key=lambda d: d[2])
    auc = 0.0
    fp1, tp1, fp2, tp2 = 0.0, 0.0, 0.0, 0.0
    # One trapezoid per distinct score: accumulating record-by-record inside
    # a tie would make the result depend on input order.
    for _, tied in groupby(arr, key=lambda d: d[2]):
        for record in tied:
            fp2 += record[0]  # noclick
            tp2 += record[1]  # click
        auc += (fp2 - fp1) * (tp2 + tp1)
        fp1, tp1 = fp2, tp2
    # If every record is a non-click or every record is a click, discard.
    threshold = len(arr) - 1e-3
    if tp2 > threshold or fp2 > threshold:
        return -0.5
    if tp2 * fp2 > 0.0:  # normal AUC
        return 1.0 - auc / (2.0 * tp2 * fp2)
    return None
def _auc_arr(score):
score_p = score[:, 0]
score_n = score[:, 1]
# print "============== p ============="
# print score_p
# print "============== n ============="
# print score_n
score_arr = []
for s in score_p.tolist():
score_arr.append([0, 1, s])
for s in score_n.tolist():
score_arr.append([1, 0, s])
return score_arr
def _eval(sess, model):
    """Evaluate ``model`` on ``test_set`` and checkpoint it on a new best GAUC.

    The original note reads "obtain the best model on the validation set".
    NOTE(review): the data actually used is the module-level ``test_set``,
    so model selection happens on the test split -- confirm this is intended.

    Returns:
        ``(test_gauc, auc)``. ``test_gauc`` is the sum of the per-batch values
        from ``model.eval`` weighted by ``len(batch[0])``, divided by
        ``len(test_set)``. ``auc`` is ``calc_auc`` over the pooled scores of
        all batches.
    """
    global best_auc
    weighted_total = 0.0
    pooled = []
    for _, batch in DataInputTest(test_set, test_batch_size):
        batch_auc, batch_scores = model.eval(sess, batch)
        pooled.extend(_auc_arr(batch_scores))
        weighted_total += batch_auc * len(batch[0])
    test_gauc = weighted_total / len(test_set)
    pooled_auc = calc_auc(pooled)
    # Keep only the best-scoring weights on disk.
    if test_gauc > best_auc:
        best_auc = test_gauc
        model.save(sess, 'save_path/ckpt')
    return test_gauc, pooled_auc
def _test(sess, model):
    """Score users from ``test_set`` with ``model.test``.

    Batches of ``predict_batch_size`` are drawn from ``DataInputTest``. They
    stop being scored once at least ``predict_users_num`` users have been
    processed.

    Returns:
        The first row (``[0]``) of the scores produced for the last batch
        that was scored.

    Raises:
        ValueError: if no batch was scored, e.g. an empty ``test_set`` or
            ``predict_users_num <= 0``.
    """
    print("test sub items")
    last_scores = None
    predicted_users = 0
    for _, uij in DataInputTest(test_set, predict_batch_size):
        if predicted_users >= predict_users_num:
            break
        last_scores = model.test(sess, uij)
        # Count the rows actually in this batch (same measure _eval uses).
        predicted_users += len(uij[0])
    if last_scores is None:
        raise ValueError("_test scored no batches; check test_set and predict_users_num")
    return last_scores[0]
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    model = Model(user_count, item_count, cate_count, cate_list,
                  predict_batch_size, predict_ads_num)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Baseline evaluation before any training. _eval also writes a checkpoint
    # whenever it sees a new best GAUC.
    print('test_gauc: %.4f\t test_auc: %.4f' % _eval(sess, model))
    sys.stdout.flush()

    lr = 1.0
    start_time = time.time()
    # Loss accumulated since the last report. It is kept across epoch
    # boundaries, so each reported value averages exactly the steps it covers.
    loss_sum = 0.0
    loss_count = 0
    for _ in range(50):
        random.shuffle(train_set)
        for _, uij in DataInput(train_set, train_batch_size):
            loss = model.train(sess, uij, lr)
            loss_sum += loss
            loss_count += 1
            # Read the step counter once: each .eval() is a separate session run.
            step = model.global_step.eval()
            # Evaluate every 1000 steps and keep the best model.
            if step % 1000 == 0:
                test_gauc, test_auc = _eval(sess, model)
                print('Epoch %d Global_step %d\tTrain_loss: %.4f\tEval_GAUC: %.4f\tEval_AUC: %.4f' %
                      (model.global_epoch_step.eval(), step,
                       loss_sum / loss_count, test_gauc, test_auc))
                sys.stdout.flush()
                loss_sum = 0.0
                loss_count = 0
            # Lower the learning rate once training has run for many steps.
            if step % 336000 == 0:
                lr = 0.1
        print('Epoch %d DONE\tCost time: %.2f' %
              (model.global_epoch_step.eval(), time.time() - start_time))
        sys.stdout.flush()
        model.global_epoch_step_op.eval()
    print('best test_gauc:', best_auc)
    sys.stdout.flush()