In [1]:
from sklearn import metrics
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn
import pickle
import random
import time
import sys
sys.path.append('/data/CaoZhong/utils/')
from my_utils import *
from tqdm import tqdm_notebook

# 加载数据 (Load data)

In [2]:
# Experiment configuration.
# Seed Python's `random` and MXNet's RNGs before any data loading or weight
# initialisation so that repeated runs start from the same RNG state.
# NOTE(review): DinDataIter comes from my_utils and may shuffle with a different
# RNG (e.g. numpy) that is not seeded here — TODO confirm.
SEED = 2019
random.seed(SEED)
mx.random.seed(SEED)

ctx = mx.gpu(5)  # hard-coded GPU index; adjust for the host this runs on
train_batch_size = 64
test_batch_size = 512
model_name = 'base_din'  # used as the prefix of the saved metrics / parameter files
data_path = '/data/CaoZhong/data/din/dataset_sub_gluon.pkl'  # absolute, host-specific path

In [3]:
# Wrap the pickled DIN dataset, then read the vocabulary sizes and build the
# train / test loaders (in that order, as before).
data_iter = DinDataIter(data_path, train_batch_size, test_batch_size)
vocab_sizes = data_iter.get_count()
user_count, item_count, cate_count = vocab_sizes
loaders = data_iter.get_data_iter()
train_iter, test_iter = loaders

user count: 1053	item count: 63001	cate count: 801
train set len:  129888
test set len:  2106


In [4]:
# Sanity-check the loaders: print the shape and device of each named field of
# the first batch, for both the train and the test loader.
BATCH_FIELDS = ('uid', 'hist', 'pre', 'label', 'sl')


def preview_first_batch(loader, field_names=BATCH_FIELDS, show_context=True):
    """Print ``shape`` (and optionally ``context``) of each field of the first batch.

    zip() stops at the shorter of ``field_names`` and the batch, so trailing
    batch elements beyond the named fields are not printed.

    Returns the first batch, or None when ``loader`` yields nothing.
    """
    for first in loader:
        for name, data in zip(field_names, first):
            if show_context:
                print(name, 'shape: ', data.shape, data.context)
            else:
                print(name, 'shape: ', data.shape)
        return first
    return None


train_batch = preview_first_batch(train_iter)
if train_batch is not None:
    # NOTE(review): the recorded output prints a bare int (190) here while `sl`
    # has shape (64,), so batch[-1] appears to be an extra element beyond the
    # five named fields — TODO confirm against DinDataIter.
    print('this batch max history len:', train_batch[-1])
test_batch = preview_first_batch(test_iter)

uid shape:  (64,) cpu(0)
hist shape:  (64, 190) cpu(0)
pre shape:  (64, 190) cpu(0)
label shape:  (64,) cpu(0)
sl shape:  (64,) cpu(0)
this batch max history len: 190
uid shape:  (512,)
hist shape:  (512, 430)
pre shape:  (512, 430)
label shape:  (512,)
sl shape:  (512,)


# 建立模型 (Build the model)

