In [1]:
import copy
import itertools as its
import math
import os
import pathlib
import re
import sys
import time
from typing import Dict, List, Optional, Tuple, Union
from collections import Counter, defaultdict, deque

import networkx as nx
import numpy as np
import pandas as pd
from IPython.display import clear_output
from matplotlib import pyplot as plt

from aoc import sim_new as sim, testing, util

# Full turn in radians (2 * pi).
twopi = 2 * math.pi

%matplotlib inline

# Relative path to this day's puzzle input: ../input/dec16.txt
INPUT_PATH = pathlib.Path('..') / 'input' / 'dec16.txt'

In [2]:
def read_data(x: str) -> List[int]:
    """Turn a string of decimal digits into a list of ints, one per character.

    Leading and trailing whitespace is stripped first. Any remaining character
    that ``int`` cannot parse raises ``ValueError``.
    """
    digits = x.strip()
    return list(map(int, digits))

In [3]:
# Read the puzzle input file and parse it into a list with one int per character.
data = read_data(INPUT_PATH.read_text().strip())

Let's start with the slow $O(n^2)$ algorithm, so that we have a reference implementation to check everything else against.

In [4]:
def fft_round(the_input: List[int], pattern: List[int]) -> List[int]:
    """Apply one round of the transform by direct summation (quadratic time).

    For output position ``row``, every entry of ``pattern`` is repeated
    ``row + 1`` times, the stretched pattern is cycled, and its very first
    entry is dropped. The output element is the last decimal digit of the
    absolute value of the dot product of ``the_input`` with that sequence.
    """
    size = len(the_input)
    period = len(pattern)
    result = []
    for row in range(size):
        repeat = row + 1
        # Dropping the first stretched entry shifts everything left by one,
        # so column `col` reads pattern entry ((col + 1) // repeat) % period.
        acc = sum(
            the_input[col] * pattern[((col + 1) // repeat) % period]
            for col in range(size)
        )
        result.append(abs(acc) % 10)
    return result

The following algorithm works as well, but it's also $O(n^2)$.

In [5]:
def fft_round_fast(the_input: List[int]):
    """Apply one round of the transform for the fixed pattern 0, 1, 0, -1.

    With 1-based ``position`` and 1-based output ``width``, the coefficient of
    an input element is entry ``(position // width) % 4`` of ``(0, 1, 0, -1)``.
    Each output element is the last decimal digit of the absolute value of
    the weighted sum. Still quadratic in the input length.
    """
    coefficients = (0, 1, 0, -1)
    digits = []
    for width in range(1, len(the_input) + 1):
        acc = 0
        for position, value in enumerate(the_input, 1):
            acc += value * coefficients[(position // width) % 4]
        digits.append(abs(acc) % 10)
    return digits


On the other hand, if you stare hard enough from the right, you see that each round you're shifting a bunch of masks to the left and shortening them. The number of masks at round $i$ (counting down from the right) is proportional to $N / i$, and so while in total this is an $O(n^2)$ algorithm, the bulk of the pain is on the left.

This algorithm is implemented in the following functions. The first, `decrement`, sets up the masks we'll need to keep track of. The second, `dec_round`, performs one round of summing. However, this is *still* $O(n^2)$ because it sums across all the runs. Finally, we implement an actual $N \log(N)$ algorithm in `dec_round_fast`.

Note that in practice, the hard part is memory management and Python's weird looping/access behavior. This problem is perfectly suited to make Python sad. :-/ It turns out that for part 2 the puzzle asks us to report the digits at an offset extremely far to the right (nearly 90% of the way toward the end). Since nothing to the left of the offset affects the values at or to the right of it, we add an `offset` parameter that lets us stop computing once we reach it.

In [6]:
def decrement(runs, N):
    """Derive the runs for the next-lower row from the runs of the current row.

    ``runs`` is a list of ``(start, end, mask)`` triples. The k-th run has its
    start moved left by ``2k + 1`` and its end moved left by ``2k + 2``, so
    every run becomes one element shorter. After that, further runs are
    appended to the right of the last one: each is as long as the last run,
    separated from its predecessor by a gap of that same length, with the
    mask sign flipped each time, for as long as the new run starts before ``N``.
    """
    shifted = [
        (lo - (2 * k + 1), hi - (2 * k + 2), sign)
        for k, (lo, hi, sign) in enumerate(runs)
    ]

    lo, hi, sign = shifted[-1]
    while hi < N:
        width = hi - lo
        # Skip a gap of `width`, then lay down the next run with opposite sign.
        lo, hi, sign = hi + width, hi + 2 * width, -sign
        if lo >= N:
            break
        shifted.append((lo, hi, sign))
    return shifted

In [7]:
def dec_round(the_input: List[int]) -> List[int]:
    """Apply one round of the transform by walking the rows from last to first.

    Starts with a single run covering only the last element, then repeatedly
    calls ``decrement`` to obtain the runs of the row above and sums the input
    over every run (weighted by the run's mask). Because each row re-sums all
    of its runs from scratch, a full round is still quadratic.

    Args:
        the_input: the current list of digits.

    Returns:
        A list of the same length holding ``abs(row_total) % 10`` per row.
        An empty input yields an empty list.
    """
    N = len(the_input)
    if N == 0:
        # Nothing to transform; fft_round and fft_round_fast also return []
        # here, whereas indexing the_input[-1] below would raise IndexError.
        return []
    # Last row: one run starting at the final element (its end lies past N and
    # is clamped by slicing).
    runs = [(N - 1, N + N - 1, 1)]
    output = [the_input[-1]]
    while runs[0][0] > 0:
        runs = decrement(runs, N)
        output.append(sum(mask * sum(the_input[start:end]) for start, end, mask in runs))
    # Rows were produced last-to-first, so flip them back into input order.
    return [abs(x) % 10 for x in reversed(output)]

In [8]:
def dec_round_fast(the_input: List[int], offset: int = 0) -> List[int]:
    """Apply one round of the transform in O(N log N) time using prefix sums.

    Row ``row`` (0-indexed) of the pattern is made of runs of ``row + 1``
    consecutive input positions. The first run starts at index ``row``,
    consecutive runs are separated by a gap of the same length, and their
    coefficients alternate +1, -1, +1, ... With a prefix-sum table every run
    costs O(1), and row ``row`` has about ``N / (2 * (row + 1))`` runs, which
    gives the harmonic-sum bound overall.

    Args:
        the_input: the current list of digits.
        offset: lowest index that has to be computed. Rows only ever read
            positions at or to the right of themselves, so everything left of
            ``offset`` can be skipped; those positions are reported as 0.
            Values below 0 are treated as 0, and the last element is always
            computed even when ``offset`` is past the end.

    Returns:
        A list as long as ``the_input`` (``[]`` for empty input).
    """
    N = len(the_input)
    if N == 0:
        return []
    first = min(max(offset, 0), N - 1)

    # prefix[k] == sum(the_input[first:first + k]). Built only over the part
    # that is actually read, so a large offset keeps the work proportional to
    # N - offset.
    prefix = [0]
    prefix.extend(its.accumulate(the_input[first:]))

    output = [0] * first
    for row in range(first, N):
        width = row + 1
        total = 0
        mask = 1
        for start in range(row, N, 2 * width):
            end = min(start + width, N)  # the final run may be cut off by the end of the input
            total += mask * (prefix[end - first] - prefix[start - first])
            mask = -mask
        output.append(abs(total) % 10)
    return output

In [9]:
def fft_rounds(the_input: List[int], pattern: List[int], num_rounds: int) -> List[int]:
    """Run ``fft_round`` ``num_rounds`` times, feeding each result into the next.

    Works on a shallow copy of ``the_input``, so with zero rounds the caller
    gets back a copy rather than the original object.
    """
    signal = copy.copy(the_input)
    for _ in range(num_rounds):
        signal = fft_round(signal, pattern)
    return signal

In [10]:
def fft_rounds_fast(the_input: List[int], num_rounds: int) -> List[int]:
    """Run ``fft_round_fast`` ``num_rounds`` times, chaining the results.

    Works on a shallow copy of ``the_input``, so with zero rounds the caller
    gets back a copy rather than the original object.
    """
    signal = copy.copy(the_input)
    for _ in range(num_rounds):
        signal = fft_round_fast(signal)
    return signal

In [11]:
def dec_rounds(the_input: List[int], num_rounds: int) -> List[int]:
    """Run ``dec_round`` ``num_rounds`` times, chaining the results.

    Works on a shallow copy of ``the_input``, so with zero rounds the caller
    gets back a copy rather than the original object.
    """
    signal = copy.copy(the_input)
    for _ in range(num_rounds):
        signal = dec_round(signal)
    return signal

In [12]:
def dec_rounds_fast(the_input: List[int], num_rounds: int, offset: int = 0) -> List[int]:
    """Run ``dec_round_fast`` ``num_rounds`` times, passing ``offset`` through.

    Works on a shallow copy of ``the_input``, so with zero rounds the caller
    gets back a copy rather than the original object.
    """
    signal = copy.copy(the_input)
    for _ in range(num_rounds):
        signal = dec_round_fast(signal, offset=offset)
    return signal

In [13]:
# Some simple tests
# Every implementation is checked against the same expected digits for the
# input '12345678': once after a single round and once after two rounds.

# Reference implementation with an explicit pattern.
testing.assert_all_equal(fft_round(read_data('12345678'), [0, 1, 0, -1]), [4,8,2,2,6,1,5,8])
testing.assert_all_equal(fft_rounds(read_data('12345678'), [0, 1, 0, -1], 2), [3,4,0,4,0,4,3,8])

# Closed-form coefficient variant.
testing.assert_all_equal(fft_round_fast(read_data('12345678')), [4,8,2,2,6,1,5,8])
testing.assert_all_equal(fft_rounds_fast(read_data('12345678'), 2), [3,4,0,4,0,4,3,8])

# Run-based variant that re-sums every run per row.
testing.assert_all_equal(dec_round(read_data('12345678')), [4,8,2,2,6,1,5,8])
testing.assert_all_equal(dec_rounds(read_data('12345678'), 2), [3,4,0,4,0,4,3,8])

# Run-based variant with incremental updates (default offset of 0).
testing.assert_all_equal(dec_round_fast(read_data('12345678')), [4,8,2,2,6,1,5,8])
testing.assert_all_equal(dec_rounds_fast(read_data('12345678'), 2), [3,4,0,4,0,4,3,8])


# Cross-check on a longer input: one round of the reference must match dec_round.
testing.assert_all_equal(fft_round(read_data('80871224585914546619083218645595'), [0, 1, 0, -1]), dec_round(read_data('80871224585914546619083218645595')))

In [14]:
# Part 1: run 100 rounds on the input and print the first eight digits of the result.
print(f'The answer to part 1 is {"".join(map(str, dec_rounds_fast(data, 100)[:8]))}')

The answer to part 1 is 23135243


In [15]:
# Part 2: the first seven digits of the input, read as one number, give the
# index at which the eight-digit answer is read from the final result.
offset = int(''.join(str(x) for x in data[:7]))
# Work on the input repeated 10000 times; positions left of `offset` are
# skipped by dec_round_fast and come back as 0.
foo = data * 10000
for num_round in range(100):
    # Replace the previous progress line instead of stacking one line per
    # round in the saved notebook output.
    clear_output(wait=True)
    print(f'On round {num_round}')
    foo = dec_round_fast(foo, offset=offset)

On round 0
On round 1
On round 2
On round 3
On round 4
On round 5
On round 6
On round 7
On round 8
On round 9
On round 10
On round 11
On round 12
On round 13
On round 14
On round 15
On round 16
On round 17
On round 18
On round 19
On round 20
On round 21
On round 22
On round 23
On round 24
On round 25
On round 26
On round 27
On round 28
On round 29
On round 30
On round 31
On round 32
On round 33
On round 34
On round 35
On round 36
On round 37
On round 38
On round 39
On round 40
On round 41
On round 42
On round 43
On round 44
On round 45
On round 46
On round 47
On round 48
On round 49
On round 50
On round 51
On round 52
On round 53
On round 54
On round 55
On round 56
On round 57
On round 58
On round 59
On round 60
On round 61
On round 62
On round 63
On round 64
On round 65
On round 66
On round 67
On round 68
On round 69
On round 70
On round 71
On round 72
On round 73
On round 74
On round 75
On round 76
On round 77
On round 78
On round 79
On round 80
On round 81
On round 82
On round 83
On

In [16]:
# Part 2: print the eight digits of the final result starting at index `offset`.
print(f'The answer to part 2 is {"".join(map(str, foo[offset:offset+8]))}')

The answer to part 2 is 21130597


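Addendum: Note that the offset is more than 50% of the way through the vector. In that half, every row of the pattern is just a single run of ones reaching to the end, so each output digit is the sum of the input digits from its own position to the end (mod 10). The answer can therefore be computed with a running sum from the right: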

In [27]:
# The shortcut below is only valid when row `offset` (and hence every later
# row) consists of a single run of ones that reaches the end of the signal,
# i.e. when offset + (offset + 1) >= total length. Fail loudly otherwise
# instead of printing a wrong answer.
assert 2 * offset + 1 >= len(data) * 10000, 'suffix-sum shortcut needs the offset in the second half of the signal'
# Keep only the part of the repeated signal from `offset` onwards.
foo = (data * 10000)[offset:]
for num_round in range(100):
    # Walk right-to-left: foo[i] already holds this round's suffix sum (mod 10)
    # when it is added into foo[i - 1].
    for i in range(len(foo) - 1, 0, -1):
        foo[i - 1] = (foo[i - 1] + foo[i]) % 10
print(f'The answer to part 2 is {"".join(map(str, foo[:8]))}')

The answer to part 2 is 21130597
