`clu` does not directly compute quantities such as cross entropy. Instead, it provides JAX-friendly utilities for computing metrics; the name stands for Common Loop Utils.

* https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb
* https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py#L326

In [1]:
from clu import metrics
import flax
import jax.numpy as jnp

# Metric computation

Metrics are computed in three steps. Remember that JAX needs functional code, so intermediate metric values are merged into new values rather than updated in place.

In [2]:
# Step 1: turn each batch of model outputs into an intermediate Accuracy value.
# Both batches use the same logits: row 0 favors class 1, row 1 favors class 0.
batch_logits = jnp.array([[-1., 1.], [1., -1.]])
accuracy_batch1 = metrics.Accuracy.from_model_output(
    logits=batch_logits,
    labels=jnp.array([0, 0]),  # row 0 is incorrect, row 1 is correct
)
accuracy_batch2 = metrics.Accuracy.from_model_output(
    logits=batch_logits,
    labels=jnp.array([1, 0]),  # both rows are correct
)

# Step 2: aggregate the intermediate values.
accuracy = accuracy_batch1.merge(accuracy_batch2)

# Step 3: reduce the aggregate to the final metric.
accuracy.compute()

Array(0.75, dtype=float32)

# Average

https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py#L96

The following cell demonstrates what the `metrics.Metric` APIs do by re-implementing a simple average.

In [3]:
@flax.struct.dataclass
class MyAverage(metrics.Metric):
    """Running average stored as a (total, count) pair.

    Illustrates the three methods a metric provides: `from_model_output`
    builds an intermediate value, `merge` combines two intermediate values,
    and `compute` reduces the accumulated state to the final number.
    """
    # Sum of every element seen so far.
    total: jnp.ndarray
    # Number of elements that contributed to `total`.
    count: jnp.ndarray

    @classmethod
    def from_model_output(cls, value: jnp.ndarray, **_) -> metrics.Metric:
        """Build an intermediate average from `value`.

        Despite the "model_output" in the name, the model itself is never
        referenced: the method only consumes the arguments it is given.

        Args:
          value: Scalar or array of numbers to average over.
          **_: Extra keyword arguments are accepted and ignored.

        Returns:
          A `MyAverage` holding the sum of `value` and its element count.
        """
        value = jnp.asarray(value)
        # Reduce to a scalar sum so that `total / count` is the mean for any
        # input rank, not an element-wise quotient.
        # The count is kept as float32 so its dtype does not depend on the
        # rank of the input.
        return cls(
            total=jnp.sum(value),
            count=jnp.asarray(value.size, dtype=jnp.float32),
        )

    def merge(self, other: metrics.Metric) -> metrics.Metric:
        """Return a new metric combining the totals and counts of both."""
        return type(self)(
            total=self.total + other.total,
            count=self.count + other.count,
        )

    def compute(self):
        """Return the average of everything accumulated so far."""
        return self.total / self.count

In [4]:
data = jnp.array([1, 2])
average = None
for value in data:
    update = MyAverage.from_model_output(value)
    # The first update seeds the aggregate; later ones are merged into it.
    if average is None:
        average = update
    else:
        average = average.merge(update)

# Final averaged value.
print(average.compute())

# Accumulated total and count.
print(average)

1.5
MyAverage(total=Array(3, dtype=int32), count=Array(2., dtype=float32))


Actual Average class in clu

In [5]:
data = jnp.array([[1],[2]])
average = None
for value in data:
    # CLU's Average also keeps track of a total and a count; it can
    # additionally take a mask into account (not demonstrated here).
    update = metrics.Average.from_model_output(value)
    if average is None:
        average = update
    else:
        average = average.merge(update)
print(average)

Average(total=Array(3, dtype=int32), count=Array(2, dtype=int32))


# Collection of metrics

In [6]:
@flax.struct.dataclass
class MyMetrics(metrics.Collection):
    """Metric collection that holds a single accuracy metric."""
    accuracy: metrics.Accuracy
    
# Running aggregate; stays None until the first update has been produced.
my_metrics = None

