In [1]:
import numpy as np
import tensorflow as tf

In [2]:
def dense_to_sparse(dense_arr, ignore_value=0):
    """Convert a dense array/tensor into a ``tf.sparse.SparseTensor``.

    Every entry that differs from ``ignore_value`` becomes an explicit
    (index, value) pair; entries equal to ``ignore_value`` are treated as
    padding and dropped. Note that this means ``ignore_value`` can never be
    represented as a real value (e.g. id 0 when the default is used).

    Args:
        dense_arr: numpy array or ``tf.Tensor`` of any rank.
        ignore_value: the "empty"/padding value to drop. Defaults to 0, the
            conventional padding value for sparse data.

    Returns:
        ``tf.sparse.SparseTensor`` whose ``indices`` (int64, one row per kept
        entry, in row-major order) point at the kept entries, whose
        ``values`` are those entries, and whose ``dense_shape`` is the shape
        of ``dense_arr``.
    """
    dense_tensor = tf.convert_to_tensor(dense_arr)
    # Cast the padding value to the input dtype so int and float inputs both work.
    keep_mask = tf.not_equal(dense_tensor, tf.cast(ignore_value, dense_tensor.dtype))
    arr_idx = tf.where(keep_mask)
    arr_value = tf.gather_nd(dense_tensor, indices=arr_idx)
    # tf.shape is evaluated at run time, so this also works when a dimension
    # (e.g. the batch size) is unknown inside tf.function / Keras graphs.
    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
    return tf.sparse.SparseTensor(indices=arr_idx, values=arr_value, dense_shape=dense_shape)
# Seeded generator so the printed indices are identical on every Restart & Run All.
demo_rng = np.random.default_rng(42)
dense_arr = demo_rng.choice([0, 1], size=[5, 2, 3])

sparse_vec = dense_to_sparse(dense_arr)
print(sparse_vec.indices)

tf.Tensor(
[[0 0 0]
 [0 1 0]
 [1 0 1]
 [1 1 1]
 [2 0 1]
 [2 0 2]
 [2 1 0]
 [2 1 1]
 [2 1 2]
 [3 0 1]
 [3 0 2]
 [3 1 1]
 [3 1 2]
 [4 0 1]
 [4 1 0]
 [4 1 2]], shape=(16, 3), dtype=int64)


In [None]:
# Embedding layer for inputs that contain both one-hot and multi-hot features.
class InputProcessLayer(tf.keras.layers.Layer):
    """Look up embeddings for one-hot and multi-hot feature groups.

    All features share a single embedding table of shape
    ``[feature_num, emb_size]``, so every feature id (one-hot or multi-hot)
    must lie in ``[0, feature_num)``. Splitting it into per-feature tables
    is a possible future change.
    """

    def __init__(self, feature_num, emb_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # add_weight registers the table with Keras (trainable_weights,
        # checkpointing). The initializer reproduces the defaults of
        # tf.random.truncated_normal (mean 0, stddev 1).
        self.emb_table = self.add_weight(
            name="emb_table",
            shape=(feature_num, emb_size),
            initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )

    def call(self, inputs, **kwargs):
        """Embed one batch of one-hot and multi-hot features.

        :param inputs: pair ``(one_hot_batch, multi_hot_batches)``:
            - ``one_hot_batch``: integer ids, e.g. ``[batch_size, num_one_hot_features]``;
              each id is looked up directly in the shared table.
            - ``multi_hot_batches``: iterable of integer id tensors, presumably
              ``[batch_size, k]`` (TODO confirm); each is converted with
              ``dense_to_sparse`` (entries it treats as padding are dropped)
              and the remaining ids of each row are mean-combined.
            Example: ``([[1, 2, 3], [4, 5, 6]], ([[3, 4], [4, 3]], [[6, 7, 8], [8, 6, 7]]))``.
            NOTE(review): the multi-hot *values* are used as embedding ids, so a
            0/1 indicator vector such as ``[0, 1, 1, 0]`` must first be turned
            into the ids of its set positions.
        :param kwargs: unused.
        :return: ``(one_hot_emb_vector, multi_hot_emb_vectors)``: the first has
            shape ``one_hot_batch.shape + [emb_size]``; the second is a list with
            one combined-embedding tensor per multi-hot group, one row per input
            row (rows with no ids give a zero vector).
        """
        one_hot_batch, multi_hot_batches = inputs
        one_hot_emb_vector = tf.nn.embedding_lookup(params=self.emb_table, ids=one_hot_batch)

        multi_hot_emb_vectors = []
        for multi_hot_batch in multi_hot_batches:
            multi_hot_sparse_batch = dense_to_sparse(multi_hot_batch)
            # embedding_lookup_sparse sizes its output from the largest row
            # index present, so trailing rows without ids would be dropped and
            # the batch would no longer line up with one_hot_emb_vector.
            # The safe variant keeps every row (empty rows -> zeros) and uses
            # the same default "mean" combiner.
            multi_hot_emb_vector = tf.nn.safe_embedding_lookup_sparse(
                embedding_weights=self.emb_table,
                sparse_ids=multi_hot_sparse_batch,
                sparse_weights=None,
            )
            multi_hot_emb_vectors.append(multi_hot_emb_vector)
        return one_hot_emb_vector, multi_hot_emb_vectors

# Toy data: 10 examples, 3 one-hot ids in [0, 3) plus two multi-hot groups
# with ids in [3, 5) and [6, 9); every id fits the 10-row embedding table.
# Seed numpy (data) and TF (embedding init) so reruns give identical output.
rng = np.random.default_rng(42)
tf.random.set_seed(42)

one_hot_arr = rng.integers(0, 3, size=[10, 3])
one_hot_ds = tf.data.Dataset.from_tensor_slices(one_hot_arr)
multi_hot_arr1 = rng.integers(3, 5, size=[10, 2])
multi_hot_arr2 = rng.integers(6, 9, size=[10, 3])
multi_hot_ds = tf.data.Dataset.from_tensor_slices((multi_hot_arr1, multi_hot_arr2))

ds = tf.data.Dataset.zip((one_hot_ds, multi_hot_ds))
batched_ds = ds.batch(4)

# Advance to the third batch: the final, partial one (10 % 4 = 2 rows),
# which exercises the layer on a batch smaller than the batch size.
iterator = iter(batched_ds)
for _ in range(3):
    one_batch = next(iterator)

input_process_layer = InputProcessLayer(feature_num=10, emb_size=2)
input_process_layer(one_batch)
