In [205]:
import torch
import numpy as np

## Loss functions and tensor types

The tensor types matter

In [206]:
loss = torch.nn.CrossEntropyLoss()
x    = torch.tensor([[0.2, 0.3, 0.4],[2, 0.3,0.6]])
y    = torch.tensor([0.,1.])

#  Uncomment line (*) and see: 
## RuntimeError: Expected object of scalar type Long but got scalar type Double for argument #2 'target'

#  loss(x,y) # (*) Uncommenting test

Nevertheless the following tensor will work

In [241]:
loss   = torch.nn.CrossEntropyLoss()
x      = torch.tensor([[0.2, 0.3, 0.4],[2, 0.3,0.6]], requires_grad=True)
y      = torch.tensor([0,1])
weigth = loss(x,y)
output = weigth*loss(x,y)

In [210]:
W = torch.nn.Linear(3,2)

In [199]:
W.state_dict()

OrderedDict([('weight', tensor([[ 0.4135, -0.2818, -0.5046],
                      [-0.5118, -0.4690,  0.2053]])),
             ('bias', tensor([-0.3032, -0.1485]))])

In [211]:
xout = W.forward(x)

In [215]:
yhat = loss(xout,y)

In [217]:
yhat.backward()

In [220]:
out.grad

In [224]:
x.grad

tensor([[-0.0363,  0.0043, -0.0466],
        [ 0.0670, -0.0080,  0.0860]])

In [227]:
W.parameters

<bound method Module.parameters of Linear(in_features=3, out_features=2, bias=True)>

### Implement a modification of a loss function

In [242]:
import warnings

from torch import functional as F
from torch.nn import Module
from torch.nn import _reduction as _Reduction
from torch._jit_internal import weak_module, weak_script_method
from torch.nn.functional import cross_entropy

class _Loss(Module):
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

class _WeightedLoss(_Loss):
    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        

class CrossEntropyLossSquared(_WeightedLoss):

    def __init__(self, weight=None, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean'):
        super(CrossEntropyLossSquared, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    @weak_script_method
    def forward(self, input, target):
        return cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)**2



In [256]:
loss = CrossEntropyLossSquared()
x    = torch.tensor([[0.2, 0.3, 0.4],[2, 0.3,0.6]], requires_grad=True)
y    = torch.tensor([0,1])
out  = loss(x,y)
out.backward()

In [257]:
x.grad

tensor([[-1.1397,  0.5414,  0.5983],
        [ 1.1401, -1.4213,  0.2812]])

In [258]:
loss   = torch.nn.CrossEntropyLoss()
x      = torch.tensor([[0.2, 0.3, 0.4],[2, 0.3,0.6]], requires_grad=True)
y      = torch.tensor([0,1])
weigth = loss(x,y)
output = weigth*loss(x,y)
output.backward()

In [259]:
x.grad

tensor([[-1.1397,  0.5414,  0.5983],
        [ 1.1401, -1.4213,  0.2812]])