In [1]:
import torch
from torch import nn,optim
from torch.nn import functional as F
import pandas as pd
import numpy as np
import random
import pickle
import logging
from tqdm import tqdm
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence
USE_CUDA = torch.cuda.is_available()

random.seed(2020)
np.random.seed(2020)
torch.manual_seed(2020)
if USE_CUDA:
    torch.cuda.manual_seed(2020)
# set cuda
gpu = 0
use_cuda = gpu >= 0 and torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(gpu)
    device = torch.device("cuda", gpu)
else:
    device = torch.device("cpu")
logging.info("Use cuda: %s, gpu id: %d.", use_cuda, gpu)

In [2]:
# 将train_data和test_data读取存入list中
def ReadTxtName(rootdir):
    lines = []
    with open(rootdir, 'r') as file_to_read:
        while True:
            line = file_to_read.readline()
            if not line:
                break
            line = line.strip('\n')
            line = list(eval(line))
            lines.append(line)
    return lines

In [3]:
train_data_path = 'data_example/train_data.txt'
test_data_path = 'data_example/test_data.txt'

In [4]:
train_set = ReadTxtName(train_data_path)
test_set = ReadTxtName(test_data_path)

In [5]:
train_data = pd.DataFrame(columns=['user_id', 'item_seq_temp','cat_list','time_list','time_last_list','time_now_list','position_list','target','item_seq_len'],data = train_set)
test_data = pd.DataFrame(columns=['user_id', 'item_seq_temp','cat_list','time_list','time_last_list','time_now_list','position_list','target','item_seq_len'],data = test_set)

In [6]:
train_data.head()

Unnamed: 0,user_id,item_seq_temp,cat_list,time_list,time_last_list,time_now_list,position_list,target,item_seq_len
0,3522,"[118873, 190989, 73693, 354311, 73693, 354311,...","[571, 894, 482, 482, 482, 482, 482, 202, 330, ...","[398008, 398008, 398104, 398104, 398104, 39810...","[0, 0, 96, 0, 0, 0, 0, 864, 48, 0, 0, 0, 0, 0,...","[1032, 1032, 936, 936, 936, 936, 936, 72, 24, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[125165, 482, 399040]",26
1,3839,"[298668, 297576, 305398, 107015, 32892, 76809,...","[636, 506, 339, 339, 636, 25, 339, 339, 25, 94...","[401704, 401704, 401704, 401704, 401704, 40170...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[168523, 943, 401752]",23
2,3719,"[211347, 153658, 7223, 81509, 19699, 81509, 10...","[719, 719, 613, 352, 1333, 352, 352, 1333, 120...","[400216, 400216, 400480, 400552, 400552, 40057...","[0, 0, 264, 72, 0, 24, 0, 0, 0, 24, 24, 288, 0...","[1008, 1008, 744, 672, 672, 648, 648, 648, 648...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[329307, 344, 401224]",17
3,4362,"[95057, 68099, 161481, 366635, 150257, 303230,...","[1203, 704, 704, 704, 704, 704, 704, 704, 704,...","[399184, 400192, 400192, 400240, 400240, 40024...","[0, 1008, 0, 48, 0, 0, 0, 0, 0, 384, 0, 0, 72,...","[2520, 1512, 1512, 1464, 1464, 1464, 1464, 146...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[146956, 232, 401704]",19
4,1104,"[228127, 202790, 37771, 210693, 316530, 186423...","[537, 381, 194, 962, 194, 301, 194, 194, 1098,...","[397672, 397984, 398008, 398536, 398656, 39865...","[0, 312, 24, 528, 120, 0, 0, 0, 24, 312, 0, 0,...","[3864, 3552, 3528, 3000, 2880, 2880, 2880, 288...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[111155, 67, 401536]",18


In [7]:
train_data['item_seq_len'] = train_data['item_seq_len']-1

In [8]:
train_data.head()

