- Paper: _Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal loss for dense object detection. In Proceedings of the IEEE international conference on computer vision (pp. 2980-2988)_.

- Focal loss assigns higher loss values to the examples on which the model performs worst. E.g.: if the GT is 1.0 and the model predicts 0.2 (a bad prediction), the focal loss is high; if the GT is 1.0 and the model predicts 0.8 (a good prediction), the focal loss is low. But this alone is generic: it holds for every loss, e.g. the crossentropy.


- The key point is that the focal loss is a modified crossentropy. It is the product of two factors, say `A * B`, where `B` is the crossentropy and `A` is a multiplier of the crossentropy with range [0, 1], so the focal loss is actually a reduced crossentropy. For examples the model classifies correctly (e.g. gt=1 and pred=0.9, or gt=0 and pred=0.1), the focal loss is a greatly reduced crossentropy (ce≈0.105, fl≈0.001, with gamma=2). This is the power of the focal loss: it reduces the contribution of well-classified examples (usually the ones found frequently in the dataset) to the total loss (e.g. the batch loss), so that badly classified examples (usually the rarest ones) stand out more and the training process can focus on them. It is a matter of proportion, not of absolute value: for badly classified examples (e.g. gt=1 and pred=0.1, or gt=0 and pred=0.9) the focal loss is still lower than the crossentropy (fl≈1.87, ce≈2.30), it is just reduced far less.
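A quick way to see the `A` factor at work is to tabulate it for a few values of the probability assigned to the true class (called `pt` here, as in the paper): the ratio between focal loss and crossentropy is exactly `(1 - pt)**gamma`. A minimal NumPy sketch, assuming gamma=2 and arbitrary `pt` values:

```python
import numpy as np

gamma = 2
pt = np.array([0.1, 0.5, 0.9, 0.99])  # probability assigned to the true class

ce = -np.log(pt)                      # crossentropy (factor B)
modulator = (1 - pt) ** gamma         # factor A, always in [0, 1]
fl = modulator * ce                   # focal loss

for p, c, f, m in zip(pt, ce, fl, modulator):
    print(f"pt={p:.2f}  ce={c:.4f}  fl={f:.6f}  fl/ce={m:.4f}")
```

The ratio drops from 0.81 at `pt=0.1` to 0.0001 at `pt=0.99`, which is why confident, correct predictions almost vanish from the total loss.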

- Focal loss formula (for a binary classification task) (y = gt, p = pred):
    - two cases:
        - if `y = 1` --> `-(1-p)**gamma * log(p)`
        - if `y = 0` --> `-(p)**gamma * log(1-p)`
    - the second factor is exactly the binary crossentropy, for the cases y=1 and y=0.
    - so, when gamma = 0, the focal loss is equal to the crossentropy.
    - we can also add another factor, an "alpha" parameter holding per-class weights, to down-weight the contribution of the frequent classes even more and up-weight the impact of the rare classes (since these weights are usually in the range [0, 1], technically everything is down-weighted, just some classes more than others). This weighting factor can also be added to the standard crossentropy; it is not specific to the focal loss.
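Writing `pt = p` if `y = 1` and `pt = 1 - p` otherwise, as the paper does, the two cases collapse into a single expression, `FL(pt) = -alpha_t * (1 - pt)**gamma * log(pt)`, where in the paper `alpha_t` is `alpha` for the positive class and `1 - alpha` for the negative one. A vectorized NumPy sketch of that form; the function name `binary_focal_loss` is hypothetical, it assumes `p` is already a probability (not a logit), and `alpha=0.25` below is only an illustrative value:

```python
import numpy as np

def binary_focal_loss(p, y, gamma=2.0, alpha=None):
    """Per-example binary focal loss. p: predicted P(y=1), y: 0/1 labels."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y)
    pt = np.where(y == 1, p, 1 - p)               # probability of the true class
    loss = -((1 - pt) ** gamma) * np.log(pt)
    if alpha is not None:
        alpha_t = np.where(y == 1, alpha, 1 - alpha)  # class-balancing weight
        loss = alpha_t * loss
    return loss

p = np.array([0.9, 0.1, 0.1, 0.9])
y = np.array([1, 0, 1, 0])
print(binary_focal_loss(p, y))              # gamma=2, no alpha
print(binary_focal_loss(p, y, gamma=0))     # plain binary crossentropy
print(binary_focal_loss(p, y, alpha=0.25))  # alpha-balanced variant
```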

- For a multiclass task there are two ways to compute the focal loss of each example: (1) considering only the prediction for the true class, and (2) considering the predictions for all the classes, including the wrong ones. In the second case we are saying "the correct class is A, and the model should have predicted neither class B nor C, but let's still take into account that it predicted 0.1 for class B and 0.2 for class C, along with 0.7 for class A". In the first case, instead, we only consider the prediction for the correct class A, 0.7, implicitly treating everything else as a single wrong class with prediction 0.3 (basically falling back to the binary case). The first approach is said to be the most common, while the second one is more thorough. Here we use the first method, but at the end classes implementing both are provided.
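The difference between the two approaches can be made concrete on the example above: true class A, predictions 0.7 (A), 0.1 (B), 0.2 (C). A small NumPy sketch, assuming gamma=2 (variable names are illustrative):

```python
import numpy as np

gamma = 2
pred = np.array([0.7, 0.1, 0.2])   # predictions for classes A, B, C
target = np.array([1, 0, 0])       # one-hot ground truth: class A

# (1) true class only: just the term of class A
pt_true = pred[target == 1][0]
fl_true_only = -(1 - pt_true) ** gamma * np.log(pt_true)

# (2) all classes: each class is treated as its own binary problem,
#     with pt = p for the true class and pt = 1 - p for the wrong ones
pt_all = np.where(target == 1, pred, 1 - pred)
fl_all_classes = np.sum(-(1 - pt_all) ** gamma * np.log(pt_all))

print(fl_true_only)    # only the A term, about 0.0321
print(fl_all_classes)  # A term + B term + C term, about 0.0421
```

Approach (2) can only be larger or equal, since every wrong class adds a non-negative term.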

In [1]:
import torch
import numpy as np

# THEORY

In [2]:
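def simple_bce_loss(pred, gt):
    # Binary crossentropy for a single example: gt in {0, 1}, pred in (0, 1)
    ce = - (gt * np.log(pred) + (1-gt) * np.log(1-pred))
    return ce

def simple_focal_loss(pred, gt, gamma=2, alpha=1):
    # `pt` is the probability the model assigns to the true class
    pt = pred if gt == 1 else 1 - pred
    # (1 - pt)**gamma down-weights well-classified examples; alpha is a constant weight
    fl = - alpha * ((1 - pt)**gamma) * np.log(pt)
    return fl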

In [3]:
print(simple_bce_loss(pred=0.9, gt=1))
print(simple_bce_loss(pred=0.1, gt=0))
print(simple_bce_loss(pred=0.1, gt=1))
print(simple_bce_loss(pred=0.9, gt=0))

0.10536051565782628
0.10536051565782628
2.3025850929940455
2.302585092994046


In [4]:
print(simple_focal_loss(pred=0.9, gt=1, gamma=2))
print(simple_focal_loss(pred=0.1, gt=0, gamma=2))
print(simple_focal_loss(pred=0.1, gt=1, gamma=2))
print(simple_focal_loss(pred=0.9, gt=0, gamma=2))

0.0010536051565782623
0.0010536051565782623
1.865093925325177
1.8650939253251773


In [5]:
# Demonstration of how the focal loss makes the badly classified examples stand out more.
# Let's say we have a batch of two examples: one well classified, the other badly classified.
gt1 = 1
pred1 = 0.9
gt2 = 1
pred2 = 0.1

In [6]:
# The batch loss is the sum of their losses
ce_loss_1 = simple_bce_loss(pred=pred1, gt=gt1)
ce_loss_2 = simple_bce_loss(pred=pred2, gt=gt2)
ce_batch_loss = ce_loss_1 + ce_loss_2

ce_loss_1, ce_loss_2

(0.10536051565782628, 2.3025850929940455)

In [7]:
# Let's see the proportional impact of each example on the batch loss
ce_loss_1 / ce_batch_loss, ce_loss_2 / ce_batch_loss

(0.04375535530340077, 0.9562446446965992)

In [8]:
# Now let's do the same with the focal loss
fl_loss_1 = simple_focal_loss(pred=pred1, gt=gt1, gamma=2)
fl_loss_2 = simple_focal_loss(pred=pred2, gt=gt2, gamma=2)
fl_batch_loss = fl_loss_1 + fl_loss_2

fl_loss_1, fl_loss_2

(0.0010536051565782623, 1.865093925325177)

In [9]:
# We can see that the contribution of the well classified example to the batch loss is way lower
fl_loss_1 / fl_batch_loss, fl_loss_2 / fl_batch_loss

(0.0005645883507968253, 0.9994354116492032)

In [10]:
# The higher the gamma, the lower the contribution of the well-classified example
fl_loss_1 = simple_focal_loss(pred=pred1, gt=gt1, gamma=5)
fl_loss_2 = simple_focal_loss(pred=pred2, gt=gt2, gamma=5)
fl_batch_loss = fl_loss_1 + fl_loss_2
fl_loss_1 / fl_batch_loss, fl_loss_2 / fl_batch_loss

(7.74906520057872e-07, 0.9999992250934798)
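Sweeping gamma on the same two-example batch shows the trend more clearly: the share of the batch loss coming from the well-classified example (pred=0.9) shrinks monotonically as gamma grows. A self-contained NumPy sketch of the same computation (the gamma values are arbitrary):

```python
import numpy as np

def focal_loss(pt, gamma):
    # pt: probability assigned to the true class
    return -(1 - pt) ** gamma * np.log(pt)

pt_good, pt_bad = 0.9, 0.1   # gt=1 with pred=0.9 and pred=0.1, as above
gammas = [0, 0.5, 1, 2, 5]
shares = []
for gamma in gammas:
    fl_good = focal_loss(pt_good, gamma)
    fl_bad = focal_loss(pt_bad, gamma)
    shares.append(fl_good / (fl_good + fl_bad))
    print(f"gamma={gamma}: share of the well-classified example = {shares[-1]:.2e}")
```

At gamma=0 the share is the crossentropy one (about 0.044), at gamma=2 and gamma=5 it matches the proportions computed above.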

In [11]:
# When gamma = 0, focal loss is the crossentropy loss.
print(simple_focal_loss(pred=0.7, gt=1, gamma=0))
print(simple_bce_loss(pred=0.7, gt=1))

0.35667494393873245
0.35667494393873245


# INNER WORKING

## FOR BINARY CLASSIFICATION

In [12]:
target = torch.tensor([1., 0., 0., 1., 1., 0., 1., 0.])
pred = torch.tensor([0.9, 0.3, 0.2, 0.3, 0.6, 0.4, 0.1, 0.8])

In [13]:
# To matrix form
target = torch.stack((1 - target, target), dim=1)
pred = torch.stack((1 - pred, pred), dim=1)

print(target)
print(pred)

tensor([[0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.]])
tensor([[0.1000, 0.9000],
        [0.7000, 0.3000],
        [0.8000, 0.2000],
        [0.7000, 0.3000],
        [0.4000, 0.6000],
        [0.6000, 0.4000],
        [0.9000, 0.1000],
        [0.2000, 0.8000]])


In [14]:
# We compute the two focal loss factors separately.
# Multiplication by `1-target` to ensure that when y=1, we use `1-p` and 
# when y=0, we use `p`, as by the formula.
a = (1 - target) * pred
a

tensor([[0.1000, 0.0000],
        [0.0000, 0.3000],
        [0.0000, 0.2000],
        [0.7000, 0.0000],
        [0.4000, 0.0000],
        [0.0000, 0.4000],
        [0.9000, 0.0000],
        [0.0000, 0.8000]])

In [15]:
# Here we use `target` in order to use `p` if y=1 and `1-p` if y=0 (opposite of above)
b = target * torch.log(pred)
b

tensor([[-0.0000, -0.1054],
        [-0.3567, -0.0000],
        [-0.2231, -0.0000],
        [-0.0000, -1.2040],
        [-0.0000, -0.5108],
        [-0.5108, -0.0000],
        [-0.0000, -2.3026],
        [-1.6094, -0.0000]])

In [16]:
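# Sum over the class dimension to collapse each row to a single value. Without this,
# multiplying `a` and `b` element-wise would give a zero matrix, since their non-zero
# entries sit in opposite columns.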
c = torch.sum(a, dim=1)
d = torch.sum(b, dim=1)

print(c)
print(d)

tensor([0.1000, 0.3000, 0.2000, 0.7000, 0.4000, 0.4000, 0.9000, 0.8000])
tensor([-0.1054, -0.3567, -0.2231, -1.2040, -0.5108, -0.5108, -2.3026, -1.6094])


In [17]:
# Focal loss computation for each example
gamma = 2
fl = - torch.pow(c, gamma) * d
fl = torch.round(fl, decimals=2)
fl

tensor([0.0000, 0.0300, 0.0100, 0.5900, 0.0800, 0.0800, 1.8700, 1.0300])

In [18]:
# Another way to pick `p` or `1-p` as a function of y is `torch.where`,
# which corresponds to the variable `pt` defined in the paper.
pt = torch.where(target == 1, pred, 1 - pred)
pt

tensor([[0.9000, 0.9000],
        [0.7000, 0.7000],
        [0.8000, 0.8000],
        [0.3000, 0.3000],
        [0.6000, 0.6000],
        [0.6000, 0.6000],
        [0.1000, 0.1000],
        [0.2000, 0.2000]])

In [19]:
# Keep just one column since they are the same
pt = pt[..., 0]
fl = - torch.pow((1 - pt), gamma) * torch.log(pt)
fl = torch.round(fl, decimals=2)
fl

tensor([0.0000, 0.0300, 0.0100, 0.5900, 0.0800, 0.0800, 1.8700, 1.0300])
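In practice a model usually outputs logits rather than probabilities, and computing a sigmoid first and a `log` afterwards can hit `log(0)` for large logits. One common workaround, assumed here rather than taken from the paper, is to compute `log(pt)` directly from the logit with a log-sigmoid, using `np.logaddexp(0, x) = log(1 + exp(x))` as a stable softplus. A NumPy sketch (`binary_focal_loss_from_logits` is a hypothetical name), checked on the same eight examples:

```python
import numpy as np

def binary_focal_loss_from_logits(z, y, gamma=2.0):
    """Per-example focal loss from logits z and 0/1 labels y."""
    z = np.asarray(z, dtype=float)
    y = np.asarray(y)
    s = np.where(y == 1, z, -z)          # pt = sigmoid(s) in both cases
    log_pt = -np.logaddexp(0.0, -s)      # log(sigmoid(s)), computed stably
    pt = np.exp(log_pt)
    return -((1 - pt) ** gamma) * log_pt

# Same examples as above, probabilities turned into logits via log(p / (1 - p))
p = np.array([0.9, 0.3, 0.2, 0.3, 0.6, 0.4, 0.1, 0.8])
y = np.array([1, 0, 0, 1, 1, 0, 1, 0])
z = np.log(p / (1 - p))
print(np.round(binary_focal_loss_from_logits(z, y), 2))  # same values as the cell above

# Extreme logits stay finite instead of producing log(0)
print(binary_focal_loss_from_logits(np.array([-100.0, 100.0]), np.array([1, 0])))
```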

## FOR MULTICLASS CLASSIFICATION

In [20]:
target = torch.tensor([
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
    [1, 0, 0],
    [1, 0, 0]
])

In [21]:
pred = torch.tensor([
    [0.6828, 0.2600, 0.0572],
    [0.4128, 0.0719, 0.5153],
    [0.7387, 0.1748, 0.0865],
    [0.6274, 0.1481, 0.2245],
    [0.2029, 0.5384, 0.2587],
    [0.5034, 0.0887, 0.4079],
    [0.4194, 0.2748, 0.3058],
    [0.4968, 0.1833, 0.3198]
])

In [342]:
# As explained at the beginning of the notebook, when computing the focal loss 
# for multi-class classification problems, it's common to calculate the loss 
# for the true class only.

In [23]:
# Following the approach above...
pt = torch.where(target == 1, pred, 1 - pred)
pt

tensor([[0.6828, 0.7400, 0.9428],
        [0.5872, 0.9281, 0.5153],
        [0.7387, 0.8252, 0.9135],
        [0.6274, 0.8519, 0.7755],
        [0.7971, 0.5384, 0.7413],
        [0.5034, 0.9113, 0.5921],
        [0.4194, 0.7252, 0.6942],
        [0.4968, 0.8167, 0.6802]])

In [24]:
# Multiply by the target to keep only the loss term of the true class
fl = - target * torch.pow(1 - pt, gamma) * torch.log(pt)
fl

tensor([[0.0384, -0.0000, -0.0000],
        [-0.0000, -0.0000, 0.1558],
        [0.0207, -0.0000, -0.0000],
        [0.0647, -0.0000, -0.0000],
        [-0.0000, 0.1319, -0.0000],
        [0.1693, -0.0000, -0.0000],
        [0.2929, -0.0000, -0.0000],
        [0.1771, -0.0000, -0.0000]])

In [25]:
fl = torch.sum(fl, dim=1)

fl

tensor([0.0384, 0.1558, 0.0207, 0.0647, 0.1319, 0.1693, 0.2929, 0.1771])

In [26]:
# From another perspective: since for a multiclass task it's common
# to compute the loss for the true class only, we are basically falling back
# to the binary case.

# We start by keeping only the `p` of the true class.
a = target * pred

a

tensor([[0.6828, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5153],
        [0.7387, 0.0000, 0.0000],
        [0.6274, 0.0000, 0.0000],
        [0.0000, 0.5384, 0.0000],
        [0.5034, 0.0000, 0.0000],
        [0.4194, 0.0000, 0.0000],
        [0.4968, 0.0000, 0.0000]])

In [27]:
# Then we put everything in one column...
b = torch.sum(a, dim=1)

b

tensor([0.6828, 0.5153, 0.7387, 0.6274, 0.5384, 0.5034, 0.4194, 0.4968])

In [28]:
# ...which is equivalent to these pred and target matrices
pred2 = torch.stack([1 - b, b], dim=1)

target2 = torch.stack([ # the target reflects the position of the true class `p`
    torch.zeros(size=[8,]),
    torch.ones(size=[8,])
    ],
    dim = 1
)

print(pred2)
print(target2)

tensor([[0.3172, 0.6828],
        [0.4847, 0.5153],
        [0.2613, 0.7387],
        [0.3726, 0.6274],
        [0.4616, 0.5384],
        [0.4966, 0.5034],
        [0.5806, 0.4194],
        [0.5032, 0.4968]])
tensor([[0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.]])


In [29]:
# Here we don't need any summing to bring all `p` to the same column,
# since it was already built that way, so we can easily compute fl
fl = - torch.pow(1 - b, gamma) * torch.log(b)

fl

tensor([0.0384, 0.1558, 0.0207, 0.0647, 0.1319, 0.1693, 0.2929, 0.1771])
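The same true-class-only computation can also be written with integer labels instead of a one-hot target, picking `pt` with `np.take_along_axis`. A NumPy sketch reusing the `pred` values above (the labels are the argmax of the one-hot `target`):

```python
import numpy as np

pred = np.array([
    [0.6828, 0.2600, 0.0572],
    [0.4128, 0.0719, 0.5153],
    [0.7387, 0.1748, 0.0865],
    [0.6274, 0.1481, 0.2245],
    [0.2029, 0.5384, 0.2587],
    [0.5034, 0.0887, 0.4079],
    [0.4194, 0.2748, 0.3058],
    [0.4968, 0.1833, 0.3198],
])
labels = np.array([0, 2, 0, 0, 1, 0, 0, 0])  # index of the true class per row
gamma = 2

# Probability of the true class for each row -> shape [N]
pt = np.take_along_axis(pred, labels[:, None], axis=1)[:, 0]
fl = -(1 - pt) ** gamma * np.log(pt)
print(np.round(fl, 4))  # same values as the torch cells above
```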

## FOR MULTICLASS SEGMENTATION

In [30]:
# Here we compute the focal loss for images: first by flattening them to mimic
# the classification case, then directly on the image tensors.
# Tensors follow the [B, C, H, W] layout (batch, classes, height, width).

In [45]:
gt = torch.tensor([
    [2, 2, 2, 0, 0, 0, 0, 0],
    [2, 2, 0, 3, 3, 3, 3, 0],
    [2, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0, 0]
])

In [46]:
# Class frequency
classes, counts = torch.unique(gt, return_counts=True)

print(classes)
print(counts / (gt.shape[0] * gt.shape[1]))

tensor([0, 1, 2, 3])
tensor([0.5625, 0.0312, 0.0938, 0.3125])
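These frequencies are one possible source for the per-class `alpha` weights mentioned at the beginning. One common heuristic, assumed here rather than taken from the paper, is inverse frequency normalized to sum to 1, which keeps the weights in [0, 1]. A NumPy sketch on the same `gt` mask:

```python
import numpy as np

gt = np.array([
    [2, 2, 2, 0, 0, 0, 0, 0],
    [2, 2, 0, 3, 3, 3, 3, 0],
    [2, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0, 0],
])

classes, counts = np.unique(gt, return_counts=True)
freq = counts / gt.size          # [0.5625, 0.03125, 0.09375, 0.3125]
inv = 1.0 / freq                 # rarer class -> larger weight
class_weights = inv / inv.sum()  # normalize so the weights sum to 1
print(np.round(class_weights, 4))
```

The rarest class (1) gets the largest weight and the background (0) the smallest. Turned into a tensor, a vector like this is the kind of input the `class_weights` argument of the loss classes at the end expects.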


In [47]:
# To one-hot
target = torch.nn.functional.one_hot(gt, num_classes=4).permute(2, 0, 1).unsqueeze(0) # [B, C, H, W]
target.shape

torch.Size([1, 4, 8, 8])

In [82]:
# Generate random prediction (high loss expected)
torch.manual_seed(2296)
pred_bad = torch.randn(size=(1, 4, 8, 8))
pred_bad = torch.nn.Softmax(dim=1)(pred_bad)
pred_bad.shape

torch.Size([1, 4, 8, 8])

In [121]:
# Generate a prediction close to the gt (low loss expected)
pred_labels = torch.tensor([
    [2, 0, 0, 0, 0, 0, 0, 0],
    [2, 2, 0, 3, 2, 2, 2, 0],
    [2, 0, 0, 3, 3, 2, 1, 0],
    [0, 0, 0, 3, 3, 3, 1, 0],
    [1, 1, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 0, 0, 1, 0, 0],
    [1, 2, 0, 0, 0, 1, 0, 0]
])

pred_good = torch.nn.functional.one_hot(pred_labels, num_classes=4).permute(2, 0, 1).unsqueeze(0).type(torch.float32)
pred_good = torch.nn.Softmax(dim=1)(pred_good) # to turn [1., 0., 0., 0.] into [0.4754, 0.1749, 0.1749, 0.1749]

pred_good.shape

torch.Size([1, 4, 8, 8])

### Using flattening to mimic the classification case

In [122]:
target_flat = target.reshape([1, 4, 64])
pred_bad_flat = pred_bad.reshape([1, 4, 64])
pred_good_flat = pred_good.reshape([1, 4, 64])

target_flat = target_flat.permute(0, 2, 1)[0]
pred_bad_flat = pred_bad_flat.permute(0, 2, 1)[0]
pred_good_flat = pred_good_flat.permute(0, 2, 1)[0]

target_flat.shape, pred_bad_flat.shape, pred_good_flat.shape

(torch.Size([64, 4]), torch.Size([64, 4]), torch.Size([64, 4]))

In [123]:
pt_bad = pred_bad_flat[target_flat == 1]
pt_good = pred_good_flat[target_flat == 1]

# # Equivalent to what was done before:
# pt = torch.where(target == 1, pred, 0)
# pt = pt.sum(dim=1)

In [101]:
gamma = 2

In [124]:
pixel_fl_bad = - torch.pow(1 - pt_bad, gamma) * torch.log(pt_bad)
fl_bad = pixel_fl_bad.mean()
fl_bad

tensor(1.1204)

In [125]:
pixel_fl_good = - torch.pow(1 - pt_good, gamma) * torch.log(pt_good)
fl_good = pixel_fl_good.mean()
fl_good

tensor(0.4042)

### With images

In [239]:
pt = pred_bad[target == 1] # [H*W]

pt

tensor([0.2985, 0.6023, 0.1072, 0.1263, 0.6341, 0.0896, 0.3313, 0.5475, 0.1245,
        0.9195, 0.2905, 0.3415, 0.5671, 0.6836, 0.0615, 0.0411, 0.5475, 0.1193,
        0.5202, 0.7757, 0.3471, 0.0399, 0.0388, 0.6864, 0.0644, 0.0207, 0.0962,
        0.2078, 0.2975, 0.3145, 0.5481, 0.2197, 0.4078, 0.1834, 0.0129, 0.3133,
        0.2000, 0.2322, 0.4959, 0.3939, 0.0679, 0.5505, 0.3391, 0.0510, 0.2323,
        0.2390, 0.1906, 0.1828, 0.3390, 0.1070, 0.5231, 0.0141, 0.1243, 0.5086,
        0.2491, 0.2309, 0.1206, 0.2604, 0.2431, 0.4729, 0.1216, 0.5495, 0.2166,
        0.0394])

In [240]:
pt.shape # [H*W]

torch.Size([64])

In [241]:
pixel_losses = - torch.pow(1 - pt, gamma) * torch.log(pt) # [H*W]
fl = pixel_losses.mean()
fl

tensor(1.1204)
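One detail about `pred_bad[target == 1]`: boolean indexing (in NumPy and PyTorch alike) returns the selected elements in row-major order over all the masked dimensions, so on a [B, C, H, W] tensor the values come out grouped by class, not in pixel order. The mean is unaffected, but per-pixel inspection is. A NumPy sketch on a tiny 1x3x2x2 example, where `np.arange` placeholder values (not probabilities) make the order visible:

```python
import numpy as np

C = 3
pred = np.arange(12.0).reshape(1, C, 2, 2)                 # value = 4*c + 2*h + w
labels = np.array([[[1, 0], [2, 0]]])                      # true class per pixel, [B, H, W]
target = np.moveaxis(np.eye(C, dtype=int)[labels], -1, 1)  # one-hot, [B, C, H, W]

# Masking on [B, C, H, W]: values come out grouped by class (C before H, W)
pt_class_major = pred[target == 1]                          # -> 1, 3, 4, 10

# Moving C last before masking keeps the pixel order (B, H, W)
pt_pixel_order = np.moveaxis(pred, 1, -1)[np.moveaxis(target, 1, -1) == 1]  # -> 4, 1, 10, 3

# Reference: true-class value of each pixel, in [B, H, W] layout
pt_map = np.take_along_axis(pred, labels[:, None], axis=1)[:, 0]

print(pt_class_major, pt_pixel_order, pt_map.ravel())
```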

In [202]:
# Check if pixel losses make sense.
# NOTE: boolean indexing walks the [B, C, H, W] tensor in row-major order, so
# `pred_bad[target == 1]` returns the values grouped by class, not in pixel order:
# `pixel_losses.reshape([8, 8])[i, j]` is NOT the loss of pixel (i, j) (the mean
# above is unaffected). To inspect a pixel, gather `pt` per pixel first.
pt_map = torch.sum(torch.where(target == 1, pred_bad, 0), dim=1)[0] # [H, W]
pixel_loss_map = - torch.pow(1 - pt_map, gamma) * torch.log(pt_map) # [H, W]

i = 5
j = 5

print(target[0, :, i, j])
print(pred_bad[0, :, i, j])
print(pixel_loss_map[i, j]) # expected: -(1 - 0.2166)**2 * log(0.2166) ≈ 0.94

tensor([0, 0, 0, 1])
tensor([0.3160, 0.2328, 0.2346, 0.2166])


## FOR A BATCH OF IMAGES

In [303]:
target_batch = torch.concatenate([target, target])
pred_batch = torch.concatenate([pred_bad, pred_good])

target_batch.shape, pred_batch.shape

(torch.Size([2, 4, 8, 8]), torch.Size([2, 4, 8, 8]))

In [304]:
# Compute `pt`, turning to zero the prediction for the wrong classes.
pt_batch = torch.where(target_batch == 1, pred_batch, 0)

pt_batch.shape

torch.Size([2, 4, 8, 8])

In [305]:
# Keep only `p` of the right class.
pt = torch.sum(pt_batch, dim=1)
pt.shape

torch.Size([2, 8, 8])

In [306]:
# Compute pixel focal losses
gamma = 2
pixel_losses = - torch.pow(1 - pt, gamma) * torch.log(pt)
pixel_losses.shape

torch.Size([2, 8, 8])

In [307]:
# Compute overall focal loss for each image in the batch
image_losses = torch.mean(pixel_losses, dim=[1, 2])
print(image_losses) # remember: bad prediction, good prediction

# Compute batch loss
batch_loss = torch.mean(image_losses)
batch_loss

tensor([1.1204, 0.4042])


tensor(0.7623)

### Demonstration that when `gamma=0` we fall back to the Crossentropy loss

In [None]:
# Re-running the focal loss cells above with gamma=0 gives:
# tensor([1.5920, 0.9468])  (image losses)
# tensor(1.2694)            (batch loss)

# NOTE: we can't use `torch.nn.CrossEntropyLoss()` here, because it expects raw logits
# and applies a (log-)softmax internally, while these tensors are already softmax outputs.
# The doc page (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
# describes the input as unnormalized logits, i.e. the softmax happens inside the loss. Proof:
# logits = torch.randn([1, 4, 8, 8])
# preds = torch.nn.Softmax(dim=1)(logits)
# a = target * torch.log(preds) # [1, C, H, W]
# pixel_losses = - torch.sum(a, dim=1) # [1, H, W]
# ce = pixel_losses.mean()
# ce_torch = torch.nn.CrossEntropyLoss()(input=logits, target=target.type(torch.float32))
# ce, ce_torch
# # (tensor(1.8743), tensor(1.8743))

In [313]:
# Compute the crossentropy (just for `pred_bad`).
# We could also have used my `WeightedCrossEntropyLoss` (see the crossentropy
# notebook), with its internal Softmax turned off.
a = target * torch.log(pred_bad) # [1, C, H, W]
pixel_losses = - torch.sum(a, dim=1) # [1, H, W]
ce = pixel_losses.mean()
ce

# 1.5920

tensor(1.5920)
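The `gamma=0` equivalence can also be checked without random tensors, on `pred_good`: there every pixel's true-class probability is either `e / (e + 3)` (label predicted right) or `1 / (e + 3)` (label predicted wrong), so both losses depend only on the number of wrong pixels. A NumPy sketch rebuilding `gt` and `pred_labels` from above (the helper `focal` is illustrative):

```python
import numpy as np

gt = np.array([
    [2, 2, 2, 0, 0, 0, 0, 0],
    [2, 2, 0, 3, 3, 3, 3, 0],
    [2, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0, 0],
])
pred_labels = np.array([
    [2, 0, 0, 0, 0, 0, 0, 0],
    [2, 2, 0, 3, 2, 2, 2, 0],
    [2, 0, 0, 3, 3, 2, 1, 0],
    [0, 0, 0, 3, 3, 3, 1, 0],
    [1, 1, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 3, 3, 3, 3, 0],
    [0, 0, 0, 0, 0, 1, 0, 0],
    [1, 2, 0, 0, 0, 1, 0, 0],
])
C = 4

# Same construction as `pred_good`: softmax over a one-hot vector
logits = np.moveaxis(np.eye(C)[pred_labels], -1, 0)   # [C, H, W]
probs = np.exp(logits) / np.exp(logits).sum(axis=0)   # softmax over classes

# Probability assigned to the true class of each pixel -> [H, W]
pt = np.take_along_axis(probs, gt[None], axis=0)[0]

def focal(pt, gamma):
    return np.mean(-(1 - pt) ** gamma * np.log(pt))

ce = np.mean(-np.log(pt))
print(np.sum(gt != pred_labels))                                 # 13 wrong pixels
print(round(ce, 4), round(focal(pt, 0), 4), round(focal(pt, 2), 4))  # 0.9468 0.9468 0.4042
```

Both numbers match the notebook: 0.9468 is the crossentropy (and gamma=0 focal loss) of `pred_good`, 0.4042 its focal loss with gamma=2.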

# CLASSES FOR FOCAL LOSS

In [340]:
# Class that considers just the predictions for the correct class.
class FocalLoss1(torch.nn.Module):
    def __init__(self, gamma, class_weights=None):
        super().__init__()
        self.gamma = gamma
        self.class_weights = class_weights
        self.activation = torch.nn.Softmax(dim=1)
        
    def forward(self, input, target):        
        input_activated = self.activation(input)
        # Compute `pt`, turning to zero the prediction for the wrong classes.
        pt_batch = torch.where(target == 1, input_activated, 0) # [B, C, H, W]
        # Keep only `p` of the right class.
        pt = torch.sum(pt_batch, dim=1) # [B, H, W]
        pt_log = torch.log(torch.where(pt==0, 1e-8, pt)) # safely compute log
        pixel_losses = -torch.pow((1 - pt), self.gamma) * pt_log # [B, H, W]
        
        if self.class_weights is not None:
            weights = self.compute_weights(gt=target, class_weights=self.class_weights)
            pixel_losses = pixel_losses * weights 
        
        batch_loss = torch.mean(pixel_losses) # []
        
        return batch_loss
        
    def compute_weights(self, gt, class_weights):
        weights = gt * class_weights.reshape(1, class_weights.shape[0], 1, 1) # [B, C, H, W]
        weights = torch.sum(weights, dim=1) # [B, H, W]
        return weights

In [341]:
torch.manual_seed(2296)
logits = torch.randn([2, 4, 8, 8])
fl_fn = FocalLoss1(gamma=2)
fl_fn(input=logits, target=target_batch)

tensor(1.0843)
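`FocalLoss1` takes the log of a softmax output and guards against `log(0)` with a `1e-8` replacement. An alternative that needs no such guard is to compute `log(pt)` with a log-softmax (`logits - logsumexp(logits)`). A NumPy sketch of that variant, assuming logits of shape [B, C, H, W], integer labels of shape [B, H, W] and optional per-class weights of shape [C]; the function name is hypothetical:

```python
import numpy as np

def focal_loss_true_class(logits, labels, gamma=2.0, class_weights=None):
    # Stable log-softmax over the class axis
    m = logits.max(axis=1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
    # log(pt) of the true class for every pixel -> [B, H, W]
    log_pt = np.take_along_axis(log_probs, labels[:, None], axis=1)[:, 0]
    pt = np.exp(log_pt)
    pixel_losses = -((1 - pt) ** gamma) * log_pt
    if class_weights is not None:
        # weight of each pixel's true class, as in `compute_weights`
        pixel_losses = pixel_losses * class_weights[labels]
    return pixel_losses.mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 8, 8))
labels = rng.integers(0, 4, size=(2, 8, 8))
print(focal_loss_true_class(logits, labels))
print(focal_loss_true_class(1000 * logits, labels))  # still finite for huge logits
```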

In [336]:
# Class that considers also predictions for the wrong classes.
class FocalLoss2(torch.nn.Module):
    def __init__(self, gamma, class_weights=None):
        super().__init__()
        self.gamma = gamma
        self.class_weights = class_weights
        self.activation = torch.nn.Softmax(dim=1)
        
    def forward(self, input, target):        
        input_activated = self.activation(input)
        pt = torch.where(target==1, input_activated, 1 - input_activated)
        pt_log = torch.log(torch.where(pt==0, 1e-8, pt)) # safely compute log
        pixel_losses = -torch.pow((1-pt), self.gamma) * pt_log # [batch_size, n_classes, H, W]
        summed_pixel_losses = torch.sum(pixel_losses, dim=1) # [batch_size, H, W]
        
        if self.class_weights is not None:
            weights = self.compute_weights(gt=target, class_weights=self.class_weights)
            summed_pixel_losses = summed_pixel_losses * weights 
        
        batch_loss = torch.mean(summed_pixel_losses) # []
        
        return batch_loss
        
    def compute_weights(self, gt, class_weights):
        weights = gt * class_weights.reshape(1, class_weights.shape[0], 1, 1) # [batch_size, n_classes, H, W]
        weights = torch.sum(weights, dim=1) # [batch_size, H, W]
        return weights

In [337]:
torch.manual_seed(2296)
logits = torch.randn([2, 4, 8, 8])
fl_fn = FocalLoss2(gamma=2)
fl_fn(input=logits, target=target_batch)

tensor(1.2854)
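`FocalLoss2` adds, for every wrong class, a term `-(p)**gamma * log(1 - p)` that is never negative, so its per-pixel loss can never be smaller than the true-class-only loss of `FocalLoss1` (class weights, when given, multiply both per-pixel totals in the same way). This is consistent with 1.2854 > 1.0843 above. A NumPy sketch checking the inequality on random softmax outputs (names and seed are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
gamma = 2
logits = rng.normal(size=(2, 4, 8, 8))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over classes
labels = rng.integers(0, 4, size=(2, 8, 8))
target = np.moveaxis(np.eye(4)[labels], -1, 1)                       # one-hot, [B, C, H, W]

# FocalLoss1-style: true class only -> [B, H, W]
pt_true = np.sum(target * probs, axis=1)
loss_true_only = -(1 - pt_true) ** gamma * np.log(pt_true)

# FocalLoss2-style: every class as a binary problem, summed over classes -> [B, H, W]
pt_all = np.where(target == 1, probs, 1 - probs)
loss_all = np.sum(-(1 - pt_all) ** gamma * np.log(pt_all), axis=1)

print(loss_true_only.mean(), loss_all.mean())
print(np.all(loss_all >= loss_true_only))
```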