In [1]:
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn, rnn
from sklearn import metrics
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn
import pickle
import random
import time
from tqdm import tqdm, tqdm_notebook
from mxboard import SummaryWriter
import sys
sys.path.append('/data/CaoZhong/utils/')
from my_utils import *

# 加载数据

In [2]:
# Experiment configuration shared by the cells below.
# Compute device: GPU index 7. It is passed to Model(...), net.initialize(...)
# and train_dien(...) further down.
ctx = mx.gpu(7)
# Batch sizes handed to DienDataIter; train_batch_size is also passed to train_dien.
train_batch_size = 64
test_batch_size = 512
# Prefix used when naming the saved metrics pickle and the saved parameter file.
model_name = 'dien_base'
# Pickled dataset consumed by DienDataIter (absolute, machine-specific path).
data_path = '/data/CaoZhong/data/dien/dataset_sub_gluon.pkl'

In [3]:
# Build the train/test iterators and read the vocabulary sizes that size the
# embedding tables of the model.
# NOTE(review): DienDataIter is not defined in this notebook; presumably it comes
# from the `from my_utils import *` star import -- confirm.
data_utils = DienDataIter(data_path, train_batch_size, test_batch_size)
train_iter, test_iter = data_utils.get_data_iter()
# Counts of distinct users / items / categories, used as embedding input sizes.
user_count, item_count, cate_count = data_utils.get_count()

user count: 1053	item count: 63001	cate count: 801
train set len:  129888
test set len:  2106
max history len: 431


In [4]:
# for batch in train_iter:
#     for name, data in zip(['uid','hist','hist_cate','hist_item_neg','hist_cate_neg','pre','pre_cate','label','sl'], batch):
#         print(name, 'shape: ', data.shape)    
#     print(batch[-1])
#     break
# for batch in test_iter:
#     for name, data in zip(['uid','hist','hist_cate','hist_item_neg','hist_cate_neg','pre','pre_cate','label','sl'], batch):
#         print(name, 'shape: ', data.shape)        
#     break

# 建立模型

In [8]:
class AuxiliaryNet(nn.Block):
    """Auxiliary click / no-click discriminator and its masked log-loss.

    A small MLP (BatchNorm -> Dense(100) -> Dense(500) -> Dense(1), every Dense
    with a sigmoid activation) scores the concatenation of a hidden state and
    an item embedding. The forward pass returns, per sample, the mean over time
    steps of ``-log(p_click) - log(1 - p_noclick)``, with steps beyond the
    valid sequence length zeroed out.

    Parameters
    ----------
    eps : float, default 1e-8
        Small constant added inside the logarithms so that a saturated sigmoid
        output (exactly 0 or 1) cannot produce ``inf``/NaN losses.
    **kwargs
        Forwarded unchanged to ``nn.Block``.
    """

    def __init__(self, eps=1e-8, **kwargs):
        super(AuxiliaryNet, self).__init__(**kwargs)
        self._eps = eps
        self.aux_net = nn.Sequential()
        self.aux_net.add(nn.BatchNorm())
        self.aux_net.add(nn.Dense(100, activation='sigmoid', flatten=False))
        self.aux_net.add(nn.Dense(500, activation='sigmoid', flatten=False))
        self.aux_net.add(nn.Dense(1, activation='sigmoid', flatten=False))

    @staticmethod
    def _mask_steps(step_loss, valid_len):
        """Zero the per-step losses past ``valid_len`` and drop the trailing unit axis.

        ``step_loss`` is batch-major with a trailing axis of size 1 ([B, T-1, 1]);
        SequenceMask masks along axis 0, hence the swap to time-major and back.
        Returns an array reshaped to (-1, T-1).
        """
        n_steps = step_loss.shape[1]
        masked = nd.SequenceMask(step_loss.swapaxes(0, 1),
                                 sequence_length=valid_len,
                                 use_sequence_length=True)       # [T-1, B, 1]
        return masked.swapaxes(0, 1).reshape((-1, n_steps))       # [B, T-1]

    def forward(self, h_states, click_hist_item, noclick_hist_item, sl):
        """Compute the per-sample auxiliary loss.

        Parameters
        ----------
        h_states : NDArray
            Hidden states, batch-major, [B, T-1, H].
        click_hist_item : NDArray
            Embeddings of the clicked items, [B, T-1, *].
        noclick_hist_item : NDArray
            Embeddings of the sampled non-clicked items, same shape as
            ``click_hist_item``.
        sl : NDArray
            Sequence lengths, one per sample. Any shape holding B values is
            accepted; it is flattened before use.

        Returns
        -------
        NDArray
            Shape [B]: mean over all T-1 steps of the masked click + no-click
            log-losses (masked steps contribute 0 to the mean).
        """
        # Each input sequence is one step shorter than the history, so one
        # fewer position is valid. Flatten so SequenceMask always receives a
        # 1-D length vector, whether the caller passed [B] or [B, 1].
        valid_len = sl.reshape((-1,)) - 1

        click_prob = self.aux_net(nd.concat(h_states, click_hist_item, dim=-1))      # [B, T-1, 1]
        noclick_prob = self.aux_net(nd.concat(h_states, noclick_hist_item, dim=-1))  # [B, T-1, 1]

        # eps keeps log() finite when the sigmoid saturates at exactly 0 or 1.
        click_loss = self._mask_steps(-nd.log(click_prob + self._eps), valid_len)
        noclick_loss = self._mask_steps(-nd.log(1 - noclick_prob + self._eps), valid_len)

        return (click_loss + noclick_loss).mean(axis=-1)

