<a href="https://colab.research.google.com/github/malcolmlett/ml-learning/blob/feature%2F20250110-explain-near-zero-grads/Learning_visualisations_v10a_MatMulExplainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Matrix Multiplication explainer
A tool for explaining the output value from matrix multiplication, particularly for near-zero outputs.

In [103]:
import os
if os.path.isdir('repo'):
  # discard any local changes and update
  !cd repo && git reset --hard HEAD
  !cd repo && git fetch
else:
  !git clone https://github.com/malcolmlett/ml-learning.git repo

# TEMP use branch
!cd repo && git checkout feature/20250110-explain-near-zero-grads

# lock to revision
#!cd repo && git checkout ea80c40
!cd repo && git pull

import sys
sys.path.append('repo')

import matmul_explainer as me
from importlib import reload
reload(me)

HEAD is now at 5e3f247 Tests working
Already on 'feature/20250110-explain-near-zero-grads'
Your branch is up to date with 'origin/feature/20250110-explain-near-zero-grads'.
Already up to date.


<module 'matmul_explainer' from '/content/repo/matmul_explainer.py'>

In [104]:
# Run the explainer's test suite.
# The module is re-executed with reload() before its run_test_suite() is called.
# NOTE: `reload` is not imported in this cell; it comes from the
# `from importlib import reload` in the setup cell, which must have been run first.
import matmul_explainer_test
reload(matmul_explainer_test)
matmul_explainer_test.run_test_suite()

In [97]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import math

## Classification
First we need to be able to classify the components that make up the matrix multiplication results.

We do that with a number of functions:

```
def classify_terms():
  return ['PP', 'PZ', 'PN', 'ZP', 'ZZ', 'ZN', 'NP', 'NZ', 'NN']

def matmul_classify(x1, x2, confidence: float=0.95, threshold1: float = None, threshold2: float = None):
 ...
 return counts, sums
```


In [63]:
# Standard ordered terms: the nine two-letter class labels returned by classify_terms().
# Later cells enumerate this list to index the last axis of the counts/sums results.
print(f"Terms: {me.classify_terms()}")

Terms: ['PP', 'PZ', 'PN', 'ZP', 'ZZ', 'ZN', 'NP', 'NZ', 'NN']


In [80]:
# Worked example: explain the 2D matmul of a small matrix with itself
reload(me)
a = np.tile(np.arange(0.0, 1.0, 0.1), (10, 1))
print(f"a = b = {a}")

counts, sums = me.matmul_classify(a, a, confidence=0.65)

# Class labels, in the order used along the last axis of counts/sums
terms = me.classify_terms()
# Only classes that actually occur are listed; the same counts-based filter is
# used for both the counts listing and the sums listing.
present = [(i, name) for i, name in enumerate(terms) if np.sum(counts[:,:,i]) > 0]

print()
print("Details...")
for i, name in present:
  print(f"Counts({name}): {counts[:,:,i]}")
for i, name in present:
  print(f"Sums({name}): {sums[:,:,i]}")

print()
print("Summary...")
print(f"Classes: {terms}")
counts_by_class = np.sum(counts, axis=(0,1))
sums_by_class = np.sum(sums, axis=(0,1))
print(f"Counts by class: {counts_by_class}")
print(f"Sums by class: {sums_by_class}")

# Adding the per-class sums back together should reproduce the ordinary matmul
print()
print("Validation...")
true_matmul = np.matmul(a, a)
derived_matmul = np.sum(sums, axis=-1)
print(f"True matmul: {true_matmul}")
print(f"Derived matmul: {derived_matmul}")


a = b = [[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]]
threshold: 0.30000000000000004 (midpoint)

Details...
Counts(PP): [[0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]
 [0 0 0 0 6 6 6 6 6 6]]
Counts(PZ): [[6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]
 [6 6 6 6 0 0 0 0 0 0]]
