# Fasterer Python

Pure Python is slow, for two reasons. First, it is an **interpreted language**, which is generally much less efficient than a compiled language. Second, Python was originally conceived to work on a single thread, to make programming easier. Its standard interpreter (CPython) has a **Global Interpreter Lock (GIL)**, which means that only one thread can execute Python bytecode at a time.

However, there are easy ways to make Python code run **very fast**, circumventing these problems.

First, Python is very good at interfacing with non-Python libraries. This means you can write fast, compiled code in Rust or C++, and then call it from Python. Compiled libraries can release the GIL while they run, which means they can do their own multithreading as well.

Second, there are a couple of forms of parallelism within Python that circumvent the GIL.

## Numpy

Numpy ("Numerical Python") is a fast library for doing all kinds of numerical array operations in Python. It's written in an unholy combination of C, C++, and Fortran, but as a user, you never have to worry about that!

The golden rule of Numpy is that using numpy operations on numpy arrays is going to be fast, but using Python operations will be slow. For example, iterating over a matrix with for loops is slow, but operating on a matrix with numpy operations is fast.

Let's see an example...

In [None]:
# Everyone calls numpy np.
import numpy as np

# Make a matrix of random numbers.
mat = np.random.default_rng().random((1024 * 1024, 2), dtype=np.float32)

# Use a loop to multiply it by 2.
def times_two(mat):
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            mat[i, j] *= 2
    return mat

# Use numpy to multiply it by 2.
def times_two_numpy(mat):
    return mat * 2

# On tmd@'s computer, times_two_numpy is ~500x faster.
%time mat = times_two(mat)
%time mat = times_two_numpy(mat)

### Numpy Basics

* Most of the time, numpy works with *arrays*.
* Arrays are "rectangular", with any number of *axes*.
* You can get the size of an array with the `.shape` attribute.
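* Often you can perform an operation along a particular axis, e.g., `np.mean(X, axis=3)`.
* When combining two arrays, you can often *broadcast* an operation: comparing the shapes from the last axis backwards, each pair of axis sizes must either match or include a 1 (missing leading axes count as size 1).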
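A minimal sketch of broadcasting, using made-up arrays: a column of shape [3, 1] combined with a row of shape [4] (treated as [1, 4]) broadcasts to [3, 4], and a per-column mean can be subtracted directly.

```python
import numpy as np

# A column vector with shape [3, 1] and a row vector with shape [4].
col = np.arange(3).reshape(3, 1)
row = np.arange(4)

# Comparing shapes from the right: (3, 1) vs (4,) -> (3, 1) vs (1, 4) -> (3, 4).
table = col * 10 + row
print(table.shape)  # (3, 4)
print(table)

# Subtract each column's mean: the mean over axis 0 has shape [4],
# which broadcasts against the [3, 4] table.
centered = table - np.mean(table, axis=0)
print(np.mean(centered, axis=0))

# Shapes (3, 4) and (3,) do NOT broadcast: the last axes (4 vs 3) differ.
try:
    table + np.arange(3)
except ValueError as err:
    print('broadcast error:', err)
```

If you do want to add a length-3 vector to each row of the [3, 4] table, reshape it to [3, 1] first.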

### Extremely useful operations

* Construct arrays
  * `np.ones`, `np.zeros` - make constant one/zero arrays
  * `np.eye` - make an identity matrix.
* Common mathy operations
  * `np.mean`, `np.std`, `np.maximum`, `np.minimum`, etc...
  * `np.linalg.norm` - L2 / Euclidean norm.
  * `np.sort`, `np.argsort`
  * `np.dot` (dot product), `np.matmul`
* Manipulate arrays
  * `np.concat` - Concatenate along an existing axis (use `np.concatenate` with NumPy versions before 2.0). (e.g., combine arrays with shapes [2, 3] and [3, 3] into a single array with shape [5, 3] by concatenating along axis 0.)
  * `np.stack` - Stack arrays with matching shapes along a new axis. (e.g., combine two arrays with shape [4, 3] into a new array with shape [2, 4, 3].)
  * `np.reshape` - Re-organize the shape of an array, e.g., [10, 8] -> [2, 5, 8].
  * `np.transpose` - Permute axes. By default it reverses them, so for a 2D array it swaps rows and columns.
* Slices
  * Slice notation can select a sub-array, e.g., `X[0:4]` selects the first four rows of `X` (the end index is exclusive). You can also slice along multiple axes at once by separating the per-axis slices with commas.
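A quick sketch of the shape manipulations listed above, on throwaway arrays (the arrays and the slice `X[0:4, 1:3]` are our own examples). It uses `np.concatenate` so it runs on NumPy versions before and after 2.0.

```python
import numpy as np

a = np.zeros((2, 3))
b = np.ones((3, 3))

# Concatenate along an existing axis: [2, 3] + [3, 3] -> [5, 3].
joined = np.concatenate([a, b], axis=0)

# Stack along a new leading axis: two [4, 3] arrays -> [2, 4, 3].
stacked = np.stack([np.zeros((4, 3)), np.ones((4, 3))])

# Reshape [10, 8] -> [2, 5, 8]; the total number of elements must match.
reshaped = np.arange(80).reshape(10, 8).reshape(2, 5, 8)

# Transpose reverses the axes by default: [2, 5, 8] -> [8, 5, 2].
transposed = np.transpose(reshaped)

# Slice rows 0..3 and columns 1..2 (end indices are exclusive).
X = np.arange(80).reshape(10, 8)
sub = X[0:4, 1:3]

print(joined.shape, stacked.shape, reshaped.shape, transposed.shape, sub.shape)
```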


In [None]:
X = np.round(np.random.uniform(size=(3, 3)), 2)
print(X)

print('\nsorted on axis 0 (cols): ')
print(np.sort(X, axis=0))

print('\nsorted on axis 1 (rows): ')
print(np.sort(X, axis=1))

print('\nindices of sorted cols: ')
print(np.argsort(X, axis=0))

print('\nmatrix product of X with itself: ')
print(np.dot(X, X))

## Building a Tokenizer

Let's get the hang of working with common Numpy operations by writing a small **tokenizer**.

A tokenizer converts vectors into discrete tokens. That's a fancy way of saying that we'll do some clustering, and label each vector with the number of the cluster it wound up in.

There are approximately a thousand ways to do clustering, but it turns out that k-means is still pretty damned useful, and more-or-less what people do in practice. So, let's do some k-means clustering.
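For reference, one round of the standard k-means (Lloyd's) iteration that the functions below implement can be written as follows, with data points $x_i$, centroids $c_j$, and cluster assignments $a_i$. The update only applies to clusters whose set $S_j$ is nonempty. Note that the notebook's `reconstruction_loss` reports the mean (unsquared) distance $L$, whereas the classic k-means objective sums squared distances.

$$
a_i = \arg\min_{j} \lVert x_i - c_j \rVert_2,
\qquad
c_j \leftarrow \frac{1}{|S_j|} \sum_{i \in S_j} x_i,
\quad \text{where } S_j = \{\, i : a_i = j \,\}
$$

$$
L = \frac{1}{n} \sum_{i=1}^{n} \min_{j} \lVert x_i - c_j \rVert_2
$$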

In [None]:
def generate_initial_centroids(data: np.ndarray, k: int) -> np.ndarray:
  """Generate initial centroids for k-means clustering.

  Args:
    data: The data to cluster.
    k: The number of clusters to generate.

  Returns:
    The initial centroids.
  """
  return data[np.random.choice(data.shape[0], k, replace=False)]


def assign_to_clusters(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
  """Assign data points to clusters.

  Args:
    data: The data to cluster.
    centroids: The centroids to use for clustering.

  Returns:
    Array of shape [num_points] with the index of the nearest centroid for each point.
  """
  distances = np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
  return np.argmin(distances, axis=1)


def update_centroids(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
  """Update the centroids to be the mean of the data points in each cluster.

  Args:
    data: The data to cluster.
    centroids: The centroids to use for clustering.

  Returns:
    The updated centroids.
  """
  assignments = assign_to_clusters(data, centroids)
  for i in range(centroids.shape[0]):
    members = data[assignments == i]
    # Skip empty clusters: the mean of zero points is NaN.
    if len(members) > 0:
      centroids[i] = np.mean(members, axis=0)
  return centroids

def reconstruction_loss(data: np.ndarray, centroids: np.ndarray) -> float:
  """Compute the mean distance to the closest centroid.

  Args:
    data: The data to cluster.
    centroids: The centroids to use for clustering.

  Returns:
    The mean distance to the closest centroid.
  """
  distances = np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
  return np.mean(np.min(distances, axis=1))
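The `update_centroids` above still loops over clusters in Python. Following the golden rule, here is a hedged sketch of a loop-free variant (the name `update_centroids_vectorized` is our own) that accumulates per-cluster sums with `np.add.at` and counts with `np.bincount`, leaving empty clusters where they are.

```python
import numpy as np

def update_centroids_vectorized(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
  """Loop-free k-means update; empty clusters keep their old centroid."""
  k = centroids.shape[0]
  # Same assignment rule as assign_to_clusters: index of the nearest centroid.
  distances = np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
  assignments = np.argmin(distances, axis=1)
  # np.add.at accumulates correctly even when the same index repeats.
  sums = np.zeros_like(centroids)
  np.add.at(sums, assignments, data)
  counts = np.bincount(assignments, minlength=k)
  new_centroids = centroids.copy()
  nonempty = counts > 0
  new_centroids[nonempty] = sums[nonempty] / counts[nonempty, np.newaxis]
  return new_centroids
```

The broadcasted distance computation builds an [n, k, dim] array, which is fine at this size but can use a lot of memory for large `n` and `k`.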

In [None]:
k = 64
n = 1024 * 8

mat = np.random.default_rng().random((n, 2), dtype=np.float32)
%time centroids = generate_initial_centroids(mat, k)
print(centroids.shape)

import tqdm
pbar = tqdm.tqdm(range(100))
for i in pbar:
  centroids = update_centroids(mat, centroids)
  loss = reconstruction_loss(mat, centroids)
  pbar.set_description(f'loss: {loss:8.3f}')

In [None]:
from matplotlib import pyplot as plt
colors = assign_to_clusters(mat, centroids)

plt.scatter(mat[:, 0], mat[:, 1], c=colors, alpha=0.1)
plt.scatter(centroids[:, 0], centroids[:, 1], c='r', marker='*', s=100.0)
plt.show()

Extra challenges...

* Usually, you can't fit all the data in memory. In that case, you want to do *partial updates* of the cluster centroids. Change `update_centroids` to take a step-size parameter. Replace each centroid with a weighted average of the current centroid and the mean of the points assigned to the centroid.

* A problem for k-means clustering on real data is 'dead' centroids. Come up with a rule to replace centroids which haven't been used in a while with a 'fresh' centroid.

* (Advanced!) When you have high-dimensional data, you need more centroids to represent the data. When you have lots of centroids, things slow down. What to do?? In *product quantization*, you break a vector into parts and apply k-means to each subvector. Then for any vector, it gets assigned a cluster for each sub-vector. Create a collection of centroids with shape `[K, P, N/P]`, and update your algorithm to train centroids for all of the sub-vectors. Try to stick to Numpy operations as much as possible!

# Fast Data Pipeline with Threading

The GIL means that threads are not always helpful in Python, but there are two special cases where Python threading can be helpful:
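* Blocking I/O operations (reading files, network requests) release the GIL while they wait, and
* Many calls into compiled C/C++/Rust libraries (e.g., the NumPy, SciPy, and audio-decoding code that librosa relies on) release the GIL while doing the heavy lifting.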

In both of these cases, we can improve execution speed by parallelizing effectively.
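A small sketch of the I/O case, using `time.sleep` as a stand-in for a blocking read (it releases the GIL the same way). The helpers `fake_io` and `timed` are hypothetical. On a standard (GIL-enabled) CPython build, a pure-Python CPU-bound loop would not speed up like this.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def fake_io(seconds: float) -> float:
    # Stand-in for a blocking read: sleeping releases the GIL.
    time.sleep(seconds)
    return seconds

def timed(fn) -> float:
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start

tasks = [0.2] * 4
serial_s = timed(lambda: [fake_io(t) for t in tasks])
with ThreadPoolExecutor(max_workers=4) as executor:
    # All four "reads" wait at the same time.
    threaded_s = timed(lambda: list(executor.map(fake_io, tasks)))
print(f'serial: {serial_s:.2f}s, threaded: {threaded_s:.2f}s')
```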

This is particularly important for *data pipelines*. When training a model, we typically need to feed it subsets of our (gigantic, too big to fit in memory) dataset. Without threading, we need to load the next batch of data, wait for the model to evaluate and update, then load the next batch of data, and so on. This means that our fancy GPU is going to waste while we load the data!

Instead, we can load data *while* the model is running, and (hopefully) have the next batch of data ready and waiting. This greatly improves throughput. At the same time, we can make the data loader multi-threaded, to load data from many files in parallel.

Let's start with a very simple example of how to use a `ThreadPoolExecutor`.

In [None]:
from concurrent.futures import ThreadPoolExecutor

# The function we want the workers to evaluate.
fun = lambda a: a**2

# Create an `executor` that we can pass work to.
with ThreadPoolExecutor(max_workers=2) as executor:
  results = []
  for i in range(10):
    # Each result is a `future` object, which can tell us whether
    # it is still running, and eventually provide a result.
    results.append(executor.submit(fun, i))
  results = [r.result() for r in results]
  print(results)

# We can also use a `map` approach more concisely.
with ThreadPoolExecutor(max_workers=2) as executor:
  # map_results is an iterator that yields results in input order,
  # waiting for each one to finish as needed.
  map_results = executor.map(fun, range(10))
  print(list(map_results))
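`executor.map` hands results back in input order. If you would rather handle each result as soon as it finishes, `concurrent.futures.as_completed` yields futures in completion order instead. A sketch with a hypothetical `slow_square` that sleeps a random amount:

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def slow_square(a: int) -> int:
    # Random delay so tasks finish out of order.
    time.sleep(random.uniform(0, 0.05))
    return a ** 2

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_square, i) for i in range(10)]
    # Completion order: whichever task finishes first comes first.
    in_completion_order = [f.result() for f in as_completed(futures)]
    # Input order: always 0, 1, 4, 9, ...
    in_input_order = list(executor.map(slow_square, range(10)))
print(in_completion_order)
print(in_input_order)
```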


In [None]:
import dataclasses
from typing import Generator
from etils import epath
import numpy as np
import librosa
import tqdm


@dataclasses.dataclass
class AudioExample:
  audio: np.ndarray
  file_id: str
  offset: float


def data_loader(
    base_path: str,
    file_glob: str,
    window_size_s: float,
    sample_rate: int) -> Generator[AudioExample, None, None]:
  base_path = epath.Path(base_path)
  glob = base_path.glob(file_glob)
  window_size = int(window_size_s * sample_rate)
  for f in glob:
    file_id = str(f.relative_to(base_path))
    audio, _ = librosa.load(f.as_posix(), sr=sample_rate, res_type='polyphase')
    if len(audio.shape) == 2:
      audio = audio[0]
    # Zero-pad to a multiple of the window size (no padding if already exact).
    pad_samples = (-audio.shape[-1]) % window_size
    audio = np.pad(audio, (0, pad_samples))
    # Reshape into window_size_s-long chunks.
    audio = np.reshape(audio, [-1, window_size])
    for i, a in enumerate(audio):
      yield AudioExample(audio=a, file_id=file_id, offset=i * window_size_s)

for x in tqdm.tqdm(data_loader(
    '/usr/local/google/home/tomdenton/terrorbyte/anuraset',
    'raw_data/INCT17/*.wav',
    5.0, 32000)):
  pass

In [None]:
import time

def data_loader_threaded(
    base_path: str,
    file_glob: str,
    window_size_s: float,
    sample_rate: int,
    max_workers: int = 2) -> Generator[AudioExample, None, None]:
  base_path = epath.Path(base_path)
  glob = base_path.glob(file_glob)
  window_size = int(window_size_s * sample_rate)

  def _pad_reshape(audio):
    # Zero-pad to a multiple of the window size (no padding if already exact).
    pad_samples = (-audio.shape[-1]) % window_size
    audio = np.pad(audio, (0, pad_samples))
    # Reshape into window_size_s-long chunks.
    audio = np.reshape(audio, [-1, window_size])
    return audio

  def _load_audio(f: epath.Path) -> tuple[epath.Path, np.ndarray]:
    audio, _ = librosa.load(f.as_posix(), sr=sample_rate, res_type='polyphase')
    if len(audio.shape) == 2:
      audio = audio[0]
    return f, _pad_reshape(audio)

  executor = ThreadPoolExecutor(max_workers=max_workers)
  try:
    for f, audio in executor.map(_load_audio, glob):
      file_id = str(f.relative_to(base_path))
      for i, a in enumerate(audio):
        yield AudioExample(audio=a, file_id=file_id, offset=i * window_size_s)
  finally:
    executor.shutdown(wait=False, cancel_futures=True)

for i, x in tqdm.tqdm(enumerate(data_loader_threaded(
    '/usr/local/google/home/tomdenton/terrorbyte/anuraset',
    'raw_data/INCT17/*.wav',
    5.0, 32000, max_workers=2))):
  pass


In [None]:
base_path = epath.Path('/usr/local/google/home/tomdenton/terrorbyte/anuraset')
glob = base_path.glob('raw_data/INCT17/*.wav')
glob

Challenge: Our current data iterator is pretty good, but has a subtle failure mode. If we process the data slower than the iterator loads the data, we can wind up with an ever-growing backlog of loaded data, until we run out of memory. Refactor the data loader so that it only preloads a fixed number of examples at a time.

### Itertools

In many contexts, we will want to deal with **batches** of examples. Or we might want to manipulate the examples in one way or another.

Now that we have a data iterator, we are in a good position to make use of `itertools`, a Python standard-library module for working with iterators.
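For example, batching can be sketched with `itertools.islice`, which pulls a fixed number of items from an iterator without loading the rest. The helper name `batched_lists` is our own; Python 3.12+ also ships `itertools.batched`, which yields tuples instead of lists.

```python
import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')

def batched_lists(iterable: Iterable[T], batch_size: int) -> Iterator[list[T]]:
  """Yields lists of batch_size items; the final batch may be shorter."""
  it = iter(iterable)
  while True:
    # islice pulls at most batch_size items from the shared iterator.
    batch = list(itertools.islice(it, batch_size))
    if not batch:
      return
    yield batch

# Peek at the first few items of an endless iterator.
first_three = list(itertools.islice(itertools.count(), 3))
batches = list(batched_lists(range(7), 3))
print(first_three)
print(batches)
```

With the audio loader, something like `np.stack([ex.audio for ex in batch])` would turn each batch of `AudioExample`s into a single [batch_size, window_size] array.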

# Intro to Jax


### Use jit to run a matrix multiply on the GPU.

In [None]:
# Apply a dense layer in numpy: multiply a batch of vectors by a matrix,
# then add a bias.

input_dim = 128
output_dim = 1024

inputs = np.random.normal(size=(100, input_dim))
weights = np.random.normal(size=(input_dim, output_dim))
bias = np.random.normal(size=(output_dim,))

%time got = np.dot(inputs, weights) + bias
%time got = np.dot(inputs, weights) + bias
%time got = np.dot(inputs, weights) + bias

In [None]:
# Now let's do it with jax.

import jax
from jax import numpy as jnp

@jax.jit
def dense_layer(inputs, weights, bias):
  return jnp.dot(inputs, weights) + bias

jinputs = jnp.array(inputs)
jweights = jnp.array(weights)
jbias = jnp.array(bias)
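# The first call includes compilation time. JAX dispatches work
# asynchronously, so block_until_ready() makes %time wait for the result.
%time got = dense_layer(jinputs, jweights, jbias).block_until_ready()
%time got = dense_layer(jinputs, jweights, jbias).block_until_ready()
%time got = dense_layer(jinputs, jweights, jbias).block_until_ready()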
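`jax.jit` compiles a separate version of a function for each new input shape and dtype, then reuses it. A quick way to see this (a sketch; `trace_log` and `scaled_sum` are made up) is a plain Python side effect, which only runs while JAX traces the function:

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def scaled_sum(x):
  # Plain Python code runs only during tracing, not on cached calls.
  trace_log.append(x.shape)
  return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))  # traces and compiles for shape (3,)
scaled_sum(jnp.ones(3))  # reuses the compiled version
scaled_sum(jnp.ones(4))  # new shape: traces and compiles again
print(trace_log)
```

This is why the first timed call is slower than the rest, and why calling a jitted function with many different input shapes can get expensive.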

In [None]:
# Helper functions called from a jitted function are traced and compiled
# along with it.

def normalize(inputs: jnp.ndarray) -> jnp.ndarray:
  return inputs / jnp.linalg.norm(inputs, axis=1, keepdims=True)

@jax.jit
def dense_layer(inputs, weights, bias):
  return jnp.dot(normalize(inputs), weights) + bias

jinputs = jnp.array(inputs)
jweights = jnp.array(weights)
jbias = jnp.array(bias)
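# block_until_ready() waits for JAX's asynchronous computation to finish,
# so %time measures the actual work.
%time got = dense_layer(jinputs, jweights, jbias).block_until_ready()
%time got = dense_layer(jinputs, jweights, jbias).block_until_ready()
%time got = dense_layer(jinputs, jweights, jbias).block_until_ready()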

### Generate data from a mixture of Gaussians using jax.random.

In [None]:
# Exercise: Generate multivariate Gaussian samples using jax.random.

def gaussian_generator(key, num_samples, mu, sigma):
  """Generate samples from a multivariate Gaussian distribution."""
  # TODO: Fix this.
  while True:
    yield jnp.ones((num_samples, mu.shape[0]))


In [None]:
# Exercise: Generate a mixture of multivariate Gaussian samples.

def mixed_gaussian_generator(key, num_samples, mus, sigmas, pis):
  """Generate samples from a mixture of Gaussians."""
  # TODO: Fix this.
  while True:
    yield jnp.ones((num_samples, mus.shape[1]))


In [None]:
# Solution.

def mixed_gaussian_generator(key, num_samples, mus, sigmas, pis):
  """Generate samples from a mixture of Gaussians.

  Args:
    key: A JAX PRNGKey.
    num_samples: The number of samples to generate.
    mus: An array of means for the Gaussians, with shape [num_components, dim].
    sigmas: An array of per-dimension standard deviations for the Gaussians,
      with shape [num_components, dim].
    pis: An array of mixing probabilities for the Gaussians, with shape
      [num_components].

  Yields:
    An array of samples from the mixture of Gaussians.
  """
  num_components, dim = mus.shape
  while True:
    # Use a fresh subkey for each random call; reusing a key gives
    # correlated samples.
    key, normal_key, component_key = jax.random.split(key, 3)
    normal_samples = jax.random.normal(normal_key, shape=(num_samples, dim))
    components = jax.random.categorical(
        component_key, jnp.log(jnp.maximum(pis, 1e-6)), shape=(num_samples,))
    yield mus[components] + normal_samples * sigmas[components]
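JAX random functions are pure: the same key always produces the same numbers, which is why the generators in this notebook call `jax.random.split` inside their loops. Each distinct random draw should get its own subkey. A minimal demonstration:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key -> exactly the same numbers

key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))  # fresh subkey -> new numbers

print(jnp.array_equal(a, b), jnp.array_equal(a, c))
```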

In [None]:
from matplotlib import pyplot as plt

n_components = 3
dim = 2
mus = jnp.array(np.random.normal(size=(n_components, dim)))
sigmas = jnp.array(
    np.random.uniform(low=0.1, high=0.4, size=(n_components, dim)))
pis = jnp.array(np.random.uniform(size=(n_components,)))
pis = pis / jnp.sum(pis)

got = next(mixed_gaussian_generator(
    key=jax.random.PRNGKey(0),
    num_samples=1000,
    mus=mus,
    sigmas=sigmas,
    pis=pis))
print(got.shape)
plt.scatter(got[:, 0], got[:, 1], alpha=0.25)

### Tokenizer in Jax

In [None]:
# Exercise: Select data points as the initial centroids.

def generate_initial_centroids(
    key: jax.random.PRNGKey, data: jnp.ndarray, k: int) -> jnp.ndarray:
  pass

# Exercise: Assign data points to clusters.

def assign_to_clusters(
    data: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
  pass

# Exercise: Update the centroids.
# This time, instead of replacing each centroid, return a weighted average of
# the old centroid and the new cluster mean: (1 - mu) * old + mu * new.
def update_centroids(
    data: jnp.ndarray,
    centroids: jnp.ndarray,
    mu: float) -> jnp.ndarray:
  """Update the centroids to be the mean of the data points in each cluster."""
  pass

# Exercise: Compute the reconstruction loss.
def reconstruction_loss(data: jnp.ndarray, centroids: jnp.ndarray) -> float:
  """Compute the mean distance to the closest centroid."""
  pass

# Exercise: Write the training step.
@jax.jit
def step(
    data: jnp.ndarray,
    centroids: jnp.ndarray,
    mu: float) -> tuple[jnp.ndarray, float]:
  """Perform one step of k-means clustering."""
  pass



In [None]:
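# Run the training loop. This needs the exercises above to be completed, plus
# `data_generator`, `centroids`, and `mu` (see the setup in the solution cells below).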
import tqdm
pbar = tqdm.tqdm(range(1000))
for i in pbar:
  data = next(data_generator)
  centroids, loss = step(data, centroids, mu)
  pbar.set_description(f'loss: {loss:8.3f}')


In [None]:
# Tokenizer in Jax

def generate_initial_centroids(
    key: jax.random.PRNGKey, data: jnp.ndarray, k: int) -> jnp.ndarray:
  """Generate initial centroids for k-means clustering."""
  return data[jax.random.choice(key, data.shape[0], shape=(k,), replace=False)]


def assign_to_clusters(
    data: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
  """Assign data points to clusters."""
  distances = jnp.linalg.norm(data[:, jnp.newaxis] - centroids, axis=2)
  return jnp.argmin(distances, axis=1)


def update_centroids(
    data: jnp.ndarray,
    centroids: jnp.ndarray,
    mu: float) -> jnp.ndarray:
  """Update the centroids to be the mean of the data points in each cluster."""
  assignments = assign_to_clusters(data, centroids)
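  for i in range(centroids.shape[0]):
    mask = (assignments == i)[:, jnp.newaxis]
    count = jnp.sum(mask)
    cluster_mean = jnp.sum(data * mask, axis=0) / jnp.maximum(count, 1)
    updated = (1 - mu) * centroids[i] + mu * cluster_mean
    # Leave empty clusters in place instead of shrinking them toward 0.
    updated = jnp.where(count > 0, updated, centroids[i])
    centroids = centroids.at[i].set(updated)
  return centroids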

def reconstruction_loss(data: jnp.ndarray, centroids: jnp.ndarray) -> float:
  """Compute the mean distance to the closest centroid."""
  distances = jnp.linalg.norm(data[:, jnp.newaxis] - centroids, axis=2)
  return jnp.mean(jnp.min(distances, axis=1))


@jax.jit
def step(
    data: jnp.ndarray,
    centroids: jnp.ndarray,
    mu: float) -> tuple[jnp.ndarray, float]:
  """Perform one step of k-means clustering."""
  centroids = update_centroids(data, centroids, mu)
  loss = reconstruction_loss(data, centroids)
  return centroids, loss
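The solution's `update_centroids` uses a Python loop, which `jax.jit` unrolls into `k` copies of the update. For larger `k`, here is a hedged loop-free sketch (the name `update_centroids_onehot` is our own) that computes all cluster sums at once with a one-hot matrix product and leaves empty clusters unchanged:

```python
import jax
import jax.numpy as jnp

def update_centroids_onehot(
    data: jnp.ndarray, centroids: jnp.ndarray, mu: float) -> jnp.ndarray:
  """Moves each centroid a fraction mu toward the mean of its cluster."""
  k = centroids.shape[0]
  distances = jnp.linalg.norm(data[:, jnp.newaxis] - centroids, axis=2)
  assignments = jnp.argmin(distances, axis=1)
  one_hot = jax.nn.one_hot(assignments, k)  # [n, k]
  counts = one_hot.sum(axis=0)  # [k]
  sums = one_hot.T @ data  # [k, dim]
  cluster_means = sums / jnp.maximum(counts, 1)[:, jnp.newaxis]
  updated = (1 - mu) * centroids + mu * cluster_means
  # Empty clusters keep their previous position.
  return jnp.where((counts > 0)[:, jnp.newaxis], updated, centroids)
```

It has the same signature as `update_centroids`, so it could be swapped into `step` directly.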

In [None]:
# Create a random mixture-of-Gaussians distribution to select data from.
seed = np.random.randint(0, 1000000)
n_components = 3
dim = 2
mus = jnp.array(
    np.random.normal(size=(n_components, dim)))
sigmas = jnp.array(
    np.random.uniform(low=0.1, high=0.4, size=(n_components, dim)))
pis = jnp.array(
    np.random.uniform(size=(n_components,)))
pis = pis / jnp.sum(pis)

data_generator = mixed_gaussian_generator(
    key=jax.random.PRNGKey(seed),
    num_samples=1000,
    mus=mus,
    sigmas=sigmas,
    pis=pis)


# Train the tokenizer.
k = 8
mu = 0.1
initial_data = next(data_generator)
centroids = generate_initial_centroids(
    jax.random.PRNGKey(42), initial_data, k)

import tqdm
pbar = tqdm.tqdm(range(1000))
for i in pbar:
  data = next(data_generator)
  centroids, loss = step(data, centroids, mu)
  pbar.set_description(f'loss: {loss:8.3f}')


In [None]:
from matplotlib import pyplot as plt

# Plot some data and the centroids.
plt.scatter(initial_data[:, 0], initial_data[:, 1], alpha=0.25)
plt.scatter(centroids[:, 0], centroids[:, 1], color='red', alpha=0.5)

## Linear Classifier with Optax

Now let's use Optax to write a tiny classifier...

In [None]:
# Generate some gaussians with labels.

def labeled_mixed_gaussian_generator(key, num_samples, mus, sigmas, pis):
  """Generate samples from a mixture of Gaussians and their labels.

  Args:
    key: A JAX PRNGKey.
    num_samples: The number of samples to generate.
    mus: An array of means for the Gaussians, with shape [num_components, dim].
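    sigmas: An array of per-dimension standard deviations for the Gaussians,
      with shape [num_components, dim].
    pis: An array of mixing probabilities for the Gaussians, with shape
      [num_components].

  Yields:
    An array of samples from the mixture of Gaussians and the source components.
  """
  num_components, dim = mus.shape
  while True:
    # Use a fresh subkey for each random call; reusing a key gives
    # correlated samples.
    key, normal_key, component_key = jax.random.split(key, 3)
    normal_samples = jax.random.normal(normal_key, shape=(num_samples, dim))
    components = jax.random.categorical(
        component_key, jnp.log(jnp.maximum(pis, 1e-6)), shape=(num_samples,))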
    features = mus[components] + normal_samples * sigmas[components]
    yield features, components

num_features = 128
batch_size = 64
num_classes = 3

data_generator = labeled_mixed_gaussian_generator(
    key=jax.random.PRNGKey(42),
    num_samples=batch_size,
    mus=np.random.normal(size=(num_classes, num_features)),
    sigmas=np.random.uniform(
        low=0.1, high=0.4, size=(num_classes, num_features)),
    pis=jnp.ones(num_classes) / num_classes)


In [None]:
import optax

# Initialize the weights and bias for the linear model.
params = {
    'weights': jnp.array(np.random.normal(size=(num_features, num_classes))),
    'bias': jnp.array(np.random.normal(size=(num_classes,)))
}

def infer_logits(data, weights, bias):
  """Infer the labels for the given data."""
  return jnp.dot(data, weights) + bias

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)
loss_fn = optax.sigmoid_binary_cross_entropy

def compute_loss(params, data, labels):
  """Compute the loss for the given data and labels."""
  logits = infer_logits(data, params['weights'], params['bias'])
  return jnp.mean(loss_fn(logits, labels))

@jax.jit
def train_step(data, labels, opt_state, params):
  # convert labels to one-hot.
  one_hot_labels = jax.nn.one_hot(labels, num_classes)
  loss, grad = jax.value_and_grad(compute_loss)(
      params, data, one_hot_labels)
  updates, opt_state = optimizer.update(grad, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss


In [None]:
import tqdm
pbar = tqdm.tqdm(range(1024))
for i in pbar:
  data, labels = next(data_generator)
  params, opt_state, loss = train_step(data, labels, opt_state, params)
  pbar.set_description(f'loss: {loss:8.3f}')


## Neural Networks with flax

Flax provides tools to handle bigger models efficiently.
The `flax.nnx` interfaces are used to define model layers, which then join together into the full model. It takes care of all of the parameter naming and such automagically.

In [None]:
from flax import nnx

class TwoLayerDenseModel(nnx.Module):
  """A simple two-layer neural network."""

  def __init__(
      self,
      input_dim: int,
      hidden_dim: int,
      output_dim: int,
      dropout_rate: float,
      rngs: nnx.Rngs):
    self.batch_norm = nnx.BatchNorm(num_features=input_dim, rngs=rngs)
    self.dense1 = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
    self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
    self.dense2 = nnx.Linear(hidden_dim, output_dim, rngs=rngs)

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    x = self.batch_norm(x)
    x = self.dense1(x)
    x = jax.nn.relu(x)
    x = self.dropout(x)
    x = self.dense2(x)
    return x


In [None]:
learning_rate = 0.005
momentum = 0.9

model = TwoLayerDenseModel(
    input_dim=num_features,
    hidden_dim=32,
    output_dim=num_classes,
    dropout_rate=0.5,
    rngs=nnx.Rngs(0))

data, labels = next(data_generator)
y = model(data)
print(y.shape)

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, b1=momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)


In [None]:
def loss_fn(model, features, labels):
  logits = model(features)
  one_hot_labels = jax.nn.one_hot(labels, num_classes)
  loss = optax.sigmoid_binary_cross_entropy(logits, one_hot_labels)
  return jnp.mean(loss), logits

@nnx.jit
def train_step(model, optimizer, metrics, data, labels):
  """Train one step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grad = grad_fn(model, data, labels)
  metrics.update(loss=loss, logits=logits, labels=labels)
  optimizer.update(grad)
  return loss

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
}
pbar = tqdm.tqdm(range(1024))
# Use `step_idx` so we don't overwrite the k-means `step` function from above.
for step_idx in pbar:
  data, labels = next(data_generator)
  loss = train_step(model, optimizer, metrics, data, labels)
  pbar.set_description(f'loss: {loss:8.3f}')

  if step_idx % 64 == 0:
    # log the metrics.
    for metric, value in metrics.compute().items():
      metrics_history[f'train_{metric}'].append(value)
    metrics.reset()


In [None]:
plt.plot(metrics_history['train_loss'])
plt.ylabel('loss')
plt.xlabel('step/64')
plt.show()
plt.plot(metrics_history['train_accuracy'])
plt.ylabel('accuracy')
plt.xlabel('step/64')
plt.show()