In [142]:
import numpy as np
import heapq

# Grid geometry and the serial number used by powerlevel().
s = 300            # side length of the square grid
gridserial = 4455  # serial number mixed into every cell's power level
# Coordinate grids in 'ij' order: X[i, j] == i and Y[i, j] == j.
_axis = np.arange(s)
X, Y = np.meshgrid(_axis, _axis, indexing='ij')

def getdigit(number, n):
    """Return the decimal digit of ``number`` found at string index ``n``.

    ``n`` indexes the decimal representation like a Python string index, so
    ``n=-1`` is the ones digit, ``n=-3`` the hundreds digit and ``n=0`` the
    leading digit.  A leading minus sign is ignored so that negative numbers
    neither raise ``ValueError`` nor count the sign as a digit.

    Returns 0 when the number has no digit at that position.
    """
    # Strip the sign from the text rather than calling abs(), so any input
    # that str() could already handle keeps working unchanged.
    digits = str(number).lstrip('-')
    try:
        return int(digits[n])
    except IndexError:
        return 0

@np.vectorize
def powerlevel(x, y):
    """Power level of the cell at coordinates (x, y).

    The rack id is ``x + 10``; the level is the hundreds digit of
    ``(rack_id * y + gridserial) * rack_id`` minus 5.  Wrapped in
    ``np.vectorize`` so it can be applied element-wise to coordinate arrays.
    """
    rack_id = x + 10
    raw_power = (rack_id * y + gridserial) * rack_id
    hundreds = getdigit(raw_power, -3)
    return hundreds - 5

# summed-area table 
%timeit np.pad(powerlevel(X, Y).cumsum(axis=0).cumsum(axis=1), pad_width=((1, 0,),)*2, mode='constant')  # 75ms
sag = np.pad(powerlevel(X, Y).cumsum(axis=0).cumsum(axis=1), pad_width=((1, 0,),)*2, mode='constant')
def totpower(x, y, dial=3):
    """Total power of the ``dial`` x ``dial`` square whose first cell is (x, y).

    Looks up the four corners of the square in the summed-area table ``sag``.
    ``x`` and ``y`` may be scalars or index arrays; the result has the same
    shape as the indexed values.
    """
    x_end = x + dial
    y_end = y + dial
    return sag[x, y] - sag[x_end, y] - sag[x, y_end] + sag[x_end, y_end]


def max_total_power(dial):
    """Find the ``dial`` x ``dial`` square with the largest total power.

    Returns ``(origin, total)`` where ``origin`` is the ``(x, y)`` index tuple
    of the square's first cell and ``total`` is its summed power.
    """
    # Only origins that keep the whole square inside the s x s grid.
    n_origins = s - dial + 1
    totals = totpower(X[:n_origins, :n_origins], Y[:n_origins, :n_origins], dial)
    origin = np.unravel_index(totals.argmax(), totals.shape)
    return origin, totals[origin]
    
# Time one scan at dial size 3, then print the best square as ((x, y), total power).
%timeit max_total_power(3)  # 3ms
print('Part 1:', max_total_power(3))

76 ms ± 396 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.08 ms ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Part 1: ((21, 54), 33)


In [143]:
# Part 2: brute force of each dial
def find_best_dial():
    """Brute-force search over every dial size from 1 to ``s``.

    Returns ``(dial, origin, total)`` for the square with the largest total
    power, where ``origin`` is the ``(x, y)`` index tuple of its first cell.
    On ties the smallest dial size wins.
    """
    # Use the grid side ``s`` instead of a hard-coded 300 so the search
    # always matches the grid that was actually built.
    results = {dial: max_total_power(dial) for dial in range(1, s + 1)}
    best_dial, (best_origin, best_total) = max(
        results.items(), key=lambda item: item[1][1]
    )
    return best_dial, best_origin, best_total

# Time the exhaustive search, then print its (dial, (x, y), total power) result.
%timeit find_best_dial()  # 320ms
print(find_best_dial())

319 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
(11, (236, 268), 74)


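However, we can improve on this brute-force approach. A square of side $k \cdot d$ can be split into $k^2$ squares of side $d$, so its total power is at most $k^2$ times the maximum total power at dial size $d$. If the maximum total power at dial size 10 is less than zero, we therefore know that the maximum total power at every integer multiple of 10 will be even lower, and those dial sizes can be skipped. Therefore: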

In [140]:
def find_best_dial2():
    no_check = set()
    dials = {}
    for dial in range(1, 300+1):
        if dial in no_check:
            #print("don't check", dial)
            continue
        rv = pl2(dial)
        dials[dial] = rv
        if rv[1] < 0:
            #print(dial)
            for mult in range(1, s//dial):
                no_check.add(mult*dial)
            #print('nocheck', sorted(no_check))
            
            
    best = max(dials.items(), key=lambda x: x[1][1])
    return best[0], best[1][0], best[1][1]

%timeit find_best_dial()  # 190ms
print(find_best_dial())

3.09 ms ± 29.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
((21, 54), 33)
187 ms ± 785 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
(11, (236, 268), 74)


However, we can optimize this further. We can find an upper bound on the maximum power level for a dial size using the following universal inequality:

$$ \mathrm{max}(x + y) \leq \mathrm{max}(x) + \mathrm{max}(y) $$

This leads us to

$$\mathrm{max\_tp}(d) \leq \sum_{i \in T_d} \mathrm{max\_tp}(d_i)$$

where $\mathrm{max\_tp}(d)$ is the maximum power level for a given dial size $d$, and $T_d$ is a square tiling of an integer square with side length $d$, $i$ is a square contained in that tiling, and $d_i$ is its side length.

For example, if $d=4$, we could tile this $4 \times 4$ square into 4 squares of $d_i=2$, i.e. 

$$\mathrm{max\_tp}(4) \leq 4 \cdot \mathrm{max\_tp}(2),$$

but it could also be tiled into 16 $1 \times 1$ squares. However, that bound is less useful: larger squares generally have a lower maximum power per cell, so a tiling made of larger squares gives a tighter bound. As such, this becomes a task of tiling a $d \times d$ square with the biggest possible squares, i.e. with the fewest squares.

For prime dial sizes, finding this tiling is not straightforward and requires dynamic programming. However, since the tiling does not depend on the actual input, it can be pre-computed, which speeds up our algorithm significantly.

In [145]:
# Counter must be imported before the seed table below is built.
from collections import Counter

# Cache of square tilings, keyed by side length d.  Each value maps a tile
# side length to the number of tiles of that size; the tile areas add up to
# d * d.  The three smallest sizes are seeded by hand.
tiling_d = {
    1: Counter({1: 1}),
    2: Counter({1: 4}),
    3: Counter({1: 5, 2: 1}),
}

def memoize(f):
    """Wrap the one-argument function ``f`` with a cache stored in ``tiling_d``.

    The wrapper returns ``tiling_d[x]`` when the key is present; otherwise it
    calls ``f(x)``, stores the result under ``x`` and returns it.
    """
    def helper(x):
        if x in tiling_d:
            return tiling_d[x]
        tiling_d[x] = computed = f(x)
        return computed
    return helper

@memoize
def tiling(d):
    """Return a tiling of a ``d`` x ``d`` square by smaller squares.

    The result is a ``Counter`` mapping tile side length to the number of
    tiles of that size; the tile areas sum to ``d * d``.  Results are cached
    in ``tiling_d`` by the ``memoize`` decorator, which also serves the
    hand-seeded entries, so no cache lookup is needed here.

    Raises ``ValueError`` for ``d < 1``.
    """
    if d < 1:
        raise ValueError("tiling is only defined for d >= 1, got %r" % (d,))

    # Composite d: use the largest proper divisor d_i as the tile size.
    for d_i in range(d // 2, 1, -1):
        if d % d_i == 0:
            per_side = d // d_i
            if per_side == 2:
                return Counter({d_i: 4})
            # One big (d - d_i) square plus an L-shaped border made of
            # 2 * per_side - 1 squares of side d_i.
            return Counter({d - d_i: 1, d_i: 2 * per_side - 1})

    if d <= 2:
        # Only unit squares fit; the odd-prime layout below does not apply.
        return Counter({1: d * d})

    # prime tiling (not optimal, but good enough): one (half + 1) square,
    # three half-sized squares and two strips of unit squares.  Counts are
    # accumulated so equal tile sizes (d == 3) add up instead of colliding
    # as duplicate dict keys.
    half = d // 2
    tiles = Counter()
    tiles[1] += 2 * half
    tiles[half] += 3
    tiles[half + 1] += 1
    return tiles

# Fill the tiling cache for every dial size up to the grid side, using a plain
# loop instead of a throwaway list, then show a few representative entries.
for d in range(1, s + 1):
    tiling(d)
for d in [1,2,3,4,5,11,31]:
    print(tiling_d[d])

Counter({1: 1})
Counter({1: 4})
Counter({1: 5, 2: 1})
Counter({2: 4})
Counter({1: 4, 2: 3, 3: 1})
Counter({1: 10, 5: 3, 6: 1})
Counter({1: 30, 15: 3, 16: 1})


In [144]:
def upper_bound(d, max_tp):
    """Upper bound on the maximum total power at dial size ``d``.

    Sums ``count * max_tp[side]`` over the tiles of ``tiling(d)``, so
    ``max_tp`` must already hold a value for every tile size in that tiling.
    """
    bound = 0
    for side, count in tiling(d).items():
        bound += count * max_tp[side]
    return bound

def find_best_dial2():
    """Search all dial sizes, skipping those whose tiling bound cannot win.

    For each dial size the bound from ``upper_bound`` is compared with the
    best total found so far; the grid is only scanned when the bound is not
    strictly lower.

    Returns ``(dial, total)`` for the best dial size.  On ties the smallest
    dial size wins.
    """
    best_dial, best_total = None, float("-inf")
    # max_tp[d] is the exact maximum for scanned dials and an upper bound for
    # skipped ones; both are valid inputs for later bounds.  The infinite
    # seed forces dial 1 to be scanned.
    max_tp = {1: float("inf")}
    for dial in range(1, s + 1):
        ub = upper_bound(dial, max_tp)
        if ub < best_total:
            max_tp[dial] = ub
            continue
        # max_total_power returns (origin, total); the original call target
        # ``pl2`` is not defined anywhere in this notebook.
        _, total = max_total_power(dial)
        max_tp[dial] = total
        if total > best_total:
            best_dial, best_total = dial, total

    return best_dial, best_total

# Time the bound-pruned search, then display its (dial, total power) result.
%timeit find_best_dial2() # 130 ms, could be improved further with better prime tilings
find_best_dial2()

128 ms ± 4.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


(11, 74)