Counts(ZP): [[0 0 0 0

In [161]:
# Convolution
def _confidence_threshold(values, confidence):
  """Infers the near-zero threshold for `values` as a percentile of abs(values)."""
  # note: on small tensors with few discrete numbers, percentile() will find a value on either
  # side of the percentage threshold, thus the rule applied is "zero if abs(value) < threshold"
  return tfp.stats.percentile(tf.abs(values), 100 * (1 - confidence), interpolation='midpoint')


def _split_pzn(values, threshold):
  """Returns boolean masks (positive, near_zero, negative) that never overlap."""
  # Exact zeros are always treated as near-zero. Without this, a threshold of 0
  # (e.g. inferred with confidence=1.0) would place 0.0 in BOTH the positive mask
  # (0 >= 0) and the negative mask (0 <= -0), double-counting those terms.
  # For any threshold > 0 this yields the same masks as `x >= t`, `|x| < t`, `x <= -t`.
  near_zero = tf.logical_or(tf.abs(values) < threshold, tf.equal(values, tf.zeros_like(values)))
  not_zero = tf.logical_not(near_zero)
  positive = tf.logical_and(not_zero, values > 0)
  negative = tf.logical_and(not_zero, values < 0)
  return positive, near_zero, negative


def conv_classify(inputs, kernel, strides=1, padding="VALID", confidence: float = 0.95, inputs_threshold: float = None, kernel_threshold: float = None, verbose: bool = True):
  """
  Like matmul_classify but for convolutions.
  Supports 1D, 2D and 3D convolution.

  Every input value and every kernel value is classified as positive (P),
  near-zero (Z) or negative (N). The convolution is then computed once per
  (input class, kernel class) pair, in the order
  PP, PZ, PN, ZP, ZZ, ZN, NP, NZ, NN (first letter = inputs, second = kernel).

  Args:
      inputs: Tensor of rank N+2. `inputs` has shape
          `(batch_size,) + inputs_spatial_shape + (num_channels,)`
      kernel: Tensor of rank N+2. `kernel` has shape
          `(kernel_spatial_shape, num_input_channels, num_output_channels)`.
          `num_input_channels` should match the number of channels in
          `inputs`.
      strides: int or int tuple/list of `len(inputs_spatial_shape)`,
          specifying the strides of the convolution along each spatial
          dimension. If `strides` is int, then every spatial dimension shares
          the same `strides`.
      padding: string, either `"valid"` or `"same"` (case-insensitive).
          `"valid"` means no padding is applied, and `"same"` results in
          padding evenly to the left/right or up/down of the input such that
          output has the same height/width dimension as the input when
          `strides=1`.
      confidence: statistical confidence (0.0 to 1.0) that you wish to meet
        that a value is accurately placed within the P, Z, or N categories.
        Higher values lead to more strict requirements for "near zero".
        1.0 only considers exactly 0.0 as "near zero".
      inputs_threshold: abs(inputs) values less than this are considered near-zero,
        otherwise inferred from confidence
      kernel_threshold: abs(kernel) values less than this are considered near-zero,
        otherwise inferred from confidence
      verbose: if True (default), prints the thresholds and the first
        batch/channel slice of each classification mask.

  Returns:
      (counts, sums) containing the counts and sums of each component, respectively.
      Each is the stack of nine convolution outputs along a new last axis, i.e.
      shape `(batch_size,) + output_spatial_shape + (num_output_channels, 9)`.
      `counts` is float32; `sums` is computed from the original values.
      Summing `sums` over the last axis gives the plain convolution.
  """
  # standardise on data format
  inputs = tf.constant(inputs)
  kernel = tf.constant(kernel)

  # tf.nn.convolution documents the upper-case spellings; normalising here lets the
  # lower-case spellings described in the docstring be used as well.
  if isinstance(padding, str):
    padding = padding.upper()

  # determine thresholds
  if inputs_threshold is None:
    inputs_threshold = _confidence_threshold(inputs, confidence)
  if kernel_threshold is None:
    kernel_threshold = _confidence_threshold(kernel, confidence)

  # create masks that classify each value individually
  inputs_p, inputs_z, inputs_n = _split_pzn(inputs, inputs_threshold)
  kernel_p, kernel_z, kernel_n = _split_pzn(kernel, kernel_threshold)

  if verbose:
    # Ellipsis indexing selects the first batch item / first channel for any
    # supported rank (1D, 2D or 3D), rather than assuming 2D spatial dims.
    print(f"Thresholds: inputs={inputs_threshold}, kernel={kernel_threshold}")
    print("Thresholded inputs:")
    print(f"inputs(p): {inputs_p[0, ..., 0]}")
    print(f"inputs(z): {inputs_z[0, ..., 0]}")
    print(f"inputs(n): {inputs_n[0, ..., 0]}")
    print("Thresholded kernel:")
    print(f"kernel(p): {kernel_p[..., 0, 0]}")
    print(f"kernel(z): {kernel_z[..., 0, 0]}")
    print(f"kernel(n): {kernel_n[..., 0, 0]}")

  def _conv(x, w):
    return tf.nn.convolution(input=x, filters=w, strides=strides, padding=padding)

  input_masks = (inputs_p, inputs_z, inputs_n)
  kernel_masks = (kernel_p, kernel_z, kernel_n)

  # counts: convolving 0/1 indicator tensors counts how many (input, kernel) term
  # pairs of each class contribute to each output position.
  # Inputs vary in the outer loop, giving the PP, PZ, PN, ZP, ... ordering.
  inputs_ind = [tf.cast(mask, tf.float32) for mask in input_masks]
  kernel_ind = [tf.cast(mask, tf.float32) for mask in kernel_masks]
  counts = [_conv(x, w) for x in inputs_ind for w in kernel_ind]

  # sums: same pairing, but keeping the original values where the mask holds
  # and zero elsewhere, so each result is that class's share of the convolution.
  inputs_val = [tf.where(mask, inputs, tf.zeros_like(inputs)) for mask in input_masks]
  kernel_val = [tf.where(mask, kernel, tf.zeros_like(kernel)) for mask in kernel_masks]
  sums = [_conv(x, w) for x in inputs_val for w in kernel_val]

  # format into final output
  return tf.stack(counts, axis=-1), tf.stack(sums, axis=-1)


# Demo: explain a 2D convolution of a 9x9 ramp image with a 3x3 kernel.
# Input is shaped (batch, height, width, channels); each row ramps 0.0 .. 0.8.
a = np.arange(0.0, 0.9, 0.1)
a = np.tile(a, (9,1)).astype(np.float32)
a = tf.reshape(a, shape=(1,9,9,1))
# Label is just "a": unlike the matmul demo there is no second copy "b" here,
# the second operand is the kernel k printed below.
print(f"a: {a.shape} = {a[0,:,:,0]}")

# Kernel is shaped (height, width, in_channels, out_channels)
k = np.array([
    [-1, 0, -1],
    [+1, +1, +1],
    [-1, 0, -1]
]).astype(np.float32)
k = tf.reshape(k, shape=(3,3,1,1))
print(f"k = {k.shape} = {k[:,:,0,0]}")

counts, sums = conv_classify(a, k, confidence=0.75)
print()
print("Results...")
print(f"Shapes: counts={counts.shape}, sums={sums.shape}")
# Only classes that actually occur are listed; the counts-based filter is used
# for both the counts listing and the sums listing.
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[...,i]) > 0:
    print(f"Counts({name}): {counts[0,:,:,0,i]}")
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[...,i]) > 0:
    print(f"Sums({name}): {sums[0,:,:,0,i]}")

