# Day 14
## Part 1

In [2]:
import parse
from collections import defaultdict

Use bitwise operators to apply the mask. Split the mask into two, one with the specified 1s and zeroes everywhere else - bitwise `or` with the value to apply this; and one with the specifies zeroes and 1s everywhere else, bitwise `and` with the value to apply this.

In [23]:
def part_1(data):
    memory = defaultdict(int)
    for line in data:
        line = line.strip()
        if (p := parse.parse('mask = {mask}', line)):
            mask = p['mask']
            mask_0 = sum(
                2 ** (len(mask) - 1 - i) 
                for i, n in enumerate(mask)
                if n != "0"
            )            
            mask_1 = sum(
                2 ** (len(mask) - 1 - i) 
                for i, n in enumerate(mask)
                if n == '1'
            )
        elif (p := parse.parse('mem[{mem:d}] = {value:d}', line)):
            memory[p['mem']] = (p['value'] | mask_1) & mask_0
            
    return sum(memory.values())

In [24]:
test_data = '''mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 11
mem[7] = 101
mem[8] = 0'''.splitlines()

assert part_1(test_data) == 165

In [25]:
data = open('input').read().strip().splitlines()
part_1(data)

15514035145260

## Part 2

What are the maximum number of combinations from a mask?

In [30]:
2**max(line.count('X') for line in data)

512

That looks tractable to brute force.

In [91]:
import itertools

def part_2(data):
    memory = defaultdict(int)
    for line in data:
        line = line.strip()
        if (p := parse.parse('mask = {mask}', line)):
            mask = p['mask']
            mask_1 = sum(
                2 ** (len(mask) - 1 - i) 
                for i, n in enumerate(mask)
                if n == '1'
            )
            mask_x = [
                len(mask) - 1 - i 
                for i, n in enumerate(mask)
                if n == 'X'
            ]
        elif (p := parse.parse('mem[{mem:d}] = {value:d}', line)):
            memory_address = (p['mem'] | mask_1)
            for floating_mask in itertools.product([0, 1], repeat=len(mask_x)):
                floating_zeroes = 2 ** 36 - 1
                floating_ones = 0
                for i, b in zip(mask_x, floating_mask):
                    if b == 0:
                        floating_zeroes -= 2 ** i
                    else:
                        floating_ones += 2 ** i
                memory[(memory_address | floating_ones) & floating_zeroes] = p['value']
    return sum(memory.values())
            

In [92]:
test_data_2 = '''mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1'''.splitlines()

part_2(test_data_2)

208

In [94]:
part_2(data)

3926790061594

In [95]:
%%timeit 
part_2(data)

488 ms ± 27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
