# AlphaTensor algorithm for matrix multiplication


First, let us load the algorithms discovered by AlphaTensor, which are represented as factorizations of matrix multiplication tensors.

## Tensor decomposition

In [15]:
# Import necessary libraries
import numpy as np
from google.colab import files

Upload the provided file `factorizations_r.npz`, which contains the factorizations in standard arithmetic (over $\mathbb{R}$).

In [16]:
# Colab's upload dialog returns {saved filename: file bytes}. If a file with the
# same name was uploaded before, Colab saves it under a new name such as
# 'factorizations_r (1).npz', so take the name from the result, not a constant.
uploaded = files.upload()
if not uploaded:
  raise ValueError('No file was uploaded; please select factorizations_r.npz.')
filename = next(iter(uploaded))
# allow_pickle=True is presumably required because each entry bundles three
# factor matrices (u, v, w) of different shapes in an object array.
# SECURITY: unpickling can execute arbitrary code, so only load trusted files.
with open(filename, 'rb') as f:
  # dict(...) reads every array while the file is still open.
  factorizations = dict(np.load(f, allow_pickle=True))

Saving factorizations_r.npz to factorizations_r (1).npz


In [17]:
# List every available factorization with its rank R, i.e. the number of
# scalar multiplications the corresponding algorithm needs. All three factor
# matrices must agree on R (their last dimension).
for key, (u, v, w) in factorizations.items():
  rank = u.shape[-1]
  assert rank == v.shape[-1] and rank == w.shape[-1]
  print(f'{key}: rank={rank}')

2,2,2: rank=7
2,2,3: rank=11
2,2,4: rank=14
2,2,5: rank=18
2,2,6: rank=21
2,2,7: rank=25
2,2,8: rank=28
2,3,3: rank=15
2,3,4: rank=20
2,3,5: rank=25
2,4,4: rank=26
2,4,5: rank=33
2,5,5: rank=40
3,3,3: rank=23
3,3,4: rank=29
3,3,5: rank=36
3,4,4: rank=38
3,4,5: rank=47
3,4,11: rank=103
3,5,5: rank=58
3,5,9: rank=105
3,9,11: rank=225
4,4,4: rank=49
4,4,5: rank=63
4,5,5: rank=76
4,5,9: rank=139
4,5,10: rank=152
4,5,11: rank=169
4,9,10: rank=255
4,9,11: rank=280
4,11,11: rank=343
4,11,12: rank=366
5,5,5: rank=98
5,5,7: rank=134
5,7,9: rank=234
5,7,10: rank=257
5,7,11: rank=280
5,8,9: rank=262
5,8,10: rank=287
5,8,11: rank=317
5,9,9: rank=296
5,9,10: rank=323
5,9,11: rank=358
5,9,12: rank=381
6,7,9: rank=270
6,7,10: rank=296
6,7,11: rank=322
6,8,10: rank=329
6,8,11: rank=365
6,9,9: rank=342
6,9,10: rank=373
6,9,11: rank=411
7,7,9: rank=318
7,7,10: rank=350
7,7,11: rank=384
7,8,9: rank=354
7,8,10: rank=393
7,8,11: rank=432
7,8,12: rank=462
7,9,9: rank=399
7,9,10: rank=441
7,9,11: rank=481
7,

In [18]:
# Build the matrix multiplication tensor that the factorizations decompose.
def get_mamu_tensor_rectangular(a: int, b: int, c: int) -> np.ndarray:
  """Returns the symmetrized matrix multiplication tensor T_{a, b, c}.

  The tensor has shape (a*b, b*c, c*a) and int32 entries. For every triple
  (i, j, k) with 0 <= i < a, 0 <= j < b, 0 <= k < c, it holds a 1 at
  position (i*b + j, j*c + k, k*a + i); every other entry is 0.

  The first two modes index the row-major flattenings of an a x b matrix A
  and a b x c matrix B. The third-mode index k*a + i refers to entry (i, k)
  of the product, so that mode indexes the row-major flattening of the
  transposed product C^T (a c x a matrix).

  Args:
    a: Number of rows of A (and of the product).
    b: Number of columns of A, which equals the number of rows of B.
    c: Number of columns of B (and of the product).

  Returns:
    The int32 tensor of shape (a*b, b*c, c*a).
  """
  rows, inner, cols = np.meshgrid(
      np.arange(a), np.arange(b), np.arange(c), indexing='ij')
  tensor = np.zeros((a * b, b * c, c * a), dtype=np.int32)
  # One nonzero per scalar product A[i, j] * B[j, k] contributing to C[i, k].
  tensor[rows * b + inner, inner * c + cols, cols * a + rows] = 1
  return tensor


# Sanity-check one factorization: summing the R rank-one terms
# u[:, r] (x) v[:, r] (x) w[:, r] must give back T_{4,4,4} exactly, either in
# integer arithmetic or after reduction modulo 2.
# u, v, w stay bound to the 4x4x4 factors; the cells below rely on them.
tensor = get_mamu_tensor_rectangular(4, 4, 4)
u, v, w = factorizations['4,4,4']
reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)
if np.array_equal(tensor, reconstruction):
  message = 'Factorization is correct in R (standard arithmetic).'
elif np.array_equal(tensor, np.mod(reconstruction, 2)):
  message = 'Factorization is correct in F2 (modular arithmetic).'
else:
  message = 'Factorization is incorrect.'
print(message)

Factorization is correct in R (standard arithmetic).


## From tensor decomposition to the result

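Once we have the tensor decomposition, we can use it to multiply two matrices. Each rank-one term $r$ yields one scalar product $m_r = \big(u_r^\top \operatorname{vec}(A)\big)\,\big(v_r^\top \operatorname{vec}(B)\big)$, and every entry of the result is a linear combination of the $m_r$ with coefficients taken from $w$. Let us see an example with $4\times 4$ matrices. There the factorization has rank $R = 49$, so only 49 scalar multiplications are needed instead of the $4^3 = 64$ of the usual algorithm.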

In [19]:
# Dimensions of the 4x4x4 factorization:
#   n2 - number of rows of u, i.e. entries of the flattened input matrix A
#   R  - rank, i.e. number of scalar multiplications of the algorithm
n2, R = u.shape[0], u.shape[1]

In [20]:
# Example 4x4 matrices, flattened row-major into length-16 vectors because the
# factor matrices u and v are applied to vectorized inputs.
A = np.array([
    [2, 32.2, 0, -1],
    [0, 343, 0, -213],
    [2, -4, 2, 0],
    [0, 43.1, 0, -2],
]).ravel()

B = np.array([
    [1, 0, -139, 0],
    [32, 2, 0, 41],
    [0, -43.3, 3, 0],
    [1, 0.001, 0, 4],
]).ravel()

In [21]:
# AlphaTensor algorithm to multiply A x B.
# Step 1: the R scalar multiplications. Product r multiplies one linear
# combination of the entries of A (coefficients u[:, r]) by one linear
# combination of the entries of B (coefficients v[:, r]).
m = (A @ u) * (B @ v)

# Step 2: every entry of the result is a linear combination of the R products,
# with coefficients taken from w (one row of w per entry of the result).
C = w @ m

# The third tensor mode indexes the product transposed (index k*a + i stands
# for C[i, k]), so the flat result is C^T in row-major order: reshape it and
# transpose back.
n = int(round(np.sqrt(C.size)))
assert n * n == C.size, 'this cell assumes a square (n, n, n) factorization'
C = C.reshape(n, n).T

print(C)

[[ 1.03140000e+03  6.43990000e+01 -2.78000000e+02  1.31620000e+03]
 [ 1.07630000e+04  6.85787000e+02 -3.63797881e-12  1.32110000e+04]
 [-1.26000000e+02 -9.46000000e+01 -2.72000000e+02 -1.64000000e+02]
 [ 1.37720000e+03  8.61980000e+01  0.00000000e+00  1.75910000e+03]]


In [23]:
# Usual algorithm to multiply A x B, used as the reference result.
# C already has the product's shape (rows of A x columns of B), so it fixes how
# the flat A and B are reshaped back into matrices.
C_reference = A.reshape(C.shape[0], -1) @ B.reshape(-1, C.shape[1])
# The AlphaTensor result must agree up to floating-point round-off: its sums
# rely on cancellations, which can leave values like -3.6e-12 instead of 0.
np.testing.assert_allclose(C, C_reference, atol=1e-9)
C_reference

array([[ 1031.4  ,    64.399,  -278.   ,  1316.2  ],
       [10763.   ,   685.787,     0.   , 13211.   ],
       [ -126.   ,   -94.6  ,  -272.   ,  -164.   ],
       [ 1377.2  ,    86.198,     0.   ,  1759.1  ]])