Unnamed: 0,user_id,item_seq_temp,cat_list,time_list,time_last_list,time_now_list,position_list,target,item_seq_len
0,3522,"[118873, 190989, 73693, 354311, 73693, 354311,...","[571, 894, 482, 482, 482, 482, 482, 202, 330, ...","[398008, 398008, 398104, 398104, 398104, 39810...","[0, 0, 96, 0, 0, 0, 0, 864, 48, 0, 0, 0, 0, 0,...","[1032, 1032, 936, 936, 936, 936, 936, 72, 24, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[125165, 482, 399040]",25
1,3839,"[298668, 297576, 305398, 107015, 32892, 76809,...","[636, 506, 339, 339, 636, 25, 339, 339, 25, 94...","[401704, 401704, 401704, 401704, 401704, 40170...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[168523, 943, 401752]",22
2,3719,"[211347, 153658, 7223, 81509, 19699, 81509, 10...","[719, 719, 613, 352, 1333, 352, 352, 1333, 120...","[400216, 400216, 400480, 400552, 400552, 40057...","[0, 0, 264, 72, 0, 24, 0, 0, 0, 24, 24, 288, 0...","[1008, 1008, 744, 672, 672, 648, 648, 648, 648...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[329307, 344, 401224]",16
3,4362,"[95057, 68099, 161481, 366635, 150257, 303230,...","[1203, 704, 704, 704, 704, 704, 704, 704, 704,...","[399184, 400192, 400192, 400240, 400240, 40024...","[0, 1008, 0, 48, 0, 0, 0, 0, 0, 384, 0, 0, 72,...","[2520, 1512, 1512, 1464, 1464, 1464, 1464, 146...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[146956, 232, 401704]",18
4,1104,"[228127, 202790, 37771, 210693, 316530, 186423...","[537, 381, 194, 962, 194, 301, 194, 194, 1098,...","[397672, 397984, 398008, 398536, 398656, 39865...","[0, 312, 24, 528, 120, 0, 0, 0, 24, 312, 0, 0,...","[3864, 3552, 3528, 3000, 2880, 2880, 2880, 288...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[111155, 67, 401536]",17


In [9]:
train_data['item_seq_temp'] = train_data.item_seq_temp.apply(lambda x:[i+1 for i in x])

In [10]:
train_data = train_data[:1000]

In [11]:
for i in tqdm(range(len(train_data))):
    train_data['item_seq_temp'][i][train_data.item_seq_len[i]] = 0

100%|███████████████████████████████████| 1000/1000 [00:00<00:00, 25298.44it/s]


In [12]:
train_data.shape

(1000, 9)

In [13]:
input_length = 50
train_data['item_seq_temp'] = train_data['item_seq_temp'].apply(lambda x:x+[0]*(input_length-len(x))if len(x)<input_length else x[:input_length])

In [14]:
# 对train_data和test_data进行padding,即item_count+1作为padding_idx
# item_id,user_id,cat_list_id进行了重新编码
train_data['target_item'] = train_data['target'].apply(lambda x:x[0]) 
test_data['target_item'] = test_data['target'].apply(lambda x:x[0]) 

In [15]:
train_data.head()

Unnamed: 0,user_id,item_seq_temp,cat_list,time_list,time_last_list,time_now_list,position_list,target,item_seq_len,target_item
0,3522,"[118874, 190990, 73694, 354312, 73694, 354312,...","[571, 894, 482, 482, 482, 482, 482, 202, 330, ...","[398008, 398008, 398104, 398104, 398104, 39810...","[0, 0, 96, 0, 0, 0, 0, 864, 48, 0, 0, 0, 0, 0,...","[1032, 1032, 936, 936, 936, 936, 936, 72, 24, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[125165, 482, 399040]",25,125165
1,3839,"[298669, 297577, 305399, 107016, 32893, 76810,...","[636, 506, 339, 339, 636, 25, 339, 339, 25, 94...","[401704, 401704, 401704, 401704, 401704, 40170...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[168523, 943, 401752]",22,168523
2,3719,"[211348, 153659, 7224, 81510, 19700, 81510, 10...","[719, 719, 613, 352, 1333, 352, 352, 1333, 120...","[400216, 400216, 400480, 400552, 400552, 40057...","[0, 0, 264, 72, 0, 24, 0, 0, 0, 24, 24, 288, 0...","[1008, 1008, 744, 672, 672, 648, 648, 648, 648...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[329307, 344, 401224]",16,329307
3,4362,"[95058, 68100, 161482, 366636, 150258, 303231,...","[1203, 704, 704, 704, 704, 704, 704, 704, 704,...","[399184, 400192, 400192, 400240, 400240, 40024...","[0, 1008, 0, 48, 0, 0, 0, 0, 0, 384, 0, 0, 72,...","[2520, 1512, 1512, 1464, 1464, 1464, 1464, 146...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[146956, 232, 401704]",18,146956
4,1104,"[228128, 202791, 37772, 210694, 316531, 186424...","[537, 381, 194, 962, 194, 301, 194, 194, 1098,...","[397672, 397984, 398008, 398536, 398656, 39865...","[0, 312, 24, 528, 120, 0, 0, 0, 24, 312, 0, 0,...","[3864, 3552, 3528, 3000, 2880, 2880, 2880, 288...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[111155, 67, 401536]",17,111155


In [16]:
# 随便padding
def pad(data):
    for i in data.columns:
        if i =='cat_list':
            data[i] = data[i].apply(lambda x:x+[x[-1]-1]*(input_length-len(x))if len(x)<input_length else x[:input_length])
        elif i=='position_list':
            data[i] = data[i].apply(lambda x:[i for i in range(input_length)])

