# Sum of Squares

Given a positive integer, find the shortest way to write it as a sum of perfect squares, and output that decomposition.

For example: 

    12 = 2^2 + 2^2 + 2^2 
    
not: 

    12 = 3^2 + 1^2 + 1^2 + 1^2

Output should be the list of the numbers being squared (the square roots, not the squares themselves): `{2 2 2}`

In [1]:
def square_sum_memoized(n):
    """Return the shortest tuple of integers whose squares sum to ``n``.

    Example: ``square_sum_memoized(12)`` is ``(2, 2, 2)`` because
    12 = 2**2 + 2**2 + 2**2 (rather than 3**2 + 1**2 + 1**2 + 1**2).
    ``n == 0`` yields ``(0,)``.

    For each candidate root ``i`` the remainder ``n - i**2`` is solved
    recursively and the shortest candidate is kept (the first one found wins
    ties).  Every value ``1..n`` is solved bottom-up before ``n`` itself, so
    each recursive call is answered from the cache and the recursion never
    goes deeper than one level.
    """
    import functools

    def _isqrt(x):
        # Floor square root.  The float estimate can be off by one for large
        # x, so nudge it until r*r <= x < (r+1)*(r+1) holds exactly.
        r = int(x ** 0.5)
        while r * r > x:
            r -= 1
        while (r + 1) * (r + 1) <= x:
            r += 1
        return r

    @functools.lru_cache(maxsize=None)
    def _ss(m):
        root = _isqrt(m)
        if root * root == m:
            return (root,)
        # Sentinel.  For m < 5 it is the valid all-ones answer.  For m >= 5
        # its length (5) exceeds the four-square bound (Lagrange), so the loop
        # below always replaces it with a real decomposition.
        best = (1,) * min(m, 5)
        for i in range(1, root + 1):
            candidate = (i,) + _ss(m - i * i)
            if len(candidate) < len(best):
                best = candidate
            if len(best) == 2:
                # m is not a perfect square, so two terms is optimal.
                break
        return best

    # Warm the cache in increasing order to keep recursion shallow.
    for i in range(1, n + 1):
        _ss(i)

    return _ss(n)

In [2]:
# Benchmark the memoized version at n = 10**5 (IPython magic; notebook-only syntax).
%timeit square_sum_memoized(10**5)

1 loops, best of 3: 21.5 s per loop


In [3]:
# Print the decomposition found for every n from 1 through 25.
for n in range(1, 26):
    print("{} {}".format(n, square_sum_memoized(n)))

1 (1,)
2 (1, 1)
3 (1, 1, 1)
4 (2,)
5 (1, 2)
6 (1, 1, 2)
7 (1, 1, 1, 2)
8 (2, 2)
9 (3,)
10 (1, 3)
11 (1, 1, 3)
12 (2, 2, 2)
13 (2, 3)
14 (1, 2, 3)
15 (1, 1, 2, 3)
16 (4,)
17 (1, 4)
18 (3, 3)
19 (1, 3, 3)
20 (2, 4)
21 (1, 2, 4)
22 (2, 3, 3)
23 (1, 2, 3, 3)
24 (2, 2, 4)
25 (5,)


In [4]:
def square_sum_arrays(n):
    """Return the shortest tuple of integers whose squares sum to ``n``.

    Bottom-up table version.  ``lengths[x]`` holds the length of the best
    decomposition of ``x`` and ``roots[x]`` the decomposition itself.  Entries
    are filled for x = 1..n in increasing order, so every sub-problem read
    while processing ``x`` is already available.  Returns ``None`` when
    ``n == 0`` (the table entry for 0 is never filled).
    """
    roots = [None] * (n + 1)
    lengths = [None] * (n + 1)

    for x in range(1, n + 1):
        base = int(x ** 0.5)

        # Perfect square: a single term.
        if base * base == x:
            lengths[x], roots[x] = 1, (base,)
            continue

        # One more than a perfect square: 1**2 plus that square.
        if lengths[x - 1] == 1:
            lengths[x], roots[x] = 2, (1,) + roots[x - 1]
            continue

        # General case: try every square r*r <= x and keep the first
        # strictly-shortest remainder; two terms cannot be beaten here.
        best_len, best_sq = x, 1
        for r in range(1, base + 1):
            sq = r * r
            cand = lengths[x - sq] + 1
            if cand < best_len:
                best_len, best_sq = cand, sq
            if best_len == 2:
                break
        lengths[x] = best_len
        roots[x] = roots[best_sq] + roots[x - best_sq]

    return roots[n]

In [5]:
# Benchmark the table-based version at n = 10**5 (IPython magic; notebook-only syntax).
%timeit square_sum_arrays(10**5)

1 loops, best of 3: 16.4 s per loop


In [6]:
# Print the decomposition found for every n from 1 through 25.
for n in range(1, 26):
    print("{} {}".format(n, square_sum_arrays(n)))

1 (1,)
2 (1, 1)
3 (1, 1, 1)
4 (2,)
5 (1, 2)
6 (1, 1, 2)
7 (1, 1, 1, 2)
8 (2, 2)
9 (3,)
10 (1, 3)
11 (1, 1, 3)
12 (2, 2, 2)
13 (2, 3)
14 (1, 2, 3)
15 (1, 1, 2, 3)
16 (4,)
17 (1, 4)
18 (3, 3)
19 (1, 3, 3)
20 (2, 4)
21 (1, 2, 4)
22 (2, 3, 3)
23 (1, 2, 3, 3)
24 (2, 2, 4)
25 (5,)
