In [3]:
import torch
from pydantic import BaseModel

from shortcutfm.nn import DotProductScaledVMFLoss, NormPenalizedVMFLoss


class VMF(BaseModel):
    """Von Mises-Fisher distribution."""

    lambda_1: float = 0.02
    lambda_2: float = 0.1


class Model(BaseModel):
    """Model configuration."""

    hidden_size: int = 128


class Loss(BaseModel):
    mvf_loss_config: VMF = VMF()


class Config(BaseModel):
    """Configuration for the model."""

    loss: Loss = Loss()
    model: Model = Model()


cfg = Config()


batch_size, seq_len, m = 32, 10, 128
output = torch.randn(batch_size, seq_len, m)
target = torch.randn(batch_size, seq_len, m)
target = target / torch.norm(target, dim=-1, keepdim=True)
loss_fn = NormPenalizedVMFLoss(cfg)
loss = loss_fn(output, target)
print(loss.shape)

torch.Size([32, 10])


In [4]:
loss_fn = DotProductScaledVMFLoss(cfg)
loss = loss_fn(output, target)
print(loss.shape, loss)

torch.Size([32, 128, 128]) tensor([[[235.5394, 235.5394, 235.5394,  ..., 235.5394, 235.5394, 235.5394],
         [235.2728, 235.2728, 235.2728,  ..., 235.2728, 235.2728, 235.2728],
         [235.3779, 235.3779, 235.3779,  ..., 235.3779, 235.3779, 235.3779],
         ...,
         [235.4033, 235.4033, 235.4033,  ..., 235.4033, 235.4033, 235.4033],
         [235.0594, 235.0594, 235.0594,  ..., 235.0594, 235.0594, 235.0594],
         [235.1856, 235.1856, 235.1856,  ..., 235.1856, 235.1856, 235.1856]],

        [[235.3070, 235.3070, 235.3070,  ..., 235.3070, 235.3070, 235.3070],
         [235.4443, 235.4443, 235.4443,  ..., 235.4443, 235.4443, 235.4443],
         [235.2218, 235.2218, 235.2218,  ..., 235.2218, 235.2218, 235.2218],
         ...,
         [235.2936, 235.2936, 235.2936,  ..., 235.2936, 235.2936, 235.2936],
         [235.0348, 235.0348, 235.0348,  ..., 235.0348, 235.0348, 235.0348],
         [235.4137, 235.4137, 235.4137,  ..., 235.4137, 235.4137, 235.4137]],

        [[235.246

In [4]:
loss_fn = torch.nn.MSELoss(reduction="none")

loss = loss_fn(output, target)
print(loss.shape, loss)

torch.Size([32, 300]) tensor([[6.5662e-01, 2.2117e-01, 1.6569e+00,  ..., 1.5923e-01, 1.1368e-02,
         1.7070e-01],
        [1.5247e+00, 5.9660e-01, 4.2197e-01,  ..., 1.1189e+00, 8.3982e-01,
         4.3370e+00],
        [1.5221e+00, 6.7003e-01, 5.1103e-02,  ..., 1.9825e+00, 1.4709e-03,
         1.8554e-01],
        ...,
        [1.8703e+00, 3.7110e-02, 9.9314e-01,  ..., 1.3050e+00, 8.9383e-03,
         5.6288e-01],
        [1.1116e+00, 3.9523e-01, 1.3246e+00,  ..., 1.2305e-01, 1.9230e-02,
         2.5331e-05],
        [1.0504e+00, 4.1154e-01, 4.1936e-01,  ..., 1.8439e-01, 2.0889e+00,
         1.2823e-01]])
