# Base model for CTR prediction
Embedding + MLP
目的：根据用户的前 n 个广告点击记录，预测其点击第 n+1 个广告的概率

In [1]:
from sklearn import metrics
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn
import pickle
import random
import time
import sys
sys.path.append('/data/CaoZhong/utils/')
from my_utils import *
from tqdm import tqdm_notebook

## 加载数据

In [2]:
# ---- run configuration ----
ctx = mx.gpu(1)  # NOTE(review): hard-coded to GPU index 1
train_batch_size, test_batch_size = 32, 512
model_name = 'din_basemodel_gluon'

# The pickle stores four objects back to back: the training samples, the test
# samples, the item -> category lookup list, and a (user, item, cate) count triple.
with open('../data/dataset_sub_gluon.pkl', 'rb') as f:
    train_set, test_set, cate_list, (user_count, item_count, cate_count) = (
        pickle.load(f) for _ in range(4)
    )

# NOTE(review): unseeded shuffle; the training DataLoader also reshuffles every epoch.
random.shuffle(train_set)
# Move the category lookup onto the training device so it can be indexed by id NDArrays.
cate_list = nd.array(cate_list, ctx=ctx)

print("user count: %d\titem count: %d\tcate count: %d" % (user_count, item_count, cate_count))
print("train set len: ", len(train_set))
print('test set len: ', len(test_set))

user count: 1053	item count: 63001	cate count: 801
train set len:  68814
test set len:  2106


## 建立模型

