# **Structure and Analysis of Vectorized/Just-In-Time Compiled 3x3 Magic Square Solver**
### _Nathan Keough_

1262743 function calls (1211430 primitive calls) in 1.510 seconds

For more detailed benchmark data, see `profile.txt` in the `tests` directory.


```python
import numpy as np
import numba as nb
import itertools as it
```

#### Import Packages:
* numpy==1.21.5
* numba==0.55.1
* itertools (standard library)

### Generating all the permutations

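```python
# Every permutation of 1..9, flattened into one uint8 stream, then reshaped to one square per row
a = it.permutations(range(1, 10))
b = np.fromiter(it.chain.from_iterable(a), dtype=np.uint8).reshape((-1, 9))
b
```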

```
array([[1, 2, 3, ..., 7, 8, 9],
       [1, 2, 3, ..., 7, 9, 8],
       [1, 2, 3, ..., 8, 7, 9],
       ...,
       [9, 8, 7, ..., 2, 3, 1],
       [9, 8, 7, ..., 3, 1, 2],
       [9, 8, 7, ..., 3, 2, 1]], dtype=uint8)
```

For an $N \times N$ square, the number of permutations, and therefore the length of array `b`, is $(N^2)!$. For $N = 3$ that is $9! = 362880$.

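`itertools.chain(...)` yields the elements of each tuple produced by `itertools.permutations(...)` in turn until the iterable is exhausted, giving one flat stream of integers. `np.fromiter` builds an array from that stream, and `.reshape((-1, 9))` restores one permutation per row.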
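The flatten-then-reshape pattern is easier to see on a smaller case. This sketch uses $N = 2$ (values 1 to 4) and plain Python lists in place of `np.fromiter` and `reshape`. It uses `itertools.chain.from_iterable`, which flattens exactly like `it.chain(*a)` but consumes the permutations lazily instead of unpacking them all into an argument tuple first.

```python
import itertools as it
import math

n = 2                      # 2x2 case: values 1..4
size = n * n
perms = it.permutations(range(1, size + 1))

# Flatten the stream of tuples into a single stream of ints,
# the same kind of sequence np.fromiter consumes in the cell above.
flat = list(it.chain.from_iterable(perms))

# Split back into rows of length n*n, mirroring .reshape((-1, n*n)).
rows = [tuple(flat[k:k + size]) for k in range(0, len(flat), size)]

print(len(rows), math.factorial(size))  # 24 24
print(rows[0], rows[-1])                # (1, 2, 3, 4) (4, 3, 2, 1)
```

Because `itertools.permutations` emits tuples in lexicographic order when its input is sorted, row 0 is the identity and the last row is its reversal, just as in `b`.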

### Check the validity of a square

```python
@nb.njit(fastmath=True, locals={'t': nb.b1}, nogil=True, cache=True)
def check(arr: np.ndarray) -> bool:
    # Six line sums at once (cells indexed row by row, a=0 ... i=8):
    # a+b+c, a+d+g, a+e+i, b+e+h, c+e+g, d+e+f.  g+h+i and c+f+i are implied.
    t = (np.array([arr[0], arr[0], arr[0], arr[1], arr[2], arr[3]]) +
         np.array([arr[1], arr[3], arr[4], arr[4], arr[4], arr[4]]) +
         np.array([arr[2], arr[6], arr[8], arr[7], arr[6], arr[5]]) == 15).all()
    return t
```

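The decorator options do the following:
* `fastmath=True` lets the compiler apply fast-math optimizations, which relax strict floating-point rules. Since this function only does integer arithmetic, it likely has little effect here.
* `locals={'t': nb.b1}`: to translate Python into native code, every variable needs a static type. Numba usually infers these, but `locals` declares a local variable's type explicitly instead of leaving it to inference (`nb.b1` is Numba's boolean type).
* `nogil=True` releases the global interpreter lock while the compiled function runs, so other Python threads can execute at the same time. This makes multicore use possible from threads, but it does not parallelize the function by itself.
* `cache=True` writes the compiled code to an on-disk cache, so later sessions can skip recompiling the function.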

For a square: $$\begin{bmatrix} a & b & c \\ d & e & f \\ g & h & i \end{bmatrix}$$
Checking the validity involves checking the following conditions:

$$a+b+c=15$$
$$d+e+f=15$$
$$g+h+i=15$$
$$a+d+g=15$$
$$b+e+h=15$$
$$c+f+i=15$$
$$a+e+i=15$$
$$c+e+g=15$$
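A plain-Python reference version of these eight conditions can be handy for cross-checking the compiled function. It is a sketch; the names `LINES` and `is_magic` are illustrative and not part of the notebook. Cells are indexed row by row, so $a$ is index 0 and $i$ is index 8, matching how `check` indexes `arr`.

```python
# Index triples for the eight lines of a 3x3 square stored row by row:
# a=0, b=1, c=2, d=3, e=4, f=5, g=6, h=7, i=8
LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
    (0, 4, 8), (2, 4, 6),              # diagonals
]

def is_magic(square, target=15):
    """Return True if every row, column and diagonal sums to target."""
    return all(square[x] + square[y] + square[z] == target for x, y, z in LINES)

print(is_magic((2, 7, 6, 9, 5, 1, 4, 3, 8)))  # True
print(is_magic((1, 2, 3, 4, 5, 6, 7, 8, 9)))  # False
```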

These conditions can be arranged into 3 vectors. Adding the vectors element-wise computes all eight line sums at once, in a form the compiler may map to SIMD instructions. Comparing the result with 15 gives a boolean vector, and `.all()` returns False if any entry is not 15.
$$\begin{bmatrix} a & d & g & a & b & c & a & c \end{bmatrix}$$
$$+$$
$$\begin{bmatrix} b & e & h & d & e & f & e & e \end{bmatrix}$$
$$+$$
$$\begin{bmatrix} c & f & i & g & h & i & i & g \end{bmatrix}$$
$$=$$
$$\begin{bmatrix} 15 & 15 & 15 & 15 & 15 & 15 & 15 & 15 \end{bmatrix}$$
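The same three-vector layout can also be applied to the whole permutation table at once with NumPy fancy indexing, without Numba. This is an alternative sketch, not the notebook's method: the index arrays `I0`, `I1`, `I2` (hypothetical names) are read straight off the three vectors above with cells numbered row by row from $a = 0$ to $i = 8$, and it checks all eight lines rather than six.

```python
import itertools as it
import numpy as np

# Column indices of the three vectors above, cells numbered row by row (a=0 ... i=8).
I0 = np.array([0, 3, 6, 0, 1, 2, 0, 2])  # a d g a b c a c
I1 = np.array([1, 4, 7, 3, 4, 5, 4, 4])  # b e h d e f e e
I2 = np.array([2, 5, 8, 6, 7, 8, 8, 6])  # c f i g h i i g

perms = np.fromiter(
    it.chain.from_iterable(it.permutations(range(1, 10))), dtype=np.uint8
).reshape((-1, 9))

# Sum the three (rows x 8) slices element-wise; uint8 cannot overflow since the max sum is 24.
sums = perms[:, I0] + perms[:, I1] + perms[:, I2]
mask = (sums == 15).all(axis=1)

print(np.count_nonzero(mask))  # 8
print(np.flatnonzero(mask))    # row indices of the magic squares
```

This trades Numba's loop for three temporary arrays of shape `(362880, 8)`, so it uses more memory but needs no compilation step.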

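The intent of this vectorized structure is to replace eight separate scalar checks with a few vector operations, which should produce fewer CPU instructions. Notice also that the implementation leaves out the lines $c+f+i$ and $g+h+i$. They are redundant: the entries $1$ to $9$ sum to $45$, so once $a+b+c=15$ and $d+e+f=15$, it follows that $g+h+i = 45 - 30 = 15$. Likewise, $a+d+g=15$ and $b+e+h=15$ force $c+f+i=15$. The six remaining checks therefore always produce the correct results.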
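The redundancy argument can also be confirmed by brute force: over all $9!$ permutations, the six sums that `check` tests accept exactly the same squares as all eight. This pure-Python sketch uses illustrative names (`SIX`, `EIGHT`, `passes`) that are not part of the notebook.

```python
import itertools as it

# All eight lines, cells numbered row by row (a=0 ... i=8).
EIGHT = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
         (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]
# The six lines check() actually tests: g+h+i and c+f+i are left out.
SIX = [(0, 1, 2), (0, 3, 6), (0, 4, 8), (1, 4, 7), (2, 4, 6), (3, 4, 5)]

def passes(sq, lines):
    return all(sq[x] + sq[y] + sq[z] == 15 for x, y, z in lines)

found = []
for p in it.permutations(range(1, 10)):
    six_ok = passes(p, SIX)
    assert six_ok == passes(p, EIGHT)  # the two tests never disagree
    if six_ok:
        found.append(p)

print(len(found))  # 8
```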


### Check All Permutations

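```python
@nb.njit(fastmath=True, locals={'d': nb.int32, 'i': nb.int32}, nogil=True)
def main() -> None:
    # Scan only the first half of the permutations in b (b is read as a global).
    for i in range(len(b) // 2):
        c = b[i]
        # Mirrored index: (N^2)! - 1 - i, which is 362879 - i for N = 3.
        d = 362879 - i
        if check(c):
            # Print the square found and its mirrored partner.
            print(i)
            print(c.reshape(3, 3), '\n')
            print(d)
            print(b[d].reshape(3, 3), '\n')
```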

Once again I have compiled this function with Numba, which generates native machine code through LLVM rather than a C equivalent. The function simply loops through the array of permutations, but it only needs to cover half of it. As long as we store the index $i$ of a valid square in its lexicographically ordered set of permutations, we can guarantee that $b[(N^2)!-i-1]$ is also a valid square: it is the square with every entry $x$ replaced by $N^2+1-x$, and each of its lines sums to $3 \cdot 10 - 15 = 15$. Additionally, the print statements are arranged so that the compiler can easily infer their formats and types.
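The mirrored index works because replacing every entry $x$ with $10 - x$ reverses lexicographic order, so the permutation at rank $i$ maps to the one at rank $9! - 1 - i$. This sketch checks that on the first square found, computing the lexicographic rank with a Lehmer-code helper; `lex_rank` is a hypothetical name, not part of the notebook.

```python
import math

def lex_rank(perm):
    """0-based position of perm among all permutations of its values, in lexicographic order."""
    remaining = sorted(perm)
    rank = 0
    for k, x in enumerate(perm):
        idx = remaining.index(x)  # number of unused values smaller than x
        rank += idx * math.factorial(len(perm) - 1 - k)
        remaining.pop(idx)
    return rank

square = (2, 7, 6, 9, 5, 1, 4, 3, 8)
complement = tuple(10 - x for x in square)  # (8, 3, 4, 1, 5, 9, 6, 7, 2)

print(lex_rank(square))                          # 69074
print(lex_rank(complement))                      # 293805
print(math.factorial(9) - 1 - lex_rank(square))  # 293805
```

Both ranks match the indices printed by `main()` below. No square can be its own partner, since that would require $x = 10 - x$ for every entry, so each pair has exactly one member in the first half of `b`.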

```python
main()
```

```
69074
[[2 7 6]
 [9 5 1]
 [4 3 8]]

293805
[[8 3 4]
 [1 5 9]
 [6 7 2]]

77576
[[2 9 4]
 [7 5 3]
 [6 1 8]]

285303
[[8 1 6]
 [3 5 7]
 [4 9 2]]

135289
[[4 3 8]
 [9 5 1]
 [2 7 6]]

227590
[[6 7 2]
 [1 5 9]
 [8 3 4]]

157120
[[4 9 2]
 [3 5 7]
 [8 1 6]]

205759
[[6 1 8]
 [7 5 3]
 [2 9 4]]
```

