In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import math

In [7]:
class PeriodicEmbeddings(nn.Module):
    def __init__(self, d_in, d_embedding, n_frequencies=16, max_frequency=10.0):
        super().__init__()
        self.d_in = d_in
        self.n_frequencies = n_frequencies

        freq_bands = torch.linspace(1.0, max_frequency, n_frequencies)
        self.register_buffer('frequencies', freq_bands[None, :].repeat(d_in, 1))

        self.norm = nn.LayerNorm(d_in)
        self.linear = nn.Linear(d_in * 2 * n_frequencies, d_embedding)

    def forward(self, x):
        x = self.norm(x)
        
        print(x.shape)
        x_proj = 2 * math.pi * x.unsqueeze(-1) * self.frequencies
        print(x_proj.shape)
        x_pe = torch.cat([x_proj.sin(), x_proj.cos()], dim=-1)
        print(x_pe.shape)
        return self.linear(x_pe.view(x.size(0), -1))


def make_mlp(in_dim, out_dim, hidden_dim, num_layers, activation='ReLU', dropout=0.0):
    act_layer = getattr(nn, activation)
    layers = [nn.Linear(in_dim, hidden_dim), act_layer(), nn.Dropout(dropout)]
    for _ in range(num_layers):
        layers += [nn.Linear(hidden_dim, hidden_dim), act_layer(), nn.Dropout(dropout)]
    layers.append(nn.Linear(hidden_dim, out_dim))
    return nn.Sequential(*layers)


