# In [1]:
from collections import defaultdict


def load_data(path):
    """Read the text file at *path* and return its lines.

    The whole file is read at once and split with ``str.splitlines``, so
    the returned strings carry no line terminators.
    """
    with open(path) as handle:
        return handle.read().splitlines()


def preprocess_data(data):
    """Convert every item of *data* to ``int`` and return them as a list.

    Raises ``ValueError`` if an item is not a valid integer literal.
    """
    return list(map(int, data))


def get_next_secret_number(secret_number):
    """Return the secret number that follows *secret_number*.

    Three mixing rounds are applied in order; each XORs the running value
    with a shifted copy of itself, where the shifted copy is reduced to
    its low 24 bits (i.e. modulo 16_777_216) before the XOR.
    """
    low_24_bits = 0xFFFFFF  # 16_777_216 - 1

    # Round 1: multiply by 64 (shift left by 6).
    mixed = secret_number ^ ((secret_number << 6) & low_24_bits)
    # Round 2: floor-divide by 32 (shift right by 5).
    mixed ^= (mixed >> 5) & low_24_bits
    # Round 3: multiply by 2_048 (shift left by 11).
    mixed ^= (mixed << 11) & low_24_bits
    return mixed


def solve_part_1(initial_secret_numbers, iterations):
    """Sum each secret number after advancing it *iterations* times.

    Args:
        initial_secret_numbers: Any iterable of starting secret numbers.
            It is only read, never modified.
        iterations: How many times ``get_next_secret_number`` is applied
            to each starting value.

    Returns:
        The sum of the advanced secret numbers (``0`` for empty input).
    """
    total = 0
    for secret_number in initial_secret_numbers:
        for _ in range(iterations):
            secret_number = get_next_secret_number(secret_number)
        total += secret_number
    return total


def solve_part_2(initial_secret_numbers, iterations):
    """Return the best total price reachable with one four-change sequence.

    For every buyer (one starting secret number each) the price is the
    last decimal digit of the secret number. The buyer's secret number is
    advanced *iterations* times; after each step the change in price is
    recorded. The first time a given run of four consecutive changes
    appears for a buyer, the price at that moment is credited to that
    sequence. The result is the largest per-sequence total over all buyers.

    Args:
        initial_secret_numbers: Any iterable of starting secret numbers.
            It is only read, never modified.
        iterations: Number of secret numbers generated per buyer, which is
            also the number of price changes examined per buyer.

    Returns:
        The highest summed price for any sequence, or ``0`` when no
        sequence of four changes exists (no buyers, or fewer than four
        iterations).
    """
    total_prices = defaultdict(int)
    for secret_number in initial_secret_numbers:
        # All sliding-window state is rebuilt per buyer so that changes
        # from one buyer can never form a sequence with the next buyer's.
        first_sale_price = {}
        recent_changes = ()
        previous_price = secret_number % 10
        for _ in range(iterations):
            # Advance first, then price the new number: this way the price
            # of the final generated number is examined too.
            secret_number = get_next_secret_number(secret_number)
            price = secret_number % 10
            recent_changes = (*recent_changes, price - previous_price)[-4:]
            previous_price = price
            if len(recent_changes) == 4:
                # Only the first occurrence of a sequence counts per buyer.
                first_sale_price.setdefault(recent_changes, price)
        for sequence, price in first_sale_price.items():
            total_prices[sequence] += price
    # default=0 avoids an error when no sequence was ever recorded.
    return max(total_prices.values(), default=0)


# Driver: read the puzzle input, parse it, then print the answer of each
# part (part 1 first, part 2 second) using 2_000 iterations.
data = load_data('input.txt')
initial_secret_numbers = preprocess_data(data)
for solver in (solve_part_1, solve_part_2):
    print(solver(initial_secret_numbers, 2_000))

# Recorded notebook output:
# 20215960478
# 2221