print()
print("Summary...")
print(f"Classes: {me.classify_terms()}")
print(f"Counts by class: {np.sum(counts, axis=(0,1,2,3))}")
print(f"Sums by class: {np.sum(sums, axis=(0,1,2,3))}")

# Adding the per-class sums back together should reproduce the plain convolution
print()
print("Validation...")
expected_conv = tf.nn.convolution(a, k)
derived_conv = tf.reduce_sum(sums, axis=-1)
print(f"True conv: {expected_conv.shape} = {expected_conv[0,:,:,0]}")
print(f"Derived conv: {derived_conv.shape} = {derived_conv[0,:,:,0]}")
print(f"Derived conv == true conv: {np.allclose(expected_conv, derived_conv)}")


a = b: (1, 9, 9, 1) = [[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]]
k = (3, 3, 1, 1) = [[-1.  0. -1.]
 [ 1.  1.  1.]
 [-1.  0. -1.]]
Thresholds: inputs=0.20000000298023224, kernel=1.0
Thresholded inputs:
inputs(p): [[False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  True  True  True  True  True  True]
 [False False  True  Tr

In [160]:
# Scratch check: look at the magnitudes of the demo kernel, i.e. the values
# that a percentile-based near-zero threshold would be computed from.
k = np.array(
    [
        [-1, 0, -1],
        [+1, +1, +1],
        [-1, 0, -1],
    ]
).astype(np.float32)
# reshape to (height, width, in_channels, out_channels)
k = tf.reshape(k, shape=(3,3,1,1))
print(f"k = {k.shape} = {k[:,:,0,0]}")

# (disabled) the percentile call itself, kept for reference:
#tfp.stats.percentile(tf.abs(k), 100 * (1 - 0.95), interpolation='midpoint')
# last expression, so the notebook displays the 3x3 magnitude grid
tf.abs(k)[:,:,0,0]

k = (3, 3, 1, 1) = [[-1.  0. -1.]
 [ 1.  1.  1.]
 [-1.  0. -1.]]


<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 1.],
       [1., 1., 1.],
       [1., 0., 1.]], dtype=float32)>

