Loading factorizations found by AlphaTensor and recombination.

- Copyright 2022 DeepMind Technologies Limited
- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0
- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode
- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.
- This is not an official Google product.

In [None]:
import numpy as np
from google.colab import files

Upload one of the two files provided in the same folder: `factorization_r.npz` (algorithms in standard arithmetic) or `factorization_f2.npz` (algorithms in arithmetic modulo 2).

In [None]:
# Let the user pick a file in the Colab upload widget, then load every array
# stored in the .npz archive into a plain dict keyed by archive entry name.
uploaded = files.upload()
filename = [*uploaded][0]  # only the first uploaded file is used
with open(filename, 'rb') as npz_file:
  # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
  # which can execute code; only upload trusted files such as the ones that
  # accompany this notebook.
  factorizations = dict(np.load(npz_file, allow_pickle=True))

In [None]:
# Print each available factorization with its rank (the number of rank-one
# terms, i.e. the shared number of columns of u, v and w), and check that
# the three factor matrices agree on it.
for key, (u, v, w) in factorizations.items():
  rank = u.shape[-1]
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if v.shape[-1] != rank or w.shape[-1] != rank:
    raise ValueError(
        f'{key}: factor matrices disagree on rank '
        f'({u.shape[-1]}, {v.shape[-1]}, {w.shape[-1]}).')
  print(f'{key}: rank={rank}')

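Please note that, as provided, the factorizations decompose the *symmetrized* version of the matrix multiplication tensor, which represents the bilinear operation $\mathbf{A}, \mathbf{B} \mapsto (\mathbf{A} \cdot \mathbf{B})^T$. This convention is standard in the literature, and factorizations can easily be converted between the symmetrized and non-symmetrized versions. Concretely, the third factor $\mathbf{w}$ stores entry $(i, k)$ of $\mathbf{A} \cdot \mathbf{B}$ at row $k \cdot a + i$ (a row-major flattening of the transpose). Permuting its rows so that this entry moves to row $i \cdot c + k$ therefore yields a factorization of the non-symmetrized tensor.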

In [None]:
def get_mamu_tensor_rectangular(a: int, b: int, c: int) -> np.ndarray:
  """Returns the symmetrized matrix multiplication tensor T_{a, b, c}.

  The result is an int32 array of shape (a*b, b*c, c*a). It holds a 1 at
  (i*b + j, j*c + k, k*a + i) for every i < a, j < b, k < c, and 0 elsewhere.
  """
  tensor = np.zeros((a * b, b * c, c * a), dtype=np.int32)
  # np.ndindex walks (i, j, k) in C order, the same as three nested ranges.
  for i, j, k in np.ndindex(a, b, c):
    tensor[i * b + j, j * c + k, k * a + i] = 1
  return tensor


# Test correctness of a factorization: rebuild T_{3,4,5} from the rank-one
# terms of the stored factorization and compare it, first exactly (standard
# arithmetic) and then after reducing modulo 2 (F2 arithmetic).
tensor = get_mamu_tensor_rectangular(3, 4, 5)
u, v, w = factorizations['3,4,5']
reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)
if np.array_equal(tensor, reconstruction):
  message = 'Factorization is correct in R (standard arithmetic).'
elif np.array_equal(tensor, reconstruction % 2):
  message = 'Factorization is correct in F2 (modular arithmetic).'
else:
  message = 'Factorization is incorrect.'
print(message)

In [None]:
def matrix_multiplication_using_factorization(matrix_a, matrix_b,
                                              factorization_key, *,
                                              factorization_table=None,
                                              modulus=None):
  """Multiplies two matrices using a stored tensor factorization.

  The factorization (u, v, w) decomposes the *symmetrized* matrix
  multiplication tensor T_{a,b,c} (see get_mamu_tensor_rectangular):
    - row i*b + j of `u` pairs with A[i, j];
    - row j*c + k of `v` pairs with B[j, k];
    - row k*a + i of `w` pairs with entry (i, k) of A @ B.
  Each of the R columns costs a single multiplication between a linear
  combination of A's entries and a linear combination of B's entries.

  Args:
    matrix_a: The first matrix, shape (a, b); anything np.asarray accepts.
    matrix_b: The second matrix, shape (b, c); anything np.asarray accepts.
    factorization_key: The key for the factorization, e.g. '3,4,5'.
    factorization_table: Mapping from keys to (u, v, w) factor triples. When
      None (the default), the notebook-level `factorizations` dict is used.
    modulus: If given, the result is reduced modulo this value. Pass 2 with
      integer inputs when using a factorization from `factorizations_f2.npz`,
      which is only correct in arithmetic modulo 2.

  Returns:
    The (a, c) product of matrix_a and matrix_b, reduced modulo `modulus`
    when one is given.

  Raises:
    KeyError: If `factorization_key` is not in the table.
    ValueError: If the matrices are not 2-D, their inner dimensions differ,
      or their shapes do not match the factorization.
  """
  table = factorizations if factorization_table is None else factorization_table
  u, v, w = table[factorization_key]
  matrix_a = np.asarray(matrix_a)
  matrix_b = np.asarray(matrix_b)

  if matrix_a.ndim != 2 or matrix_b.ndim != 2:
    raise ValueError(
        f'Expected 2-D matrices, got shapes {matrix_a.shape} and '
        f'{matrix_b.shape}.')
  rows_a, inner_a = matrix_a.shape
  inner_b, cols_b = matrix_b.shape
  if inner_a != inner_b:
    raise ValueError(
        f'Inner dimensions differ: {matrix_a.shape} @ {matrix_b.shape}.')
  # u, v and w have a*b, b*c and c*a rows. Matching all three sizes (with
  # equal inner dimensions and nonzero sizes) forces the input shapes to be
  # exactly (a, b) and (b, c).
  if (matrix_a.size != u.shape[0] or matrix_b.size != v.shape[0] or
      rows_a * cols_b != w.shape[0]):
    raise ValueError(
        f'Shapes {matrix_a.shape} @ {matrix_b.shape} do not match '
        f'factorization {factorization_key!r}.')

  # One scalar multiplication per rank-one term: m_r = <A, u_r> * <B, v_r>.
  products = (matrix_a.ravel() @ u) * (matrix_b.ravel() @ v)
  # Entry k*a + i of w @ products holds C[i, k], so the flat vector is C
  # laid out column-major: reshape in Fortran order.
  result = (w @ products).reshape((rows_a, cols_b), order='F')
  if modulus is not None:
    result = np.mod(result, modulus)
  return result

# Example usage: multiply two random real matrices through the '3,4,5'
# factorization and compare the result with NumPy's ordinary product. The
# factorized product is computed in standard arithmetic, so the comparison is
# only expected to succeed for factorizations from `factorizations_r.npz`.
matrix_a = np.random.rand(3, 4)
matrix_b = np.random.rand(4, 5)

# Assuming you have a factorization for matrices of size (3, 4, 5)
factorization_key = '3,4,5'

if factorization_key not in factorizations:
  print(f"Factorization for key '{factorization_key}' not found.")
else:
  product_matrix = matrix_multiplication_using_factorization(
      matrix_a, matrix_b, factorization_key)
  print("Product Matrix (using factorization):")
  print(product_matrix)

  # Reference result from the standard matrix product.
  standard_product = np.dot(matrix_a, matrix_b)
  print("\nProduct Matrix (standard multiplication):")
  print(standard_product)

  # Tolerate floating-point rounding differences between the two paths.
  results_match = np.allclose(product_matrix, standard_product)
  print("\nResults are close, factorization appears to be working."
        if results_match else
        "\nResults differ significantly, factorization may not be working correctly.")