Sinusoidal Embedding

In [17]:
import math

import torch
from torch import nn

In [18]:
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of time values.

    Input: t of shape [B] (more generally any shape [*]).
    Output: t_embedd of shape [B, F] (more generally [*, F]), with F = ``dim``.

    The first ``dim // 2`` output features are ``sin(t * w_k)`` and the next
    ``dim // 2`` are ``cos(t * w_k)``. The frequencies ``w_k`` are spaced
    geometrically from 1 down to ``1 / max_period``. An even ``dim`` is
    recommended. If ``dim`` is odd, one zero column is appended so the
    output width still equals ``dim``.

    Args:
        dim: Number of output features F.
        max_period: Controls the lowest frequency (``1 / max_period``).
            Defaults to 10000.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` into ``self.dim`` sinusoidal features.

        Args:
            x: Tensor of time values. The original code required shape [B];
                any shape now works and gains a trailing feature axis.

        Returns:
            Tensor of shape ``x.shape + (self.dim,)``.
        """
        device = x.device
        half_dim = self.dim // 2
        # Spread the exponents over [0, 1]. Clamp the divisor so that
        # half_dim == 1 (dim 2 or 3) does not divide by zero; a single
        # frequency of 1 is used in that case.
        scale = math.log(self.max_period) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        # x[..., None] adds a trailing axis that broadcasts against freqs.
        args = x[..., None] * freqs
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Odd dim: pad one zero feature so the width matches self.dim.
            emb = nn.functional.pad(emb, (0, 1))
        return emb



In [19]:
# Embed two time values into 4 features. TimeEmbedding asks for an even dim.
t = torch.tensor([0.4, 0.6])
time = TimeEmbedding(dim=4)
t_embedding = time(t)
print(t_embedding)  # shape [2, 4]: sines at both frequencies, then cosines

tensor([[3.8942e-01, 4.0000e-05, 9.2106e-01, 1.0000e+00],
        [5.6464e-01, 6.0000e-05, 8.2534e-01, 1.0000e+00]])
