In [2]:
import random
import time
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Callable

## Question 02: Strassen's Algorithm

In [3]:
def add_matrix(A, B):
    """Return the element-wise sum A + B as a new list of rows.

    The result has as many rows as A and as many columns as A's first row;
    B is read at the same (row, column) positions. Neither input is modified.
    """
    summed = []
    for r in range(len(A)):
        # Width is read from A's first row inside the loop, so an empty A
        # simply yields an empty result without touching A[0].
        width = len(A[0])
        summed_row = []
        for c in range(width):
            summed_row.append(A[r][c] + B[r][c])
        summed.append(summed_row)
    return summed

def subtract_matrix(A, B):
    """Return the element-wise difference A - B as a new list of rows.

    The result has as many rows as A and as many columns as A's first row;
    B is read at the same (row, column) positions. Neither input is modified.
    """
    difference = []
    for r in range(len(A)):
        # Width is read from A's first row inside the loop, so an empty A
        # simply yields an empty result without touching A[0].
        width = len(A[0])
        difference_row = []
        for c in range(width):
            difference_row.append(A[r][c] - B[r][c])
        difference.append(difference_row)
    return difference

def strassen(A, B, depth=0):
    """Multiply two square matrices A and B using Strassen's algorithm.

    Each level splits the operands into four quadrants, forms seven recursive
    products (M1..M7) and recombines them into the four quadrants of the
    result. An odd-sized operand (n > 1) is padded with one zero row and one
    zero column so that it splits evenly, and the padding is cropped off again
    before returning. Any size n >= 0 is therefore supported, not only powers
    of two.

    A trace of the recursion is printed, indented by two spaces per level.

    Args:
        A: Left operand, an n x n matrix given as a sequence of n rows.
        B: Right operand, an n x n matrix given as a sequence of n rows.
        depth: Current recursion depth; it only affects the indentation and
            the depth numbers shown in the printed trace.

    Returns:
        The n x n product as a new list of lists. A and B are not modified.

    Raises:
        ValueError: If A and B are not square matrices of the same size.
    """
    n = len(A)
    indent = "  " * depth  # Indentation based on recursion depth for better readability

    # Reject shapes the quadrant split cannot handle; without this check the
    # recursion either fails deep inside with an IndexError or returns garbage.
    if len(B) != n or any(len(row) != n for row in A) or any(len(row) != n for row in B):
        raise ValueError("strassen requires two square matrices of the same size")

    print(f"{indent}Strassen called at depth {depth} for matrix size {n}x{n}")

    # Empty product: nothing to split (splitting would recurse forever).
    if n == 0:
        return []

    # Base case: when the matrix is 1x1, just multiply the elements
    if n == 1:
        print(f"{indent}Base case reached: Multiplying {A[0][0]} and {B[0][0]}")
        return [[A[0][0] * B[0][0]]]

    # An odd size cannot be cut into four equal quadrants. Embed the operands
    # in the next even size by appending a zero row and a zero column; the
    # extra zeros contribute nothing to the top-left n x n part of the product.
    size = n
    if n % 2:
        size = n + 1
        print(f"{indent}Padding {n}x{n} to {size}x{size} with zeros")
        A = [list(row) + [0] for row in A] + [[0] * size]
        B = [list(row) + [0] for row in B] + [[0] * size]

    # Divide the matrices into quarters
    mid = size // 2
    A11 = [row[:mid] for row in A[:mid]]
    A12 = [row[mid:] for row in A[:mid]]
    A21 = [row[:mid] for row in A[mid:]]
    A22 = [row[mid:] for row in A[mid:]]

    B11 = [row[:mid] for row in B[:mid]]
    B12 = [row[mid:] for row in B[:mid]]
    B21 = [row[:mid] for row in B[mid:]]
    B22 = [row[mid:] for row in B[mid:]]

    # Recursively calculate the 7 products
    print(f"{indent}Calculating M1...")
    M1 = strassen(add_matrix(A11, A22), add_matrix(B11, B22), depth + 1)

    print(f"{indent}Calculating M2...")
    M2 = strassen(add_matrix(A21, A22), B11, depth + 1)

    print(f"{indent}Calculating M3...")
    M3 = strassen(A11, subtract_matrix(B12, B22), depth + 1)

    print(f"{indent}Calculating M4...")
    M4 = strassen(A22, subtract_matrix(B21, B11), depth + 1)

    print(f"{indent}Calculating M5...")
    M5 = strassen(add_matrix(A11, A12), B22, depth + 1)

    print(f"{indent}Calculating M6...")
    M6 = strassen(subtract_matrix(A21, A11), add_matrix(B11, B12), depth + 1)

    print(f"{indent}Calculating M7...")
    M7 = strassen(subtract_matrix(A12, A22), add_matrix(B21, B22), depth + 1)

    # Compute the resulting submatrices
    C11 = add_matrix(subtract_matrix(add_matrix(M1, M4), M5), M7)
    C12 = add_matrix(M3, M5)
    C21 = add_matrix(M2, M4)
    C22 = add_matrix(subtract_matrix(add_matrix(M1, M3), M2), M6)

    # Combine the submatrices into one: top half rows, then bottom half rows
    C = [C11[i] + C12[i] for i in range(mid)]
    C += [C21[i] + C22[i] for i in range(mid)]

    # Drop the zero padding that was added for an odd-sized input
    if size != n:
        C = [row[:n] for row in C[:n]]

    print(f"{indent}Returning result from depth {depth}")
    return C

# Main function to multiply two matrices using Strassen's algorithm
def main():
    """Run the fixed 4x4 demonstration of Strassen's multiplication.

    Multiplies two hard-coded integer matrices with ``strassen`` (passing them
    as nested Python lists), prints the product row by row, and then prints
    NumPy's ``matmul`` result for the same operands for comparison.
    """
    # Operands of the 4x4 example
    left = np.array([
        [1, 0, 2, 1],
        [4, 1, 1, 0],
        [0, 1, 3, 0],
        [5, 0, 2, 1],
    ])
    right = np.array([
        [0, 1, 0, 1],
        [2, 1, 1, 4],
        [2, 0, 1, 1],
        [1, 3, 5, 0],
    ])

    # Strassen works on plain nested lists, so convert before calling it
    print("Starting Strassen's matrix multiplication...")
    product = strassen(left.tolist(), right.tolist())

    # Show the Strassen result one row per line
    print("\nResult of Strassen's Matrix Multiplication:")
    for product_row in product:
        print(product_row)

    # Reference result from the library implementation
    print("\nResult of Numpy Matrix Multiplication:")
    reference = np.matmul(left, right)
    print(reference)

# Run the demo only when this code is executed as the top-level module
# (as it is when the cell or script is run directly), not when it is imported.
if __name__ == "__main__":
    main()


Starting Strassen's matrix multiplication...
Strassen called at depth 0 for matrix size 4x4
Calculating M1...
  Strassen called at depth 1 for matrix size 2x2
  Calculating M1...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying 6 and 2
  Calculating M2...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying 8 and 1
  Calculating M3...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying 4 and 1
  Calculating M4...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying 2 and 6
  Calculating M5...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying 4 and 1
  Calculating M6...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying 2 and 3
  Calculating M7...
    Strassen called at depth 2 for matrix size 1x1
    Base case reached: Multiplying -2 and 8
  Returning result from depth 1
Calculatin