## Counting Rubik's Snake shapes, up to reversal symmetry

*Dmytro Fedoriaka, August 2024*

This is an addition to the [main notebook](count-shapes.ipynb). 

Here I compute one of the sequences defined in that notebook, namely the number of Rubik's Snake shapes *up to reversal*: we count snakes without distinguishing head from tail. Let's denote this sequence $D_n$.

It turns out that we can use the pre-computed sequence $S_n$ (see the main notebook) to compute $D_n$ with much better asymptotic complexity: $O(2^n \cdot n)$ instead of $O(4^n)$.

Let $F_n$ denote the number of shapes that are mapped to themselves by reversal. Then $2 D_n = S_n + F_n$, and hence $D_n = (S_n + F_n)/2$.
 * This can be proven by a counting argument. Take one representative formula from each class counted by $D_n$, and add its reverse. The resulting multiset has $2 D_n$ elements and contains every formula counted by $S_n$. The only formulas that appear twice are those fixed by reversal, and there are $F_n$ of them.
 * Alternatively, this follows from Burnside's lemma.

How do we compute $F_n$? The shapes fixed by reversal are exactly the valid shapes whose formula is a palindrome. So we can explicitly enumerate all palindromes of length $n-1$ over the characters 0, 1, 2, 3; there are $4^{\lfloor n/2 \rfloor} = O(2^n)$ of them. For each one, we directly check whether it is the formula of a valid shape. Each check takes $O(n)$ time, so the total complexity is $O(2^n \cdot n)$.


In [1]:
import numpy as np
import numba
import time 

