# **Loss Functions**

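**Loss functions** measure how far a model's predictions are from the targets. Training adjusts the network's parameters to minimize this value, so the choice of loss shapes what the model learns. The common losses fall into two main categories, matching the two types of supervised learning: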

1. **Regression Losses** - MSE, MAE, RMSE, Huber Loss
2. **Classification Losses** - Cross-Entropy (variations: Binary, Categorical, Sparse Categorical)

I'll demonstrate implementations with sklearn, JAX, and PyTorch for the two most common of these losses, MSE and binary cross-entropy. This section assumes the models are already built and initialized: wherever you see `model`, read it as the neural network defined in that framework (a Flax module in the JAX examples, a `torch.nn.Module` in the PyTorch ones).
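The sections below work through MSE and BCE only. For the other losses listed that are not covered below (MAE, RMSE, and Huber loss, all of which are regression losses), a minimal NumPy sketch could look like this. The function names and the Huber threshold `delta=1.0` are illustrative choices, not any library's API:

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean absolute error: the average of |y - y_hat|."""
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))

def rmse(y_true, y_pred):
    """Root mean squared error: sqrt(MSE), expressed in the units of y."""
    diff = np.asarray(y_true) - np.asarray(y_pred)
    return np.sqrt(np.mean(diff ** 2))

def huber(y_true, y_pred, delta=1.0):
    """Huber loss: quadratic for |error| <= delta, linear beyond it."""
    err = np.abs(np.asarray(y_true) - np.asarray(y_pred))
    quadratic = 0.5 * err ** 2
    linear = delta * (err - 0.5 * delta)
    return np.mean(np.where(err <= delta, quadratic, linear))

# Illustrative values (named so they do not overwrite the notebook's y_true / y_pred)
targets = np.array([3.0, -0.5, 2.0, 7.0])
preds = np.array([2.5, 0.0, 2.0, 8.0])
print(mae(targets, preds))    # 0.5
print(rmse(targets, preds))   # ~0.6124, i.e. sqrt(0.375)
print(huber(targets, preds))  # 0.1875 (every |error| <= 1, so all terms are quadratic)
```

Huber loss behaves like MSE for small errors and like MAE for large ones, which makes it less sensitive to outliers than MSE while staying smooth near zero.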
"""


# Mean Squared Error Loss Function

$\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$

Where:

$n$ = number of samples

$y_i$ = true value for sample $i$

$\hat{y}_i$ = predicted value for sample $i$

$\sum$ = summation over all samples
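A tiny worked example makes the formula concrete. With four hand-picked values (chosen only for illustration), the errors are -1, 0, 1 and -2, the squared errors sum to 6, and dividing by $n = 4$ gives an MSE of 1.5:

```python
import numpy as np

targets = np.array([2.0, 4.0, 6.0, 8.0])   # y_i
preds = np.array([3.0, 4.0, 5.0, 10.0])    # y_hat_i

errors = targets - preds                    # [-1.0, 0.0, 1.0, -2.0]
squared = errors ** 2                       # [1.0, 0.0, 1.0, 4.0]
mse_by_hand = squared.sum() / len(targets)  # (1/n) * sum of squared errors
print(mse_by_hand)  # 1.5
```

Squaring makes every term non-negative and weights the error of -2 four times as heavily as the errors of size 1, which is why MSE is sensitive to outliers.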

## Using the sklearn metrics API

In [2]:
from sklearn.metrics import mean_squared_error
import numpy as np

# Generate some random data
y_true = np.random.rand(100)
y_pred = np.random.rand(100)

# Calculate the MSE
mse = mean_squared_error(y_true, y_pred)

mse


0.1744077613492566

## Using JAX

In [9]:
import jax
import jax.numpy as jnp


def mse(y_true, y_pred):
  """Mean squared error between two arrays of the same shape."""
  y_true = jnp.asarray(y_true)
  y_pred = jnp.asarray(y_pred)
  if y_true.shape != y_pred.shape:
    # Shapes are static under jit, so this Python check is safe to trace
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  return jnp.mean(jnp.square(y_true - y_pred))

# jax.jit takes the function itself (not its result) and returns a compiled version
mse_jit = jax.jit(mse)
mse_jit(y_true, y_pred)

Array(0.17440777, dtype=float32)

### Taking it a step further for neural networks


In [None]:
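def mse(params, x, y_true):
  """MSE of a Flax model's predictions; `model` is assumed to be defined already.

  `params` comes first because jax.grad differentiates with respect to the
  first argument by default.
  """
  y_true = jnp.asarray(y_true)
  # Forward pass: predictions from the current parameters
  y_pred = model.apply(params, x)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  return jnp.mean(jnp.square(y_true - y_pred))

# Gradient of the loss with respect to params, compiled once with XLA
grad_fn = jax.jit(jax.grad(mse))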
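Training needs the derivative of a loss like this with respect to the model parameters (with `jax.grad`, that means making the parameters the first argument or selecting them with `argnums`). To see what that derivative is without depending on Flax, here is a NumPy sketch for a hypothetical linear model $\hat{y} = Xw + b$, where the MSE gradient has the closed form $\partial\,\text{MSE}/\partial w = -\frac{2}{n}X^\top(y - \hat{y})$ and $\partial\,\text{MSE}/\partial b = -\frac{2}{n}\sum_i (y_i - \hat{y}_i)$. Central finite differences serve as an independent check; the model, data, and function names are illustrative assumptions:

```python
import numpy as np

def linear_mse(w, b, X, y):
    """MSE of a hypothetical linear model y_hat = X @ w + b."""
    residual = y - (X @ w + b)
    return np.mean(residual ** 2)

def linear_mse_grad(w, b, X, y):
    """Analytic gradient of linear_mse with respect to (w, b)."""
    n = len(y)
    residual = y - (X @ w + b)
    grad_w = (-2.0 / n) * (X.T @ residual)
    grad_b = (-2.0 / n) * residual.sum()
    return grad_w, grad_b

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y_lin = rng.normal(size=8)
w = rng.normal(size=3)
b = 0.1

grad_w, grad_b = linear_mse_grad(w, b, X, y_lin)

# Central finite differences: nudge each parameter up and down by eps
eps = 1e-6
fd_w = np.array([
    (linear_mse(w + eps * e, b, X, y_lin) - linear_mse(w - eps * e, b, X, y_lin)) / (2 * eps)
    for e in np.eye(3)
])
fd_b = (linear_mse(w, b + eps, X, y_lin) - linear_mse(w, b - eps, X, y_lin)) / (2 * eps)
print(np.allclose(grad_w, fd_w, atol=1e-6), np.isclose(grad_b, fd_b, atol=1e-6))  # True True
```

The residual $y - \hat{y}$ appears in both gradients: parameters move in proportion to how wrong the predictions are, which is the link between predictions, targets, and gradients that an autodiff call like `jax.grad` computes automatically for any model.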


## Using PyTorch

In [8]:
import torch

def mse(y_true, y_pred):
  # Convert NumPy arrays to torch tensors (as_tensor avoids a copy when possible)
  y_true = torch.as_tensor(y_true)
  y_pred = torch.as_tensor(y_pred)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  # Average of the squared errors
  return torch.mean(torch.square(y_true - y_pred))

mse(y_true, y_pred)

tensor(0.1744, dtype=torch.float64)

### Taking it a step further for neural networks

In [None]:
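def mse(y_true, x):
  # Forward pass; calling model(x) instead of model.forward(x) also runs registered hooks
  y_pred = model(x)
  # Put the targets on the same dtype and device as the predictions
  y_true = torch.as_tensor(y_true, dtype=y_pred.dtype, device=y_pred.device)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  return torch.mean(torch.square(y_true - y_pred))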


# Binary Cross-Entropy
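$\text{BCE}_i = -\left[y_i \log(\hat{y}_i) + (1 - y_i)\log(1 - \hat{y}_i)\right]$

$y_i$ = true binary label for sample $i$ (0 or 1; in multi-label problems each output is its own 0/1 label, not a one-hot vector)

$\hat{y}_i$ = predicted probability that $y_i = 1$ (typically the output of a sigmoid)

The loss for a batch is this per-sample term averaged (or summed) over the samples.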
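A small hand-checkable example (values chosen for illustration): with labels $y = [1, 0]$ and probabilities $\hat{y} = [0.9, 0.2]$, only one log term survives per sample, giving $-\log(0.9) \approx 0.105$ and $-\log(1 - 0.2) \approx 0.223$. The hand-written JAX and PyTorch versions below sum these terms, while sklearn's `log_loss` averages them, so the two conventions differ by a factor of $n$:

```python
import numpy as np

labels = np.array([1.0, 0.0])   # y_i, each 0 or 1
probs = np.array([0.9, 0.2])    # y_hat_i = predicted P(y_i = 1)

# Per-sample BCE: -[y log(y_hat) + (1 - y) log(1 - y_hat)]
per_sample = -(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
print(per_sample)         # approx [0.1054, 0.2231]
print(per_sample.sum())   # approx 0.3285 (sum reduction)
print(per_sample.mean())  # approx 0.1643 (mean reduction)
```

A confident correct prediction (0.9 for a positive) costs little, while the loss grows without bound as the predicted probability of the true class approaches 0.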

## Using sklearn

In [24]:
y_true = np.random.randint(low=0, high=2, size=(1, 4))
y_true

# y_pred[0, j] is the predicted probability that y_true[0, j] == 1
y_pred = np.random.rand(1, 4)
y_pred

array([[0.36016328, 0.42604236, 0.2383781 , 0.09121179]])

In [32]:
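from sklearn.metrics import log_loss
# For binary problems, log_loss expects 1-D labels and 1-D probabilities of the
# positive class, so flatten the (1, 4) arrays into four separate samples.
# labels=[0, 1] keeps it working even if the random y_true contains only one class.
log_loss(y_true.ravel(), y_pred.ravel(), labels=[0, 1])

Passed unflattened, the (1, 4) arrays are read as one sample whose four columns are treated as classes, so `log_loss` computes categorical cross-entropy rather than BCE. For the `y_pred` shown above that call returns 4.268985431428761, which is only the sum of $-\log(\hat{y}_i)$ over the positions labelled 1; the $(1 - y_i)\log(1 - \hat{y}_i)$ terms never appear. With flattened inputs, `log_loss` averages the per-sample BCE instead of summing it, so its result should equal the JAX and PyTorch values below divided by 4 (4.5413 / 4 ≈ 1.1353).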

## Using JAX

In [27]:
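def binary_cross_entropy(y_true, y_pred, eps=1e-7):
  y_true = jnp.asarray(y_true)
  # Keep probabilities away from exactly 0 and 1 so the logs stay finite
  y_pred = jnp.clip(jnp.asarray(y_pred), eps, 1 - eps)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  # Summed (not averaged) over all entries
  return -jnp.sum(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred))

# jax.jit wraps the function, which is then called with the data
jax.jit(binary_cross_entropy)(y_true, y_pred)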

Array(4.5412903, dtype=float32)

### Taking it a step further for neural networks

In [None]:
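def binary_cross_entropy(params, x, y_true, eps=1e-7):
  """BCE of a Flax model's predictions.

  Assumes `model` ends in a sigmoid, so model.apply returns probabilities.
  `params` comes first so jax.grad differentiates with respect to it.
  """
  y_true = jnp.asarray(y_true)
  # Forward pass, then clip so log() never sees exactly 0 or 1
  y_pred = jnp.clip(model.apply(params, x), eps, 1 - eps)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  return -jnp.sum(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred))

# jax.jit wraps the function (not its result) and compiles it with XLA
loss_fn = jax.jit(binary_cross_entropy)
grad_fn = jax.jit(jax.grad(binary_cross_entropy))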

## Using PyTorch

In [34]:
def binary_cross_entropy(y_true, y_pred, eps=1e-7):
  y_true = torch.as_tensor(y_true)
  # Clamp probabilities so torch.log never receives exactly 0 or 1
  y_pred = torch.clamp(torch.as_tensor(y_pred), eps, 1 - eps)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  return -torch.sum(y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))

binary_cross_entropy(y_true, y_pred)

tensor(4.5413, dtype=torch.float64)

### Taking it a step further for neural networks

In [31]:
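def binary_cross_entropy(y_true, x, eps=1e-7):
  # Forward pass; assumes model ends in a sigmoid, so its outputs are probabilities
  y_pred = torch.clamp(model(x), eps, 1 - eps)
  # Put the targets on the same dtype and device as the predictions
  y_true = torch.as_tensor(y_true, dtype=y_pred.dtype, device=y_pred.device)
  if y_true.shape != y_pred.shape:
    raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
  # -[y log(p) + (1 - y) log(1 - p)], summed over all entries
  return -torch.sum(y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))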
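One reason BCE is usually paired with a sigmoid output: if the network produces a logit $z$ and $\hat{y} = \sigma(z)$, the derivative of the per-sample BCE with respect to $z$ simplifies to $\hat{y} - y$, a clean "prediction minus target" signal. The NumPy sketch below uses illustrative values (not tied to any model above) and checks that identity against central finite differences:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce_from_logit(z, y):
    """Per-sample BCE written in terms of the logit z, with y_hat = sigmoid(z)."""
    p = sigmoid(z)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

z = np.array([-2.0, 0.5, 3.0])  # illustrative logits
y = np.array([0.0, 1.0, 1.0])   # illustrative labels

analytic = sigmoid(z) - y       # dBCE/dz = y_hat - y
eps = 1e-6
# Each loss depends only on its own z, so an elementwise central difference works
numeric = (bce_from_logit(z + eps, y) - bce_from_logit(z - eps, y)) / (2 * eps)
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

The sigmoid's derivative $\hat{y}(1 - \hat{y})$ cancels against the $1/\hat{y}$ and $1/(1 - \hat{y})$ factors from the logs, which is why the gradient does not vanish even when the sigmoid saturates on a wrong prediction.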

# Summary
I implemented MSE and BCE by hand in both PyTorch and JAX/Flax to show how loss functions work at the mathematical and functional-programming level. These two losses are the simplest representatives of regression and binary classification, so writing them out exposes the structure most losses share: the forward pass produces predictions, the loss compares them with the targets to give a single scalar, and the gradient of that scalar with respect to the parameters drives training.

However, for real training it is safer and more efficient to use the built-in loss functions, for example `torch.nn.functional.mse_loss` and `torch.nn.functional.binary_cross_entropy_with_logits` in PyTorch, or `optax.sigmoid_binary_cross_entropy` in the JAX ecosystem. Framework implementations handle numerical stability (clamping, safe logs, overflow protection), memory efficiency (fused ops, optimized kernels), and dtype/device correctness, and they avoid the silent bugs that custom code can introduce during large-scale training.
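To make the numerical-stability point concrete: computing $\sigma(z)$ first and then taking logs fails for large $|z|$, because $\sigma(z)$ rounds to exactly 0 or 1 in floating point and one of the logs becomes $\log(0)$. Rewriting the per-sample loss directly in terms of the logit as $\max(z, 0) - z\,y + \log(1 + e^{-|z|})$ gives the same value without ever forming $\log(0)$; PyTorch's documentation describes its with-logits BCE as using this kind of log-sum-exp rearrangement. A NumPy sketch comparing the two (function names and values are illustrative):

```python
import numpy as np

def naive_bce_from_logits(z, y):
    """Sigmoid first, then logs: straightforward but fragile for large |z|."""
    p = 1.0 / (1.0 + np.exp(-z))
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def stable_bce_from_logits(z, y):
    """Same loss rewritten as max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))

z = np.array([-3.0, 0.0, 2.0, 40.0])
y = np.array([0.0, 1.0, 1.0, 0.0])

with np.errstate(divide="ignore"):      # silence the log(0) warning for the demo
    print(naive_bce_from_logits(z, y))  # last entry is inf: sigmoid(40) rounds to 1.0
print(stable_bce_from_logits(z, y))     # last entry is 40.0 (to float precision)
```

On the first three entries the two functions agree; only the saturated case differs, and that is exactly the case a confidently wrong network produces during training.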

So the manual losses are here to demonstrate understanding, while actual training should rely on the optimized, numerically stable versions the frameworks provide.