In [None]:
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch.nn import _reduction as _Reduction

## Base Loss Code From PyTorch Docs

In [None]:
def mse_loss(x, y):
    """Return the mean of the element-wise squared difference of ``x`` and ``y``."""
    residual = x - y
    squared_error = residual ** 2
    return torch.mean(squared_error)

In [None]:
class _Loss(Module):
    reduction: str

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super().__init__()
        if size_average is not None or reduce is not None:
            self.reduction: str = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

In [None]:
class MSELoss(_Loss):
    r"""Criterion measuring the mean squared error (squared L2 norm) between
    each element of the input :math:`x` and the target :math:`y`.

    With :attr:`reduction` set to ``'none'`` the loss is left unreduced:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = \left( x_n - y_n \right)^2,

    with :math:`N` the batch size. For any other :attr:`reduction`
    (the default is ``'mean'``):

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), &  \text{if reduction} = \text{`mean';}\\
            \operatorname{sum}(L),  &  \text{if reduction} = \text{`sum'.}
        \end{cases}

    :math:`x` and :math:`y` may have any shape, as long as each holds
    :math:`n` elements in total.

    The mean is taken over all elements, i.e. the sum is divided by :math:`n`.
    Set ``reduction = 'sum'`` to skip that division.

    Args:
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            losses are averaged over every loss element in the batch (some losses have
            several elements per sample). When set to ``False``, the losses are summed
            per minibatch instead. Ignored when :attr:`reduce` is ``False``.
            Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, losses
            are averaged or summed over the observations of each minibatch, depending
            on :attr:`size_average`. When ``False``, a loss per batch element is returned
            and :attr:`size_average` is ignored. Default: ``True``
        reduction (str, optional): Reduction applied to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction,
            ``'mean'``: the summed output is divided by the number of output elements,
            ``'sum'``: the output is summed. Note: :attr:`size_average` and :attr:`reduce`
            are being deprecated; until then, passing either of them overrides
            :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Target: :math:`(*)`, same shape as the input.

    Examples::

        >>> loss = nn.MSELoss()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    __constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        # The base class resolves the legacy flags into ``self.reduction``.
        super().__init__(size_average=size_average, reduce=reduce, reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Apply the functional MSE loss using the stored reduction mode."""
        mode = self.reduction
        return F.mse_loss(input, target, reduction=mode)

## Update This Function to Follow the Loss Classes Above

In [None]:
def cost_l2_torch(F, D, V, learning_batch, lambdaF=1e-7, lambdaD=1e-3, lambdaE=1e-6, Nd=2, Ne=64):
    """Return the L2-regularised decoder cost.

    c_L2 = lambdaE * ||D F - V+||^2 + lambdaD * ||D||^2 + lambdaF * ||F||^2

    ``||.||`` is ``torch.linalg.matrix_norm`` with its default order (Frobenius)
    and ``V+`` is ``V`` with its first column dropped.

    Args:
        F: EMG signals, channels x time (64 channels in the original notes).
            Must be multipliable as ``D @ F`` and the product must be
            shape-compatible with ``V[:, 1:]``.
        D: Decoder; any tensor with ``Nd * Ne`` elements, reshaped to
            ``(Nd, Ne)`` (2 velocity components x 64 channels by default).
        V: Target velocity, ``Nd`` x time.
        learning_batch: Unused; kept so existing callers keep working.
        lambdaF: Weight of the ``||F||^2`` penalty.
        lambdaD: Weight of the ``||D||^2`` penalty.
        lambdaE: Weight of the tracking-error term (the original notes state
            1e-6 for all conditions).
        Nd: Number of decoder output rows.
        Ne: Number of decoder input columns (channels).

    Returns:
        Scalar tensor holding the total cost.

    Note:
        Everything here is ordinary tensor arithmetic, so autograd tracks it
        whenever an input requires grad. Wrap the call in ``torch.no_grad()``
        if the value is only needed for monitoring.
    """
    # reshape (not view) so a non-contiguous D is accepted as well.
    decoder = D.reshape(Nd, Ne)
    Vplus = V[:, 1:]

    # Performance: squared norm of the velocity prediction error.
    term1 = lambdaE * (torch.linalg.matrix_norm(torch.matmul(decoder, F) - Vplus) ** 2)
    # D norm: the square applies to the norm, not to the matrix entries.
    term2 = lambdaD * (torch.linalg.matrix_norm(decoder) ** 2)
    # F norm: same as above.
    term3 = lambdaF * (torch.linalg.matrix_norm(F) ** 2)
    return term1 + term2 + term3

In [None]:
class CPHSLoss(_Loss):
    """L2-regularised decoder cost packaged as a loss module.

    Evaluates

        lambdaE * ||D F - V+||^2 + lambdaD * ||D||^2 + lambdaF * ||F||^2

    on the tensors given at construction time. ``||.||`` is
    ``torch.linalg.matrix_norm`` with its default order (Frobenius) and
    ``V+`` is ``V`` with its first column dropped.

    Args:
        F: Signal matrix; must be multipliable as ``D @ F`` and the product
            must be shape-compatible with ``V[:, 1:]``.
        D: Decoder; any tensor with ``Nd * Ne`` elements.
        V: Target matrix with ``Nd`` rows.
        learning_batch: Stored on the instance; not used by ``forward``.
        lambdaF: Weight of the ``||F||^2`` penalty.
        lambdaD: Weight of the ``||D||^2`` penalty.
        lambdaE: Weight of the tracking-error term.
        Nd: Number of decoder rows.
        Ne: Number of decoder columns.
    """

    # NOTE(review): the default lambdas here differ from those of
    # ``cost_l2_torch`` in this file -- confirm which set is intended.
    def __init__(self, F, D, V, learning_batch, lambdaF=0, lambdaD=1e-6, lambdaE=1e-3, Nd=2, Ne=64) -> None:
        # _Loss.__init__ only understands the reduction-related arguments;
        # forwarding the cost inputs to it raises TypeError.
        super().__init__()
        self.F = F
        self.D = D
        self.V = V
        self.learning_batch = learning_batch
        self.lambdaF = lambdaF
        self.lambdaD = lambdaD
        self.lambdaE = lambdaE
        self.Nd = Nd
        self.Ne = Ne

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the scalar cost computed from the stored ``F``, ``D`` and ``V``.

        NOTE(review): ``input`` and ``target`` are accepted for API
        compatibility with other loss modules but are not used -- confirm
        whether they should replace the stored tensors.
        """
        # Work on a local reshaped tensor: reassigning self.D would mutate
        # the module and fails outright when D is an nn.Parameter.
        decoder = self.D.reshape(self.Nd, self.Ne)
        Vplus = self.V[:, 1:]

        # Performance: squared norm of the prediction error.
        term1 = self.lambdaE * (torch.linalg.matrix_norm(torch.matmul(decoder, self.F) - Vplus) ** 2)
        # D norm: the square applies to the norm, not to the matrix entries.
        term2 = self.lambdaD * (torch.linalg.matrix_norm(decoder) ** 2)
        # F norm: same as above.
        term3 = self.lambdaF * (torch.linalg.matrix_norm(self.F) ** 2)
        return term1 + term2 + term3