In [14]:
# Worked example program from the part-1 puzzle text, one instruction per entry.
example = [
    "mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X",
    "mem[8] = 11",
    "mem[7] = 101",
    "mem[8] = 0",
]

# Puzzle input, one instruction per entry. Each line is stripped of its newline
# and blank lines are dropped, so a trailing empty line in the file never reaches
# the instruction parser (which cannot split a blank line on '=').
with open("day14.txt", "r") as f:
    data = [line.strip() for line in f if line.strip()]

In [15]:
from typing import List, Union, Dict
from collections import defaultdict

class MaskedMemory(object):
    """Sparse memory whose writes have a bitmask applied to the *value* (part 1).

    The mask is a list with one entry per bit, most-significant bit first:
    ``0`` / ``1`` force that bit of every written value, ``None`` leaves it as-is.
    The mask's length determines how many low bits it affects.
    """

    def __init__(self, bits: int = 36):
        # Start with a pass-through mask (no bits forced) of the requested width.
        self.mask: List[Union[int, None]] = [None] * bits

        # Unwritten addresses read as 0.
        self.memory: Dict[int, int] = defaultdict(int)

    @property
    def memory_total(self) -> int:
        """Sum of every value currently stored in memory."""
        return sum(self.memory.values())

    def set_mask(self, mask: Union[str, List[Union[int, None]]]):
        """Replace the current mask.

        :param mask: either a string of ``'0'``, ``'1'`` and ``'X'`` characters
            (most-significant bit first), or an already-parsed list of
            ``0``, ``1`` and ``None``.
        :raises ValueError: if a string mask contains any other character.
        """
        if isinstance(mask, str):
            parsed: List[Union[int, None]] = []
            for char in mask:
                if char == 'X':
                    parsed.append(None)
                elif char in ('0', '1'):
                    parsed.append(int(char))
                else:
                    # int() alone would accept '2'..'9', which write() would then
                    # silently treat as a forced 1.
                    raise ValueError(f"invalid mask character {char!r} in {mask!r}")
            mask = parsed

        # Copy so that later mutation of the caller's list cannot change the mask.
        self.mask = list(mask)

    def write(self, address: int, value: int):
        """Store ``value`` at ``address`` after forcing the mask's 0/1 bits onto it."""
        width = len(self.mask)
        for i, bit in enumerate(self.mask):
            if bit is None:
                continue

            # Mask index 0 corresponds to the most-significant bit.
            shifted_bit = 1 << (width - i - 1)

            if bit == 0:
                value &= ~shifted_bit
            else:
                value |= shifted_bit

        self.memory[address] = value

# Check MaskedMemory against the part-1 worked example: after each write the
# running memory total must match the puzzle text.
test_memory = MaskedMemory()
test_memory.set_mask("XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X")
for address, value, expected_total in [(8, 11, 73), (7, 101, 174), (8, 0, 165)]:
    test_memory.write(address, value)
    assert test_memory.memory_total == expected_total

In [16]:
class InstructionProcessor(object):
    """Parse day-14 program lines and apply them to a memory object.

    Recognised forms (whitespace around tokens is ignored):
        ``mask = <mask>``       -> ``memory.set_mask(<mask>)``
        ``mem[<addr>] = <val>`` -> ``memory.write(int(addr), int(val))``
    """

    def __init__(self, memory: "MaskedMemory"):
        # Any object exposing set_mask(str) and write(int, int) works here,
        # e.g. MaskedMemory or AddressMaskedMemory.
        self.memory = memory

    def process(self, instruction: str):
        """Apply a single instruction line to ``self.memory``.

        Blank (whitespace-only) lines are ignored, so raw file lines with
        trailing newlines or empty lines do not crash a run.

        :raises ValueError: if the line is not a recognised instruction or its
            address/value are not integers.
        """
        if not instruction.strip():
            return None

        left, sep, right = instruction.partition('=')
        if not sep:
            raise ValueError(f"instruction has no '=': {instruction!r}")
        left, right = left.strip(), right.strip()

        if left == "mask":
            return self.memory.set_mask(right)

        if left.startswith("mem[") and left.endswith("]"):
            addr = int(left[len("mem["):-1])
            return self.memory.write(addr, int(right))

        raise ValueError(f"unrecognised instruction: {instruction!r}")

    
# Push the part-1 worked example through the full parse -> memory pipeline.
example_memory = MaskedMemory()
test_processor = InstructionProcessor(example_memory)
for instruction in example:
    test_processor.process(instruction)

assert test_processor.memory.memory_total == 165
        

In [18]:
# Part 1: run the puzzle program with value-masking memory and report the sum.
part1_memory = MaskedMemory()
part1_processor = InstructionProcessor(part1_memory)
for instruction in data:
    part1_processor.process(instruction)

print(f"Final Sum (part 1): {part1_processor.memory.memory_total}")

Final Sum (part 1): 13727901897109


In [34]:
from typing import Iterator

class AddressMaskedMemory(MaskedMemory):
    """Memory where the mask rewrites the *address* instead of the value (part 2).

    Address-mask semantics:
      * ``0``    -> the address bit is left unchanged,
      * ``1``    -> the address bit is forced to 1,
      * ``None`` -> the bit "floats": every 0/1 combination is written to.
    """

    def write(self, address: int, value: int):
        """Store ``value`` (unmasked) at every address the mask expands ``address`` to."""
        targets = self.get_addresses(address)
        for target in targets:
            self.memory[target] = value

    def get_addresses(self, address: int) -> Iterator[int]:
        """Yield each concrete address produced by applying the mask to ``address``.

        A mask with ``k`` floating bits yields ``2 ** k`` addresses. The first one
        has every floating bit set; bit ``j`` of the enumeration counter (LSB first)
        clears the ``j``-th floating bit, counting floating bits from the
        most-significant end of the mask.
        """
        width = len(self.mask)

        # Bit positions (0 = least significant) of the floating entries,
        # listed from the most-significant end of the mask.
        floating = [width - 1 - idx for idx, entry in enumerate(self.mask) if entry is None]

        # Force every '1' bit of the mask on; '0' bits leave the address untouched.
        base = address
        for idx, entry in enumerate(self.mask):
            if entry == 1:
                base |= 1 << (width - 1 - idx)

        for combo in range(2 ** len(floating)):
            candidate = base
            for j, position in enumerate(floating):
                flag = 1 << position
                if (combo >> j) & 1:
                    candidate &= ~flag
                else:
                    candidate |= flag
            yield candidate

# Check AddressMaskedMemory against the part-2 worked example:
# (mask, address, value, expected running total after the write).
test_memory = AddressMaskedMemory()
for mask, address, value, expected_total in [
    ("000000000000000000000000000000X1001X", 42, 100, 400),
    ("00000000000000000000000000000000X0XX", 26, 1, 208),
]:
    test_memory.set_mask(mask)
    test_memory.write(address, value)
    assert test_memory.memory_total == expected_total


In [35]:
# Part 2: same program, but the mask now rewrites addresses instead of values.
part2_memory = AddressMaskedMemory()
part2_processor = InstructionProcessor(part2_memory)
for instruction in data:
    part2_processor.process(instruction)

print(f"Final Sum (part 2): {part2_processor.memory.memory_total}")

Final Sum (part 2): 5579916171823
