In [1]:
import numpy as np

# Grid parameters: the grid has s x s cells; gridserial is the serial number
# that powerlevel() mixes into every cell.
gridserial = 4455
s = 300

# Index grids in matrix ('ij') order, so X[i, j] == i and Y[i, j] == j.
# NOTE(review): these coordinates run 0..s-1 -- confirm whether the puzzle's
# coordinates are meant to be 1-based.
X, Y = np.meshgrid(*([range(s)] * 2), indexing='ij')

def powerlevel(x, y, serial=None):
    """Return the power level of the cell(s) at coordinates (x, y).

    Works element-wise, so x and y may be scalars or equally shaped arrays.

    Parameters
    ----------
    x, y : int or array of int
        Cell coordinates.
    serial : int, optional
        Grid serial number. Defaults to the module-level ``gridserial``,
        which is looked up at call time.

    Returns
    -------
    int or array of int
        The hundreds digit of ((x + 10) * y + serial) * (x + 10), minus 5.
    """
    if serial is None:
        serial = gridserial
    rack_id = x + 10
    power = (rack_id * y + serial) * rack_id
    # Keep only the hundreds digit, then shift it down by 5.
    return power // 100 % 10 - 5

%timeit powerlevel(X, Y)  # 3ms
grid = powerlevel(X, Y)
# summed-area table 
%timeit np.pad(grid.cumsum(axis=0).cumsum(axis=1), pad_width=((1, 0,),)*2, mode='constant')  # 2ms
sag = np.pad(grid.cumsum(axis=0).cumsum(axis=1), pad_width=((1, 0,),)*2, mode='constant')
def totpower(x, y, dial=3):
    """Return the total power of the dial x dial block starting at (x, y).

    x and y may be scalars or index arrays. The sum is read from the
    module-level summed-area table ``sag`` by combining its four corners.
    """
    x_end = x + dial
    y_end = y + dial
    # Inclusion-exclusion over the four corners of the block.
    return sag[x, y] - sag[x_end, y] - sag[x, y_end] + sag[x_end, y_end]


def max_total_power(dial):
    """Find the dial x dial block with the largest total power.

    Returns a pair (index, power): ``index`` is the (x, y) index tuple of the
    block's starting cell and ``power`` is that block's total power.
    """
    # Number of starting positions per axis for which the block fits in the grid.
    n_origins = s - dial + 1
    totals = totpower(X[:n_origins, :n_origins], Y[:n_origins, :n_origins], dial)
    best_at = np.unravel_index(np.argmax(totals), totals.shape)
    return best_at, totals[best_at]
    
# Benchmark the search for dial size 3, then print the Part 1 result as
# ((x, y), total power).
%timeit max_total_power(3)  # 3ms
print('Part 1:', max_total_power(3))

2.77 ms ± 21.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.23 ms ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.15 ms ± 97.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Part 1: ((21, 54), 33)


In [2]:
# Part 2: brute force of each dial
def find_best_dial():
    """Brute-force search over every dial size from 1 to s.

    Returns
    -------
    tuple
        (dial, index, power): the dial size with the highest maximum total
        power, the (x, y) index tuple of that block's starting cell, and the
        block's total power.
    """
    # Use the grid size s rather than a literal 300, so the search range
    # always matches the grid that max_total_power() operates on.
    results = {dial: max_total_power(dial) for dial in range(1, s + 1)}
    best_dial, (best_index, best_power) = max(
        results.items(), key=lambda item: item[1][1]
    )
    return best_dial, best_index, best_power

# Benchmark the brute-force search, then print (dial, (x, y), total power).
%timeit find_best_dial()  # 320ms
print(find_best_dial())

321 ms ± 6.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
(11, (236, 268), 74)


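However, we can improve on this brute-force approach. A square of side $k \cdot d$ can be tiled with $k^2$ squares of side $d$, so its total power is at most $k^2$ times the maximum total power at dial size $d$. If the maximum total power at some dial size, say 10, is less than zero, we therefore know that the maximum total power at every integer multiple of 10 will be even lower, and those dial sizes can be skipped. Therefore: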

In [3]:
def find_best_dial2():
    """Search all dial sizes, skipping multiples of sizes with a negative maximum.

    A square of side k*dial tiles into k**2 squares of side dial, so if the
    best block of side dial is negative, every block of side k*dial (k >= 2)
    totals even less and cannot be the overall best.

    Returns
    -------
    tuple
        (dial, index, power): the best dial size, the (x, y) index tuple of
        that block's starting cell, and the block's total power.
    """
    no_check = set()
    dials = {}
    for dial in range(1, s + 1):
        if dial in no_check:
            continue
        result = max_total_power(dial)
        dials[dial] = result
        if result[1] < 0:
            # Prune every larger multiple up to and including s. The previous
            # range(1, s//dial) stopped one multiple short of the largest.
            no_check.update(range(2 * dial, s + 1, dial))

    best_dial, (best_index, best_power) = max(
        dials.items(), key=lambda item: item[1][1]
    )
    return best_dial, best_index, best_power

# Benchmark the pruned search, then print (dial, (x, y), total power).
%timeit find_best_dial2()  # 190ms
print(find_best_dial2())

188 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
(11, (236, 268), 74)


However, we can optimize this further. We can find an upper bound on the maximum power level for a dial size using the following universal inequality:

$$ \mathrm{max}(x + y) \leq \mathrm{max}(x) + \mathrm{max}(y) $$

This leads us to

$$\mathrm{max\_tp}(d) \leq \sum_{i \in T_d} \mathrm{max\_tp}(d_i)$$

where $\mathrm{max\_tp}(d)$ is the maximum power level for a given dial size $d$, and $T_d$ is a square tiling of an integer square with side length $d$, $i$ is a square contained in that tiling, and $d_i$ is its side length.

For example, if $d=4$, we could tile this $4 \times 4$ square into 4 squares of $d_i=2$, i.e. 

$$\mathrm{max\_tp}(4) \leq 4 \cdot \mathrm{max\_tp}(2),$$

but it could also be tiled into 16 $1 \times 1$ squares. However, this is less useful, since smaller tiles give a looser bound: $\mathrm{max\_tp}(2) \leq 4 \cdot \mathrm{max\_tp}(1)$, so $4 \cdot \mathrm{max\_tp}(2) \leq 16 \cdot \mathrm{max\_tp}(1)$. As such, this becomes a task of tiling a $d \times d$ square so that its smallest tile is as large as possible, or perhaps with the fewest squares [[1](https://www2.stetson.edu/~efriedma/mathmagic/1298.html)].

In general, finding this tiling is not straightforward (e.g. for prime dial sizes) and requires integer linear programming [[2](http://www.wm-archive.uni-bayreuth.de/fileadmin/Sascha/Publikationen2/square.pdf)]. However, since the tiling depends only on the dial size and not on the actual input, it can be pre-computed and can speed up our algorithm significantly.

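However, a much simpler approach that works just as well in practice is to cover a square of odd side length $2k+1$ with two pairs of squares whose side lengths differ by 1: two $(k+1) \times (k+1)$ squares in opposite corners, which overlap in the single cell in the middle, and two $k \times k$ squares in the remaining corners. The middle cell is counted twice, so its power must be subtracted once. Since $0 \leq \mathrm{max}(x) - \mathrm{min}(x)$, i.e. every cell's power is at least the minimum single-cell power $\mathrm{min\_tp}(1)$, we can account for this overlap and have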

$$\mathrm{max\_tp}(2k+1) \leq 2 \cdot \mathrm{max\_tp}(k+1) + 2 \cdot \mathrm{max\_tp}(k) - \mathrm{min\_tp}(1)$$

In [4]:
def upper_bound(d, max_tp, min_p):
    """Return an upper bound on the maximum total power for dial size d.

    Parameters
    ----------
    d : int
        Dial size to bound.
    max_tp : mapping of int to number
        Maximum total power (or an upper bound on it) for smaller dial sizes.
    min_p : number
        Minimum power of a single cell.
    """
    if d == 1:
        # Nothing smaller to tile with; fall back to the stored value.
        return max_tp[1]
    half, is_odd = divmod(d, 2)
    if not is_odd:
        # An even square splits into four squares of half the side length.
        return 4 * max_tp[half]
    # An odd square is covered by two (half+1)-squares and two half-squares;
    # the doubly counted middle cell is bounded below by min_p.
    return 2 * max_tp[half + 1] + 2 * max_tp[half] - min_p

def find_best_dial3():
    """Search all dial sizes, skipping those whose upper bound cannot win.

    For each dial size, upper_bound() gives a cheap bound built from the
    results for smaller sizes. If that bound is already below the best total
    power found so far, the exact search for this size is skipped and the
    bound itself is stored, so that bounds for larger sizes remain valid.

    Returns
    -------
    tuple
        (dial, power, skipped): the best dial size, its total power, and a
        string giving the share of dial sizes that were skipped. Unlike
        find_best_dial() and find_best_dial2(), the block's location is not
        returned.
    """
    best_power = float("-inf")
    # Dial size 1 starts with an infinite bound so that it is always computed.
    max_tp = {1: float("inf")}
    min_p = np.min(grid)
    skipped = 0
    for dial in range(1, s + 1):
        ub = upper_bound(dial, max_tp, min_p)
        if ub < best_power:
            max_tp[dial] = ub
            skipped += 1
            continue
        _, dial_max = max_total_power(dial)
        max_tp[dial] = dial_max
        best_power = max(dial_max, best_power)

    # Skipped sizes hold bounds strictly below the best exact value, so the
    # overall maximum is always an exactly computed entry.
    best_dial, best_value = max(max_tp.items(), key=lambda item: item[1])
    return best_dial, best_value, "{:.1%} skipped".format(skipped / s)

# Benchmark the bounded search; the bare call then displays
# (dial, total power, share of dial sizes skipped).
%timeit find_best_dial3()  # 120ms
find_best_dial3()

118 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


(11, 74, '83.3% skipped')