<a href="https://colab.research.google.com/github/malcolmlett/ml-learning/blob/main/Learning_visualisations_v17_onDemandAutoGraphs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Learning Visualisations v17: On-demand Auto-Graphing in callbacks
Despite significant improvements in how I write `@tf.function` methods, I'm still getting a lot of these warnings when doing rapid development with lots of experiment re-runs:
```
WARNING:tensorflow:5 out of the last 374 calls to <function ActivityStatsCollectingMixin._accum_activity_stats_internal at 0x7e2dfb532480> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:5 out of the last 374 calls to <function ValueStatsCollectingMixin._compute_iteration_value_stats at 0x7e2dfb532ac0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:5 out of the last 374 calls to <function ActivityStatsCollectingMixin._compute_activity_stats at 0x7e2dfb533600> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
```

Re-tracing clearly isn't happening very often, and it's probably only triggered by my frequent module re-loads, so it isn't a real performance problem. It does, however, add a lot of noise to experiment output.
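
Since the issue is noise rather than speed, one blunt option is to raise the TensorFlow logger level around a training run. The sketch below is only that: `quiet_tf_warnings` is a hypothetical helper name, and it hides every TensorFlow warning emitted inside the block, not just the retracing ones.

```python
import contextlib
import logging

import tensorflow as tf


@contextlib.contextmanager
def quiet_tf_warnings(level=logging.ERROR):
  """Temporarily raise the TensorFlow logger level (hypothetical helper)."""
  logger = tf.get_logger()
  previous = logger.level
  logger.setLevel(level)
  try:
    yield
  finally:
    # Restore whatever level was active before entering the block.
    logger.setLevel(previous)


with quiet_tf_warnings():
  # Training code would go here; WARNING-level TF log records are dropped.
  pass
```

In this notebook the `with` block would wrap a call such as `tot.fit(...)`. The trade-off is that genuinely useful warnings are silenced as well.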

I've noticed that TF tends to manually call auto-graph at the start of training. This actually makes sense for my callbacks too. Whatever graphs they trace will almost certainly be irrelevant on the next run (when doing rapid development). So it makes sense to auto-graph programmatically at the start of training, and then to discard the resulting graph afterwards.
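
A minimal sketch of that lifecycle, assuming a plain `tf.keras.callbacks.Callback` rather than the toolkit's own base classes. `OnDemandGraphCallback` and its `record` method are hypothetical names used only for illustration. It shows where the wrapper is created and discarded; it does not by itself establish that the retracing warnings go away.

```python
import tensorflow as tf


class OnDemandGraphCallback(tf.keras.callbacks.Callback):
  """Hypothetical callback that wraps its compute method once per run."""

  def __init__(self):
    super().__init__()
    self.means = []
    self._compute_fn = None  # only set while training is in progress

  def on_train_begin(self, logs=None):
    # Build a fresh tf.function wrapper for this training run.
    self._compute_fn = tf.function(self._compute)

  def on_train_end(self, logs=None):
    # Drop the wrapper so its traced graphs can be garbage collected.
    self._compute_fn = None

  def record(self, values):
    # Stand-in for whichever hook receives the tensors during training.
    self.means.append(self._compute_fn(values))

  def _compute(self, values):
    means = [tf.reduce_mean(tensor) for tensor in values]
    return tf.reduce_mean(means)


cb = OnDemandGraphCallback()
cb.on_train_begin()
cb.record([tf.ones((2, 3)), tf.zeros((4,))])
cb.on_train_end()
print(float(cb.means[0]))  # 0.5
```

The hooks are called by hand here so the sketch runs without a model; in real use the training loop would call them.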

In [1]:
import os
if os.path.isdir('repo'):
  # discard any local changes and update
  !cd repo && git reset --hard HEAD
  !cd repo && git fetch
else:
  !git clone https://github.com/malcolmlett/ml-learning.git repo

# lock to revision
!cd repo && git checkout f289e95
#!cd repo && git pull

import sys
sys.path.append('repo')

import train_observability_toolkit as tot
from importlib import reload
reload(tot)

Cloning into 'repo'...
remote: Enumerating objects: 1178, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (12/12), done.
remote: Total 1178 (delta 6), reused 9 (delta 3), pack-reused 1163 (from 2)
Receiving objects: 100% (1178/1178), 83.33 MiB | 17.73 MiB/s, done.
Resolving deltas: 100% (714/714), done.
Updating files: 100% (37/37), done.
Already up to date.


<module 'train_observability_toolkit' from '/content/repo/train_observability_toolkit.py'>

In [2]:
import train_observability_toolkit_test
reload(train_observability_toolkit_test)
reload(tot)
train_observability_toolkit_test.run_test_suite()

All train_observability_toolkit tests passed.


In [3]:
import keras
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers, models, datasets, optimizers, metrics
import numpy as np
from sklearn.decomposition import PCA
from scipy.stats import pearsonr
import pandas as pd
import matplotlib.pyplot as plt
import math
import sklearn
import sklearn.datasets
import time
import timeit
import tqdm

## Basics


In [4]:
def binary_classification_dataset():
  np.random.seed(1)
  train_X, train_Y = sklearn.datasets.make_circles(n_samples=300, noise=.05)
  np.random.seed(2)
  test_X, test_Y = sklearn.datasets.make_circles(n_samples=100, noise=.05)
  # Labels become column vectors; the features are used as-is.
  train_Y = train_Y.reshape((-1, 1))
  test_Y = test_Y.reshape((-1, 1))
  return train_X, train_Y, test_X, test_Y

def mnist_dataset():
  np.random.seed(1)
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # Normalize the data
  x_train = x_train / 255.0
  return x_train, y_train

def binary_classification_model(init_scheme):
  """
  init_scheme: one of "zeros", "large_normal", "he_normal"
  """
  if init_scheme == "zeros":
    kernel_initializer='zeros'
  elif init_scheme == "large_normal":
    kernel_initializer=tf.keras.initializers.RandomNormal(stddev=10.)
  elif init_scheme == "he_normal":
    kernel_initializer='he_normal'
  else:
    raise ValueError("Unknown init_scheme")

  model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(2,)),
    tf.keras.layers.Dense(100, activation='relu', kernel_initializer=kernel_initializer),
    tf.keras.layers.Dense(100, activation='relu', kernel_initializer=kernel_initializer),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(100, activation='relu', kernel_initializer=kernel_initializer),
    tf.keras.layers.Dense(100, activation='relu', kernel_initializer=kernel_initializer),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(100, activation='relu', kernel_initializer=kernel_initializer),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(5, activation='relu', kernel_initializer=kernel_initializer),
    tf.keras.layers.Dense(1, activation='sigmoid', kernel_initializer=kernel_initializer)
  ])

  return model

def mnist_cnn_model():
  model = tf.keras.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Reshape((28, 28, 1)),
    layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', strides=2, activation='relu'),
    layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=2, activation='relu'),
    layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', strides=2, activation='relu'),
    layers.Flatten(),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax')  # Output layer for 10 classes
  ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

  return model

class TimingCallback(tf.keras.callbacks.Callback):
  def __init__(self):
    super().__init__()
    self._start = None
    self._epochs = None
    self.total = None
    self.per_epoch = None

  def on_epoch_begin(self, epoch, logs=None):
    # Start the clock when epoch index 1 begins.
    if epoch == 1:
      self._start = tf.timestamp()
    # Epochs timed once this one completes: indices 1..epoch.
    self._epochs = epoch

  def on_train_end(self, logs=None):
    self.total = (tf.timestamp() - self._start).numpy()
    self.per_epoch = self.total / self._epochs

# Basic implementation
Let's first measure this directly and experiment with implementation options.

In [94]:
# This demonstrates it quite nicely.
# Re-run this cell block multiple times and you'll get the following approximately every second run:
#   WARNING:tensorflow:5 out of the last 5 calls to <function example_function at 0x798d5ef09620> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
#   WARNING:tensorflow:6 out of the last 6 calls to <function example_function at 0x798d5ef09620> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
def example_function(x):
  return x * x

example_fn = tf.function(example_function)
print(f"{example_fn.name=}")
print(f"{example_fn.__name__=}")
print()

# Call function with different dtypes and shapes (each new signature triggers a trace)
example_fn(tf.constant(3.0))  # float32
example_fn(tf.constant(3))    # int32
example_fn(tf.constant([1, 2, 3], dtype=tf.float32))  # 1D tensor, float32

# Print the concrete signatures traced so far
print(example_fn.pretty_printed_concrete_signatures())



example_fn.name='example_function'
example_fn.__name__='example_function'

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

Input Parameters:
  x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(3,), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(3,), dtype=tf.float32, name=None)
Captures:
  None
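
The two float32 signatures above differ only in shape. As a hedged sketch, an `input_signature` with an unknown shape lets both calls share a single trace. `square` and `trace_log` are illustrative names, and the Python-side `append` is used only because it runs during tracing and not on later graph executions. The int32 call is left out because it would not match a float32 signature.

```python
import tensorflow as tf

trace_log = []  # grows only while the function body is being traced


def square(x):
  trace_log.append("traced")  # Python side effect, not part of the graph
  return x * x


# shape=None means "any rank, any dimensions" for a float32 tensor.
square_fn = tf.function(
    square,
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])

a = square_fn(tf.constant(3.0))              # scalar, float32
b = square_fn(tf.constant([1.0, 2.0, 3.0]))  # 1D tensor, float32

print(len(trace_log))  # 1
```

The cost is that the graph can no longer specialise on a static shape, which may or may not matter for a given function.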


In [28]:
# simple example callback that just computes the mean-of-means across all values
class MyCallback1(tot.BaseGradientCallback):
  def __init__(self):
    super().__init__()
    self.means = []

  def on_epoch_end(self, epoch, loss, gradients, trainable_variables, activations, output_gradients):
    self.means.append(self._compute(gradients))

  @tf.function
  def _compute(self, values):
    means = [tf.reduce_mean(tensor) for tensor in values]
    mean = tf.reduce_mean(means)
    return mean

class MyCallback2(tot.BaseGradientCallback):
  def __init__(self):
    super().__init__()
    self.means = []
    self._compute_fn = tf.function(self._compute)

  def on_epoch_end(self, epoch, loss, gradients, trainable_variables, activations, output_gradients):
    self.means.append(self._compute_fn(gradients))

  def _compute(self, values):
    means = [tf.reduce_mean(tensor) for tensor in values]
    mean = tf.reduce_mean(means)
    return mean


In [29]:
# Both callbacks seem to perform in almost identical time, but I'm unable to replicate the re-tracing problem
reload(tot)
tf.config.run_functions_eagerly(False)

cb = MyCallback2()
timing = TimingCallback()
model = binary_classification_model('he_normal')
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', 'mse', 'binary_crossentropy'])
train_X, train_Y, _, _ = binary_classification_dataset()
dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y))
start = tf.timestamp()
history = tot.fit(model, dataset.batch(32), epochs=10, verbose=0, callbacks=[tot.LessVerboseProgressLogger(), tot.HistoryStats(), cb, timing])
duration = (tf.timestamp() - start).numpy()
print(f"Total training time: {duration:.2f} secs. Average: {timing.per_epoch*1000:.2f}ms/epoch")

Epoch     1 - 3.44s/epoch: accuracy: 0.5567  binary_crossentropy: 0.6904  loss: 0.6904  mse: 0.2482  
Epoch     2 - 107.02ms/epoch: accuracy: 0.5667  binary_crossentropy: 0.6708  loss: 0.6708  mse: 0.2392  
Epoch     3 - 58.90ms/epoch: accuracy: 0.6300  binary_crossentropy: 0.6427  loss: 0.6427  mse: 0.2258  
Epoch     4 - 60.44ms/epoch: accuracy: 0.6533  binary_crossentropy: 0.6235  loss: 0.6235  mse: 0.2165  
Epoch     5 - 57.55ms/epoch: accuracy: 0.6733  binary_crossentropy: 0.6030  loss: 0.6030  mse: 0.2080  
Epoch     6 - 60.81ms/epoch: accuracy: 0.6767  binary_crossentropy: 0.5949  loss: 0.5949  mse: 0.2038  
Epoch     7 - 60.32ms/epoch: accuracy: 0.7533  binary_crossentropy: 0.5381  loss: 0.5381  mse: 0.1790  
Epoch     8 - 64.88ms/epoch: accuracy: 0.7367  binary_crossentropy: 0.5422  loss: 0.5422  mse: 0.1815  
Epoch     9 - 58.66ms/epoch: accuracy: 0.7633  binary_crossentropy: 0.5075  loss: 0.5075  mse: 0.1667  
Epoch    10 - 60.30ms/epoch: accuracy: 0.7300  binary_crossentrop

In [None]:
cb = MyCallback1()
timing = TimingCallback()
model = binary_classification_model('he_normal')
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', 'mse', 'binary_crossentropy'])
train_X, train_Y, _, _ = binary_classification_dataset()
dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y))
# model = mnist_cnn_model()
# train_X, train_Y = mnist_dataset()
# dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y)).take(32000).batch(64)
start = tf.timestamp()
history = tot.fit(model, dataset.batch(32), epochs=100, verbose=0, callbacks=[tot.LessVerboseProgressLogger(), tot.HistoryStats(), cb, timing])
#history = tot.fit(model, dataset, epochs=10, callbacks=[tot.HistoryStats(), cb , timing])
duration = (tf.timestamp() - start).numpy()
print(f"Total training time: {duration:.2f} secs. Average: {timing.per_epoch*1000:.2f}ms/epoch")

## Using existing callbacks

In [25]:
# here the problem occurs because these different callbacks collect different data from differently shaped tensors, but re-use shared code.
# The problem doesn't occur if I only use a single callback.
#
# Unfortunately, with the code modified to use a simple programmatic tf.function(func_ref) I still get retracing notifications, but worse, now they don't
# even mention the function name:
#  WARNING:tensorflow:5 out of the last 50 calls to <tensorflow.python.eager.polymorphic_function.polymorphic_function.Function object at 0x798d5d252ed0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

reload(tot)
tf.config.run_functions_eagerly(False)

per_step=False
variables = tot.VariableHistoryCallback(per_step=per_step, collection_sets=[{}])
gradients = tot.GradientHistoryCallback(per_step=per_step, collection_sets=[{}])
outputs = tot.LayerOutputHistoryCallback(per_step=per_step, collection_sets=[{}])
epoch_gradients = tot.GradientHistoryCallback(per_step=per_step, collection_sets=[{}])
output_gradients = tot.LayerOutputGradientHistoryCallback(per_step=per_step, collection_sets=[{}])

model = binary_classification_model('he_normal')
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', 'mse', 'binary_crossentropy'])
train_X, train_Y, _, _ = binary_classification_dataset()
dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y))
history = tot.fit(model, dataset.batch(32), epochs=10, verbose=0, callbacks=[tot.LessVerboseProgressLogger(), variables, gradients, epoch_gradients, outputs, output_gradients, tot.HistoryStats(per_step=False)])
#history = tot.fit(model, dataset.batch(32), epochs=10, verbose=0, callbacks=[tot.LessVerboseProgressLogger(), output_gradients, tot.HistoryStats(per_step=False)])

Epoch     1 - 5.16s/epoch: accuracy: 0.6500  binary_crossentropy: 0.7035  loss: 0.7035  mse: 0.2377  




Epoch     2 - 12.23s/epoch: accuracy: 0.6100  binary_crossentropy: 0.6655  loss: 0.6655  mse: 0.2347  
Epoch     3 - 461.14ms/epoch: accuracy: 0.6367  binary_crossentropy: 0.6391  loss: 0.6391  mse: 0.2234  
Epoch     4 - 531.53ms/epoch: accuracy: 0.6700  binary_crossentropy: 0.6210  loss: 0.6210  mse: 0.2162  
Epoch     5 - 248.03ms/epoch: accuracy: 0.7000  binary_crossentropy: 0.5920  loss: 0.5920  mse: 0.2013  
Epoch     6 - 272.24ms/epoch: accuracy: 0.7333  binary_crossentropy: 0.5360  loss: 0.5360  mse: 0.1812  
Epoch     7 - 254.83ms/epoch: accuracy: 0.7167  binary_crossentropy: 0.5360  loss: 0.5360  mse: 0.1797  
Epoch     8 - 302.25ms/epoch: accuracy: 0.7633  binary_crossentropy: 0.5071  loss: 0.5071  mse: 0.1671  
Epoch     9 - 248.69ms/epoch: accuracy: 0.7833  binary_crossentropy: 0.5063  loss: 0.5063  mse: 0.1636  
Epoch    10 - 255.62ms/epoch: accuracy: 0.7200  binary_crossentropy: 0.5173  loss: 0.5173  mse: 0.1754  
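
The suspected cause, shared code fed differently shaped tensor lists, can be reproduced in isolation. This is a sketch under assumptions: `mean_of_means` mirrors the `_compute` method from the earlier example callbacks, and the three input lists are made-up stand-ins for what different callbacks might pass. It is not the toolkit's actual code.

```python
import tensorflow as tf

trace_log = []  # grows only while the function body is being traced


def mean_of_means(values):
  trace_log.append(len(values))  # Python side effect, not part of the graph
  means = [tf.reduce_mean(v) for v in values]
  return tf.reduce_mean(means)


shared_fn = tf.function(mean_of_means)

# Hypothetical inputs from three different callbacks.
variable_like = [tf.ones((2, 100)), tf.ones((100,))]
output_like = [tf.ones((32, 100)), tf.ones((32, 1))]
gradient_like = [tf.ones((2, 100)), tf.ones((100,)), tf.ones((100, 1))]

# The last call repeats the first signature, so it reuses that trace.
for values in (variable_like, output_like, gradient_like, variable_like):
  shared_fn(values)

print(len(trace_log))  # 3
```

Each distinct combination of list length and tensor shapes gets its own trace, so several callbacks sharing one function multiply the number of traces even when each callback's own inputs never change.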


## Conclusions

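* I'm giving up for now. The warnings are noise rather than a performance problem, and switching to programmatic `tf.function(func_ref)` didn't remove them, so it's not worth the effort.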