# Requirements

In [50]:
import itertools
import math
import operator
import time

# Problem statement

A magic square is an $n \times n$ square matrix with elements $\{1, \ldots, n^2\}$ such that the sum of each row, each column, the main diagonal and the skew diagonal are all equal.

For example, the following $3 \times 3$ square is magic:
$$
\begin{array}{ccc}
  8  & 1 & 6 \\
  3  & 5 & 7 \\
  4  & 9 & 2 \\
\end{array}
$$

The magic constant for a $3 \times 3$ square is 15.  In general for an $n \times n$ square, the sum of all the elements is $\sum_{i=1}^{n^2} i = \frac{1}{2}n^2(n^2 + 1)$.  Since each of the row sums is equal to the magic constant $M$, and there are $n$ rows, $M = \frac{1}{2}n(n^2 + 1)$.
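As a quick sanity check of the formula, here is a minimal sketch that evaluates $M$ for small $n$. The function name `magic_constant` is ours and is not used elsewhere in this notebook.

```python
def magic_constant(n: int) -> int:
    """Return M = n(n^2 + 1)/2, the common line sum of an n x n magic square."""
    # n*(n**2 + 1) is always even (either n or n**2 + 1 is even),
    # so the integer division is exact.
    return n*(n**2 + 1)//2


for n in range(1, 6):
    print(f'{n}: {magic_constant(n)}')
```

For $n = 3$ this reproduces the value 15 mentioned above.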

The magic $1 \times 1$ square is of course trivial, and there is no $2 \times 2$ magic square.  For larger values of $n$, we might naively construct magic squares by considering the permutations of the fundamental $n \times n$ square (i.e., a square filled row by row with the numbers $1$ to $n^2$).  This soon turns out to be infeasible since the number of permutations is $(n^2)!$.

In [7]:
for n in range(6):
    print(f'{n}: {math.factorial(n**2)} ({math.factorial(n**2):.2e})')

0: 1 (1.00e+00)
1: 1 (1.00e+00)
2: 24 (2.40e+01)
3: 362880 (3.63e+05)
4: 20922789888000 (2.09e+13)
5: 15511210043330985984000000 (1.55e+25)


It is clear that even for $4 \times 4$ squares, the number of squares to test is prohibitively large.  However, there are algorithms to construct magic squares.  We will implement one for odd values of $n$ attributed to Simon de la Loubère.

# Helper functions

For performance reasons, we will represent an $n \times n$ square as a list of integers, so its length is $n^2$.  The square's elements are stored rowwise, so the element of the square at row $i$ and column $j$ is the list element at index $i \cdot n + j$ where $i, j \in \{0, \ldots, n - 1\}$.
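The mapping between a (row, column) pair and the list index can be sketched as follows. The helper names `to_index` and `to_row_col` are hypothetical and only serve as an illustration; the rest of the notebook computes the index inline.

```python
def to_index(i: int, j: int, n: int) -> int:
    """Map zero-based row i and column j to the index in the flat list."""
    return i*n + j


def to_row_col(index: int, n: int) -> tuple[int, int]:
    """Inverse mapping: flat list index to zero-based (row, column)."""
    return divmod(index, n)


n = 3
elements = list(range(1, n**2 + 1))
# Element at row 1, column 2 of the square filled row by row with 1..9.
print(elements[to_index(1, 2, n)])
print(to_row_col(5, n))
```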

We implement a simple function that converts a list representation to a list of lists where each list represents a row of the square.

In [48]:
def convert_to_square(elements: list[int]) -> list[list[int]]:
    n = math.isqrt(len(elements))
    if n**2 != len(elements):
        raise ValueError('list does not represent a square')
    # The groupby key is the row index, computed as the list index integer
    # divided by n.  This yields a sequence of tuples, the first element of which
    # is the key, the second a group.  The group contains tuples, the first element
    # the list index, the second the list element.  We select the second by mapping
    # itemgetter(1), and then convert the group to a list.
    return list(map(lambda t: list(map(operator.itemgetter(1), t[1])),
                    itertools.groupby(enumerate(elements), lambda t: t[0]//n)))

In [20]:
convert_to_square(list(range(1, 10)))

[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
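As a point of comparison, the same conversion can be written with list slicing instead of `groupby`. This is only a sketch of an alternative, not the implementation used in the rest of the notebook; the name `convert_to_square_sliced` is ours.

```python
import math


def convert_to_square_sliced(elements: list[int]) -> list[list[int]]:
    n = math.isqrt(len(elements))
    if n**2 != len(elements):
        raise ValueError('list does not represent a square')
    # Row i occupies the slice [i*n, (i + 1)*n) of the flat list; list() also
    # turns tuple slices into lists, so tuples are accepted as input.
    return [list(elements[i*n:(i + 1)*n]) for i in range(n)]


print(convert_to_square_sliced(list(range(1, 10))))
```

The slicing version trades the lazy grouping for code that is arguably easier to read; both produce a list of rows.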

For convenience, we also write a function to display a square.

In [32]:
def print_square(elements: list[int]) -> None:
    n = math.isqrt(len(elements))
    if n**2 != len(elements):
        raise ValueError('list does not represent a square')
    for i in range(n):
        print(''.join(f'{value:5d}' for value in elements[i*n:i*n+n]))

We also define a function to check whether a square is magic.

In [33]:
def is_magic_squuare(elements: list[int]) -> bool:
    n = math.isqrt(len(elements))
    if n**2 != len(elements):
        raise ValueError('list does not represent a square')
    diag_sum = sum(elements[i + i*n] for i in range(n))
    for i in range(n):
        row_sum = sum(elements[j + i*n] for j in range(n))
        if row_sum != diag_sum:
            return False
    for j in range(n):
        col_sum = sum(elements[j + i*n] for i in range(n))
        if col_sum != diag_sum:
            return False
    return sum(elements[i*n + (n - i - 1)] for i in range(n)) == diag_sum

In [34]:
is_magic_squuare(list(range(1, 10)))

False

In [35]:
is_magic_squuare([2, 9, 4, 7, 5, 3, 6, 1, 8])

True

We implement a function that creates the $n \times n$ fundamental square.

In [29]:
def create_fundamental_square(n: int) -> list[int]:
    return list(range(1, n**2 + 1))

In [36]:
create_fundamental_square(3)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

# Finding magic squares

## Enumeration

The naive algorithm to find a magic square is straightforward: enumerate all permutations of the fundamental square and return the first one that is magic.  We build in a maximum number of iterations to avoid unrealistic runtimes of the function.

In [58]:
def find_magic_square(n: int, max_iterations: int=10_000_000) -> list[list[int]] | None:
    for iteration, elements in enumerate(itertools.permutations(create_fundamental_square(n))):
        if is_magic_squuare(elements):
            return convert_to_square(elements)
        if iteration >= max_iterations:
            break
    return None

In [59]:
find_magic_square(3)

[[2, 7, 6], [9, 5, 1], [4, 3, 8]]

This worked well and was fast for $n = 3$.  However, let's time it for $n = 4$.

In [60]:
start = time.perf_counter()
find_magic_square(4)
time.perf_counter() - start

12.033219652000298

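It takes more than 10 seconds to iterate over $10^7$ squares.  We know there are on the order of $10^{13}$ possible $4 \times 4$ squares, so enumerating them all would take on the order of $10^7$ seconds, or about 116 days.  With the exact count of $2.09 \times 10^{13}$ and the measured 12 seconds, the figure is closer to 290 days.  Of course, we might get lucky and hit upon a magic square "early".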
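The estimate can be redone with the exact number of candidates and the timing measured above. This is a minimal sketch; the variable names are ours, and the rate of $10^7$ squares per 12.03 seconds is specific to the machine the cell ran on.

```python
import math

nr_squares = math.factorial(4**2)        # number of 4 x 4 candidate squares
squares_per_second = 10_000_000/12.03    # measured rate, machine dependent
seconds = nr_squares/squares_per_second
print(f'{seconds:.2e} s, about {seconds/86_400:.0f} days')
```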

For $n = 3$, we can enumerate all magic squares.

In [66]:
from collections.abc import Iterator

def create_all_magic_square(n: int, is_verbose: bool=False) -> Iterator[list[list[int]]]:
    nr_squares = 0
    for elements in itertools.permutations(create_fundamental_square(n)):
        if is_magic_squuare(elements):
            yield convert_to_square(elements)
        nr_squares += 1
        if is_verbose and nr_squares % 10_000_000 == 0:
            print(f'{nr_squares} tested')

In [67]:
%%time
for square in create_all_magic_square(3):
    print(square)

[[2, 7, 6], [9, 5, 1], [4, 3, 8]]
[[2, 9, 4], [7, 5, 3], [6, 1, 8]]
[[4, 3, 8], [9, 5, 1], [2, 7, 6]]
[[4, 9, 2], [3, 5, 7], [8, 1, 6]]
[[6, 1, 8], [7, 5, 3], [2, 9, 4]]
[[6, 7, 2], [1, 5, 9], [8, 3, 4]]
[[8, 1, 6], [3, 5, 7], [4, 9, 2]]
[[8, 3, 4], [1, 5, 9], [6, 7, 2]]
CPU times: user 478 ms, sys: 0 ns, total: 478 ms
Wall time: 476 ms


It takes only half a second to enumerate all magic $3 \times 3$ squares, but there are only 362,880 candidates.

## Simon de la Loubère's method

An algorithm attributed to Simon de la Loubère works for odd values of $n$ only.  It works as follows:
* start in the middle of the first row;
* fill the corresponding broken diagonal with the values 1 to $n$, moving one row up and one column to the right at each step and wrapping around at the edges of the square;
* from the position of the last value entered, move one row down, and repeat with the next $n$ values;
* repeat until all $n^2$ values are entered into the square.

For $n = 3$, the square looks as follows after the first broken diagonal:
$$
\begin{array}{ccc}
   \cdot &   1   & \cdot \\
      3  & \cdot & \cdot \\
   \cdot & \cdot &    2  \\
\end{array}
$$
For the second diagonal, which is not broken:
$$
\begin{array}{ccc}
   \cdot &   1   &    6  \\
      3  &   5   & \cdot \\
      4  & \cdot &    2  \\
\end{array}
$$
And finally for the last broken diagonal:
$$
\begin{array}{ccc}
      8  &   1   &    6  \\
      3  &   5   &    7  \\
      4  &   9   &    2  \\
\end{array}
$$


The implementation is straightforward.

In [72]:
def create_magic_square(n: int) -> list[int]:
    if n < 1:
        raise ValueError(f'{n} is not a valid dimension')
    if n % 2 == 0:
        raise NotImplementedError('algorithm not implemented for even sizes')
    elements = [-1]*n**2
    i, j = 0, n//2
    for value in create_fundamental_square(n):
        elements[i*n + j] = value
        if value % n == 0:
            i = (i + 1) % n
        else:
            i = (i - 1 + n) % n
            j = (j + 1) % n
    return elements

For $n = 3$, we create the square and also check whether it is indeed magic:

In [74]:
print_square(create_magic_square(3))

    8    1    6
    3    5    7
    4    9    2


In [76]:
is_magic_squuare(create_magic_square(3))

True

For $n = 5$, we again create the square and check it:

In [75]:
print_square(create_magic_square(5))

   17   24    1    8   15
   23    5    7   14   16
    4    6   13   20   22
   10   12   19   21    3
   11   18   25    2    9


In [78]:
is_magic_squuare(create_magic_square(5))

True

This method is of course very fast compared to enumerating squares: its complexity is quadratic in $n$, i.e., linear in the number of elements $n^2$.
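Because the construction is cheap, it is feasible to check it for many odd sizes at once. The sketch below is self-contained and therefore re-implements the method on a nested-list representation rather than the flat list used above. The names `siamese_square` and `line_sums` are ours and do not appear elsewhere in this notebook.

```python
def siamese_square(n: int) -> list[list[int]]:
    """Build an n x n square (n odd) with the up-and-right rule described above."""
    if n < 1 or n % 2 == 0:
        raise ValueError(f'{n} is not a positive odd dimension')
    square = [[0]*n for _ in range(n)]
    i, j = 0, n//2                            # middle of the first row
    for value in range(1, n**2 + 1):
        square[i][j] = value
        if value % n == 0:
            i = (i + 1) % n                   # diagonal complete: one row down
        else:
            i, j = (i - 1) % n, (j + 1) % n   # up and to the right, wrapping
    return square


def line_sums(square: list[list[int]]) -> set[int]:
    """Collect the distinct sums of all rows, columns and both diagonals."""
    n = len(square)
    sums = {sum(row) for row in square}
    sums |= {sum(column) for column in zip(*square)}
    sums.add(sum(square[i][i] for i in range(n)))
    sums.add(sum(square[i][n - i - 1] for i in range(n)))
    return sums


# Every line should sum to the magic constant M = n(n^2 + 1)/2.
for n in range(1, 100, 2):
    assert line_sums(siamese_square(n)) == {n*(n**2 + 1)//2}
```

Collecting the sums in a set keeps the check compact: the square is magic exactly when the set contains a single value.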