In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [43]:
EPS = 0.00001

In [78]:
class EntityEmbeddingLayer(nn.Module):
    def __init__(self, num_level,embedding_dim,centroid):
        super(EntityEmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(num_level, embedding_dim)
        self.centroid = torch.tensor(centroid).detach_().unsqueeze(1)
    def forward(self, x):
        x = x.unsqueeze(1)
        d = 1.0 / ((x - self.centroid).abs() + EPS)
        w = F.softmax(d.squeeze(2),1)
        v = torch.mm(w,self.embedding.weight)
        return v

In [62]:
num_level = 10
embedding_dim = 5
embedding = nn.Embedding(num_level,embedding_dim) # 每个level（分组）的embedding vector

In [46]:
batch_size = 4
x = torch.randn(batch_size, 1)
centroid = torch.randn(num_level, 1)

In [47]:
embedding.weight.shape

torch.Size([10, 5])

In [48]:
x.shape

torch.Size([4, 1])

In [49]:
centroid.shape

torch.Size([10, 1])

In [50]:
x = x.unsqueeze(2) # 注意，如果不增加维度，这里不满足broadcast的机制，从后往前看，要么没有，要么为1，要么相同
x.shape

torch.Size([4, 1, 1])

In [51]:
(x - centroid).shape

torch.Size([4, 10, 1])

In [52]:
d = 1.0 / ((x - centroid).abs() +EPS)
d.shape

torch.Size([4, 10, 1])

In [53]:
w = F.softmax(d,1)
w.shape

torch.Size([4, 10, 1])

In [54]:
w.sum()

tensor(4.)

In [55]:
v = torch.mm(w.squeeze(2),embedding.weight)
v.shape

torch.Size([4, 5])

In [79]:
# Test
x =  torch.randn(batch_size, 1)
centroid = torch.randn(num_level)
entity_embedding = EntityEmbeddingLayer(num_level, embedding_dim, centroid)
entity_embedding(x)

  """


tensor([[-0.3525,  0.5941, -0.9614,  1.1326, -0.2228],
        [-0.1477, -0.0731, -0.0249,  0.7893, -0.0524],
        [-2.7144, -1.4985, -0.7244, -0.1882,  0.2108],
        [-0.1037,  0.0973, -0.0225,  0.0235,  0.2428]], grad_fn=<MmBackward>)