## Setup

In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn

## MSE Loss

In [3]:
# 1-D toy example: four predictions compared against four regression targets
pred, target = (
    torch.tensor([0.5, 0.6, 0.7, 0.8]),
    torch.tensor([0.0, 3.0, 4.0, 5.0]),
)

In [4]:
# Default reduction is 'mean': the squared errors are averaged over all elements
loss = F.mse_loss(input=pred, target=target)
loss

tensor(8.6350)

In [5]:
# reduction='none' returns one squared error per element (no averaging)
loss = F.mse_loss(input=pred, target=target, reduction="none")
loss

tensor([ 0.2500,  5.7600, 10.8900, 17.6400])

In [6]:
# Apply weights: elements with a positive target count `pos_weight` times,
# all other elements count `neg_weight` times.
pos_weight = 4
neg_weight = 1

weight = torch.where(
    target > 0,
    torch.ones_like(pred) * pos_weight,
    torch.ones_like(pred) * neg_weight
)

# Recompute the unreduced loss here rather than rescaling the `loss` left over
# from the previous cell: `loss = loss * weight` was not idempotent, so re-running
# this cell (or running it after any cell that rebinds `loss`) silently gave a
# wrong result.
unreduced_loss = F.mse_loss(pred, target, reduction='none')
loss = (unreduced_loss * weight).mean()
loss

tensor(34.3525)

## Custom Loss: Weighted MSE as an `nn.Module`

In [7]:
class WeightedMSELoss(nn.Module):
    """Mean-squared-error loss that up-weights elements with a positive target.

    Each element's squared error ``(pred - target) ** 2`` is multiplied by
    ``pos_weight`` where ``target > 0`` and by ``neg_weight`` elsewhere, then
    reduced.

    Args:
        pos_weight: Multiplier for elements whose target is strictly positive.
        neg_weight: Multiplier for elements whose target is <= 0.
        reduction: ``'mean'`` (default, the original behaviour), ``'sum'``, or
            ``'none'`` (return the weighted per-element losses), mirroring
            ``F.mse_loss``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.

    Note:
        With ``'mean'`` the weighted sum is divided by the number of elements,
        not by the sum of the weights.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, pos_weight=4, neg_weight=1, reduction='mean'):
        super().__init__()
        # Fail at construction time instead of on the first forward pass.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight
        self.reduction = reduction

    def forward(self, pred, target):
        """Return the weighted MSE between ``pred`` and ``target``.

        The weight tensor is built with ``torch.ones_like(pred)``, so it shares
        ``pred``'s shape, dtype and device.
        """
        weight = torch.where(
            target > 0,
            torch.ones_like(pred) * self.pos_weight,
            torch.ones_like(pred) * self.neg_weight,
        )
        loss = F.mse_loss(pred, target, reduction='none') * weight
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

In [8]:
# 2-D batch: 2 samples x 3 outputs (rebinds the 1-D `pred`/`target` from above)
pred = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).reshape(2, 3)
target = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).reshape(2, 3)

In [9]:
# Unweighted MSE on the 2-D batch, as a baseline for the weighted version below
F.mse_loss(input=pred, target=target)

tensor(6.9850)

In [10]:
# Give the criterion its own name: `loss` held a loss *tensor* in earlier cells,
# and rebinding it to an nn.Module makes out-of-order re-runs confusing.
weighted_mse = WeightedMSELoss(pos_weight=4, neg_weight=1)

loss_value = weighted_mse(pred, target)
loss_value

tensor(27.9350)

## Weighted BCE Loss

In [2]:
import torch
import torch.nn as nn

# Example predicted probabilities (i.e. values already passed through a sigmoid)
predictions = torch.tensor([0.1, 0.4, 0.35, 0.8], requires_grad=True)

# Example binary target labels (0 or 1)
targets = torch.tensor([0, 1, 0, 1], dtype=torch.float32)

# Positive examples (label 1) count twice as much as negatives (label 0).
POS_WEIGHT = 2.0
NEG_WEIGHT = 1.0

# nn.BCELoss's `weight` is a per-element rescaling tensor (broadcast to the input
# shape), not a per-class table. Expand the class weights to one weight per sample
# from the labels instead of hard-coding them by position, so the weights stay
# correct if `targets` changes.
class_weights = torch.where(
    targets > 0,
    torch.full_like(targets, POS_WEIGHT),
    torch.full_like(targets, NEG_WEIGHT),
)

# Unweighted BCE, averaged over all elements
criterion = nn.BCELoss()
loss = criterion(predictions, targets)

print('BCELoss without class weights:', loss.item())

# Weighted BCE. With the default reduction='mean', BCELoss divides the weighted
# sum by the number of elements, not by the sum of the weights.
loss_with_weights = nn.BCELoss(weight=class_weights)(predictions, targets)

print('BCELoss with class weights:', loss_with_weights.item())


BCELoss without class weights: 0.4188944399356842
BCELoss with class weights: 0.7037529945373535
