In [82]:
import numpy as np
import itertools

class GeneralLatinSquare:
    def __init__(self, N: int, S: int):
        self.N = N
        self.S = S
        assert S >= N
        self.cube = np.zeros((self.N, self.N, self.S))
        # for i in range(N):
        #     for j in range(N):
        #         self.cube[i,j,(i+j)%self.N] = 1
        for i in range(N):
                for j in range(N):
                    self.cube[i,j,(i+j)%self.N] = 1

    def __str__(self):
        return self.cube.__str__()

    def __repr__(self):
        return self.__str__()

    def copy(self):
        CP = GeneralLatinSquare(self.N, self.S)
        CP.cube = self.cube.copy()
        return CP

    def improper(self):
        improper_pts = []
        for i in range(self.N):
            for j in range(self.N):
                for k in range(self.S):
                    if self.cube[i,j,k] == -1:
                        improper_pts.append((i,j,k))
        assert len(improper_pts) <= 1
        if len(improper_pts) == 0:
            return None
        else:
            return improper_pts[0]

    def ones(self, x: int, y: int, z: int):
        assert x == -1 or y == -1 or z == -1
        indices = []
        if x == -1:
            for i in range(self.N):
                if self.cube[i,y,z] == 1:
                    indices.append(i)
        elif y == -1:
            for i in range(self.N):
                if self.cube[x,i,z] == 1:
                    indices.append(i)
        elif z == -1:
            for i in range(self.N):
                if self.cube[x,y,i] == 1:
                    indices.append(i)
        return indices

    def apply_move(self, pt1, pt2):
        x1, y1, z1 = pt1
        x2, y2, z2 = pt2
        self.cube[x1, y1, z1] += 1
        self.cube[x1, y1, z2] -= 1
        self.cube[x1, y2, z1] -= 1
        self.cube[x2, y1, z1] -= 1
        self.cube[x1, y2, z2] += 1
        self.cube[x2, y1, z2] += 1
        self.cube[x2, y2, z1] += 1
        self.cube[x2, y2, z2] -= 1
        self.improper()

    def adj(self):
        imp = self.improper()
        moves = []
        if imp is None:
            for i in range(self.N):
                for j in range(self.N):
                    for k in range(self.S):
                        if self.cube[i,j,k] == 0:
                            pt1 = (i,j,k)
                            X = self.ones(-1,j,k)
                            Y = self.ones(i,-1,k)
                            Z = self.ones(i,j,-1)
                            for pt2 in itertools.product(X, Y, Z):
                                moves.append((pt1, pt2))
        else:
            i,j,k = imp
            X = self.ones(-1,j,k)
            Y = self.ones(i,-1,k)
            Z = self.ones(i,j,-1) 
            for pt2 in itertools.product(X, Y, Z):
                moves.append((imp, pt2))
        adj_list = []
        for move in moves:
            LS2 = self.copy()
            LS2.apply_move(move[0], move[1])
            adj_list.append(LS2)
        return adj_list
            
            



In [83]:
N = 4
S = N+1
LS = GeneralLatinSquare(N, S)
LS

[[[1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]]

 [[0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [1. 0. 0. 0. 0.]]

 [[0. 0. 1. 0. 0.]
  [0. 0. 0. 1. 0.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]]

 [[0. 0. 0. 1. 0.]
  [1. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 1. 0. 0.]]]

In [84]:
LS.adj()

[[[[ 0.  1.  0.  0.  0.]
   [ 1.  0.  0.  0.  0.]
   [ 0.  0.  1.  0.  0.]
   [ 0.  0.  0.  1.  0.]]
 
  [[ 1.  0.  0.  0.  0.]
   [-1.  1.  1.  0.  0.]
   [ 0.  0.  0.  1.  0.]
   [ 1.  0.  0.  0.  0.]]
 
  [[ 0.  0.  1.  0.  0.]
   [ 0.  0.  0.  1.  0.]
   [ 1.  0.  0.  0.  0.]
   [ 0.  1.  0.  0.  0.]]
 
  [[ 0.  0.  0.  1.  0.]
   [ 1.  0.  0.  0.  0.]
   [ 0.  1.  0.  0.  0.]
   [ 0.  0.  1.  0.  0.]]],
 [[[0. 0. 1. 0. 0.]
   [0. 1. 0. 0. 0.]
   [1. 0. 0. 0. 0.]
   [0. 0. 0. 1. 0.]]
 
  [[0. 1. 0. 0. 0.]
   [0. 0. 1. 0. 0.]
   [0. 0. 0. 1. 0.]
   [1. 0. 0. 0. 0.]]
 
  [[1. 0. 0. 0. 0.]
   [0. 0. 0. 1. 0.]
   [0. 0. 1. 0. 0.]
   [0. 1. 0. 0. 0.]]
 
  [[0. 0. 0. 1. 0.]
   [1. 0. 0. 0. 0.]
   [0. 1. 0. 0. 0.]
   [0. 0. 1. 0. 0.]]],
 [[[ 0.  0.  0.  1.  0.]
   [ 0.  1.  0.  0.  0.]
   [ 0.  0.  1.  0.  0.]
   [ 1.  0.  0.  0.  0.]]
 
  [[ 0.  1.  0.  0.  0.]
   [ 0.  0.  1.  0.  0.]
   [ 0.  0.  0.  1.  0.]
   [ 1.  0.  0.  0.  0.]]
 
  [[ 0.  0.  1.  0.  0.]
   [ 0.  0.  0.  1.  0.]


In [85]:
import sys
sys.setrecursionlimit(10**9)

In [86]:
def dfs(visited, LS):
    visited[str(LS)] = LS
    for V in LS.adj():
        if not str(V) in visited:
            dfs(visited, V)

In [87]:
visited = dict()
dfs(visited, LS)

In [88]:
len(visited)

7488

In [89]:
proper_cnt = 0
for LS in visited.values():
    if LS.improper() == None:
        proper_cnt += 1
proper_cnt

576