In [1]:
from sklearn import metrics
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn
import pickle
import random
import time
import sys
sys.path.append('/data/CaoZhong/utils/')
from my_utils import *
from tqdm import tqdm_notebook,tqdm

## 加载数据

In [2]:
# Run identity and input location.
model_name = 'rnn_din'
data_path = '/data/CaoZhong/data/din/dataset_sub_gluon.pkl'

# Mini-batch sizes handed to DinDataIter for the train and test splits.
train_batch_size, test_batch_size = 64, 512

# Device handle for GPU index 5; used when initialising and training the net.
ctx = mx.gpu(5)

In [3]:
# DinDataIter comes from the my_utils star import; its internals are not
# visible in this notebook.
data_iter = DinDataIter(data_path, train_batch_size, test_batch_size)

# Vocabulary sizes, needed later to size the model's embedding tables.
(user_count, item_count, cate_count) = data_iter.get_count()

# Batch iterators for the train and test splits.
(train_iter, test_iter) = data_iter.get_data_iter()

user count: 1053	item count: 63001	cate count: 801
train set len:  129888
test set len:  2106


In [4]:
# Labels for the positional fields of a batch; used only for printing.
field_names = ('uid', 'hist', 'hist_cate', 'pre', 'cate', 'label', 'sl')

# Peek at the first training batch only: shape and device context per field.
for batch in train_iter:
    for name, data in zip(field_names, batch):
        print(name, 'shape: ', data.shape, data.context)
    break

# Same peek for the first test batch, shapes only.
for batch in test_iter:
    for name, data in zip(field_names, batch):
        print(name, 'shape: ', data.shape)
    break

uid shape:  (64,) cpu(0)
hist shape:  (64, 225) cpu(0)
hist_cate shape:  (64, 225) cpu(0)
pre shape:  (64,) cpu(0)
cate shape:  (64,) cpu(0)
label shape:  (64,) cpu(0)
sl shape:  (64,) cpu(0)
uid shape:  (512,)
hist shape:  (512, 350)
hist_cate shape:  (512, 350)
pre shape:  (512,)
cate shape:  (512,)
label shape:  (512,)
sl shape:  (512,)


## 建立模型

