# Problem 34

https://projecteuler.net/problem=34

145 is a curious number, as 1! + 4! + 5! = 1 + 24 + 120 = 145.

Find the sum of all numbers which are equal to the sum of the factorial of their digits.

Note: as 1! = 1 and 2! = 2 are not sums they are not included.

## WIP: Delete...

## Notes

We can easily brute force this by iterating numbers 10 and above, breaking them down into their component digits, and summing up their factorials. The challenge here is when to stop. My hunch is that we can stop once the sum of the factorials is greater than the original number.

Hmm... Okay, my hunch is wrong. 9! is 362880. Therefore 1! + 9! > 19. If we had stopped here we would never find 145.
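To make the failed hunch concrete, here is a minimal sketch using the standard library's `math.factorial`. The helper name `digit_factorial_sum` is introduced here for illustration only and is not used elsewhere in these notes.

```python
from math import factorial

def digit_factorial_sum(n):
    """Return the sum of the factorials of the decimal digits of n."""
    return sum(factorial(int(d)) for d in str(n))

# The sum overshoots at 19, long before we reach 145...
print(digit_factorial_sum(19), digit_factorial_sum(19) > 19)
# ...yet 145 is still a match, so "stop at the first overshoot" is wrong.
print(digit_factorial_sum(145) == 145)
```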

In [None]:
from scipy.special import factorial
factorial(9)

In [None]:
def fac(n):
    return int(factorial(n))
fac(3)

In [None]:
import numpy as np
N = 100000
for i in range(10, N+1):
    digits = list(str(i))
    factorials = [fac(int(x)) for x in digits]
    summation = np.sum(factorials)
    #print('{} -> {} -> sum {}'.format(i, factorials, summation))
    if i == summation:
        print(i)

Brute force in the manner above is too slow. We can do better. For example, we can save the results of the summations and reuse them.
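One possible reading of saving the summations: the digit factorial sum of `n` is the sum for `n // 10` plus the factorial of the last digit, so each new number costs one lookup and one addition. The sketch below assumes that reading. The names `FACT` and `curious_below` are made up for this example.

```python
from math import factorial

# Factorials of the ten digits, computed once.
FACT = [factorial(d) for d in range(10)]

def curious_below(limit):
    """Return all n in [10, limit) equal to the sum of their digit factorials."""
    sums = [0] * limit
    found = []
    for n in range(1, limit):
        # Reuse the stored sum of the leading digits and add the last digit.
        sums[n] = sums[n // 10] + FACT[n % 10]
        if n >= 10 and sums[n] == n:
            found.append(n)
    return found

print(curious_below(100_000))
```

This trades memory (one stored sum per number) for speed, and it still needs an upper limit chosen by the caller.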

I think there's a better way...

## Take 2

First, we only need the factorials of the digits 0 to 9. Looking them up in a precomputed list is much faster than calling scipy's `factorial` function every time.

In [None]:
from scipy.special import factorial
def create_fac():
    factorials = [int(factorial(i)) for i in range(10)]
    def f(x):
        return factorials[x]
    return f
fac = create_fac()

In [None]:
%time fac(9)

In [None]:
%time factorial(9)

In [None]:
def dfs(chain, max_length):
    # Rebuild the number whose digits are listed in `chain`.
    concat_value = 0
    for i in chain:
        concat_value = concat_value * 10 + i
    fac_sum = sum(map(fac, chain))
    #print(concat_value, fac_sum)
    if fac_sum == concat_value and len(chain) >= 2:
        yield concat_value
    if len(chain) < max_length:
        # The leading digit cannot be 0.
        lo = 1 if len(chain) == 0 else 0
        for i in range(lo, 10):
            yield from dfs(chain + [i], max_length)


In [None]:
print(sum(map(lambda x: fac(x), [1,4,5])))

In [None]:
list(dfs([],5))

## Thinking in terms of multi sets

Observe that the sum of factorials for 145 is the same as for any permutation of its digits: 541, 154, and so on.

Let the notation `{1145}` denote a multi set of digits 1, 1, 4, and 5. For clarity, we will always list the smaller digits first. Also, since it is a multi set, repeated digits are allowed, such as the two 1s in `{1145}`.

Let the function F(x) map a multi set x into a sum of factorials, as per the definition in problem 34.

Let the function M(y) map a number y into a multi set corresponding to its digits. For example `M(123) == M(321)`.

Using this notation we can say the following about 145.

    F(M(145)) == 145

Generalizing, we can rephrase problem 34 as finding all values y such that `F(M(y)) == y`. 

Observe that `M(123)` is the same as `M(321)`. That is, several numbers can map to a single multi set: a many-to-one mapping. Put another way, each multi set corresponds to one or more numbers. Given this, and that the factorial sum depends only on the multi set, it makes more sense to search the space of multi sets than to search the space of numbers.
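To get a feel for how much smaller the space of multi sets is, here is a rough count. It assumes the standard "stars and bars" formula: there are `comb(k + 9, 9)` multi sets of `k` digits drawn from 0 to 9. The function name `count_multisets` is made up for this sketch, and the count slightly overestimates because it includes the all-zero multi sets.

```python
from math import comb  # requires Python 3.8+

def count_multisets(max_digits):
    """Count digit multi sets with 1..max_digits members (stars and bars)."""
    return sum(comb(k + 9, 9) for k in range(1, max_digits + 1))

for max_digits in (3, 5):
    # Positive integers with at most max_digits digits.
    numbers = 10 ** max_digits - 1
    print(max_digits, numbers, count_multisets(max_digits))
```

The gap widens quickly as the number of digits grows, which is what makes the multi set search attractive.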

Observe that given a multi set, there is exactly one factorial sum. Then we can evaluate if **that sum** is a curious number by computing its multi set and comparing it with the multi set we started from. Therefore we can rephrase the problem in terms of multi sets as follows. This form is better suited for searching the space of multi sets.

    M(F(x)) == x

The next piece of the puzzle is how to navigate the space of multi sets in an efficient manner. Also, how would we know when to stop?
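One common argument for a stopping point, which these notes do not work out, goes like this: a number with d digits is at least 10^(d-1), while its digit factorial sum is at most d * 9!. Once d * 9! falls below 10^(d-1), no number with d digits can match. The sketch below finds that cutoff. The function name `max_digits_to_search` is an assumption made for illustration.

```python
from math import factorial

def max_digits_to_search():
    """Return the largest digit count d for which d * 9! can reach a d-digit number."""
    d = 1
    # A (d+1)-digit number is at least 10**d. Keep going while the largest
    # possible factorial sum for d+1 digits can still reach that value.
    while (d + 1) * factorial(9) >= 10 ** d:
        d += 1
    return d

d = max_digits_to_search()
upper_bound = d * factorial(9)
print(d, upper_bound)
```

Under this argument, any search can stop at `upper_bound`, or equivalently at multi sets with `d` members.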

## Vector representation for multi sets

Since the only set members in our multi sets are digits, we can represent our multi sets as vectors of digit counts. For example the multi set `{1145}` would be a vector `[0,2,0,0,1,1,0,0,0,0]`.

**Note:** The count at the 0-th position of the vector is also important because `0! == 1`.

In this notation, two multi sets are equal when their vectors have the same count at each digit position. In Python, lists can be compared directly with the `==` operator.

## Code

In [19]:
# Define a `fac(n)` function that returns the factorial of `n`
# for n in 0..9. We precalculate the factorials since those
# are the only factorials we will need to solve this problem.
from scipy.special import factorial
def create_fac():
    factorials = [int(factorial(i)) for i in range(10)]
    def f(x):
        return factorials[x]
    return f
fac = create_fac()

In [20]:
# Define F(v) where v is a multiset in vector form.
def F(v):
    assert len(v) == 10
    s = 0
    for i in range(10):
        s += fac(i) * v[i]
    return s

In [27]:
# The following assertions should pass.
assert F([0,1,0,0,1,1,0,0,0,0]) == 145
assert F([1,0,0,0,1,2,0,0,1,0]) == 40585

In [22]:
# Vector equality works in Python
assert [0,1,0,0,1,1,0,0,0,0] == [0,1,0,0,1,1,0,0,0,0]
assert [1,1,0,0,0,0,0,0,0,0] != [0,0,0,0,0,0,0,0,1,1]

In [23]:
# Define M(k) that converts a number k into a multiset vector.
def M(k):
    v = [0] * 10
    while k > 0:
        r = k % 10
        k = k // 10
        v[r] += 1
    return v

In [28]:
# The following assertions about M(k) should pass.
assert M(145) == [0,1,0,0,1,1,0,0,0,0]
assert M(541) == [0,1,0,0,1,1,0,0,0,0]
assert M(5141) == [0,2,0,0,1,1,0,0,0,0]
assert M(40585) == [1,0,0,0,1,2,0,0,1,0]

In [29]:
# Assertions to test relationships mentioned in the design.
assert F(M(145)) == 145
assert M(F([0,1,0,0,1,1,0,0,0,0])) == [0,1,0,0,1,1,0,0,0,0]

In [38]:
# try this....

def explore(v):
    # Let s be the factorial sum of multi set v
    s = F(v)
    # Let u be the multi set vector for s.
    u = M(s)

    # print(v, s, u)

    if u == v:
        if sum(u) > 1:
            yield s
    else:
        if sum(u) <= sum(v):
            for i in range(10):
                v[i] += 1
                for x in explore(v):
                    yield x
                v[i] -= 1


In [39]:
for i in explore(M(10)):
    print(i)

RecursionError: maximum recursion depth exceeded in comparison
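The recursion above never bottoms out. `explore` always tries `v[0] += 1` first, and each extra 0 adds only `0! == 1` to the sum, so `sum(u) <= sum(v)` keeps holding and the search keeps adding zeros. A possible alternative, sketched below under two assumptions, is to enumerate the multi sets directly instead of recursing. The first assumption is a cap of 7 digits, because 8 * 9! = 2903040 has only 7 digits. The second is that sorted digit tuples stand in for the count vectors. The names `FACT`, `MAX_DIGITS` and `curious_numbers` are made up for this sketch.

```python
from itertools import combinations_with_replacement
from math import factorial

FACT = [factorial(d) for d in range(10)]
MAX_DIGITS = 7  # assumption: 8 * 9! = 2903040 has only 7 digits

def curious_numbers(max_digits=MAX_DIGITS):
    """Yield every number >= 10 that equals the sum of its digit factorials."""
    for size in range(2, max_digits + 1):
        # Each non-decreasing digit tuple is one multi set, visited exactly once.
        for digits in combinations_with_replacement(range(10), size):
            s = sum(FACT[d] for d in digits)
            # M(F(x)) == x, with multi sets compared as sorted digit tuples.
            if tuple(sorted(int(c) for c in str(s))) == digits:
                yield s

found = sorted(curious_numbers())
print(found, sum(found))
```

Because every multi set is generated exactly once, there is no recursion depth to worry about and no multi set is revisited through different insertion orders.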