Loading factorizations found by AlphaTensor and recombination.

- Copyright 2022 DeepMind Technologies Limited
- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0
- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode
- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.
- This is not an official Google product.

In [1]:
import numpy as np
# from google.colab import files

Upload one of the two files provided in the same folder: `factorizations_r.npz` (algorithms in standard arithmetic) or `factorizations_f2.npz` (algorithms in arithmetic modulo 2). Outside Colab, keep the upload lines commented out and set `filename` to a local copy instead.

In [2]:
# Choose the archive to analyse. On Colab, pick it interactively instead:
#   uploaded = files.upload()
#   filename = list(uploaded.keys())[0]
filename = "factorizations_r.npz"

# allow_pickle=True lets np.load read object arrays, which unpickles data:
# only load archives from a trusted source (e.g. the files released with the paper).
# Presumably needed because rectangular entries (where u, v, w differ in size)
# are stored as object arrays — TODO confirm.
with open(filename, 'rb') as fh:
  factorizations = dict(np.load(fh, allow_pickle=True).items())

In [3]:
# Summarise every factorization in the archive: its rank (= number of scalar
# multiplications) and the number of scalar additions/subtractions it needs.
# A linear combination with n nonzero coefficients costs n - 1 additions:
#   * before the products: one combination per column of u and one per column of v;
#   * after the products: one combination per row of w.
for key, (u, v, w) in factorizations.items():
    rank = u.shape[-1]
    # All three factor matrices must have one column per rank-one term.
    assert v.shape[-1] == rank == w.shape[-1]
    before_sums = np.count_nonzero(u) + np.count_nonzero(v) - 2 * u.shape[1]
    after_sums = np.count_nonzero(w) - w.shape[0]
    print(f'{key}: \trank={rank}\tsums:{before_sums + after_sums}')

2,2,2: 	rank=7	sums:22
2,2,3: 	rank=11	sums:25
2,2,4: 	rank=14	sums:48
2,2,5: 	rank=18	sums:65
2,2,6: 	rank=21	sums:54
2,2,7: 	rank=25	sums:75
2,2,8: 	rank=28	sums:103
2,3,3: 	rank=15	sums:58
2,3,4: 	rank=20	sums:88
2,3,5: 	rank=25	sums:113
2,4,4: 	rank=26	sums:122
2,4,5: 	rank=33	sums:200
2,5,5: 	rank=40	sums:283
3,3,3: 	rank=23	sums:110
3,3,4: 	rank=29	sums:148
3,3,5: 	rank=36	sums:185
3,4,4: 	rank=38	sums:204
3,4,5: 	rank=47	sums:293
3,4,11: 	rank=103	sums:708
3,5,5: 	rank=58	sums:369
3,5,9: 	rank=105	sums:665
3,9,11: 	rank=225	sums:3299
4,4,4: 	rank=49	sums:468
4,4,5: 	rank=63	sums:473
4,5,5: 	rank=76	sums:549
4,5,9: 	rank=139	sums:1026
4,5,10: 	rank=152	sums:1568
4,5,11: 	rank=169	sums:1226
4,9,10: 	rank=255	sums:3104
4,9,11: 	rank=280	sums:3568
4,11,11: 	rank=343	sums:4426
4,11,12: 	rank=366	sums:4864
5,5,5: 	rank=98	sums:643
5,5,7: 	rank=134	sums:918
5,7,9: 	rank=234	sums:2613
5,7,10: 	rank=257	sums:3048
5,7,11: 	rank=280	sums:6174
5,8,9: 	rank=262	sums:3203
5,8,10: 	rank=287	su

Please note that as provided, the factorizations decompose the *symmetrized* version of the matrix multiplication tensor, representing the bilinear operation $\mathbf{A}, \mathbf{B} \mapsto (\mathbf{A} \cdot \mathbf{B})^T$. This is standard in the literature, and factorizations can be easily converted
between the symmetrized and non-symmetrized versions.

In [4]:
def get_mamu_tensor_rectangular(a: int, b: int, c: int) -> np.ndarray:
  """Builds the symmetrized matrix multiplication tensor T_{a, b, c}.

  Entry (i*b + j, j*c + k, k*a + i) is 1 for every i < a, j < b, k < c, and
  every other entry is 0. Contracting it with row-major flattened A (a x b)
  and B (b x c) therefore yields (A @ B)^T (c x a) flattened row-major.

  Args:
    a: number of rows of A.
    b: number of columns of A (= rows of B).
    c: number of columns of B.

  Returns:
    An int32 array of shape (a*b, b*c, c*a).
  """
  # Broadcast grids of all (i, j, k) index triples, so the ones can be
  # scattered in a single fancy-indexing assignment.
  ii, jj, kk = np.meshgrid(np.arange(a), np.arange(b), np.arange(c), indexing='ij')
  tensor = np.zeros((a * b, b * c, c * a), dtype=np.int32)
  tensor[ii * b + jj, jj * c + kk, kk * a + ii] = 1
  return tensor

# Shape (a, b, c) of the product to inspect: A is a x b and B is b x c.
# Alternative: FACT = (4,4,4)
FACT = (2,2,2)

# Verify the factorization: summing the rank-one outer products of the columns
# of u, v, w must rebuild the matrix multiplication tensor, either exactly
# (over the reals) or after reducing modulo 2 (over F2).
tensor = get_mamu_tensor_rectangular(*FACT)
u, v, w = factorizations[','.join(map(str, FACT))]
reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)

if np.array_equal(tensor, reconstruction):
    print('Factorization is correct in R (standard arithmetic).')
elif np.array_equal(tensor, reconstruction % 2):
    print('Factorization is correct in F2 (modular arithmetic).')
else:
    print('Factorization is incorrect.')

# Additions: (nonzeros - 1) per row of w to assemble the outputs, and
# (nonzeros - 1) per column of u and of v to assemble the product factors.
after_sums = np.count_nonzero(w) - w.shape[0]
before_sums = np.count_nonzero(u) + np.count_nonzero(v) - 2*u.shape[1]
print(f'The computation uses {u.shape[1]} muls and {before_sums+after_sums} sums ({before_sums}+{after_sums})')

Factorization is correct in R (standard arithmetic).
The computation uses 7 muls and 22 sums (14+8)


In [5]:
# Display the matrix multiplication tensor for the shape selected by FACT.
# Building it from FACT rather than from a hard-coded (2, 2, 2) keeps
# `tensor` consistent with the factors u, v, w loaded for FACT; later cells
# compare reconstructions from those factors against it.
tensor = get_mamu_tensor_rectangular(*FACT)
tensor

array([[[1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0]],

       [[0, 1, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]]], dtype=int32)

In [6]:
# Show the factor matrices (u, v, w) of the factorization selected by FACT;
# column r of each matrix holds the coefficients of the r-th rank-one term.
factorizations[','.join(map(str, FACT))]

array([[[ 0,  1,  1,  0,  1,  1,  0],
        [ 0,  0, -1,  1,  0,  0,  0],
        [ 1,  1,  1,  0,  1,  0,  0],
        [-1, -1, -1,  0,  0,  0,  1]],

       [[ 0,  0,  0,  0,  1,  1,  0],
        [ 1,  1,  0,  0,  1,  0,  1],
        [ 0,  1,  1,  1,  1,  0,  0],
        [ 0,  1,  1,  0,  1,  0,  1]],

       [[ 0,  0,  0,  1,  0,  1,  0],
        [ 0, -1,  0,  0,  1, -1, -1],
        [-1,  1, -1, -1,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  1]]], dtype=int32)

The einsum expression `ir,jr,kr->ijk` multiplies the three factor matrices entry-wise and sums over the index $r$, which does not appear in the output:

$$\text{reconstruction}_{ijk}=\sum_{r=1}^{R} u_{ir}\, v_{jr}\, w_{kr}$$

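This is exactly a rank-$R$ decomposition of the tensor:

$$\mathcal{T}=\sum_{r=1}^{R} u^{(r)}\otimes v^{(r)}\otimes w^{(r)}$$

Here $u^{(r)}, v^{(r)}, w^{(r)}$ are the $r$-th columns of the factor matrices, and $\otimes$ is the outer (tensor) product. For vectors, the Kronecker product contains the same entries flattened into a single vector, which is why the next cell reshapes it back into a 3-D array.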

In [7]:
# for r in range(rank):
#     print(u[:,r], v[:,r], w[:,r])

In [8]:
# Rebuild the tensor by hand. For 1-D inputs, np.kron(np.kron(x, y), z) holds
# the entries x[i] * y[j] * z[k] in row-major order, so reshaping it to
# (u_dim, v_dim, w_dim) gives the rank-one term u_r (x) v_r (x) w_r.
u_dim, rank = u.shape
# Index shape[0] instead of unpacking into `_`: in IPython, assigning to `_`
# shadows the automatic last-output variable for the rest of the session.
v_dim = v.shape[0]
w_dim = w.shape[0]

reconstruction_manual = np.zeros((u_dim, v_dim, w_dim), dtype=int)
for r in range(rank):
    reconstruction_manual += np.kron(np.kron(u[:,r], v[:,r]), w[:,r]).reshape(u_dim,v_dim,w_dim)

# Compare against a freshly built tensor for FACT, so the check does not depend
# on whatever the global `tensor` holds (an earlier cell may have rebuilt it
# for a different shape).
np.array_equal(get_mamu_tensor_rectangular(*FACT), reconstruction_manual)

True

It does. Next, I want to rewrite the factorization as an explicit sequence of multiplications and additions... wish me luck.

The *rank* gives the number of multiplications, but what is the number of additions?

Una operación de multiplicación de matrices de 2x2 puede representarse como un tensor de tamaño 4x4x4 (`tensor` o `reconstruction`). Ese tensor es independiente de las matrices a ser multiplicadas.

Voy a implementar el **Algorithm 1** del paper.

In [9]:
def term2tex(i, c, txt):
    """Formats the term ``c * txt_i`` as a LaTeX string.

    Args:
      i: subscript of the variable, rendered in braces (4 -> ``a_{4}``).
      c: coefficient (callers pass ints). A negative value adds a leading
        '-', and the magnitude is written only when it exceeds 1.
      txt: variable name, e.g. 'a', 'b' or 't'.

    Returns:
      The LaTeX term, or '' when ``c == 0``.
    """
    if c == 0:
        return ''
    sign = '-' if c < 0 else ''
    magnitude = str(abs(c)) if abs(c) > 1 else ''
    return f'{sign}{magnitude}{txt}_{{{i}}}'

term2tex(4, -3, 'a')

'-3a_{4}'

In [10]:
from IPython.display import Latex

# Algorithm 1: rewrite the factorization (u, v, w) as straight-line code and
# render it as LaTeX.
#   t_r = (sum_i u[i, r] a_i) * (sum_j v[j, r] b_j)   one multiplication per column r
#   c_k = sum_r w[k, r] t_r                            one combination per row k of w
# A linear combination with n nonzero coefficients costs n - 1 additions.
# Because the tensor is symmetrized, c_k is the k-th entry of the row-major
# flattening of (A.B)^T, not of A.B.
before_sums = 0
after_sums = 0

t_strings = []
for r in range(rank):
    # Additions needed to form the two factors of product t_r.
    before_sums += np.count_nonzero(u[:,r])-1 + np.count_nonzero(v[:,r])-1

    # Nonzero (index, coefficient) pairs of each factor.
    a_terms = [(i, int(u[i,r])) for i in range(u_dim) if u[i,r] != 0]
    b_terms = [(i, int(v[i,r])) for i in range(v_dim) if v[i,r] != 0]

    # Parenthesize a factor only when it is a sum of several terms; collapsing
    # '+-' renders negative coefficients as subtractions.
    # TODO: Remove unneeded parens
    a_string = '+'.join(term2tex(i+1, ui, 'a') for i, ui in a_terms).replace('+-', '-')
    if len(a_terms) != 1:
        a_string = ''.join(['(', a_string, ')'])

    b_string = '+'.join(term2tex(i+1, vi, 'b') for i, vi in b_terms).replace('+-', '-')
    if len(b_terms) != 1:
        b_string = ''.join(['(', b_string, ')'])

    t_strings.append(f"t_{{{r+1}}} = {a_string}{b_string}")

# One output entry per row of w. Iterate over w's rows (c*a of them), not u's
# (a*b). The counts differ for rectangular products, and using u's row count
# would either drop output entries (and undercount additions) or raise
# IndexError.
for k in range(w.shape[0]):
    # Additions needed to combine the products into output entry c_k.
    after_sums += np.count_nonzero(w[k,:])-1
    c_terms = [(r, int(w[k,r])) for r in range(rank) if w[k,r] != 0]
    c_string = '+'.join(term2tex(r+1, wr, 't') for r, wr in c_terms).replace('+-', '-')

    t_strings.append(f"c_{{{k+1}}} = {c_string}")

print(f"Total: {rank} muls and {before_sums+after_sums} sums ({before_sums}/{after_sums})")

# Assemble an align* environment, aligning every equation on its '=' sign.
tex_code = "\\begin{align*}\n"+" \\\\\n".join(t_strings).replace(' = ', ' &= ')+"\n\\end{align*}"
print(tex_code)
Latex(tex_code)

Total: 7 muls and 22 sums (14/8)
\begin{align*}
t_{1} &= (a_{3}-a_{4})b_{2} \\
t_{2} &= (a_{1}+a_{3}-a_{4})(b_{2}+b_{3}+b_{4}) \\
t_{3} &= (a_{1}-a_{2}+a_{3}-a_{4})(b_{3}+b_{4}) \\
t_{4} &= a_{2}b_{3} \\
t_{5} &= (a_{1}+a_{3})(b_{1}+b_{2}+b_{3}+b_{4}) \\
t_{6} &= a_{1}b_{1} \\
t_{7} &= a_{4}(b_{2}+b_{4}) \\
c_{1} &= t_{4}+t_{6} \\
c_{2} &= -t_{2}+t_{5}-t_{6}-t_{7} \\
c_{3} &= -t_{1}+t_{2}-t_{3}-t_{4} \\
c_{4} &= t_{1}+t_{7}
\end{align*}


<IPython.core.display.Latex object>