In [17]:
# input_length = 50
pad(train_data)

In [18]:
# 获得train_data和test_data
columns = ['user_id','item_seq_temp','cat_list','position_list','item_seq_len','target_item']
train_data= train_data[columns]
test_data = test_data[columns]

In [19]:
# 方便处理数据，转换为numpy格式
user_id = np.array(train_data['user_id'].tolist(),dtype = np.int32)
item_seq_temp = np.array(train_data['item_seq_temp'].tolist(),dtype = np.int32)
cat_list = np.array(train_data['cat_list'].tolist(),dtype = np.int32)
position_list = np.array(train_data['position_list'].tolist(),dtype = np.int32)
item_seq_len = np.array(train_data['item_seq_len'].tolist(),dtype = np.int64)
target_item = np.array(train_data['target_item'].tolist(),dtype = np.int32)

# 生成batch

In [20]:
def get_batch(x_user_id,x_item_seq_temp,x_cat_list,x_position_list,x_item_seq_len,y,batch_size,shuffle = True):
    assert x_user_id.shape[0] == y.shape[0]
    if shuffle:
        shuffled_index = np.random.permutation(y.shape[0])
        x_user_id,x_item_seq_len,x_cat_list,x_position_list,x_item_seq_len = x_user_id[shuffled_index],x_item_seq_len[shuffled_index],x_cat_list[shuffled_index],x_position_list[shuffled_index],x_item_seq_len[shuffled_index]
        y = y[shuffled_index]
    
    n_batches = int(x_user_id.shape[0]/batch_size)
    for i in range(n_batches-1):
        x_user_id_batch = x_user_id[i*batch_size:(i+1)*batch_size]
        x_item_seq_temp_batch = x_item_seq_temp[i*batch_size:(i+1)*batch_size]
        x_cat_list_batch = x_cat_list[i*batch_size:(i+1)*batch_size]
        x_position_list_batch = x_position_list[i*batch_size:(i+1)*batch_size]
        x_item_seq_len_batch = x_item_seq_len[i*batch_size:(i+1)*batch_size]
        y_batch = y[i*batch_size:(i+1)*batch_size]
        yield x_user_id_batch,x_item_seq_temp_batch,x_cat_list_batch,x_position_list_batch,x_item_seq_len_batch,y_batch

# 构造模型

In [56]:
class GetEmbedding(nn.Module):
    def __init__(self,parameter_path,input_length,embedding_dim):
        super(GetEmbedding,self).__init__()
        self.parameter_path = parameter_path
        self.parameters = self.get_parameter(self.parameter_path)
        self.user_count = self.parameters['user_count']
        self.item_count = self.parameters['item_count']
        self.category_count = self.parameters['category_count']
        self.input_length = input_length
        self.embedding_dim = embedding_dim
        self.user_id_embedding = nn.Embedding(self.user_count+3,self.embedding_dim)
        self.item_list_embedding = nn.Embedding(self.item_count+3,self.embedding_dim,padding_idx = 0, max_norm = 1.5)
        self.item_list_embedding_weight = self.item_list_embedding.weight
        self.category_list_embedding = nn.Embedding(self.category_count+3,self.embedding_dim)
        self.position_list_embeddig = nn.Embedding(self.input_length,self.embedding_dim)
        self.apply(self.init_weights)
    def init_weights(self, module):
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.002)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.05)
            if module.bias is not None:
                constant_(module.bias.data, 0)
        
    def get_parameter(self,file_path):
        with open(file_path, 'rb') as f:  
            parameters = pickle.loads(f.read())
        return parameters

In [81]:
class STAMP(nn.Module):
    def __init__(self,parameter_path,input_length,embedding_dim,hidden_size,n_layers = 1):
        super(STAMP,self).__init__()
        self.parameter_path = parameter_path
        self.input_length = input_length
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.emb = GetEmbedding(self.parameter_path,self.input_length,self.embedding_dim)
        self.W1 = nn.Linear(self.embedding_dim,self.embedding_dim,bias = False)
        self.W2 = nn.Linear(self.embedding_dim,self.embedding_dim,bias = False)
        self.W3 = nn.Linear(self.embedding_dim,self.embedding_dim,bias = False)
        self.W0 = nn.Linear(self.embedding_dim,1,bias = False)
        self.sigmoid = nn.Sigmoid()
        self.mlp_a = nn.Linear(self.embedding_dim,self.embedding_dim,bias = False)
        self.mlp_b = nn.Linear(self.embedding_dim,self.embedding_dim,bias = False)
        self.tanh = nn.Tanh()
        
        
        
       
    def forward(self,x_item_seq_temp_batch,x_item_seq_len_batch):
        #[batch_size,seq_len,embedding_dim]
        item_list_emb = self.emb.item_list_embedding(x_item_seq_temp_batch)
        #[batch_dize,embedding_dim]
        last_inputs = self.gather_indexes(item_list_emb,x_item_seq_len_batch-1)
        org_memory = item_list_emb
        #[batch_size,embedding_dim]
        ms = torch.div(torch.sum(org_memory,dim = 1),x_item_seq_len_batch.unsqueeze(1).float())
        alpha = self.count_alpha(org_memory,last_inputs,ms)
