What is a Loss Function?

A loss function tells a neural network how “wrong” its predictions are compared with the actual answers. It converts the difference between the model’s predictions and the correct values into a single number (the “loss”). The smaller this number, the closer the predictions are to the targets.



Why Do We Need Loss Functions?

Loss functions are essential for training. After the model makes a prediction, the loss function scores it against the correct answer. The training algorithm then uses the gradient of the loss to adjust the network’s parameters (its weights and biases) in the direction that reduces the loss, so predictions improve over time.
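
To make this feedback loop concrete, here is a minimal sketch that fits a one-parameter model y = w * x by gradient descent on the mean squared error. The toy data, the learning rate, and the names `w`, `lr`, and `mse_and_grad` are illustrative assumptions, not part of the original notebook.

```python
# Toy data generated by y = 2 * x, so the ideal weight is 2.0 (assumed for illustration).
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]


def mse_and_grad(w):
    """Return the MSE of the model y = w * x and its derivative with respect to w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


w = 0.0    # initial guess for the weight
lr = 0.01  # learning rate (hypothetical value)
losses = []
for step in range(200):
    loss, grad = mse_and_grad(w)
    losses.append(loss)
    w -= lr * grad  # move the weight against the gradient to reduce the loss

print("first loss:", losses[0])
print("last loss:", losses[-1])
print("learned w:", w)
```

The loss starts high and shrinks as `w` approaches 2.0. In a real network the same idea applies to every weight and bias, with the gradients computed automatically.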


Types of Loss Functions

1. Mean Squared Error (MSE)
Used for regression problems, where the output is a number (like predicting price or temperature).
Averages the squared differences between predicted and actual values, so larger mistakes are penalized disproportionately.


2. Binary Cross-Entropy (BCE)
Used for tasks where there are only two possible answers (like yes/no, true/false, or cat/dog).
Measures how well the model’s predicted probabilities match the true 0/1 labels; confident wrong predictions are penalized most heavily.


3. Categorical Cross-Entropy
Used when there are more than two possible categories (like classifying images into dog, cat, or horse).
Measures how much probability the model assigns to the correct class out of several options.

In [9]:
import torch
import torch.nn as nn

In [13]:
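print("calculation: mean-squared-error ")
# Targets and predictions for four regression examples
y_true = torch.tensor([2.5, 0.0, 2.0, 8.0])
y_pred = torch.tensor([3.0, -0.5, 2.0, 7.0])
# nn.MSELoss averages the squared differences by default (reduction="mean")
mse_loss = nn.MSELoss()
loss_mse = mse_loss(y_pred, y_true)
print("y_true:", y_true.tolist())
print("y_pred:", y_pred.tolist())
# .item() converts the zero-dimensional loss tensor to a Python float
print("MSE Loss:", loss_mse.item())
print("-" * 40)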


calculation: mean-squared-error 
y_true: [2.5, 0.0, 2.0, 8.0]
y_pred: [3.0, -0.5, 2.0, 7.0]
MSE Loss: 0.375
----------------------------------------
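
The 0.375 printed above can be checked by hand. The sketch below repeats the calculation in plain Python, without PyTorch, using the same four values; the variable names `squared_errors` and `mse` are only illustrative.

```python
y_true = [2.5, 0.0, 2.0, 8.0]
y_pred = [3.0, -0.5, 2.0, 7.0]

# MSE = (1/n) * sum((prediction - target)^2)
squared_errors = [(p - t) ** 2 for p, t in zip(y_pred, y_true)]
mse = sum(squared_errors) / len(squared_errors)

print("squared errors:", squared_errors)  # [0.25, 0.25, 0.0, 1.0]
print("MSE by hand:", mse)                # 0.375
```

The last prediction is off by 1.0, twice as far as the first two, yet it contributes four times as much to the total. That is the effect of squaring.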


In [15]:
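# Binary labels must be floats (1.0 = positive class, 0.0 = negative class)
y_true = torch.tensor([1., 0., 1., 0.])
# nn.BCELoss expects probabilities in [0, 1], e.g. the output of a sigmoid
y_pred = torch.tensor([0.9, 0.2, 0.8, 0.1])
bin_loss = nn.BCELoss()
calculated_loss_bce = bin_loss(y_pred, y_true)
print("BCE Loss")
print(calculated_loss_bce)

BCE Loss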
tensor(0.1643)
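
The same value can be reproduced from the binary cross-entropy formula. This is a plain-Python sketch using the standard `math` module and the same labels and probabilities as above; `per_sample` and `bce` are illustrative names.

```python
import math

y_true = [1.0, 0.0, 1.0, 0.0]
y_pred = [0.9, 0.2, 0.8, 0.1]

# BCE = -(1/n) * sum(y * log(p) + (1 - y) * log(1 - p))
per_sample = [
    -(t * math.log(p) + (1 - t) * math.log(1 - p))
    for p, t in zip(y_pred, y_true)
]
bce = sum(per_sample) / len(per_sample)

print("per-sample losses:", [round(v, 4) for v in per_sample])
print("BCE by hand:", round(bce, 4))
```

For a positive label only the `log(p)` term is active, and for a negative label only the `log(1 - p)` term, so each sample is scored by the probability the model gave to the correct answer.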

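
Categorical cross-entropy, the third loss listed earlier, has no cell of its own in this notebook. The following is a minimal plain-Python sketch of the idea, assuming made-up raw scores (logits) for two images over the classes dog, cat, and horse. The data and the helper `softmax` are hypothetical. In PyTorch, `nn.CrossEntropyLoss` plays this role: it takes raw logits and class indices and applies the log-softmax step internally.

```python
import math

# Hypothetical raw scores (logits) for two images over the classes [dog, cat, horse].
logits = [
    [2.0, 1.0, 0.1],
    [0.5, 2.5, 0.3],
]
targets = [0, 1]  # index of the correct class for each image


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    shifted = [s - max(scores) for s in scores]  # subtract the max for numerical stability
    exps = [math.exp(s) for s in shifted]
    total = sum(exps)
    return [e / total for e in exps]


# Cross-entropy for one sample is -log(probability assigned to the correct class).
per_sample = [-math.log(softmax(row)[t]) for row, t in zip(logits, targets)]
cce = sum(per_sample) / len(per_sample)

print("per-sample losses:", per_sample)
print("categorical cross-entropy:", cce)
```

The second image gets the lower loss because the model puts more of its probability on the correct class there.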