# Problem 34

https://projecteuler.net/problem=34

145 is a curious number, as 1! + 4! + 5! = 1 + 24 + 120 = 145.

Find the sum of all numbers which are equal to the sum of the factorial of their digits.

Note: as 1! = 1 and 2! = 2 are not sums they are not included.

## Notes

We can easily brute force this by iterating over the numbers from 10 upward, breaking each into its digits, and summing the digits' factorials. The challenge is knowing when to stop. My hunch is that we can stop once the sum of the factorials exceeds the original number.

Hmm... okay, my hunch is wrong. 9! = 362880, so for 19 we already have 1! + 9! > 19. If we had stopped there, we would never have reached 145.
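A stopping rule that does work compares magnitudes instead of individual sums: an n-digit number is at least 10^(n-1), while its digit-factorial sum is at most n·9!. Once 10^(n-1) exceeds n·9!, no number with n or more digits can qualify. The sketch below computes that cutoff; the function name `digit_factorial_bound` is hypothetical and not part of the original notebook.

```python
from math import factorial

def digit_factorial_bound():
    """Return an upper limit above which no number can equal the sum of
    the factorials of its digits."""
    max_fac = factorial(9)  # largest possible contribution of one digit
    n = 1
    # Keep growing n while the smallest n-digit number, 10**(n-1), could
    # still be reached by the largest possible digit-factorial sum n * 9!.
    while 10 ** (n - 1) <= n * max_fac:
        n += 1
    # No n-digit number can qualify now, and no (n-1)-digit number can have
    # a digit-factorial sum larger than (n-1) * 9!.
    return (n - 1) * max_fac

print(digit_factorial_bound())
```

The loop stops at n = 8 because 8·9! = 2903040 is smaller than 10^7, so the returned limit is 7·9! = 2540160.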

In [23]:
from scipy.special import factorial
factorial(9)

array(362880.0)

In [28]:
def fac(n):
    return int(factorial(n))
fac(3)

6

In [53]:
N = 100000
for i in range(10, N+1):
    digits = str(i)  # decimal digits as characters, e.g. '145'
    factorials = [fac(int(d)) for d in digits]
    summation = sum(factorials)
    #print('{} -> {} -> sum {}'.format(i, factorials, summation))
    if i == summation:
        print(i)

145
40585


Brute force as above is too slow: every number is converted to a string and every digit's factorial is recomputed with scipy. We can do better, for example by saving the results of the summations.
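One way to save the summations, sketched here assuming we still scan every number up to a fixed limit: the digit-factorial sum of `i` equals the sum for `i // 10` (all digits but the last) plus the factorial of the last digit `i % 10`, so each value is built from one that is already stored. The function name `curious_numbers` is hypothetical.

```python
from math import factorial

def curious_numbers(limit):
    """Return all numbers 10 <= i <= limit that equal the sum of the
    factorials of their digits, reusing previously computed sums."""
    fac = [factorial(d) for d in range(10)]  # lookup table for digits 0-9
    sums = [0] * (limit + 1)  # sums[i] = sum of the digit factorials of i
    found = []
    for i in range(1, limit + 1):
        # sums[i // 10] is already known because i // 10 < i;
        # add the factorial of the last digit.
        sums[i] = sums[i // 10] + fac[i % 10]
        if i >= 10 and sums[i] == i:  # 1! and 2! are not sums
            found.append(i)
    return found

print(curious_numbers(100000))
```

This trades memory (one stored sum per number) for doing only one addition and one table lookup per number instead of a string conversion and a factorial call per digit.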

I think there's a better way...

## Take 2

First, we only need the factorials of the digits 0 to 9. Looking them up in a precomputed list is much faster than calling scipy's factorial function every time.

In [39]:
from scipy.special import factorial
def create_fac():
    factorials = [int(factorial(i)) for i in range(10)]
    def f(x):
        return factorials[x]
    return f
fac = create_fac()

In [40]:
%time fac(9)

CPU times: user 5 µs, sys: 1 µs, total: 6 µs
Wall time: 11 µs


362880

In [41]:
%time factorial(9)

CPU times: user 70 µs, sys: 0 ns, total: 70 µs
Wall time: 77 µs


array(362880.0)

In [74]:
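def dfs(chain, max_length):
    """Yield every number with at most max_length digits, built digit by
    digit from `chain`, that equals the sum of its digits' factorials."""
    # Rebuild the number whose decimal digits are `chain`.
    concat_value = 0
    for d in chain:
        concat_value = concat_value * 10 + d
    fac_sum = sum(map(fac, chain))
    # 1! = 1 and 2! = 2 are not sums, so require at least two digits.
    if fac_sum == concat_value and len(chain) >= 2:
        yield concat_value
    if len(chain) < max_length:
        # A leading zero would only repeat a shorter number, so start at 1.
        lo = 1 if len(chain) == 0 else 0
        for d in range(lo, 10):
            yield from dfs(chain + [d], max_length)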


In [72]:
print(sum(map(fac, [1, 4, 5])))

145


In [75]:
for i in dfs([], 5):
    print(i)

145
40585


In [79]:
list(dfs([],10))

KeyboardInterrupt: 
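The 10-digit search never finishes because the DFS still visits every number below 10^10 one at a time. Two observations shrink the search; this is a sketch of a swapped-in technique, not the notebook's own method. First, no solution has more than 7 digits, since 8·9! = 2903040 is already below the smallest 8-digit number. Second, the digit-factorial sum depends only on which digits occur, not on their order. So we can enumerate sorted digit multisets with `itertools.combinations_with_replacement`, compute each sum once, and keep it when the sum's own sorted digits are exactly that multiset. The names `FAC` and `curious_numbers_by_multiset` are hypothetical.

```python
from itertools import combinations_with_replacement
from math import factorial

FAC = [factorial(d) for d in range(10)]  # digit factorial lookup table

def curious_numbers_by_multiset(max_digits=7):
    """Find numbers equal to the sum of the factorials of their digits by
    enumerating sorted digit multisets instead of individual numbers."""
    found = set()
    for length in range(2, max_digits + 1):  # 1! and 2! are not sums
        # Each tuple is non-decreasing, e.g. (1, 4, 5) stands for 145, 154, 415, ...
        for digits in combinations_with_replacement(range(10), length):
            total = sum(FAC[d] for d in digits)
            # Keep the total only if its own digits form exactly this multiset.
            if tuple(sorted(int(c) for c in str(total))) == digits:
                found.add(total)
    return sorted(found)

print(curious_numbers_by_multiset())
```

For 7 digits there are only C(16, 7) = 11440 multisets, compared with millions of individual numbers, so the whole search up to the bound is small.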