# Day 14 - Bit masking operations

* https://adventofcode.com/2020/day/14

All we need to do is generate *two* bitmasks per input mask:

- A mask to _set_ bits, using bit-wise OR (`number | bitmask`); any `1` in the bitmask sets that bit in the output.
- A mask to _clear_ bits, using bit-wise AND (`number & bitmask`); any `0` in the bitmask clears that bit.
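
A minimal sketch of the two operations on a small 6-bit value. The helper name `apply_masks` and the 6-bit toy masks are illustrative assumptions, not part of the solution. A set mask only has `1`s where the matching clear mask also has `1`s, so the two operations never conflict and can be applied in either order.

```python
def apply_masks(value: int, set_mask: int, clear_mask: int) -> int:
    """Hypothetical helper: clear bits with AND, then set bits with OR."""
    return (value & clear_mask) | set_mask

# 6-bit toy masks for the mask string "X1XX0X": force bit 4 on, force bit 1 off.
set_mask = 0b010000
clear_mask = 0b111101

value = 0b001011  # 11
result = apply_masks(value, set_mask, clear_mask)
print(bin(result))  # 0b11001

# Applying OR first and AND second gives the same answer, because the masks never disagree on a bit.
assert (value | set_mask) & clear_mask == result
```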

The masks are easy to generate. Replace the `"X"`s with `"0"` for the set mask or with `"1"` for the clear mask, then parse the resulting string as a binary integer. For example, the mask `XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X` produces a set mask of `int("000000000000000000000000000001000000", 2)` and a clear mask of `int("111111111111111111111111111111111101", 2)`.
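
A sketch of that string transformation using `str.replace` (the solution below uses `str.translate`, which does the same job here). The function name `make_masks` is hypothetical. Applying the example mask to the three example values gives 73, 101 and 64, the values the puzzle's example expects.

```python
def make_masks(mask: str) -> tuple[int, int]:
    """Hypothetical helper returning (set_mask, clear_mask) for a mask string."""
    set_mask = int(mask.replace("X", "0"), 2)    # only the 1s remain set
    clear_mask = int(mask.replace("X", "1"), 2)  # only the 0s remain cleared
    return set_mask, clear_mask

set_mask, clear_mask = make_masks("XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X")
for value in (11, 101, 0):
    # & binds tighter than |, so this is (value & clear_mask) | set_mask
    print(value, "->", value & clear_mask | set_mask)
# 11 -> 73
# 101 -> 101
# 0 -> 64
```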

Memory is just a `defaultdict(int)` object.
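
A small illustration of that memory model, replaying the already-masked writes from the Part 1 example by hand. Later writes to the same address overwrite earlier ones, and the answer is the sum of whatever values remain. Because the program only ever writes to memory, a plain `dict` would work too; `defaultdict(int)` just makes any unwritten address read as 0.

```python
from collections import defaultdict

mem = defaultdict(int)
# Values after masking with XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 73
mem[7] = 101
mem[8] = 64  # overwrites the earlier write to address 8

print(dict(mem))          # {8: 64, 7: 101}
print(sum(mem.values()))  # 165
```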

```python
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

@dataclass
class Mask:
    set_mask: int = 0
    clear_mask: int = (2 ** 36) - 1

    def __rand__(self, other: Any) -> int:
        # Called for `number & mask`: clear the 0 bits, then set the 1 bits.
        if not isinstance(other, int):
            return NotImplemented
        return other & self.clear_mask | self.set_mask

    @classmethod
    def from_mask(cls, mask: str) -> "Mask":
        return cls(
            int(mask.translate({88: 48}), 2),  # X -> 0
            int(mask.translate({88: 49}), 2),  # X -> 1
        )

# Matches either a `mem[addr] = val` line or a `mask = ...` line.
_instr = re.compile(r"""
  ^(?:
    mem\[(?P<addr>\d+)\]\s*=\s*(?P<val>\d+)
  | mask\s*=\s*(?P<mask>[01X]+)
  )$
""", flags=re.VERBOSE).search

def initialize_program(lines: list[str], _parse=_instr) -> int:
    mem = defaultdict(int)
    mask = Mask()
    for match in map(_parse, lines):
        if (mval := match["mask"]):
            mask = Mask.from_mask(mval)
        else:
            mem[int(match["addr"])] = int(match["val"]) & mask
    return sum(mem.values())

assert initialize_program("""\
mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 11
mem[7] = 101
mem[8] = 0
""".splitlines()) == 165
```

```python
import aocd
lines = aocd.get_data(day=14, year=2020).splitlines()
```

```python
print("Part 1:", initialize_program(lines))
```

```
Part 1: 11501064782628
```


## Part 2 - generating addresses

At first I wondered whether there was a way to eliminate addresses from the input, as I was worried about the size of the address space. But in my input no mask has more than 9 `X`s, so each write expands to at most 2⁹ = 512 addresses. That's easy enough to brute-force.
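
A quick way to check that bound, sketched here against the Part 2 example program rather than real puzzle data; `max_floating_bits` is a hypothetical helper name. A mask with *n* `X`s produces 2ⁿ addresses per write.

```python
def max_floating_bits(lines: list[str]) -> int:
    """Hypothetical helper: the largest number of Xs in any mask line."""
    return max(
        (line.count("X") for line in lines if line.startswith("mask")),
        default=0,  # no mask lines at all
    )

example = """\
mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1
""".splitlines()

n = max_floating_bits(example)
print(n, 2 ** n)  # 3 8
```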

To generate the floating addresses, we first _clear_ the bits at the `X` positions, then supply every possible combination of values for those bits. Each `X` position becomes a `(0, 1 << bitpos)` tuple; feeding these tuples to `product()` yields every combination, and OR-ing the values in each combination together gives an additional set mask.
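
The `product()` step in isolation, for two floating bits at positions 5 and 0. Each output tuple is folded into a single OR mask with `reduce(or_, ...)`, much as `FloatingMask.__iter__` does further down (there the fold starts from the set mask instead of 0).

```python
from functools import reduce
from itertools import product
from operator import or_

address_bits = [1 << 5, 1 << 0]  # bit values of the X positions

# Each X contributes either 0 (bit off) or its bit value (bit on).
choices = [(0, bit) for bit in address_bits]
combos = list(product(*choices))
print(combos)  # [(0, 0), (0, 1), (32, 0), (32, 1)]

# Fold each combination into one extra OR mask.
or_masks = [reduce(or_, combo, 0) for combo in combos]
print(or_masks)  # [0, 1, 32, 33]
```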

For example, `000000000000000000000000000000X1001X` produces a set mask of `000000000000000000000000000000010010`. Its `X`s map to `1 << 5` and `1 << 0`, each paired with `0`, so `product()` generates `(0, 0)`, `(0, 1 << 0)`, `(1 << 5, 0)` and `(1 << 5, 1 << 0)`: the 4 extra OR masks needed to produce the floating addresses.
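
Working that example through for `mem[42]`, as a sketch with plain integers instead of the `Mask` classes: clear the `X` bits, set the `1` bits, then OR in each combination. The helper name `floating_addresses` is hypothetical. The resulting addresses 26, 27, 58 and 59 are the ones the puzzle's example writes to.

```python
from functools import reduce
from itertools import product
from operator import or_

def floating_addresses(address: int, mask: str) -> list[int]:
    """Hypothetical helper: every address a Part 2 mask decodes `address` to."""
    set_mask = int(mask.replace("X", "0"), 2)                      # 1s force bits on
    clear_mask = int(mask.replace("0", "1").replace("X", "0"), 2)  # only X bits cleared
    bits = [1 << (len(mask) - 1 - i) for i, c in enumerate(mask) if c == "X"]
    base = address & clear_mask | set_mask
    return [
        base | reduce(or_, combo, 0)
        for combo in product(*((0, bit) for bit in bits))
    ]

print(floating_addresses(42, "000000000000000000000000000000X1001X"))
# [26, 27, 58, 59]
```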

I opted to have iteration over a `FloatingMask` produce `Mask` instances; applying each of those to an address generates all of the possible floating addresses:

```python
from collections.abc import Iterable, Iterator
from functools import reduce
from itertools import product
from operator import or_


@dataclass
class FloatingMask:
    set_mask: int = 0
    clear_mask: int = (2 ** 36) - 1
    address_bits: Iterable[int] = ()

    def __rand__(self, other: Any) -> int:
        # Applies only the fixed bits: X bits cleared, 1 bits set.
        if not isinstance(other, int):
            return NotImplemented
        return other & self.clear_mask | self.set_mask

    def __iter__(self) -> Iterator[Mask]:
        # One Mask per combination of floating bits, each OR-ed into the set mask.
        for combo in product(*([0, a] for a in self.address_bits)):
            yield Mask(reduce(or_, combo, self.set_mask), self.clear_mask)

    @classmethod
    def from_mask(cls, mask: str) -> "FloatingMask":
        return cls(
            int(mask.translate({88: 48}), 2),  # X -> 0
            int(mask.translate({48: 49, 88: 48}), 2),  # 0 -> 1, X -> 0
            [1 << (35 - i) for i, c in enumerate(mask) if c == "X"],
        )


def apply_memory_address_decoder_program(lines: list[str], _parse=_instr) -> int:
    mem = defaultdict(int)
    fmask = FloatingMask()
    for match in map(_parse, lines):
        if (mval := match["mask"]):
            fmask = FloatingMask.from_mask(mval)
        else:
            addr, val = map(int, (match["addr"], match["val"]))
            # Write val to every address the floating mask decodes addr to.
            mem |= {addr & mask: val for mask in fmask}
    return sum(mem.values())


assert apply_memory_address_decoder_program("""\
mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1
""".splitlines()) == 208
```

```python
print("Part 2:", apply_memory_address_decoder_program(lines))
```

```
Part 2: 5142195937660
```
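
As an alternative to `product()`, the floating combinations can also be enumerated with the common submask-enumeration trick, `sub = (sub - 1) & floating`, which visits every subset of the set bits in `floating` without building tuples. This is a swapped-in technique, not the approach used above, and `submasks` is a hypothetical name.

```python
from collections.abc import Iterator

def submasks(floating: int) -> Iterator[int]:
    """Yield every subset of the set bits in `floating`, ending with 0."""
    sub = floating
    while True:
        yield sub
        if sub == 0:
            break
        sub = (sub - 1) & floating  # next smaller subset of floating's bits

floating = (1 << 5) | (1 << 0)
base = 26  # address 42 with its X bits cleared and its 1 bits set
print(sorted(base | sub for sub in submasks(floating)))
# [26, 27, 58, 59]
```

Both approaches generate 2ⁿ values for *n* floating bits; this one just avoids the intermediate tuples.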
