In [11]:
import keras
from keras import ops
import tensorflow as tf

In [22]:
class MetricEveryN(keras.metrics.Metric):
  """Running mean of `metric_fn`, evaluated only on every `n`-th batch.

  The wrapped function is called inside `ops.cond`, so on the other batches
  it is skipped rather than computed and discarded.

  Args:
    metric_fn: Callable invoked as `metric_fn(y_true, y_pred)`; its output is
      reduced with `ops.mean` to a scalar.
    name: Name of the metric.
    n: Evaluate `metric_fn` when the per-epoch batch counter is a multiple of
      `n`. Must be >= 1.
    **kwargs: Forwarded to `keras.metrics.Metric`.

  Raises:
    ValueError: If `n` is smaller than 1.
  """
  def __init__(self, metric_fn, name='custom_metric', n=10, **kwargs):
    if n < 1:
      raise ValueError(f'`n` must be >= 1, got {n!r}')
    super().__init__(name=name, **kwargs)
    self.metric_fn = metric_fn
    self.n = n
    # Sum of the per-batch means over the batches that were evaluated.
    self.total = self.add_weight(name='total', initializer='zeros')
    # Number of batches on which `metric_fn` was evaluated.
    self.count = self.add_weight(name='count', initializer='zeros')
    # Number of batches seen since the last reset.
    self.batch_counter = self.add_weight(name='batch_counter', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Counts the batch and, on every `n`-th one, accumulates the metric.

    `sample_weight` is accepted for API compatibility but is not used.
    """
    self.batch_counter.assign_add(1)

    # Evaluate the predicate once and reuse it for both accumulators.
    is_nth_batch = ops.equal(self.batch_counter % self.n, 0)

    # Only compute the metric every N batches. Both branches return a scalar
    # of the accumulator's dtype so `cond` sees matching outputs.
    batch_value = ops.cond(
      is_nth_batch,
      lambda: ops.cast(ops.mean(self.metric_fn(y_true, y_pred)), self.total.dtype),
      lambda: ops.zeros((), dtype=self.total.dtype),
    )
    self.total.assign_add(batch_value)
    # The boolean predicate cast to float is 1.0 on evaluated batches, else 0.0.
    self.count.assign_add(ops.cast(is_nth_batch, self.count.dtype))

  def result(self):
    """Returns the mean over evaluated batches, or 0 if none was evaluated."""
    return ops.cond(
      ops.equal(self.count, 0),
      lambda: ops.zeros((), dtype=self.total.dtype),
      lambda: self.total / self.count,
    )

  def reset_state(self):
    """Clears the accumulators and the batch counter."""
    self.total.assign(0.0)
    self.count.assign(0.0)
    self.batch_counter.assign(0.0)

  def reset_states(self):
    """Alias of `reset_state`, kept for callers using the older name."""
    self.reset_state()


In [23]:
@tf.function
def metric(a, b):
  """Mean squared error of `a` and `b`; emits 'Running' via `tf.print`.

  The print makes it visible in the training log when this function is
  actually executed.
  """
  tf.print('Running')
  mse = keras.metrics.mean_squared_error(a, b)
  return mse

# One Dense unit on scalar inputs, tracked with the every-10-batches metric.
model = keras.Sequential([
  keras.layers.Input(shape=(1,)),
  keras.layers.Dense(1),
])
model.compile(
  optimizer=keras.optimizers.SGD(learning_rate=1e-12),
  loss=keras.losses.MeanSquaredError(),
  metrics=[MetricEveryN(metric, n=10)],
)

# Toy regression data: Y = 2 * X + 1 for X = 0..99, in batches of 4
# (25 batches per epoch).
X = tf.range(100, dtype=tf.float32)
Y = X * 2 + 1
ds = tf.data.Dataset.from_tensor_slices((X, Y))
ds = ds.batch(4)

model.fit(ds)

[1m 1/25[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m7s[0m 302ms/step - custom_metric: 0.0000e+00 - loss: 52.7875Running
Running
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - custom_metric: 18444.0801 - loss: 14810.4609


<keras.src.callbacks.history.History at 0x7f2a281b3e50>

In [None]:
# Push only the first batch through the model by hand and print its loss.
# NOTE(review): the saved output of this cell is nan although `fit` reported
# finite losses; the cause is not evident from this notebook -- confirm on a
# fresh kernel.
for x, y in ds.take(1):
  z = model(x, training=True)
  l = model.loss(y, z)
  print(l)

tf.Tensor(nan, shape=(), dtype=float32)