# Grid setup: a cube at (x, y, z) inside a (2K+1)^3 box is stored at the flat
# index x*dx + y*dy + z*dz, so stepping along an axis means adding dx, dy or dz.
MAX_N = 26
K = (MAX_N // 2) * 2
box_size = 1 + 2 * K
dx = 1
dy = box_size
dz = box_size * box_size
# Flat index of the box centre (K, K, K).
CENTER_COORD = K * dx + K * dy + K * dz

# Cube faces are numbered 0..5, and face f is opposite face 5 - f.
# CUBE[f] lists the four faces sharing an edge with f (neither f nor 5 - f).
CUBE = [
  [1, 3, 4, 2],
  [0, 2, 5, 3],
  [0, 4, 5, 1],
  [0, 1, 5, 4],
  [0, 3, 5, 2],
  [1, 2, 4, 3],
]
# Flat-index offset of the neighbouring cube across each face:
# face 0 -> +y, 1 -> +z, 2 -> +x, 3 -> -x, 4 -> -z, 5 -> -y.
DELTAS = np.array([dy, dz, dx, -dx, -dz, -dy])
# Two-way lookup between wedge ids and ordered face pairs (filled by register_wedge).
WEDGE_ID_TO_FACE_IDS = {}
FACE_IDS_TO_WEDGE_ID = {}

def register_wedge(f1, f2, wedge_id):
  """Record wedge_id <-> (f1, f2) in both module-level lookup dicts."""
  face_pair = (f1, f2)
  FACE_IDS_TO_WEDGE_ID[face_pair] = wedge_id
  WEDGE_ID_TO_FACE_IDS[wedge_id] = face_pair

# Assign an id to every ordered pair (f1, f2) of non-opposite faces: the six
# pairs below plus their opposite-face pairs cover all 12 cube edges, and each
# edge is registered in both orders (24 ids in total, max id 28).
# The low 4 bits (wedge_id & 15) are the same for (f1, f2) and (f2, f1). The
# opposite pair (5-f1, 5-f2) gets 13 minus that value, so two complementary
# wedges sum to 13 -- the condition is_shape_valid_fast uses to let two wedges
# share one cube.
for i, (f1, f2) in enumerate([(0,1),(0,2),(0,3),(0,4),(1,2),(1,3)]):
  register_wedge(f1,f2, i+1)
  register_wedge(f2,f1, i+1+16)
  register_wedge(5-f1,5-f2, 13-(i+1))
  register_wedge(5-f2,5-f1, 13-(i+1)+16)
  
# Transition tables, both indexed by wedge id (sized 36 so every id <= 28 fits):
#   WEDGE_ID_TO_NEXT_DELTA[wedge_id]: flat-index offset added to a wedge's cube
#     coordinate to get the next wedge's cube (see get_next_wedge_coord).
#   ROT_AND_WEDGE_ID_TO_NEXT_WEDGE_ID[wedge_id + 36*rot]: id of the next wedge
#     when the joint between the two wedges is set to rotation rot in 0..3.
WEDGE_ID_TO_NEXT_DELTA = np.zeros(36, dtype=np.int64)
ROT_AND_WEDGE_ID_TO_NEXT_WEDGE_ID = np.zeros(36*4, dtype=np.int64)
for f1 in range(6):
  for f2 in CUBE[f1]:
    wedge_id = FACE_IDS_TO_WEDGE_ID[(f1,f2)]
    # First face of the next wedge: the face opposite f2.
    f1p=5-f2
    # Second face of the next wedge, one candidate per rotation:
    # rot 0 -> 5-f1, rot 1 -> the face following f2 in CUBE[f1],
    # rot 2 -> f1, rot 3 -> the face opposite the rot-1 face.
    # None of these is f1p or its opposite f2, so every (f1p, f2p[rot]) is a
    # registered pair.
    f2p=[5-f1,0,f1,0]
    f2p[1]=CUBE[f1][(CUBE[f1].index(f2)+1)%4]
    f2p[3]=5-f2p[1]
    WEDGE_ID_TO_NEXT_DELTA[wedge_id] = DELTAS[f1p]
    for rot in range(4):
      ROT_AND_WEDGE_ID_TO_NEXT_WEDGE_ID[wedge_id+rot*36] = FACE_IDS_TO_WEDGE_ID[(f1p, f2p[rot])]

@numba.jit("i8(i8,i8)", inline="always")
def encode_wedge(coord, wedge_id):
  """Pack a cube coordinate and a wedge id into a single int64.

  The id occupies the low 6 bits (decode with `& 63`) and the coordinate the
  remaining high bits (decode with `>> 6`).
  """
  shifted_coord = coord << 6
  return shifted_coord + wedge_id
      
@numba.jit("(i8,i8,i8[:],i8[:])", inline="always")
def push_wedge(wedge_coord, wedge_id, wedges, cubes):
  """Push a wedge onto the stack and record it in its cube's occupancy.

  wedges[0] is the stack pointer. It is decremented first, and the packed
  wedge is stored at the new position, so the stack grows toward index 1.
  cubes[wedge_coord] accumulates the wedge's type (wedge_id & 15).
  """
  new_top = wedges[0] - 1
  wedges[0] = new_top
  wedges[new_top] = encode_wedge(wedge_coord, wedge_id)
  cubes[wedge_coord] += wedge_id & 15
  
@numba.jit("i8(i8,i8)", inline="always")
def get_next_wedge_coord(last_wedge_id, last_wedge_coord):
  """Return the flat cube coordinate of the wedge that follows `last_wedge_id`."""
  step = WEDGE_ID_TO_NEXT_DELTA[last_wedge_id]
  return last_wedge_coord + step

@numba.jit("i8(i8,i8)", inline="always")
def get_next_wedge_id(last_wedge_id, rot):
  """Return the id of the wedge attached after `last_wedge_id` with joint rotation `rot` (0..3)."""
  table_index = 36 * rot + last_wedge_id
  return ROT_AND_WEDGE_ID_TO_NEXT_WEDGE_ID[table_index]

@numba.jit("i8(i8[:],i8[:],i8[:])")
def is_shape_valid_fast(formula, wedges, cubes):
  """Return 1 if `formula` extends the current snake without collisions, else 0.

  Each entry of `formula` is a joint rotation (0..3) that places one more wedge
  after the wedge currently on top of the `wedges` stack. A new wedge fits if
  its cube is empty, or if the cube holds a single wedge whose type
  (wedge_id & 15) complements the new one (the two types sum to 13).
  Every wedge pushed here is popped again before returning, so `wedges` and
  `cubes` end up exactly as they were on entry.
  """
  pushed = 0
  ok = 1
  for rot in formula:
    top = wedges[wedges[0]]
    cur_coord = top >> 6
    cur_id = top & 63
    new_coord = get_next_wedge_coord(cur_id, cur_coord)
    new_id = get_next_wedge_id(cur_id, rot)
    occupied = cubes[new_coord]
    if occupied != 0 and occupied + (new_id & 15) != 13:
      ok = 0
      break
    push_wedge(new_coord, new_id, wedges, cubes)
    pushed += 1
  # Undo the pushes (most recent first) to restore the caller's state.
  for _ in range(pushed):
    top = wedges[wedges[0]]
    cubes[top >> 6] -= top & 15
    wedges[0] += 1
  return ok

@numba.jit("i8(i8,i8[:],i8[:])")
def count_palindrome_shapes(n, wedges, cubes):
  """Count valid n-wedge shapes whose formula is a palindrome (the F_n term).

  A formula for n wedges has n - 1 joint rotations. Only the first n // 2
  positions are free; position n-2-j mirrors position j. That gives
  4**(n // 2) candidates, and each one is checked with is_shape_valid_fast.
  The check starts from the wedge currently on top of `wedges`, and it
  restores `wedges` and `cubes` before returning.

  Returns 0 for n < 1, since there are no shapes without wedges (S[0] = 0).
  Without this guard, np.zeros(n - 1) would be asked for a negative size.
  """
  if n < 1:
    return 0
  ans = 0
  half = n // 2
  rots = np.zeros(n - 1, dtype=np.int64)
  for i in range(4 ** half):
    for j in range(half):
      # Bits 2j..2j+1 of i select the rotation at position j and its mirror.
      r = (i >> (2 * j)) & 3
      rots[j] = r
      rots[n - 2 - j] = r
    if is_shape_valid_fast(rots, wedges, cubes):
      ans += 1
  return ans
  
class RubikSnakeCounter:
  """Search state for counting palindromic Rubik's Snake shapes.

  Holds the wedge stack and the cube-occupancy grid used by the numba kernels.
  The stack is seeded with a single wedge at the centre of the box.
  """

  def __init__(self):
    # wedges[0] is the stack pointer (index of the most recently pushed wedge).
    # Slots 1..MAX_N hold packed wedges, filled from the end of the array.
    self.wedges = np.zeros(MAX_N + 1, dtype=np.int64)
    self.wedges[0] = MAX_N + 1
    # Per-cube sum of the types of the wedges occupying that cube.
    self.cubes = np.zeros(box_size**3, dtype=np.int64)
    push_wedge(CENTER_COORD, FACE_IDS_TO_WEDGE_ID[(0, 3)], self.wedges, self.cubes)  # Initial wedge.

  def count_palindrome_shapes(self, n):
    """Return the number of valid n-wedge shapes with a palindromic formula.

    Raises:
      ValueError: if n > MAX_N. The stack has room for only MAX_N wedges, and
        numba does not bounds-check array indexing by default. A larger n
        would therefore silently overwrite the stack pointer in wedges[0].
    """
    if n > MAX_N:
      raise ValueError(f"n must be at most MAX_N={MAX_N}, got {n}")
    return count_palindrome_shapes(n, self.wedges, self.cubes)

# Computed in the main notebook.
S = [0, 1, 4, 16, 64, 241, 920, 3384, 12585, 46471, 172226, 633138, 2333757, 8561679, 31462176, 
     115247629, 422677188, 1546186675, 5661378449, 20689242550, 75663420126, 276279455583, 
     1009416896015, 3683274847187, 13446591920995, 49037278586475, 178904588083788]
assert len(S) == MAX_N+1 
D = [0]*(MAX_N+1)

t0 = time.time()
ctr = RubikSnakeCounter()
# Iterate over every n for which S is known (tied to MAX_N instead of a literal).
for n in range(1, MAX_N + 1):
  F = ctr.count_palindrome_shapes(n)
  # D_n = (S_n + F_n) / 2 by Burnside's lemma, so the sum must be even.
  # An odd sum would indicate a bug in F or in the pre-computed S.
  assert (F + S[n]) % 2 == 0, f"S[{n}] + F[{n}] is odd"
  D[n] = (F + S[n]) // 2
  print(f"S[{n}]={S[n]}, F[{n}]={F}, D[{n}]={D[n]}")
print("Total time: %fs." % (time.time()-t0))

print("Answer: D=", D)


S[1]=1, F[1]=1, D[1]=1
S[2]=4, F[2]=4, D[2]=4
S[3]=16, F[3]=4, D[3]=10
S[4]=64, F[4]=16, D[4]=40
S[5]=241, F[5]=13, D[5]=127
S[6]=920, F[6]=60, D[6]=490
S[7]=3384, F[7]=52, D[7]=1718
S[8]=12585, F[8]=221, D[8]=6403
S[9]=46471, F[9]=185, D[9]=23328
S[10]=172226, F[10]=802, D[10]=86514
S[11]=633138, F[11]=700, D[11]=316919
S[12]=2333757, F[12]=2957, D[12]=1168357
S[13]=8561679, F[13]=2483, D[13]=4282081
S[14]=31462176, F[14]=10820, D[14]=15736498
S[15]=115247629, F[15]=9199, D[15]=57628414
S[16]=422677188, F[16]=39608, D[16]=211358398
S[17]=1546186675, F[17]=33105, D[17]=773109890
S[18]=5661378449, F[18]=144593, D[18]=2830761521
S[19]=20689242550, F[19]=122038, D[19]=10344682294
S[20]=75663420126, F[20]=527782, D[20]=37831973954
S[21]=276279455583, F[21]=439415, D[21]=138139947499
S[22]=1009416896015, F[22]=1922239, D[22]=504709409127
S[23]=3683274847187, F[23]=1613723, D[23]=1841638230455
S[24]=13446591920995, F[24]=7005651, D[24]=6723299463323
S[25]=49037278586475, F[25]=5817729, D[25]

Validation:
* The first 14 terms match those computed with [snek](https://github.com/scholtes/snek).
* I also wrote my own program that explicitly enumerates all valid states and deduplicates them (see [here](scracth.ipynb)). I used it to compute the first 12 terms, and the results match.
* However, my value for $D_{24}$ is 6723299463323, which differs from the value reported [here](https://blog.ylett.com/2011/09/rubiks-snake-combinations.html) by 0.02%.