In [120]:
# Scratch check: the plain convolution of the demo input and kernel, for
# comparison against the per-class sums.
# Input is shaped (batch, height, width, channels); each row ramps 0.0 .. 0.8.
a = np.arange(0.0, 0.9, 0.1)
a = np.tile(a, (9,1)).astype(np.float32)
a = tf.reshape(a, shape=(1,9,9,1))
# Label is just "a": there is no second copy "b" in this cell, the other
# operand is the kernel k printed below.
print(f"a: {a.shape} = {a[0,:,:,0]}")

# Kernel is shaped (height, width, in_channels, out_channels)
k = np.array([
    [-1, 0, -1],
    [+1, +1, +1],
    [-1, 0, -1]
]).astype(np.float32)
k = tf.reshape(k, shape=(3,3,1,1))
print(f"k = {k.shape} = {k[:,:,0,0]}")

# last expression, so the notebook displays the resulting tensor
tf.nn.convolution(a, k)

a = b: (1, 9, 9, 1) = [[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]]
k = (3, 3, 1, 1) = [[-1.  0. -1.]
 [ 1.  1.  1.]
 [-1.  0. -1.]]


<tf.Tensor: shape=(1, 7, 7, 1), dtype=float32, numpy=
array([[[[-0.09999999],
         [-0.20000002],
         [-0.29999995],
         [-0.40000004],
         [-0.5       ],
         [-0.6       ],
         [-0.70000005]],

        [[-0.09999999],
         [-0.20000002],
         [-0.29999995],
         [-0.40000004],
         [-0.5       ],
         [-0.6       ],
         [-0.70000005]],

        [[-0.09999999],
         [-0.20000002],
         [-0.29999995],
         [-0.40000004],
         [-0.5       ],
         [-0.6       ],
         [-0.70000005]],

        [[-0.09999999],
         [-0.20000002],
         [-0.29999995],
         [-0.40000004],
         [-0.5       ],
         [-0.6       ],
         [-0.70000005]],

        [[-0.09999999],
         [-0.20000002],
         [-0.29999995],
         [-0.40000004],
         [-0.5       ],
         [-0.6       ],
         [-0.70000005]],

        [[-0.09999999],
         [-0.20000002],
         [-0.29999995],
         [-0.40000004],