class MLP_PLR(nn.Module):
    def __init__(
        self,
        input_size=28 * 28,
        num_classes=10,
        hidden_dim=128,
        num_layers=2,
        embedding_type='periodic',
        d_embedding=128,
        n_frequencies=32,
        max_frequency=10.0,
        activation='ReLU',
        dropout=0.0,
    ):
        super().__init__()

        if embedding_type == 'periodic':
            self.embedding = PeriodicEmbeddings(
                d_in=input_size,
                d_embedding=d_embedding,
                n_frequencies=n_frequencies,
                max_frequency=max_frequency,
            )
            embedding_out_dim = d_embedding
        elif embedding_type == 'none':
            self.embedding = nn.Identity()
            embedding_out_dim = input_size
        else:
            raise ValueError(f"Unknown embedding_type: {embedding_type}")

        self.network = make_mlp(
            in_dim=embedding_out_dim,
            out_dim=num_classes,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            activation=activation,
            dropout=dropout,
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.embedding(x)
        return self.network(x)


In [8]:
periodic_emb = PeriodicEmbeddings(d_in=5, d_embedding=10)

In [9]:
x = torch.randn(1,5)

In [10]:
y = periodic_emb(x)
y.shape

torch.Size([1, 5])
torch.Size([1, 5, 16])
torch.Size([1, 5, 32])


torch.Size([1, 10])

In [11]:
from torch import Tensor
import statistics
from dataclasses import dataclass
from typing import Any, Callable, Literal, cast


def cos_sin(x: Tensor) -> Tensor:
    return torch.cat([torch.cos(x), torch.sin(x)], -1)

class PeriodicOptions:
    n: int  # the output size is 2 * n
    sigma: float
    trainable: bool
    initialization: Literal['log-linear', 'normal']


class Periodic(nn.Module):
    def __init__(self, n_features: int, options: PeriodicOptions) -> None:
        super().__init__()
        if options.initialization == 'log-linear':
            coefficients = options.sigma ** (torch.arange(options.n) / options.n)
            coefficients = coefficients[None].repeat(n_features, 1)
        else:
            assert options.initialization == 'normal'
            coefficients = torch.normal(0.0, options.sigma, (n_features, options.n))
        if options.trainable:
            self.coefficients = nn.Parameter(coefficients)  # type: ignore[code]
        else:
            self.register_buffer('coefficients', coefficients)

    def forward(self, x: Tensor) -> Tensor:
        assert x.ndim == 2
        return cos_sin(2 * torch.pi * self.coefficients[None] * x[..., None])

In [13]:
options = PeriodicOptions()
options.n = 10
options.sigma = 0.3
options.trainable = True
options.initialization = 'normal' 
periodic_emb = Periodic(n_features = 5 , options= options )

In [15]:
y = periodic_emb(x)
y.shape

torch.Size([1, 5, 20])

In [None]:
from torch.nn.parameter import Parameter


def _check_input_shape(x: Tensor, expected_n_features: int) -> None:
    if x.ndim < 1:
        raise ValueError(
            f'The input must have at least one dimension, however: {x.ndim=}'
        )
    if x.shape[-1] != expected_n_features:
        raise ValueError(
            'The last dimension of the input was expected to be'
            f' {expected_n_features}, however, {x.shape[-1]=}'
        )
class _Periodic(nn.Module):
    """
    NOTE: THIS MODULE SHOULD NOT BE USED DIRECTLY.

    Technically, this is a linear embedding without bias followed by
    the periodic activations. The scale of the initialization
    (defined by the `sigma` argument) plays an important role.
    """

    def __init__(self, n_features: int, k: int, sigma: float) -> None:
        if sigma <= 0.0:
            raise ValueError(f'sigma must be positive, however: {sigma=}')

        super().__init__()
        self._sigma = sigma
        self.weight = Parameter(torch.empty(n_features, k))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters."""
        # NOTE[DIFF]
        # Here, extreme values (~0.3% probability) are explicitly avoided just in case.
        # In the paper, there was no protection from extreme values.
        bound = self._sigma * 3
        nn.init.trunc_normal_(self.weight, 0.0, self._sigma, a=-bound, b=bound)

    def forward(self, x: Tensor) -> Tensor:
        """Do the forward pass."""
        _check_input_shape(x, self.weight.shape[0])
        x = 2 * math.pi * self.weight * x[..., None]
        x = torch.cat([torch.cos(x), torch.sin(x)], -1)
        return x


# _NLinear is a simplified copy of delu.nn.NLinear:
# https://yura52.github.io/delu/stable/api/generated/delu.nn.NLinear.html
class _NLinear(nn.Module):
    """N *separate* linear layers for N feature embeddings.

    In other words,
    each feature embedding is transformed by its own dedicated linear layer.
    """

    def __init__(
        self, n: int, in_features: int, out_features: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.weight = Parameter(torch.empty(n, in_features, out_features))
        self.bias = Parameter(torch.empty(n, out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters."""
        d_in_rsqrt = self.weight.shape[-2] ** -0.5
        nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass."""
        if x.ndim != 3:
            raise ValueError(
                '_NLinear supports only inputs with exactly one batch dimension,'
                ' so `x` must have a shape like (BATCH_SIZE, N_FEATURES, D_EMBEDDING).'
            )
        assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]

        x = x.transpose(0, 1)
        x = x @ self.weight
        x = x.transpose(0, 1)
        if self.bias is not None:
            x = x + self.bias
        return x


class PeriodicEmbeddings(nn.Module):
    """Embeddings for continuous features based on periodic activations.

    See README for details.

    **Shape**

    - Input: `(*, n_features)`
    - Output: `(*, n_features, d_embedding)`

    **Examples**

    >>> batch_size = 2
    >>> n_cont_features = 3
    >>> x = torch.randn(batch_size, n_cont_features)
    >>>
    >>> d_embedding = 24
    >>> m = PeriodicEmbeddings(n_cont_features, d_embedding, lite=False)
    >>> m(x).shape
    torch.Size([2, 3, 24])
    >>>
    >>> m = PeriodicEmbeddings(n_cont_features, d_embedding, lite=True)
    >>> m(x).shape
    torch.Size([2, 3, 24])
    >>>
    >>> # PL embeddings.
    >>> m = PeriodicEmbeddings(n_cont_features, d_embedding=8, activation=False, lite=False)
    >>> m(x).shape
    torch.Size([2, 3, 8])
    """  # noqa: E501

    def __init__(
        self,
        n_features: int,
        d_embedding: int = 24,
        *,
        n_frequencies: int = 48,
        frequency_init_scale: float = 0.01,
        activation: bool = True,
        lite: bool,
    ) -> None:
        """
        Args:
            n_features: the number of features.
            d_embedding: the embedding size.
            n_frequencies: the number of frequencies for each feature.
                (denoted as "k" in Section 3.3 in the paper).
            frequency_init_scale: the initialization scale for the first linear layer
                (denoted as "sigma" in Section 3.3 in the paper).
                **This is an important hyperparameter**, see README for details.
            activation: if `False`, the ReLU activation is not applied.
                Must be `True` if ``lite=True``.
            lite: if True, the outer linear layer is shared between all features.
                See README for details.
        """
        super().__init__()
        self.periodic = _Periodic(n_features, n_frequencies, frequency_init_scale)
        self.linear: Union[nn.Linear, _NLinear]
        if lite:
            # NOTE[DIFF]
            # The lite variation was introduced in a different paper
            # (about the TabR model).
            if not activation:
                raise ValueError('lite=True is allowed only when activation=True')
            self.linear = nn.Linear(2 * n_frequencies, d_embedding)
        else:
            self.linear = _NLinear(n_features, 2 * n_frequencies, d_embedding)
        self.activation = nn.ReLU() if activation else None

    def forward(self, x: Tensor) -> Tensor:
        """Do the forward pass."""
        x = self.periodic(x)
        x = self.linear(x)
        if self.activation is not None:
            x = self.activation(x)
        return x