In [5]:
class HybridNet(nn.HybridBlock):
    """Base (attention-free) DIN-style click model.

    The candidate is represented by its item embedding concatenated with its
    category embedding. The user is represented by averaging the concatenated
    item/category embeddings over the valid part of the behaviour history, then
    applying BatchNorm and a Dense projection. User and candidate vectors are
    concatenated and passed through an MLP whose last layer has no activation,
    so the output is one raw logit per sample. Pair it with a loss that applies
    the sigmoid itself, e.g. ``SigmoidBinaryCrossEntropyLoss`` with its default
    ``from_sigmoid=False``.

    Parameters
    ----------
    item_count, cate_count : int
        Vocabulary sizes of the item and category embedding tables.
    embed_size : int
        Width of each embedding table. Candidate and history vectors are
        ``2 * embed_size`` wide after concatenation.
    num_hiddens : int
        Output width of the Dense layer applied to the pooled history.
    ctx : mxnet.Context
        Unused. Kept for call-site compatibility; pass ``ctx`` to
        ``initialize`` instead.
    """

    def __init__(self, item_count, cate_count, embed_size, num_hiddens,  ctx, **kwargs):
        super(HybridNet, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        # Item and category tables are shared by the candidate and the history.
        self.item_embedding = nn.Embedding(item_count, embed_size)
        self.cate_embedding = nn.Embedding(cate_count, embed_size)
        self.batch_normal_layer = nn.BatchNorm()
        self.dense_layer = nn.Dense(num_hiddens)
        self.mlp = nn.HybridSequential()
        self.mlp.add(nn.BatchNorm())
        self.mlp.add(nn.Dense(80, activation='sigmoid'))
        self.mlp.add(nn.Dense(40, activation='sigmoid'))
        self.mlp.add(nn.Dense(1, activation=None))  # raw logit

    def hybrid_forward(self, F, item, cate, hist, hist_cate, ls):
        """Score one candidate item per sample.

        Parameters
        ----------
        F : module
            ``mx.nd`` or ``mx.sym``, supplied by Gluon.
        item, cate : NDArray or Symbol
            Candidate item id and its category id, one per sample; flattened to ``[B]``.
        hist, hist_cate : NDArray or Symbol
            Padded history item ids / category ids, presumably ``[B, T]`` —
            TODO confirm against the data iterator.
        ls : NDArray or Symbol
            Number of valid history steps per sample; flattened to ``[B]``.

        Returns
        -------
        NDArray or Symbol
            Logits of shape ``[B, 1]``.
        """
        item = item.reshape((-1,))                              # [B]
        cate = cate.reshape((-1,))                              # [B]
        seq_len = ls.reshape((-1,))                             # [B]

        item_emb_w = self.item_embedding(item)                  # [B, H/2]
        cate_emb_w = self.cate_embedding(cate)                  # [B, H/2]
        i_emb = F.concat(item_emb_w, cate_emb_w, dim=1)         # [B, H]

        hi_emb = self.item_embedding(hist)                      # [B, T, H/2]
        hc_emb = self.cate_embedding(hist_cate)                 # [B, T, H/2]
        h_emb = F.concat(hi_emb, hc_emb, dim=-1)                # [B, T, H]

        # SequenceMask expects time-major data and zeroes every step t >= ls[b].
        h_emb = F.SequenceMask(h_emb.swapaxes(0, 1), sequence_length=seq_len,
                               use_sequence_length=True)        # [T, B, H]
        # Average over the valid steps only. A plain mean over axis 0 would
        # divide by the padded length T, which changes from batch to batch, so
        # the user vector would depend on padding rather than on the history.
        # The step counts come from a masked ones-tensor, so they share h_emb's
        # dtype whatever dtype `ls` has.
        valid = F.SequenceMask(
            F.ones_like(F.slice_axis(h_emb, axis=-1, begin=0, end=1)),
            sequence_length=seq_len, use_sequence_length=True)  # [T, B, 1]
        h_cnt = F.maximum(F.sum(valid, axis=0), 1)              # [B, 1]; empty history -> /1
        h_emb = F.broadcast_div(F.sum(h_emb, axis=0), h_cnt)    # [B, H]
        # Keep the batch axis (0) and flatten the rest; this holds for any
        # num_hiddens, not only num_hiddens == 2 * embed_size.
        h_emb = h_emb.reshape((0, -1))
        h_emb = self.batch_normal_layer(h_emb)
        user_emb = self.dense_layer(h_emb)                      # [B, num_hiddens]

        din = F.concat(user_emb, i_emb, dim=-1)                 # [B, num_hiddens + H]
        score = self.mlp(din)                                   # [B, 1]
        return score

# 模型训练 (Train the model)

In [6]:
# Loss on raw logits (the model's last Dense layer has no activation), then the
# network with 64-wide embedding tables and a 128-unit history projection.
loss = gloss.SigmoidBinaryCrossEntropyLoss()
net = HybridNet(item_count, cate_count, embed_size=64, num_hiddens=128, ctx=ctx)
net.initialize(ctx=ctx, init=init.Xavier(), force_reinit=True)
net.hybridize()

In [7]:
# Positional hyper-parameters of train_din, given names for readability.
# NOTE(review): 0.1 is presumably the learning rate — TODO confirm against
# train_din's signature in my_utils. 30 appears to be the epoch count: the outer
# progress bar below has max=30, and each inner bar has max=2030, about
# 129888 / 64 train batches.
learning_rate = 0.1
num_epochs = 30
log_dir = '../logs'
loss_list, auc_list, acc_list = train_din(
    net, train_iter, test_iter, loss, train_batch_size,
    learning_rate, num_epochs, ctx, log_dir,
    loss_name='loss_base_din', auc_name='auc_base_din', acc_name='acc_base_din')

auc: 0.4705	acc:0.5000


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))




In [8]:
# Persist the training curves and the weights. The four objects are pickled
# back-to-back, so reading them requires four sequential pickle.load calls in
# the same order: loss, AUC, accuracy, model name.
file_name = '{}_loss_auc_acc.pkl'.format(model_name)
with open(file_name, 'wb') as f:
    for record in (loss_list, auc_list, acc_list, model_name):
        pickle.dump(record, f, pickle.HIGHEST_PROTOCOL)
net.save_parameters(model_name + '.net')