#         [batch_size,1,dim]
        vec = torch.matmul(alpha.unsqueeze(1),org_memory)
        ma = vec.squeeze(1)+ms
        hs = self.tanh(self.mlp_a(ma))
        ht = self.tanh( self.mlp_b(last_inputs))
        pred = hs * ht
        
        return pred,self.emb.item_list_embedding_weight
        
        
    def count_alpha(self,context,aspect,output):
        """
        context:orgmemory  [batch_size,seq_len,embedding_dim]
        aspect:last_inputs [batch_size,embedding_dim]
        output:ms: [batch_size,embedding_dim]
        ctx_bitamp:[batch_size,seq_len]  : use to remove influence of padding
        """
        timesteps = context.size()[1]
        aspect_3dim = aspect.repeat(1,timesteps).view(-1,timesteps,self.embedding_dim)
        output_3dim = output.repeat(1,timesteps).view(-1,timesteps,self.embedding_dim)
        res_ctx = self.W1(context)
        res_asp = self.W2(aspect_3dim)
        res_output = self.W3(output_3dim)
        res_sum = res_ctx + res_asp + res_output 
        res_act = self.W0(self.sigmoid(res_sum))
        alpha = res_act.squeeze(2)
        return alpha
        
        
    
    def gather_indexes(self, output, gather_index):
        "Gathers the vectors at the spexific positions over a minibatch"
        gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, self.embedding_dim)
        output_tensor = output.gather(dim=1, index=gather_index)
        return output_tensor.squeeze(1)     

In [82]:
class Myloss(nn.Module):
    def __init__(self):
        super(Myloss,self).__init__()
        self.log_softmax = nn.LogSoftmax()
    def forward(self,emb,pred,truth):
        item_lookup_table_T = emb.t()
        logits = torch.matmul(pred,item_lookup_table_T)
        log_probs = self.log_softmax(logits)
        truth = torch.reshape(truth,[-1])
        one_hot_labels = F.one_hot(truth, num_classes=emb.shape[0])
        loss_origin = -torch.sum(log_probs.float() * one_hot_labels.float(), dim=-1)
        loss = torch.mean(loss_origin)
        return loss    

In [83]:
# train
batch_size = 10
parameter_path = 'data_example/parameters.pkl'
input_length = 50
embedding_dim = 32
hidden_size=16
n_layers = 1
stamp = STAMP(parameter_path,input_length,embedding_dim,hidden_size,n_layers = 1)
stamp = stamp.to(device)
stamp.train()
criterion = Myloss()
train_loss = []
optimizer = optim.Adam(stamp.parameters(),lr=0.001)
for epoch in range(10):
    for batch_idx,data in enumerate(get_batch(user_id,item_seq_temp,cat_list,position_list,item_seq_len,target_item,batch_size)):
        x_user_id_batch = torch.LongTensor(data[0])
        x_item_seq_temp_batch = torch.LongTensor(data[1])
        x_cat_list_batch = torch.LongTensor(data[2])
        x_position_list_batch = torch.LongTensor(data[3])
        x_item_seq_len_batch = torch.LongTensor(data[4])
        y_batch = torch.LongTensor(data[5])
        pred,item_list_emb_weight = stamp(x_item_seq_temp_batch,x_item_seq_len_batch)
#         print('item_list_emb.shape',item_list_emb_weight.shape)
        loss = criterion(item_list_emb_weight,pred,y_batch)
        train_loss.append(loss.item())
        # Backward and optimizer
        # 1.优化器保存先前的梯度信息
        optimizer.zero_grad()
        # 2.计算梯度
        loss.backward()
#         3.梯度更新 w' = w - lr*grads
        optimizer.step()
        if batch_idx%200 == 0:
            print(epoch,batch_idx,loss.item())

  


0 0 12.823461532592773
1 0 12.823461532592773
2 0 12.820480346679688
3 0 8.538239479064941
4 0 7.723677158355713
5 0 7.875708103179932
6 0 7.633425712585449
7 0 6.969459533691406
8 0 7.026850700378418
9 0 7.232846260070801
