<a href="https://colab.research.google.com/github/malcolmlett/ml-learning/blob/main/Learning_visualisations_v12a_MatMulExplainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Matrix Multiplication explainer
A tool for explaining the output value from matrix multiplication, particularly for near-zero outputs.

In [37]:
import os
if os.path.isdir('repo'):
  # discard any local changes and update
  !cd repo && git reset --hard HEAD
  !cd repo && git fetch
else:
  !git clone https://github.com/malcolmlett/ml-learning.git repo

# lock to revision
#!cd repo && git checkout ea80c40
!cd repo && git pull

import sys
sys.path.append('repo')

import matmul_explainer as me
from importlib import reload
reload(me)

HEAD is now at db39277 Bug fixes and additional unit test scenarios
remote: Enumerating objects: 5, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (1/1), done.[K
remote: Total 3 (delta 2), reused 3 (delta 2), pack-reused 0 (from 0)[K
Unpacking objects: 100% (3/3), 287 bytes | 143.00 KiB/s, done.
From https://github.com/malcolmlett/ml-learning
   db39277..fdc460c  main       -> origin/main
Updating db39277..fdc460c
Fast-forward
 matmul_explainer.py | 2 [32m+[m[31m-[m
 1 file changed, 1 insertion(+), 1 deletion(-)


<module 'matmul_explainer' from '/content/repo/matmul_explainer.py'>

In [30]:
# Run the unit test suite that accompanies matmul_explainer.
import matmul_explainer_test
# Reload so that re-running this cell executes the current version of the tests
# rather than the copy cached by an earlier import.
reload(matmul_explainer_test)
matmul_explainer_test.run_test_suite()

All matmul_explainer tests passed.


In [3]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import math

## Classification
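First we need to be able to classify the individual products (the "terms") that are summed together to produce each value in the matrix multiplication result. Each operand value is classed as positive (P), near-zero (Z), or negative (N), so each product falls into one of nine classes, such as `PP` or `ZN`.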

We do that with a number of functions:

```
def classify_terms():
  return ['PP', 'PZ', 'PN', 'ZP', 'ZZ', 'ZN', 'NP', 'NZ', 'NN']

def matmul_classify(x1, x2, confidence: float=0.95, threshold1: float = None, threshold2: float = None):
  ...
  return counts, sums
```


In [4]:
# Standard ordered terms
# (the cells below use the position of each term in this list as the index
# into the last axis of the counts and sums arrays)
print(f"Terms: {me.classify_terms()}")

Terms: ['PP', 'PZ', 'PN', 'ZP', 'ZZ', 'ZN', 'NP', 'NZ', 'NN']


In [5]:
# Check that the summarisation API works under a number of usages:
reload(me)
# 10x10 matrix in which every row is [0.0, 0.1, ..., 0.9]
a = np.arange(0.0, 1.0, 0.1)
a = np.tile(a, (10,1))
counts, sums = me.matmul_classify(a, a, confidence=0.75)

# Usage 1: counts and sums passed as two separate arguments
print(f"Summary when passed counts and sums separately:\n  {me.summarise(counts, sums)}")
# Usage 2: the result of matmul_classify passed straight through as a single argument
print(f"Summary when passed counts and sums directly from classification:\n  {me.summarise(me.matmul_classify(a, a, confidence=0.75))}")


Summary when passed counts and sums separately:
  PP: 640 = 193.6, ZP: 160 = 4.399999999999999, PZ: 160 = 4.4, ZZ: 40 = 0.10000000000000003
Summary when passed counts and sums directly from classification:
  PP: 640 = 193.6, ZP: 160 = 4.399999999999999, PZ: 160 = 4.4, ZZ: 40 = 0.10000000000000003


In [31]:
# Show how the values within tensors get classified
def show_classification_results(x, confidence):
  """
  Prints how `me.classification_mask` classifies the values of `x`.

  The report has a header line showing `x` and `confidence`, followed by the
  threshold and then the three masks, labelled p, z and n.

  Args:
    x: values to classify; handed to `me.classification_mask` unchanged.
    confidence: handed to `me.classification_mask` as its `confidence` keyword.
  """
  masks = me.classification_mask(x, confidence=confidence)
  # The result unpacks as the three masks followed by the threshold that was applied
  p, z, n, t = masks
  report_lines = [
      f"classification results for {x} @ {confidence}:",
      f"  threshold: {t}",
      f"  p: {p}",
      f"  z: {z}",
      f"  n: {n}",
  ]
  print("\n".join(report_lines))

# Case 1: non-negative values only, evenly spread from 0.0 to 0.9
a = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
show_classification_results(a, confidence=0.75)

# Case 2: a single exact zero between two equal negative values
a = np.array([-0.5, 0, -0.5])
show_classification_results(a, confidence=0.75)

# Case 3: negative, zero and positive values together, at a higher confidence
a = np.array([-0.5, 0.25, 0, 0.25, -0.5])
show_classification_results(a, confidence=0.95)

classification results for [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9] @ 0.75:
  threshold: 0.25
  p: [False False False  True  True  True  True  True  True  True]
  z: [ True  True  True False False False False False False False]
  n: [False False False False False False False False False False]
classification results for [-0.5  0.  -0.5] @ 0.75:
  threshold: 0.25
  p: [False False False]
  z: [False  True False]
  n: [ True False  True]
classification results for [-0.5   0.25  0.    0.25 -0.5 ] @ 0.95:
  threshold: 0.125
  p: [False  True False  True False]
  z: [False False  True False False]
  n: [ True False False False  True]


In [32]:
# Simple 2D matmul explanation
#
# Multiplies a 10x10 matrix (every row is [0.0, 0.1, ..., 0.9]) by itself and
# breaks each output value down by the class of the terms summed to produce it.
reload(me)
a = np.arange(0.0, 1.0, 0.1)
a = np.tile(a, (10,1))
print(f"a = b = {a}")

# counts and sums are indexed below as [row, col, class], with the class axis
# in the same order as me.classify_terms()
counts, sums = me.matmul_classify(a, a, confidence=0.75)

print()
print("Details...")
# Only show the classes that actually occur somewhere in the result
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[:,:,i]) > 0:
    print(f"Counts({name}): {counts[:,:,i]}")
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[:,:,i]) > 0:
    print(f"Sums({name}): {sums[:,:,i]}")

print()
print("Summary...")
# Collapse the row and column axes to leave one total per class
print(f"Classes: {me.classify_terms()}")
print(f"Counts by class: {np.sum(counts, axis=(0,1))}")
print(f"Sums by class: {np.sum(sums, axis=(0,1))}")
print(f"Summary: {me.summarise(counts, sums)}")

print()
print("Validation...")
# Adding up the per-class sums should reproduce the ordinary matmul result.
# Print an explicit verdict as well as the two matrices, as the convolution
# cells do, so the reader doesn't have to compare 100 values by eye.
true_matmul = np.matmul(a, a)
derived_matmul = np.sum(sums, axis=-1)
print(f"True matmul: {true_matmul}")
print(f"Derived matmul: {derived_matmul}")
print(f"Derived matmul == true matmul: {np.allclose(true_matmul, derived_matmul)}")


a = b = [[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]]

Details...
Counts(PP): [[0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]
 [0 0 8 8 8 8 8 8 8 8]]
Counts(PZ): [[8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]
 [8 8 0 0 0 0 0 0 0 0]]
Counts(ZP): [[0 0 2 2 2 2 2 2 2 2]
 [0 0 2 2 2 2 2 2 2 2]
 [0 0

In [33]:
# 2D convolution explanation
#
# Convolves a single 9x9 single-channel image (every row is [0.0, 0.1, ..., 0.8])
# with a single 3x3 kernel and breaks each output value down by the class of
# the terms summed to produce it.
a = np.arange(0.0, 0.9, 0.1)
a = np.tile(a, (9,1)).astype(np.float32)
a = tf.reshape(a, shape=(1,9,9,1))
# Labelled "a" rather than "a = b": the second operand in this cell is the kernel k
print(f"a: {a.shape} = {a[0,:,:,0]}")

k = np.array([
    [-1, 0, -1],
    [+1, +1, +1],
    [-1, 0, -1]
]).astype(np.float32)
k = tf.reshape(k, shape=(3,3,1,1))
print(f"k = {k.shape} = {k[:,:,0,0]}")

# counts and sums carry one trailing class axis, in the same order as me.classify_terms()
counts, sums = me.conv_classify(a, k, confidence=0.75)
print()
print("Results...")
print(f"Shapes: counts={counts.shape}, sums={sums.shape}")
# Only show the classes that actually occur somewhere in the result
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[...,i]) > 0:
    print(f"Counts({name}): {counts[0,:,:,0,i]}")
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[...,i]) > 0:
    print(f"Sums({name}): {sums[0,:,:,0,i]}")

print()
print("Summary...")
# Collapse every axis except the class axis to leave one total per class
print(f"Classes: {me.classify_terms()}")
print(f"Counts by class: {np.sum(counts, axis=(0,1,2,3))}")
print(f"Sums by class: {np.sum(sums, axis=(0,1,2,3))}")
print(f"Summary: {me.summarise(counts, sums)}")

print()
print("Validation...")
# Adding up the per-class sums should reproduce the ordinary convolution result
expected_conv = tf.nn.convolution(a, k)
derived_conv = tf.reduce_sum(sums, axis=-1)
print(f"True conv: {expected_conv.shape} = {expected_conv[0,:,:,0]}")
print(f"Derived conv: {derived_conv.shape} = {derived_conv[0,:,:,0]}")
print(f"Derived conv == true conv: {np.allclose(expected_conv, derived_conv)}")


a = b: (1, 9, 9, 1) = [[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]]
k = (3, 3, 1, 1) = [[-1.  0. -1.]
 [ 1.  1.  1.]
 [-1.  0. -1.]]

Results...
Shapes: counts=(1, 7, 7, 1, 9), sums=(1, 7, 7, 1, 9)
Counts(PP): [[1. 2. 3. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3. 3.]
 [1. 2. 3. 3. 3. 3. 3.]]
Counts(PZ): [[0. 2. 2. 2. 2. 2. 2.]
 [0. 2. 2. 2. 2. 2. 2.]
 [0. 2. 2. 2. 2. 2. 2.]
 [0. 2. 2. 2. 2. 2. 2.]
 [0. 2. 2. 2. 2. 2. 2.]
 [0. 2. 2. 2. 2. 2. 2.]
 [0. 2. 2. 2. 2. 2. 2.]]
Counts(PN): [[2. 2. 4. 4. 4. 4. 4.]
 [2. 2. 4. 4. 4. 4. 4.]
 [2. 2. 4. 4. 4. 4. 4.]
 [2. 2. 4. 4. 4. 4. 4.]
 [2. 2. 4. 4. 4. 4. 4.]
 [

In [34]:
# 1D convolution explanation
#
# Convolves a single length-9 single-channel sequence [0.0, 0.1, ..., 0.8] with
# a single length-3 kernel and breaks each output value down by the class of
# the terms summed to produce it.
reload(me)
a = np.arange(0.0, 0.9, 0.1).astype(np.float32)
a = tf.reshape(a, shape=(1,9,1))
# Labelled "a" rather than "a = b": the second operand in this cell is the kernel k
print(f"a: {a.shape} = {a[0,:,0]}")

k = np.array([-1, 0, -1]).astype(np.float32)
k = tf.reshape(k, shape=(3,1,1))
print(f"k = {k.shape} = {k[:,0,0]}")

# return_thresholds=True makes conv_classify also return the thresholds it used
counts, sums, thresholds = me.conv_classify(a, k, confidence=0.90, return_thresholds=True)
print()
print("Results...")
print(f"Shapes: counts={counts.shape}, sums={sums.shape}, thresholds={thresholds}")
# Only show the classes that actually occur somewhere in the result
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[...,i]) > 0:
    print(f"Counts({name}): {counts[0,:,0,i]}")
for i, name in enumerate(me.classify_terms()):
  if np.sum(counts[...,i]) > 0:
    print(f"Sums({name}): {sums[0,:,0,i]}")

print()
print("Summary...")
# Collapse every axis except the class axis to leave one total per class
print(f"Classes: {me.classify_terms()}")
print(f"Counts by class: {np.sum(counts, axis=(0,1,2))}")
print(f"Sums by class: {np.sum(sums, axis=(0,1,2))}")
print(f"Summary: {me.summarise(counts, sums)}")

print()
print("Validation...")
# Adding up the per-class sums should reproduce the ordinary convolution result
expected_conv = tf.nn.convolution(a, k)
derived_conv = tf.reduce_sum(sums, axis=-1)
print(f"True conv: {expected_conv.shape} = {expected_conv[0,:,0]}")
print(f"Derived conv: {derived_conv.shape} = {derived_conv[0,:,0]}")
print(f"Derived conv == true conv: {np.allclose(expected_conv, derived_conv)}")

a = b: (1, 9, 1) = [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8]
k = (3, 1, 1) = [-1.  0. -1.]

Results...
Shapes: counts=(1, 7, 1, 9), sums=(1, 7, 1, 9), thresholds=[0.05, 0.5]
Counts(PZ): [1. 1. 1. 1. 1. 1. 1.]
Counts(PN): [1. 2. 2. 2. 2. 2. 2.]
Counts(ZN): [1. 0. 0. 0. 0. 0. 0.]
Sums(PZ): [0. 0. 0. 0. 0. 0. 0.]
Sums(PN): [-0.2       -0.4       -0.6       -0.8       -1.        -1.2
 -1.4000001]
Sums(ZN): [0. 0. 0. 0. 0. 0. 0.]

Summary...
Classes: ['PP', 'PZ', 'PN', 'ZP', 'ZZ', 'ZN', 'NP', 'NZ', 'NN']
Counts by class: [ 0.  7. 13.  0.  0.  1.  0.  0.  0.]
Sums by class: [ 0.   0.  -5.6  0.   0.   0.   0.   0.   0. ]
Summary: PN: 13.0 = -5.599999904632568, PZ: 7.0 = 0.0, ZN: 1.0 = 0.0

Validation...
True conv: (1, 7, 1) = [-0.2       -0.4       -0.6       -0.8       -1.        -1.2
 -1.4000001]
Derived conv: (1, 7, 1) = [-0.2       -0.4       -0.6       -0.8       -1.        -1.2
 -1.4000001]
Devired conv == true conv: True


In [38]:
# Simple single tensor explanation
# Classifies the values of one tensor on its own, rather than the terms of a
# two-operand matmul or convolution.
a = np.arange(0.0, 0.9, 0.1).astype(np.float32)
a = tf.reshape(a, shape=(1,9,1))

# return_threshold=True makes tensor_classify also return the threshold it used
counts, sums, thresholds = me.tensor_classify(a, confidence=0.90, return_threshold=True)
print(f"Shapes: counts={counts.shape}, sums={sums.shape}, thresholds={thresholds}")
# classify_terms is given counts here; presumably it uses it to select the
# single-tensor class names rather than the nine two-operand ones - TODO confirm
print(f"Classes: {me.classify_terms(counts)}")
# Collapse every axis except the last to leave one total per class
print(f"Counts by class: {np.sum(counts, axis=(0,1,2))}")
print(f"Sums by class: {np.sum(sums, axis=(0,1,2))}")
print(f"Summary: {me.summarise(counts, sums)}")



Shapes: counts=(1, 9, 1, 3), sums=(1, 9, 1, 3), thresholds=0.05000000074505806
Classes: ['P', 'Z', 'N']
Counts by class: [8. 1. 0.]
Sums by class: [3.6 0.  0. ]
Summary: P: 8.0 = 3.5999999046325684, Z: 1.0 = 0.0