In [5]:
class Model(nn.Block):
    """Score a candidate item against a user's behaviour history with an LSTM.

    The history (item ids + category ids) is embedded, passed through an LSTM
    via ``dynamic_rnn``, batch-normalised and projected by a dense layer to
    form a user vector. That vector is concatenated with the candidate item's
    embedding (item id + category id) and fed to an MLP whose last layer is a
    single unit without activation, i.e. one raw score per sample.

    NOTE(review): ``rnn`` and ``dynamic_rnn`` are not imported explicitly in
    this notebook; presumably they are provided by ``from my_utils import *``
    -- confirm.

    Parameters
    ----------
    item_count : int
        Number of rows in the item-id embedding table.
    cate_count : int
        Number of rows in the category-id embedding table.
    embed_size : int
        Width of each individual embedding (item and category alike).
    num_hiddens : int
        Output width of the dense projection applied to the encoder output;
        the LSTM itself is built with ``2 * num_hiddens`` hidden units.
    ctx
        Accepted for interface compatibility; not used inside this block.
    **kwargs
        Forwarded to ``nn.Block.__init__``.
    """

    def __init__(self, item_count, cate_count, embed_size, num_hiddens, ctx, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens

        # Embedding tables shared by the candidate item and the history.
        self.item_embedding = nn.Embedding(item_count, embed_size)
        self.cate_embedding = nn.Embedding(cate_count, embed_size)

        # History encoder followed by normalisation and projection.
        self.batch_normal_layer = nn.BatchNorm()
        self.dense_layer = nn.Dense(num_hiddens)
        self.encoder = rnn.LSTM(2 * num_hiddens)

        # Scoring head; the final layer has no activation, so it emits a raw
        # score rather than a probability.
        self.mlp = nn.Sequential()
        self.mlp.add(nn.BatchNorm())
        self.mlp.add(nn.Dense(80, activation='sigmoid'))
        self.mlp.add(nn.Dense(40, activation='sigmoid'))
        self.mlp.add(nn.Dense(1, activation=None))

    def forward(self, item, cate, hist_item, hist_cate, seq_len):
        """Return the MLP output for each sample in the batch.

        Parameters
        ----------
        item, cate
            Candidate item ids and category ids, one per sample.
        hist_item, hist_cate
            History item ids and category ids per sample
            (assumed [batch, time] -- TODO confirm against DinDataIter).
        seq_len
            Per-sample history length, forwarded to ``dynamic_rnn``.
        """
        # Flatten every per-sample input to 1-D. ``cate`` is flattened as well
        # so that a (B, 1) input yields a 2-D embedding, matching ``item``;
        # otherwise the dim=1 concat below would fail on mismatched ranks.
        item = item.reshape((-1,))
        cate = cate.reshape((-1,))
        seq_len = seq_len.reshape((-1,))

        # Candidate representation: item embedding joined with its category.
        item_idx_emb = self.item_embedding(item)
        cate_idx_emb = self.cate_embedding(cate)
        item_emb = nd.concat(item_idx_emb, cate_idx_emb, dim=1)

        # History representation: concatenate along the feature (last) axis.
        hist_item_idx_emb = self.item_embedding(hist_item)
        hist_cate_idx_emb = self.cate_embedding(hist_cate)
        hist_emb = nd.concat(hist_item_idx_emb, hist_cate_idx_emb, dim=-1)

        # Axes 0 and 1 are swapped before encoding (presumably to time-major
        # layout -- confirm). The original note says the result is [B, H].
        c = dynamic_rnn(self.encoder, hist_emb.swapaxes(0, 1), sequence_length=seq_len)

        # Normalise and project the encoded history into the user vector.
        h_emb = self.batch_normal_layer(c)
        h_emb = self.dense_layer(h_emb)
        user_emb = h_emb

        # Join user and candidate vectors and score them.
        din = nd.concat(user_emb, item_emb, dim=-1)
        score = self.mlp(din)
        return score

## 模型训练

In [6]:
# Binary cross-entropy loss applied to the model's single output score.
loss = gloss.SigmoidBinaryCrossEntropyLoss()

# Build the network; vocabulary sizes come from the data iterator.
net = Model(
    item_count,
    cate_count,
    embed_size=64,
    num_hiddens=128,
    ctx=ctx,
)

# Xavier initialisation on the target device; force_reinit makes re-running
# this cell reset any previously initialised parameters.
net.initialize(
    init=init.Xavier(),
    force_reinit=True,
    ctx=ctx,
)

In [6]:
# train_din comes from the my_utils star import; positional arguments are
# kept in their original order because its signature is not visible here.
loss_list, auc_list, acc_list = train_din(
    net,
    train_iter,
    test_iter,
    loss,
    train_batch_size,
    0.1,        # presumably the learning rate -- confirm against train_din
    30,         # presumably the epoch count (outer progress bar max=30) -- confirm
    ctx,
    '../logs',
    loss_name='loss_rnn_din',
    auc_name='auc_rnn_din',
    acc_name='acc_rnn_din',
)

auc: 0.4873


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2151), HTML(value='')))




In [7]:
# Persist the training histories next to the notebook.
file_name = 'train_result_' + model_name + '.pkl'
with open(file_name, 'wb') as f:
    # The histories are written as consecutive pickle records in the order
    # loss, auc, acc; a reader must call pickle.load once per record in that
    # same order. acc_list is appended last so that readers which only load
    # the first two records keep working.
    for history in (loss_list, auc_list, acc_list):
        pickle.dump(history, f, pickle.HIGHEST_PROTOCOL)