In [3]:
class HybridNet(nn.HybridBlock):
    """Embedding + MLP base model for CTR prediction.

    The user is represented by the length-normalised average of the
    (item, category) embeddings in their click history. That vector is passed
    through BatchNorm + Dense and concatenated with the candidate item's
    (item, category) embedding; an MLP then produces one logit per sample.

    Args:
        item_count: size of the item-id vocabulary.
        cate_count: size of the category-id vocabulary.
        embed_size: width E of each item / category embedding; the concatenated
            item+category embedding is 2*E wide.
        num_hiddens: output width of the dense layer applied to the pooled history.
        ctx: accepted for interface compatibility; not used inside the block.
    """

    def __init__(self, item_count, cate_count, embed_size, num_hiddens, ctx, **kwargs):
        super(HybridNet, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.item_embedding = nn.Embedding(item_count, embed_size)
        self.cate_embedding = nn.Embedding(cate_count, embed_size)
        self.batch_normal_layer = nn.BatchNorm()
        self.dense_layer = nn.Dense(num_hiddens)

        self.mlp = nn.HybridSequential()
        self.mlp.add(nn.BatchNorm())
        self.mlp.add(nn.Dense(80, activation='sigmoid'))
        self.mlp.add(nn.Dense(40, activation='sigmoid'))
        # No activation: the block returns raw logits.
        self.mlp.add(nn.Dense(1, activation=None))

    def hybrid_forward(self, F, item, cate, hist, hist_cate, ls):
        """Score candidate items against each user's click history.

        Args:
            F: mx.nd or mx.sym (supplied by Gluon).
            item: candidate item ids; flattened to [B].
            cate: category ids of the candidate items (callers in this notebook pass [B]).
            hist: right-padded history item ids, [B, T].
            hist_cate: category ids of the history items, [B, T].
            ls: true (unpadded) history lengths, [B, 1] or [B].

        Returns:
            [B, 1] logits.
        """
        item = item.reshape((-1,))                              # [B]
        item_emb_w = self.item_embedding(item)                  # [B, E]
        cate_emb_w = self.cate_embedding(cate)                  # [B, E]
        i_emb = F.concat(item_emb_w, cate_emb_w, dim=1)         # [B, 2E]

        hi_emb = self.item_embedding(hist)                      # [B, T, E]
        hc_emb = self.cate_embedding(hist_cate)                 # [B, T, E]
        h_emb = F.concat(hi_emb, hc_emb, dim=-1)                # [B, T, 2E]

        # Zero every position past each row's true length (SequenceMask wants time-major input).
        seq_len = ls.reshape((-1,))                             # [B]
        h_emb = F.SequenceMask(h_emb.swapaxes(0, 1), sequence_length=seq_len,
                               use_sequence_length=True)        # [T, B, 2E]
        # Average over the real history items only. F.mean(axis=0) would divide
        # by the padded length T, which varies from batch to batch.
        # Lengths are clamped to >= 1 so an empty history yields zeros, not NaN.
        h_sum = F.sum(h_emb, axis=0)                            # [B, 2E]
        denom = F.maximum(seq_len, 1).reshape((-1, 1))          # [B, 1]
        h_emb = F.broadcast_div(h_sum, denom)                   # [B, 2E]

        h_emb = self.batch_normal_layer(h_emb)
        user_emb = self.dense_layer(h_emb)                      # [B, num_hiddens]

        din = F.concat(user_emb, i_emb, dim=-1)                 # [B, num_hiddens + 2E]
        return self.mlp(din)                                    # [B, 1]

## 混合式编程

In [4]:
# NOTE(review): stale smoke test. HybridNet.hybrid_forward takes five inputs
# (item, cate, hist, hist_cate, ls), but the call below passes only four (no
# hist_cate). Re-enabling it would need e.g. `hist_cate_i = cate_list_i[hist_i]`
# passed between hist_i and sl_i.
# cate_list_i = nd.array([0, 1, 2, 3, 4, 5, 6, 7, 8,9] *10 , ctx =ctx)

# item_i = nd.array([1,2],ctx=ctx)
# hist_i = nd.array([[2,3,0,0],[1,4,7,0]],ctx=ctx)
# sl_i = nd.array([2,3],ctx=ctx)
# label_i =  nd.array([0,1],ctx=ctx).reshape((2, -1))

# cate = cate_list_i[item_i].reshape((-1))
# net_i = HybridNet(100, 10, 4, 8, ctx)
# net_i.initialize(init=init.Xavier(),force_reinit=True, ctx=ctx)

# pred_i = net_i(item_i,cate, hist_i, sl_i)
# print(pred_i)

In [5]:
# net_i.hybridize()
# pred_i = net_i(item_i,cate, hist_i, sl_i)
# print(pred_i)

In [6]:
# pred_i = net_i(item_i,cate, hist_i, sl_i)
# print(pred_i)

## 建立数据迭代器

In [7]:
def get_data(data_set):
    """Split (user, hist, pre, label) records into four parallel lists.

    Args:
        data_set: iterable of 4-element records.

    Returns:
        Tuple (users, hists, pres, labels) of lists, in record order.
    """
    records = [(u, hist, pre, label) for u, hist, pre, label in data_set]
    if not records:
        return [], [], [], []
    # Transpose rows -> columns.
    return tuple(list(column) for column in zip(*records))

In [8]:
def batchify(data):
    """Collate (uid, hist, pre, label) samples into one padded mini-batch.

    Each history is right-padded with 0 to the longest history in this batch.

    Returns:
        (uid [B,1], hist [B,T], pre [B,1], label [B,1], sl [B,1]) NDArrays,
        where T is the batch's max history length and sl holds the true lengths.
    """
    samples = [(u, h, p, l) for u, h, p, l in data]
    lengths = [len(h) for _, h, _, _ in samples]
    width = max(lengths)
    padded = [h + [0] * (width - n) for (_, h, _, _), n in zip(samples, lengths)]

    def as_column(values):
        # 1-D list -> [B, 1] NDArray
        return nd.array(values).reshape((-1, 1))

    return (as_column([s[0] for s in samples]),
            nd.array(padded),
            as_column([s[2] for s in samples]),
            as_column([s[3] for s in samples]),
            as_column(lengths))

In [9]:
# Training split: column-split the records and wrap them as a Gluon dataset.
dataset = gdata.ArrayDataset(*get_data(train_set))
# Test split (after this cell the all_* names refer to the test columns).
all_user, all_hist, all_pre, all_label = get_data(test_set)
dataset_test = gdata.ArrayDataset(all_user, all_hist, all_pre, all_label)

# Mini-batch loaders; batchify pads each batch's histories to that batch's max length.
data_iter = gdata.DataLoader(dataset, batch_size=train_batch_size, shuffle=True,
                             batchify_fn=batchify)
data_iter_test = gdata.DataLoader(dataset_test, batch_size=test_batch_size, shuffle=True,
                                  batchify_fn=batchify)

In [10]:
# Peek at the first batch of each loader to sanity-check the shapes
# (and, for the training loader, the device the raw batches live on).
batch_fields = ['uid', 'hist', 'pre', 'label', 'sl']

batch = next(iter(data_iter), None)
if batch is not None:
    for name, data in zip(batch_fields, batch):
        print(name, 'shape: ', data.shape, data.context)

batch = next(iter(data_iter_test), None)
if batch is not None:
    for name, data in zip(batch_fields, batch):
        print(name, 'shape: ', data.shape)

uid shape:  (32, 1) cpu(0)
hist shape:  (32, 311) cpu(0)
pre shape:  (32, 1) cpu(0)
label shape:  (32, 1) cpu(0)
sl shape:  (32, 1) cpu(0)
uid shape:  (512, 1)
hist shape:  (512, 430)
pre shape:  (512, 1)
label shape:  (512, 1)
sl shape:  (512, 1)


## 模型训练

In [11]:
def eval_auc(net, ctx, loader=None):
    """Compute the ROC AUC of `net` on an evaluation loader.

    Args:
        net: model called as net(pre, cate, hist, hist_cate, sl), returning logits.
        ctx: device the batches are copied to.
        loader: iterable of (uid, hist, pre, label, sl) batches; defaults to the
            notebook global ``data_iter_test``.

    Returns:
        AUC computed with sklearn's roc_curve + auc.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if loader is None:
        loader = data_iter_test

    # Collect per-batch results and concatenate once. Growing one NDArray with
    # repeated concat copies the accumulated data on every batch.
    scores, labels = [], []
    for batch in loader:
        uid, hist, pos, label, sl = [data.as_in_context(ctx) for data in batch]
        cate = cate_list[pos].reshape((-1))
        hist_cate = cate_list[hist]
        scores.append(nd.sigmoid(net(pos, cate, hist, hist_cate, sl)))
        labels.append(label)

    if not scores:
        raise ValueError('eval_auc: the evaluation loader yielded no batches')

    score = nd.concat(*scores, dim=0)
    y = nd.concat(*labels, dim=0)
    fpr, tpr, thresholds = metrics.roc_curve(y.asnumpy().ravel(), score.asnumpy().ravel())
    return metrics.auc(fpr, tpr)

In [12]:
def train(net, lr, num_epochs, ctx, log_every=1000):
    """Train `net` with plain SGD, periodically logging loss and test AUC.

    Uses the notebook globals ``data_iter`` (training loader), ``cate_list``
    (item -> category lookup on ``ctx``) and ``loss`` (the loss block). It also
    uses ``tip_info``, presumably provided by ``my_utils``.

    Args:
        net: model called as net(pre, cate, hist, hist_cate, sl), returning logits.
        lr: SGD learning rate.
        num_epochs: number of passes over ``data_iter``.
        ctx: device the batches are copied to.
        log_every: evaluate and log every this many global steps (default 1000).

    Returns:
        (loss_list, auc_list, x_vals):
        - the mean training loss over each logging window;
        - the test AUC at each log point;
        - the log index (global_step // log_every).
    """
    auc_list, loss_list, x_vals = [], [], []
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})

    global_step = 1
    stime = time.time()   # training start; epoch messages report elapsed time since here
    stime2 = time.time()  # start of the current logging window

    # Running loss sum for the current logging window. It is deliberately NOT
    # reset at epoch boundaries: a window can span two epochs, and it always
    # covers exactly `log_every` steps when it is logged.
    l_sum = 0.0

    print('auc: %.4f' % (eval_auc(net, ctx)))
    epoch_bar = tqdm_notebook(range(1, num_epochs + 1))
    for epoch in epoch_bar:
        bar = tqdm_notebook(data_iter)
        for batch in bar:
            uid, hist, pre, label, sl = [data.as_in_context(ctx) for data in batch]
            cate = cate_list[pre].reshape((-1))
            hist_cate = cate_list[hist]
            with autograd.record():
                pred = net(pre, cate, hist, hist_cate, sl)
                l = loss(pred, label)
            l.backward()
            # Normalise by the actual batch size: the last batch of an epoch can
            # be smaller than train_batch_size.
            trainer.step(label.shape[0])

            l_sum += l.mean().asscalar()

            if global_step % log_every == 0:
                test_auc = eval_auc(net, ctx)
                loss_val = l_sum / log_every
                time_val = time.time() - stime2
                tip = "epoch %d, global step:%d, loss %.4f, test auc:%.4f, time:%.2f" % (epoch, global_step, loss_val, test_auc, time_val)
                bar.set_description_str(tip)
                tip_info(tip, out=False)
                loss_list.append(loss_val)
                auc_list.append(test_auc)
                x_vals.append(global_step // log_every)
                l_sum = 0.0
                stime2 = time.time()
            global_step += 1
        tip = 'epoch %d done, cost time:%.2f' % (epoch, time.time() - stime)
        epoch_bar.set_description_str(tip)
        tip_info(tip, out=False)

    return loss_list, auc_list, x_vals
            

In [13]:
# The network outputs raw logits, so the loss applies the sigmoid itself
# (SigmoidBinaryCrossEntropyLoss defaults to from_sigmoid=False).
loss = gloss.SigmoidBinaryCrossEntropyLoss()

net = HybridNet(item_count, cate_count, embed_size=64, num_hiddens=128, ctx=ctx)
net.initialize(init=init.Xavier(), force_reinit=True, ctx=ctx)
net.hybridize()  # compile to a symbolic graph for faster execution

In [14]:
# 30 epochs of SGD at learning rate 0.1; returns the logged loss/AUC curves.
loss_list, auc_list, x_vals = train(net, lr=0.1, num_epochs=30, ctx=ctx)

auc: 0.4916


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))




In [16]:
# Persist the training curves: two pickled objects in order, loss_list then auc_list.
file_name = 'train_result_' + model_name + '.pkl'
with open(file_name, 'wb') as f:
    for curve in (loss_list, auc_list):
        pickle.dump(curve, f, pickle.HIGHEST_PROTOCOL)

In [18]:
# Number of evaluation points logged by train() (one per 1000 global steps).
len(x_vals)

21