Jane Street puzzle, October 2020: Candy Collectors (https://www.janestreet.com/puzzles/candy-collectors/)

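We solve an abstraction of the problem:
Fill a 5x5 grid with integers in the range [0, 5]. Each row represents a child and each column represents a type of candy, so the entry in row r, column c is the number of pieces of candy type c that child r receives.

Constraints:
1. Column sums are equal to 5 (there are 5 pieces of each type, so the total is 25)
2. Row sums are equal to 5 (each child receives 5 pieces)
3. Each column has a strict maximum, and the maxima of different columns lie in different rows (every child is the unique winner of exactly one type of candy)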
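Constraints 1 to 3 can be written as a predicate on a completed grid. This is a minimal sketch with hypothetical helper names (`column_winner`, `is_favourable`). It assumes, as the diagonal-plus-5! optimization below does, that the column maxima must sit in different rows.

```python
def column_winner(grid, c):
    """Row index of the strict maximum of column c, or None if the maximum is shared."""
    column = [row[c] for row in grid]
    top = max(column)
    return column.index(top) if column.count(top) == 1 else None

def is_favourable(grid, k=5):
    """Check constraints 1-3 on a complete k x k grid (rows = children, columns = candy types)."""
    if any(sum(row) != k for row in grid):                       # constraint 2
        return False
    if any(sum(row[c] for row in grid) != k for c in range(k)):  # constraint 1
        return False
    winners = [column_winner(grid, c) for c in range(k)]
    # constraint 3: every column has a strict winner and no child wins two types
    return None not in winners and len(set(winners)) == k

diagonal = [[5 if r == c else 0 for c in range(5)] for r in range(5)]
shared_winner = [
    [2, 1, 1, 1, 0],
    [3, 2, 0, 0, 0],  # row 1 holds the strict maximum of columns 0 and 1
    [0, 1, 3, 0, 1],
    [0, 1, 0, 3, 1],
    [0, 0, 1, 1, 3],
]
print(is_favourable(diagonal), is_favourable(shared_winner))  # True False
```

The second grid satisfies both sum constraints and every column has a strict maximum, but row 1 wins two types and row 0 wins none, so it is rejected. Every diagonal entry of that grid is at least 2, so a search that only seeds the diagonal with 2 still has to check where each column's maximum actually lies.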

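Not every grid is equally likely, so the probability is not simply a ratio of grid counts. Model the random distribution as shuffling the 25 candies and dealing them 5 at a time, so that every distinguishable deal order is equally likely. There are N = 25! / (5!)^5 such orders (a multinomial coefficient), and a grid a is produced by (5!)^5 / prod(a[r][c]!) of them, because child r's 5 pieces can arrive in 5! / (a[r][0]! a[r][1]! a[r][2]! a[r][3]! a[r][4]!) distinguishable orders. Summing this weight over all grids that satisfy constraints 1 to 3 gives the number n_a of favourable deals, and the probability is p = n_a / N.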
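Grids are not all equally likely, which matters when turning counts into a probability. A brute-force check on a hypothetical small instance (3 children, 3 candy types, 3 pieces of each) enumerates every distinct deal order, tallies the grid each one produces, and compares the tally with the weight K! / prod(a!) per child. The names `grid_of` and `deal_weight` are illustrative.

```python
from collections import Counter
from itertools import permutations
from math import factorial, prod

K = 3              # hypothetical small instance: 3 children, 3 types, 3 pieces each
TYPES = "ABC"
BAG = "AAABBBCCC"  # K pieces of each type

def grid_of(order):
    """Child i receives positions i*K .. i*K+K-1 of the deal order."""
    return tuple(
        tuple(order[i * K:(i + 1) * K].count(t) for t in TYPES) for i in range(K)
    )

def deal_weight(grid):
    """Deal orders producing `grid`: each child's K pieces can arrive in
    K! / prod(a!) distinguishable orders, independently of the other children."""
    return prod(factorial(K) // prod(factorial(a) for a in row) for row in grid)

orders = set(permutations(BAG))               # distinct, equally likely deal orders
counts = Counter(grid_of(o) for o in orders)  # how often each grid occurs

print(len(orders), factorial(K * K) // factorial(K) ** K)  # N for this instance
print(len(counts))                                          # distinct grids
print(all(counts[g] == deal_weight(g) for g in counts))     # weights match tally
```

For example, the grid where each child gets all 3 pieces of one type arises from a single deal order, while the grid where every child gets one piece of each type arises from 6^3 = 216 of them.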

Approach:
Recursively fill the grid one cell at a time (backtracking), and stop expanding a partial grid as soon as we recognize that a constraint has been violated.
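The backtracking idea can be tried on smaller grids first. This sketch (a hypothetical `count_grids`, for a k x k grid whose line sums are all k) fills cells in row-major order, keeps running row and column sums, and abandons a branch as soon as a sum overshoots or a completed row or column misses its target. It counts grids only, without weights and without constraint 3, so it exercises the pruning logic rather than the final probability.

```python
def count_grids(k):
    """Count k x k grids of non-negative integers whose rows and columns all sum to k."""
    row_sum = [0] * k  # running sum of each row (only the sums are needed to count)
    col_sum = [0] * k  # running sum of each column

    def fill(cell):
        if cell == k * k:
            return 1  # every row and column was checked when it was completed
        r, c = divmod(cell, k)
        total = 0
        for v in range(k + 1):
            if row_sum[r] + v > k or col_sum[c] + v > k:
                break  # overshoot: larger values of v only overshoot further
            if c == k - 1 and row_sum[r] + v != k:
                continue  # a completed row must sum to exactly k
            if r == k - 1 and col_sum[c] + v != k:
                continue  # a completed column must sum to exactly k
            row_sum[r] += v
            col_sum[c] += v
            total += fill(cell + 1)
            row_sum[r] -= v  # undo the choice (backtrack)
            col_sum[c] -= v
        return total

    return fill(0)

for k in range(1, 5):
    print(k, count_grids(k))
```

For k = 3 it finds the 55 grids of non-negative integers with all line sums equal to 3.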

Optimizations:
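- If a column is not yet filled but its sum already exceeds 5, the constraints cannot be satisfied => reject.
- If a row is not yet filled but its sum already exceeds 5, the constraints cannot be satisfied => reject.
- Each column maximum is at least 2: a strict maximum of 1 is impossible, because five entries of at most 1 that sum to 5 are all equal to 1. Since every child must win a different type, we can fix child i as the winner of type i, place 2 pieces on each diagonal cell up front, and then distribute only the remaining 3 pieces of each type instead of 5. This reduces the search space a lot, but the search must then also check that each diagonal cell is the strict maximum of its column. We compensate by multiplying by 5! at the end: permuting the rows turns the diagonal assignment into any of the 5! possible assignments of winners to candy types without changing how likely a grid is.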
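The 5! compensation relies on symmetry: permuting the rows (children) of a grid changes which child wins which type but leaves its row sums, column sums, and number of producing deal orders, (K!)^K / prod(a!), unchanged. A brute-force check on a hypothetical 3-child, 3-type instance (the helpers `grids`, `winners`, and `weight` are illustrative) compares the weighted favourable count over all winner assignments with the diagonal-only count times 3!, and reports the smallest winning amount.

```python
from itertools import product
from math import factorial, prod

K = 3  # hypothetical small instance: 3 children, 3 candy types, 3 pieces of each

def grids(k):
    """All k x k grids whose rows and columns all sum to k (brute force)."""
    rows = [row for row in product(range(k + 1), repeat=k) if sum(row) == k]
    for g in product(rows, repeat=k):
        if all(sum(col) == k for col in zip(*g)):
            yield g

def winners(g):
    """For each column, the row holding its strict maximum (None on a tie)."""
    result = []
    for col in zip(*g):
        top = max(col)
        result.append(col.index(top) if col.count(top) == 1 else None)
    return result

def weight(g, k):
    """Number of equally likely deal orders that produce grid g."""
    return factorial(k) ** k // prod(factorial(a) for row in g for a in row)

# Favourable: every column has a strict winner and each child wins one type.
favourable = [g for g in grids(K) if set(winners(g)) == set(range(K))]
any_order = sum(weight(g, K) for g in favourable)

# Only grids where child i wins type i (maxima on the diagonal).
diagonal = [g for g in favourable if winners(g) == list(range(K))]
on_diagonal = sum(weight(g, K) for g in diagonal)

print(any_order, on_diagonal * factorial(K))               # equal by symmetry
print(min(g[i][i] for g in diagonal for i in range(K)))   # smallest winning amount
```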


In [None]:
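# N: number of distinguishable orders in which the 25 candies (5 of each
# of 5 types) can be dealt to 5 children, 5 pieces each; all are equally likely.
# Multinomial coefficient: 25! / (5! 5! 5! 5! 5!)

def fact(n):
    acc = 1
    for i in range(2, n+1):
        acc *= i
    return acc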

N = fact(25)//((fact(5))**5)
print("N:", N)

N: 623360743125120
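As a cross-check on the multinomial coefficient, N can also be built by choosing which 5 of the 25 deal positions hold the first candy type, which 5 of the remaining 20 hold the second, and so on. A short sketch using `math.comb` and `math.prod` (Python 3.8+):

```python
from math import comb, prod

# C(25,5) * C(20,5) * C(15,5) * C(10,5) * C(5,5)
N_check = prod(comb(25 - 5 * t, 5) for t in range(5))
print(N_check)  # 623360743125120, matching N
```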


In [None]:
from math import prod

na = 0      # weighted number of favourable deals (before the 5! factor)
tables = 0  # number of distinct favourable grids found

# grid[r][c]: pieces of candy type c held by child r
grid = []

for i in range(5):
    grid.append([0, 0, 0, 0, 0])

# Child i is the designated strict winner of candy type i, so each
# diagonal cell starts at 2, the smallest possible strict maximum.
for i in range(5):
    grid[i][i] = 2

def printGrid():
    for r in range(5):
        for c in range(5):
            print(f"{grid[r][c]}\t", end="")
        print("")
    print("")

def passesConstraints(r, c):
    # (r, c) is the next cell to fill; all cells before it are final.
    for rn in range(r):
        # filled rows must sum to exactly 5
        if sum(grid[rn]) != 5: return False

    if r != 5:
        # the row being filled may not exceed 5
        if sum(grid[r]) > 5: return False
    else: # edge case: grid complete, so every column is checked below
        c = 5
        r = 4

    if r == 4:
        for cn in range(c):
            # filled column: it must sum to 5, and its diagonal cell must be
            # the strict maximum (child cn is the unique winner of type cn)
            column = [grid[rn][cn] for rn in range(5)]
            if sum(column) != 5:
                return False
            if any(column[rn] >= column[cn] for rn in range(5) if rn != cn):
                return False

    for cn in range(5): # a possibly unfilled column may not exceed 5
        if sum(grid[rn][cn] for rn in range(5)) > 5: return False

    return True

def dfs(r, c):
    global na, tables
    if not passesConstraints(r, c):
        return
    if r == 5:
        # Grid is filled and passes all constraints. Weight it by the number
        # of equally likely deal orders that produce it: (5!)^5 / prod(a!).
        tables += 1
        na += fact(5) ** 5 // prod(fact(a) for row in grid for a in row)
        if tables % 10000 == 0:
            print(f"{tables} favourable grids found so far")
        return

    for v in range(0, 4): # add 0, 1, 2 or 3 pcs of candy
        grid[r][c] += v
        dfs(r + ((c+1) // 5), (c+1) % 5)
        grid[r][c] -= v

dfs(0, 0)
# The diagonal fixed child i as the winner of type i; permuting the rows
# gives each of the 5! winner assignments with the same total weight.
na *= fact(5)
print(f"na = {na}")


In [None]:
from math import gcd
g = gcd(na, N)
print(f"N = {N}")
print(f"n_a = {na}")
print(f"p = n_a/N = {na // g}/{N // g} = {na / N}")
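A Monte Carlo simulation gives an independent sanity check on the exact probability. This sketch assumes the deal model in which all orders of the 25 candies are equally likely and consecutive blocks of 5 go to children 0 to 4; `simulate` is a hypothetical helper. With 100,000 trials the standard error of the estimate is at most about 0.0016, so an exact p should agree with it to roughly that precision.

```python
import random
from collections import Counter

def simulate(trials, seed=0):
    """Estimate p by shuffling a bag of 25 candies and dealing 5 to each child."""
    rng = random.Random(seed)
    bag = [t for t in range(5) for _ in range(5)]  # 5 pieces of each type 0..4
    hits = 0
    for _ in range(trials):
        rng.shuffle(bag)
        hands = [Counter(bag[5 * i:5 * i + 5]) for i in range(5)]
        winners = []
        for t in range(5):
            amounts = [hand[t] for hand in hands]  # missing types count as 0
            top = max(amounts)
            winners.append(amounts.index(top) if amounts.count(top) == 1 else None)
        # favourable: every type has a strict winner and each child wins one type
        if set(winners) == set(range(5)):
            hits += 1
    return hits / trials

print(simulate(100_000))
```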
