In [1]:

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as nn

from layer import attention, dice, AUGRU
from utils import sequence_mask

attention过程

In [2]:
# Build the attention scoring MLP: sigmoid hidden layers followed by a linear
# layer whose width is the last entry of dim_layers.
dim_layers = [36, 1]
*hidden_units, score_units = dim_layers
fc = tf.keras.Sequential(
    [nn.Dense(units, activation='sigmoid') for units in hidden_units]
    + [nn.Dense(score_units, activation=None)]
)

传入参数：`queries` 是 `item_join_emb`（形状 [B, H]），`keys` 是 `hist_join_emb`（形状 [B, T, H]）；
`keys_length` 是 batch 中每个 user 的真实历史序列长度（形状 [B]）。

In [15]:
# Build synthetic inputs so the attention walkthrough runs without a dataset.
# The global NumPy seed is fixed so a fresh kernel reproduces the same values.
np.random.seed(0)

BATCH_SIZE = 32      # rows per batch (B)
MAX_SEQ_LEN = 20     # padded history length (T)
EMBED_DIM = 8 + 6    # width of the joined embedding (H)

# float32 matches the Keras layers' default dtype, so feeding these arrays to
# the Dense layers does not trigger an implicit float64 -> float32 cast.
queries = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)            # [B, H]
keys = np.random.rand(BATCH_SIZE, MAX_SEQ_LEN, EMBED_DIM).astype(np.float32)  # [B, T, H]
# Unpadded history length per row, drawn from 1..MAX_SEQ_LEN inclusive.
keys_length = np.random.randint(1, MAX_SEQ_LEN + 1, (BATCH_SIZE,))            # [B]

In [16]:
# Repeat each query along a new time axis so it lines up element-wise with its
# keys: [B, H] -> [B, 1, H] -> [B, T, H].
history_len = tf.shape(keys)[1]
queries = tf.expand_dims(queries, axis=1)
queries = tf.tile(queries, multiples=[1, history_len, 1])

In [17]:
# def printshape(name):
#     print(name.shape)
  
# for i in [queries, keys, queries-keys, queries*keys]:
#     printshape(i)

In [18]:
# Assemble the attention MLP input: the query/key pair plus their difference
# and element-wise product, joined along the feature axis.
interaction_features = [queries, keys, queries - keys, queries * keys]
din_all = tf.concat(interaction_features, axis=-1)
din_all.shape

TensorShape([32, 20, 56])

In [20]:
# Run the features through the fully connected stack, then swap the last two
# axes so the per-position scores lie along the final dimension.
position_scores = fc(din_all)
outputs = tf.transpose(position_scores, perm=[0, 2, 1])
outputs.shape

TensorShape([32, 1, 20])

In [29]:
# Some scores come from padded (invalid) history positions and must be masked.
# Zeroing them out is not enough: a score of 0 still receives a non-zero
# weight after the softmax.
# The mask width is taken from the keys' time axis rather than
# max(keys_length): if no sequence in the batch reaches the padded length, a
# mask sized by the batch maximum is narrower than the scores and the shapes
# no longer match.
key_masks = tf.sequence_mask(keys_length, tf.shape(keys)[1], dtype=tf.float32)  # [B, T]
key_masks = tf.expand_dims(key_masks, 1)  # [B, 1, T]
key_masks.shape

TensorShape([32, 1, 20])

In [15]:
# Very negative filler with the same shape and dtype as the scores; it is the
# replacement value for masked positions.
MASK_FILL_VALUE = -2 ** 32 + 1
paddings = MASK_FILL_VALUE * tf.ones_like(outputs)
paddings.shape

TensorShape([32, 1, 20])

In [16]:
# Apply the mask: keep the score where the position is valid, otherwise use the
# large negative filler (-2 ** 32 + 1) so the later softmax drives it to zero.
# tf.where needs a boolean condition and key_masks is built with
# dtype=tf.float32, so convert it first (a no-op if it is already boolean).
outputs = tf.where(tf.cast(key_masks, tf.bool), outputs, paddings)

In [21]:
# Display the score tensor's shape as a sanity check.
outputs.shape

TensorShape([32, 1, 20])

In [23]:
# Scale the scores by sqrt(H), where H is the embedding width read from the
# keys' last axis instead of a hard-coded 6 + 8, so this step keeps working if
# the embedding size changes.
outputs = outputs / (int(keys.shape[-1]) ** 0.5)

In [24]:
# Normalise the scores along the last axis into attention weights.
outputs = tf.keras.activations.softmax(outputs, axis=-1)
outputs.shape

TensorShape([32, 1, 20])

In [18]:
# Bring both matmul operands to float32.
outputs, keys = (tf.cast(tensor, tf.float32) for tensor in (outputs, keys))


In [19]:
# Weighted sum of the keys using the attention weights:
# [B, 1, T] @ [B, T, H] -> [B, 1, H].
# Squeeze only the singleton middle axis; a bare tf.squeeze would also drop
# the batch axis whenever B == 1.
outputs = tf.squeeze(tf.matmul(outputs, keys), axis=1)  # [B, H]

In [20]:
# Display the result of the weighted sum.
outputs

<tf.Tensor: id=133, shape=(32, 14), dtype=float32, numpy=
array([[0.48441058, 0.5810695 , 0.5439096 , 0.36418653, 0.36395156,
        0.4636144 , 0.60085315, 0.56507784, 0.51698405, 0.48451677,
        0.44668573, 0.48170775, 0.4257315 , 0.57143736],
       [0.5191581 , 0.6042936 , 0.4374274 , 0.39265877, 0.28725877,
        0.40437052, 0.45542845, 0.5012349 , 0.43098414, 0.34788415,
        0.41686141, 0.59862727, 0.45880666, 0.48151138],
       [0.5191933 , 0.47186404, 0.4616608 , 0.49709198, 0.42934102,
        0.53557587, 0.56539303, 0.4764052 , 0.53780836, 0.4523728 ,
        0.4940291 , 0.55269784, 0.41719928, 0.56797355],
       [0.54032695, 0.8907372 , 0.1367836 , 0.6305817 , 0.52380645,
        0.38354266, 0.26662576, 0.31388074, 0.26164034, 0.4550085 ,
        0.54338574, 0.45832253, 0.39412034, 0.75757045],
       [0.48735633, 0.5731554 , 0.63998187, 0.61434305, 0.5733051 ,
        0.5038809 , 0.62862283, 0.47936112, 0.4204821 , 0.5654786 ,
        0.5793242 , 0.30579573, 0.