In [9]:
class Model(nn.Block):
    """Click-prediction network with two stacked GRUs, attention and an auxiliary loss.

    The notebook saves this model under the name 'dien_base'. Items are
    represented by the concatenation of an item-id embedding and a category-id
    embedding (each of width ``embed_size``); users by a single embedding of
    width ``2 * embed_size``.

    Parameters
    ----------
    n_uid, n_mid, n_cat : int
        Vocabulary sizes of the user, item and category embedding tables.
    embed_size : int
        Width of the item and category embeddings.
    hidden_size : int
        Hidden size of both GRU layers.
    attention_size : int
        Size passed to the ``Attention`` layer (defined outside this notebook).
    ctx : mxnet.Context
        Accepted for call compatibility; not used inside this class.
    **kwargs
        Forwarded unchanged to ``nn.Block``.
    """

    def __init__(self, n_uid, n_mid, n_cat, embed_size, hidden_size, attention_size, ctx, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.hidden_size = hidden_size
        self.attention_size = attention_size
        self.embed_size = embed_size
        self.uid_embedding = nn.Embedding(n_uid, 2*embed_size)
        self.item_embedding = nn.Embedding(n_mid, embed_size)
        self.cate_embedding = nn.Embedding(n_cat, embed_size)
        self.rnn1 = rnn.GRU(hidden_size)
        self.rnn2 = rnn.GRU(hidden_size)
        self.attention_layer = Attention(attention_size)
        self.aux_net = AuxiliaryNet()

        # Final scorer; the last layer emits a raw logit (no activation).
        self.mlp = nn.Sequential()
        self.mlp.add(nn.BatchNorm())
        self.mlp.add(nn.Dense(200, activation='relu'))
        self.mlp.add(nn.Dense(80, activation='relu'))
        self.mlp.add(nn.Dense(1, activation=None))

    def forward(self, uid, hist_item, hist_cate, noclk_hist_item, noclk_hist_cate,item, cate, seq_len):
        """Score a batch of candidates and compute the auxiliary loss.

        Parameters
        ----------
        uid : NDArray
            User ids.
        hist_item, hist_cate : NDArray
            Clicked item ids / category ids of the history, [B, T].
        noclk_hist_item, noclk_hist_cate : NDArray
            Sampled non-clicked item ids / category ids per history step; only
            the sample at index 0 of the third axis is used.
        item, cate : NDArray
            Candidate item id / category id.
        seq_len : NDArray
            Valid history length per sample.

        Returns
        -------
        (aux_loss, score)
            ``aux_loss`` is the per-sample auxiliary loss from ``AuxiliaryNet``;
            ``score`` is the flattened raw logit of the final MLP.
        """
        # SequenceMask needs a 1-D length vector; flatten once and reuse.
        seq_len_flat = seq_len.reshape((-1,))

        uid_embed = self.uid_embedding(uid)
        item_embed = nd.concat(self.item_embedding(item),
                               self.cate_embedding(cate), dim=-1)

        # Clicked history: [B, T, 2E], then a length-masked sum over time -> [B, 2E].
        hist_item_embed = nd.concat(self.item_embedding(hist_item),
                                    self.cate_embedding(hist_cate), dim=-1)
        hist_masked = nd.SequenceMask(hist_item_embed.swapaxes(0, 1),
                                      sequence_length=seq_len_flat,
                                      use_sequence_length=True)              # [T, B, 2E]
        hist_item_sum = nd.sum(hist_masked, axis=0)                          # [B, 2E]

        # Negative samples. Category ids must go through the category table so
        # the negative (item, category) embedding is built exactly like the
        # positive one (the earlier code looked them up in the item table).
        noclk_item_embed = self.item_embedding(noclk_hist_item)
        noclk_cate_embed = self.cate_embedding(noclk_hist_cate)
        noclk_hist = nd.concat(noclk_item_embed, noclk_cate_embed, dim=-1)
        # Keep only the first negative sample of every step.
        noclk_hist_item_embed = noclk_hist[:, :, 0, :].reshape(
            (-1, hist_item_embed.shape[1], hist_item_embed.shape[-1]))       # [B, T, 2E]

        # The GRU layers consume time-major input.
        rnn_outputs = self.rnn1(hist_item_embed.swapaxes(0, 1))              # [T, B, H]
        rnn_out2 = self.rnn2(rnn_outputs)

        c = self.attention_layer(rnn_out2, item_embed, seq_len)

        # Auxiliary supervision: the hidden state at step t is paired with the
        # item actually clicked at step t+1 and with a non-clicked item.
        interest_state = rnn_outputs.swapaxes(0, 1)[:, :-1, :]               # [B, T-1, H]
        click_item = hist_item_embed[:, 1:, :]
        noclick_item = noclk_hist_item_embed[:, 1:, :]
        aux_loss = self.aux_net(interest_state, click_item, noclick_item, seq_len_flat)

        inp = nd.concat(uid_embed, item_embed, hist_item_sum,
                        item_embed * hist_item_sum, c, dim=1)
        score = self.mlp(inp).reshape((-1))
        return aux_loss, score
            

# 模型训练

In [10]:
# Binary cross-entropy on raw logits; handed to train_dien below.
loss = gloss.SigmoidBinaryCrossEntropyLoss()
# Embedding tables are sized from the dataset vocabulary counts read earlier.
net = Model(user_count, item_count, cate_count, embed_size=16, hidden_size=32, attention_size=16, ctx=ctx)
# Xavier initialisation on the configured device; force_reinit=True means
# re-running this cell re-initialises already-initialised parameters.
net.initialize(init=init.Xavier(),force_reinit=True, ctx=ctx)

In [11]:
# Run training and collect the per-evaluation loss / AUC / accuracy histories.
# NOTE(review): train_dien is not defined in this notebook (presumably from
# my_utils); the meaning of the positional values 0.1 and 30 should be
# confirmed against its signature. The progress output shows 30 outer
# iterations, which suggests 30 is the epoch count.
loss_list, auc_list, acc_list = train_dien(
    net,
    train_iter,
    test_iter,
    loss,
    train_batch_size,
    0.1,
    30,
    ctx,
    './logs',
    'loss_base_dien',
    'auc_base_dien',
    acc_name='acc_base_dien',
)

auc: 0.5027	acc:0.5052


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2030), HTML(value='')))




In [12]:
# Persist the training curves and the trained weights next to the notebook.
file_name = model_name+'_loss_auc_acc.pkl'
with open(file_name, 'wb') as f:
    # Objects are written back-to-back into one file; read them back with
    # successive pickle.load calls in this same order.
    for record in (loss_list, auc_list, acc_list, model_name):
        pickle.dump(record, f, pickle.HIGHEST_PROTOCOL)
net.save_parameters(model_name+'.net')