# (input, label) pairs; each label is wrapped in a 0-d array.
data = [
    (x, jnp.array(y))
    for x, y in ((0, 0), (1, 0), (2, 1), (3, 1))
]

def model(x):
    """Toy classifier returning fixed three-class scores for an input.

    An input equal to 0 gets scores whose maximum is at index 0; every other
    input gets scores whose maximum is at index 1.
    """
    if x == 0:
        # Highest score at index 0, i.e. predict label 0.
        return jnp.array([0.7, 0.2, 0.1])
    # Highest score at index 1, i.e. predict label 1.
    return jnp.array([0.4, 0.5, 0.1])
    
# Given the above, accuracy should be 3/4 = 0.75
# input 0 -> predicted 0, actual 0
# input 1 -> predicted 1, actual 0
# input 2 -> predicted 1, actual 1
# input 3 -> predicted 1, actual 1

for inputs, labels in data:
    logits = model(inputs)
    # single_from_model_output is the entry point when pmap is not involved;
    # under pmap, use gather_from_model_output instead.
    #
    # Conceptually, the call below yields:
    # {
    #   'accuracy': Accuracy.from_model_output(logits=logits, labels=labels)
    # }
    #
    # Accuracy.from_model_output requires logits and labels, where labels is a
    # jnp.int32 array and logits.ndim == labels.ndim + 1. In this example:
    # labels: jnp.array(3).dtype == jnp.int32
    #         jnp.array(3).ndim == 0
    # logits: jnp.array([0.7, 0.2, 0.1]).ndim == 1
    #
    # When batched, labels and logits have ndim 1 and 2, respectively.
    update = MyMetrics.single_from_model_output(logits=logits, labels=labels)
    # The first update seeds the aggregate; later ones are merged into it.
    if my_metrics is None:
        my_metrics = update
    else:
        my_metrics = my_metrics.merge(update)

print(my_metrics.accuracy)
print(my_metrics.accuracy.compute())

Accuracy(total=Array(3., dtype=float32), count=Array(4, dtype=int32))
0.75


Example with multiple metrics

In [7]:
@flax.struct.dataclass
class MyCollection(metrics.Collection):
    """Metric collection tracking accuracy plus the mean and std of the loss."""
    # Accuracy is computed from the `logits` and `labels` arguments.
    accuracy: metrics.Accuracy
    # 'from_output' selects which single_from_model_output argument the
    # metric collects; here both metrics below read 'loss'.
    loss: metrics.Average.from_output('loss')
    loss_std: metrics.Std.from_output('loss')

# First example: build the collection from a single (loss, logits, labels).
my_collection = MyCollection.single_from_model_output(
    # correct: the largest logit is at index 0 and the label is 0
    loss=0.1, logits=jnp.array([0.9, 0.1]), labels=jnp.array(0))
print(f'{my_collection.compute()=}')
# Second example, kept separate first so its own metrics can be printed.
update = MyCollection.single_from_model_output(
    # wrong: the logits favor label 0 but the actual label is 1
    loss=0.7, logits=jnp.array([0.8, 0.2]), labels=jnp.array(1))
print(f'{update.compute()=}')
# Merge so the collection reflects both examples together.
my_collection = my_collection.merge(update)
print(f'{my_collection.compute()=}')

my_collection.compute()={'accuracy': Array(1., dtype=float32), 'loss': Array(0.1, dtype=float32), 'loss_std': Array(0., dtype=float32)}
update.compute()={'accuracy': Array(0., dtype=float32), 'loss': Array(0.7, dtype=float32), 'loss_std': Array(0., dtype=float32)}
my_collection.compute()={'accuracy': Array(0.5, dtype=float32), 'loss': Array(0.4, dtype=float32), 'loss_std': Array(0.29999995, dtype=float32)}


# Collecting metrics

A metric can also simply collect raw values so that the final metric is computed later on the host CPU, e.g., gather logits and labels, then compute accuracy using sklearn. See https://github.com/google/CommonLoopUtils/blob/cdd17c5d5f69280d216ab47061ef9a87f3a0a5a4/clu/metrics.py#L326