In [1]:
import torch
import torch.nn as nn
import math

In [2]:
def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    Retrieved from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py#LL90C1-L109C13
    Retrieved from https://www.udemy.com/course/diffusion-models/learn/lecture/37971218#overview
    """

    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2

    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = timesteps.type(torch.float32) [:, None] *emb[None, :]
    emb = torch.concat([torch.sin(emb), torch.cos(emb)], axis=1)

    print(emb.shape)

    if embedding_dim % 2 == 1: # zero pad
        emb = torch.pad(emb, [[0, 0], [0, 1]])
    
    assert emb.shape == (timesteps.shape[0], embedding_dim), f"{emb.shape}"

    return emb

In [3]:
t = (torch.rand(100) * 10).long()

In [4]:
type(t.shape)

torch.Size

In [5]:
get_timestep_embedding(t, 64)

torch.Size([100, 64])


tensor([[ 0.4121,  0.3926, -0.9675,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.1411,  0.7912,  0.9964,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.9093,  0.9964,  0.8930,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [ 0.6570, -0.8831, -0.6612,  ...,  1.0000,  1.0000,  1.0000],
        [-0.9589, -0.5423,  0.3724,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.8415,  0.6765,  0.5244,  ...,  1.0000,  1.0000,  1.0000]])

In [6]:
print(torch.range.__doc__)


range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor

Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1`
with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is
the gap between two values in the tensor.

.. math::
    \text{out}_{i+1} = \text{out}_i + \text{step}.

    This function is deprecated and will be removed in a future release because its behavior is inconsistent with
    Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end).

Args:
    start (float): the starting value for the set of points. Default: ``0``.
    end (float): the ending value for the set of points
    step (float): the gap between each pair of adjacent points. Default: ``1``.

Keyword args:
    out (Tensor, optional): the output tensor.
    dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor