## JAX loss functions 
Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/jax-loss-functions/"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logo.png" alt="Open in ML Nuggets"></a>

In [None]:
# `%pip` installs into the environment of the running kernel; writing the
# magic out explicitly also works when IPython's automagic is turned off.
%pip install optax

In [5]:
import optax

In [4]:
import jax.numpy as jnp

In [11]:
import jax
def custom_sigmoid_binary_cross_entropy(logits, labels):
  """Element-wise binary cross-entropy computed directly from logits.

  Both log-probabilities come from ``jax.nn.log_sigmoid``: log(p) is
  ``log_sigmoid(logits)`` and log(1 - p) is ``log_sigmoid(-logits)``, so
  sigmoid(logits) is never formed explicitly before taking a log.

  Args:
    logits: raw (pre-sigmoid) scores; a scalar or an array.
    labels: target values (typically 0/1 or probabilities), combined
      element-wise with ``logits``.

  Returns:
    ``-(labels * log(p) + (1 - labels) * log(1 - p))`` element-wise.
  """
  positive_term = labels * jax.nn.log_sigmoid(logits)
  negative_term = (1. - labels) * jax.nn.log_sigmoid(-logits)
  return -(positive_term + negative_term)

In [12]:
# Logit 0.5, label 0: loss = -log(1 - sigmoid(0.5)) = log(1 + e**0.5) ~ 0.974.
custom_sigmoid_binary_cross_entropy(0.5,0.0)

DeviceArray(0.974077, dtype=float32, weak_type=True)

## Sigmoid binary cross entropy

In [15]:
# Built-in optax equivalent of the hand-written function above: same inputs,
# same result (0.974077).
optax.sigmoid_binary_cross_entropy(0.5,0.0)

DeviceArray(0.974077, dtype=float32, weak_type=True)

## Softmax cross entropy

In [10]:
# `labels` is a soft target distribution over 5 classes (it sums to 1.0).
# The result is the scalar -sum(labels * log_softmax(logits)).
logits = jnp.array([0.50,0.60,0.70,0.30,0.25])
labels = jnp.array([0.20,0.30,0.10,0.20,0.2])
optax.softmax_cross_entropy(logits,labels)

DeviceArray(1.6341426, dtype=float32)

## Cosine distance

In [7]:
# Cosine distance = 1 - cosine similarity of `predictions` and `targets`.
# NOTE(review): `epsilon` acts as a lower bound on each vector's norm. Here
# ||targets|| = sqrt(0.22) ~ 0.469 < 0.7, so that norm is clamped to 0.7 and
# the result (0.4128) is NOT the true cosine distance, which is
# 1 - 0.46 / (1.119 * 0.469) ~ 0.124. That matches 1 - 0.8763 from the
# cosine-similarity cell, whose vectors are these scaled by 100. Use the
# default (tiny) epsilon for the textbook value; kept as-is so the recorded
# output still matches.
predictions = jnp.array([0.50,0.60,0.70,0.30,0.25])
targets = jnp.array([0.20,0.30,0.10,0.20,0.2])
optax.cosine_distance(predictions,targets,epsilon=0.7)

DeviceArray(0.4128204, dtype=float32)

## Cosine similarity

In [17]:
# These are the cosine-distance vectors scaled by 100, so they point in the
# same directions. Both norms (~111.9 and ~46.9) are far above epsilon=0.5,
# so epsilon has no effect and the result is the true cosine similarity.
predictions = jnp.array([50.0,60.0,70,30.0,25.0])
targets = jnp.array([20.0,30.0,10.0,20.0,20.0])
optax.cosine_similarity(predictions,targets,epsilon=0.5)

DeviceArray(0.87630975, dtype=float32)

In [30]:
# Nearly parallel vectors, so the cosine similarity is close to 1.
# `epsilon` is left at optax's default here.
predictions = jnp.array([12.0, 20.0,29., 60.])
targets = jnp.array([14., 18., 27., 55.])
optax.cosine_similarity(predictions,targets)

DeviceArray(0.9989536, dtype=float32)

## Huber loss

In [9]:
# Huber loss is a regression loss, so the inputs are named predictions/targets
# (optax.huber_loss's own parameter names, as in the neighbouring cells).
# With optax's default delta=1.0 every error here (at most 0.6) is below delta,
# so each element equals 0.5 * (prediction - target) ** 2.
predictions = jnp.array([0.50,0.60,0.70,0.30,0.25])
targets = jnp.array([0.20,0.30,0.10,0.20,0.2])
optax.huber_loss(predictions,targets)

DeviceArray([0.045     , 0.045     , 0.17999998, 0.005     , 0.00125   ],            dtype=float32)

## L2 loss

In [10]:
# optax.l2_loss is element-wise 0.5 * (predictions - targets) ** 2: note the
# 0.5 factor, and that no mean is taken (one value per element).
predictions = jnp.array([0.50,0.60,0.70,0.30,0.25])
targets = jnp.array([0.20,0.30,0.10,0.20,0.2])
optax.l2_loss(predictions,targets)

DeviceArray([0.045     , 0.045     , 0.17999998, 0.005     , 0.00125   ],            dtype=float32)

## Log-cosh loss

In [11]:
# Element-wise log(cosh(predictions - targets)). For small errors this is
# close to 0.5 * error**2, hence values just under the l2_loss output above.
predictions = jnp.array([0.50,0.60,0.70,0.30,0.25])
targets = jnp.array([0.20,0.30,0.10,0.20,0.2])
optax.log_cosh(predictions,targets)

DeviceArray([0.04434085, 0.04434085, 0.17013526, 0.00499171, 0.00124949],            dtype=float32)

## Smooth labels

In [12]:
# optax.smooth_labels computes (1 - alpha) * labels + alpha / num_classes.
# Here that is 0.6 * labels + 0.4 / 5, e.g. 0.30 -> 0.26 and 0.10 -> 0.14.
labels = jnp.array([0.20,0.30,0.10,0.20,0.2])
optax.smooth_labels(labels,alpha=0.4)

DeviceArray([0.2 , 0.26, 0.14, 0.2 , 0.2 ], dtype=float32)

## jax_metrics

In [None]:
# `%pip` installs into the environment of the running kernel; writing the
# magic out explicitly also works when IPython's automagic is turned off.
%pip install jax_metrics


In [7]:
import jax_metrics as jm

# Loss object: calling it with target/preds returns the cross-entropy.
crossentropy = jm.losses.Crossentropy()

logits = jnp.array([0.50,0.60,0.70,0.30,0.25])

# NOTE(review): y is used as a soft target but sums to 2.35, not 1, so it is
# not a probability distribution. The output (3.6687) equals
# -sum(y * log_softmax(logits)), i.e. preds are treated as logits; use a
# target that sums to 1 (e.g. one-hot) for a meaningful cross-entropy value.
y = jnp.array([0.50,0.60,0.70,0.30,0.25])
crossentropy(target=y, preds=logits) 

DeviceArray(3.668735, dtype=float32)

In [34]:

# Functional form of the same loss: identical inputs, identical result.
logits = jnp.array([0.50,0.60,0.70,0.30,0.25])
y = jnp.array([0.50,0.60,0.70,0.30,0.25])
jm.losses.crossentropy(target=y, preds=logits) 

DeviceArray(3.668735, dtype=float32)

In [17]:
# Every |target - preds| (24.75 to 69.3) is above 1, so Huber is in its linear
# regime: with delta = 1 (the value the output implies) the loss is
# mean(|error|) - 0.5 = 46.53 - 0.5 = 46.03.
target = jnp.array([50,60,70,30,25])
preds = jnp.array([0.50,0.60,0.70,0.30,0.25])
huber_loss = jm.losses.Huber()
huber_loss(target=target, preds=preds)

DeviceArray(46.030003, dtype=float32)

In [20]:
# Each preds value is target / 100, so |target - preds| = 0.99 * target and
# the mean is 0.99 * mean(target) = 0.99 * 47 = 46.53.
target = jnp.array([50,60,70,30,25])
preds = jnp.array([0.50,0.60,0.70,0.30,0.25])
jm.losses.mean_absolute_error(target=target, preds=preds)


DeviceArray(46.530003, dtype=float32)

In [26]:
rng = jax.random.PRNGKey(42)

# NOTE(review): the same key `rng` is passed to both draws below. JAX's
# guidance is to never reuse a key; split it first, e.g.
#     target_key, preds_key = jax.random.split(rng)
# so the two arrays are drawn independently. Left unchanged here so the
# recorded output below still matches.
target = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
preds = jax.random.uniform(rng, shape=(2, 3))

# One value per row (axis=1) of the (2, 3) arrays. Both arrays are
# non-negative, so their true cosine similarity is >= 0; the negative output
# shows this loss returns the *negated* similarity (minimizing it maximizes
# similarity).
jm.losses.cosine_similarity(target, preds, axis=1)

DeviceArray([-0.8602638 , -0.33731455], dtype=float32)

In [27]:
# 100 * mean(|target - preds| / |target|): every ratio is 0.99, so the result
# is 99% (shown as 98.99999 because of float32 rounding).
target = jnp.array([50,60,70,30,25])
preds = jnp.array([0.50,0.60,0.70,0.30,0.25])
jm.losses.mean_absolute_percentage_error(target=target, preds=preds)


DeviceArray(98.99999, dtype=float32)

In [28]:
# mean((log(1 + target) - log(1 + preds)) ** 2)
target = jnp.array([50,60,70,30,25])
preds = jnp.array([0.50,0.60,0.70,0.30,0.25])
jm.losses.mean_squared_logarithmic_error(target=target, preds=preds)


DeviceArray(11.7779, dtype=float32)

In [30]:
# target and preds are identical, so the mean squared error is exactly 0.
target = jnp.array([0.50,0.60,0.70,0.30,0.25])
preds = jnp.array([0.50,0.60,0.70,0.30,0.25])
jm.losses.mean_squared_error(target=target, preds=preds)


DeviceArray(0., dtype=float32)

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.