# Entity Embedding

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
EPS = 1e-5

## Note

This is the final implementation, but not the process.

In [3]:
class EntityEmbeddingLayer(nn.Module):
    def __init__(self, num_level, embedding_dim, centroid):
        super(EntityEmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(num_level, embedding_dim)
        self.centroid = torch.tensor(centroid).detach_().unsqueeze(1)

    def forward(self, x):
        """
        x must be batch_size times 1
        :param x:
        :return:
        """
        x = x.unsqueeze(1)
        d = 1.0 / ((x - self.centroid).abs() + EPS)
        w = F.softmax(d.squeeze(2), 1)
        v = torch.mm(w, self.embedding.weight)
        return v

## Experimentation

This is part where you experiment with the data

In [4]:
num_level = 10
embedding_dim = 5
embedding = nn.Embedding(num_level, embedding_dim)

In [19]:
batch_size = 4
x = torch.randn(batch_size, 1)
centroid = torch.randn(num_level, 1)

In [20]:
embedding.weight.shape

torch.Size([10, 5])

In [21]:
x.shape

torch.Size([4, 1])

In [28]:
(x - centroid).shape

torch.Size([4, 10, 1])

In [23]:
x = x.unsqueeze(1)

In [24]:
x.shape

torch.Size([4, 1, 1])

In [25]:
centroid = centroid.unsqueeze(0)

In [26]:
centroid.shape

torch.Size([1, 10, 1])

In [29]:
d = 1.0 / ((x-centroid).abs() + EPS)

In [31]:
d.shape

torch.Size([4, 10, 1])

In [33]:
w = F.softmax(d.squeeze(2), 1)

In [34]:
w.shape

torch.Size([4, 10])

In [35]:
w.sum()

tensor(4.)

In [36]:
v = torch.mm(w, embedding.weight)

In [38]:
v.shape

torch.Size([4, 5])

## Testing

In [39]:
x = torch.randn(batch_size, 1)
centroid = torch.randn(num_level)
entity_embedding = EntityEmbeddingLayer(num_level, embedding_dim, centroid)

  self.centroid = torch.tensor(centroid).detach_().unsqueeze(1)


In [40]:
entity_embedding(x)

tensor([[ 0.0774,  0.8502, -1.1379,  1.5159, -0.0448],
        [ 0.0897,  0.8454, -1.1106,  1.4162, -0.0494],
        [ 1.0609, -1.1119, -1.2553, -0.4212, -0.3883],
        [-0.3805, -1.1960,  2.2615, -2.0092,  0.9339]], grad_fn=<MmBackward0>)