In [1]:
with open("input.txt") as f:
    # Iterating the file object yields one line at a time; str.strip
    # removes the trailing newline and any surrounding whitespace.
    lines = list(map(str.strip, f))

In [2]:
from collections import namedtuple

# A "mask = ..." instruction; `mask` holds the raw string of X/0/1
# characters exactly as it appeared in the input.
Mask = namedtuple("Mask", ["mask"])
# A "mem[addr] = val" instruction (the parser below fills both fields
# with ints). The typename matches the variable name so that reprs read
# "MemStore(...)" and pickling can find the class by name.
MemStore = namedtuple("MemStore", ["addr", "val"])

In [3]:
import re

# Raw strings: in a plain string literal "\[", "\]" and "\d" are invalid
# escape sequences (DeprecationWarning; SyntaxWarning since Python 3.12).
# The regex patterns themselves are unchanged.
mask_regex = r"mask = (?P<mask>[X01]+)"
mem_regex = r"mem\[(?P<addr>\d+)\] = (?P<val>\d+)"

# Parse every input line into a Mask or MemStore instruction, in order.
instructions = []

for line_no, line in enumerate(lines, start=1):
    if match := re.fullmatch(mask_regex, line):
        instructions.append(Mask(match.group("mask")))
    elif match := re.fullmatch(mem_regex, line):
        addr = int(match.group("addr"))
        val = int(match.group("val"))
        instructions.append(MemStore(addr, val))
    elif line:
        # Blank lines are ignored, but anything else we cannot parse is
        # reported instead of being silently dropped from the program.
        raise ValueError(f"line {line_no}: unrecognized instruction {line!r}")

## Part 1: apply the mask to each value before storing it

In [4]:
def apply_mask(val: int, mask: str) -> int:
    """Return ``val`` with the part-1 mask rules applied.

    Each character of ``mask`` governs the bit of ``val`` at the same
    position (the leftmost character is the mask's most significant bit):

    - ``"0"`` forces that bit to 0,
    - ``"1"`` forces that bit to 1,
    - any other character (e.g. ``"X"``) keeps the bit from ``val``.

    Bits of ``val`` above the mask's width are cleared, because the
    AND-mask built below is implicitly 0 in those positions.
    """
    # AND-mask: 0 exactly where the mask says "0", 1 everywhere else,
    # so the AND clears the forced-0 bits and preserves the rest.
    and_bits = int("".join("0" if ch == "0" else "1" for ch in mask), 2)
    # OR-mask: 1 exactly where the mask says "1", 0 everywhere else,
    # so the OR sets the forced-1 bits and leaves the rest alone.
    or_bits = int("".join("1" if ch == "1" else "0" for ch in mask), 2)
    return (val & and_bits) | or_bits

In [5]:
# Run the program with part-1 semantics: the current mask rewrites each
# value before it is stored; addresses are used as-is.
cur_mask = None
memory = {}

for instruction in instructions:
    if isinstance(instruction, Mask):
        cur_mask = instruction.mask
    elif isinstance(instruction, MemStore):
        if cur_mask is None:
            # Without this check apply_mask would fail with an opaque
            # TypeError from iterating None.
            raise ValueError(
                f"{instruction} appears before any mask instruction"
            )
        addr = instruction.addr
        val = instruction.val
        memory[addr] = apply_mask(val, cur_mask)

sum(memory.values())

14553106347726

## Part 2: decode each address through the mask (floating bits expand to every combination)

In [6]:
from typing import List

def address_decoder(addr: str, mask: str) -> List[str]:
    """Expand a binary address through a part-2 mask into every address it denotes.

    ``addr`` and ``mask`` must be strings of equal length (36 characters
    in this puzzle). They are compared position by position:

    - mask ``"1"``: the resulting bit is 1,
    - mask ``"0"``: the resulting bit is copied from ``addr``,
    - any other mask character (``"X"``, a floating bit): the resulting
      bit takes both values 0 and 1, doubling the number of addresses.

    Returns:
        A list of binary-string addresses (strings, not ints). With k
        floating bits it has 2**k entries, ordered as if the floating
        bits were counted upward with the leftmost floating bit most
        significant.

    Raises:
        ValueError: if ``addr`` and ``mask`` differ in length.
    """
    if len(addr) != len(mask):
        # Positions would not line up; refuse rather than truncate.
        raise ValueError(
            f"address and mask lengths differ: {len(addr)} != {len(mask)}"
        )

    # Build the addresses left to right. At each position every prefix is
    # extended by each bit value allowed there; extending a prefix with
    # "0" before "1" yields the counting order described above.
    prefixes = [""]
    for addr_bit, mask_bit in zip(addr, mask):
        if mask_bit == "1":
            choices = "1"
        elif mask_bit == "0":
            choices = addr_bit
        else:
            choices = "01"
        prefixes = [prefix + bit for prefix in prefixes for bit in choices]
    return prefixes

In [7]:
def to_36_bit_string(val: int) -> str:
    """Return ``val`` as a zero-padded, 36-character binary string.

    Raises:
        ValueError: if ``val`` is negative or needs more than 36 bits.
            Slicing ``bin()`` would otherwise produce a malformed string
            for negatives (e.g. ``"b101"`` for -5) and a string longer
            than 36 characters for large values.
    """
    if not 0 <= val < 1 << 36:
        raise ValueError(f"{val} does not fit in 36 unsigned bits")
    return format(val, "036b")

In [8]:
# Run the program with part-2 semantics: the current mask decodes each
# address into possibly many addresses; values are stored unmodified.
cur_mask = None
memory = {}  # keyed by 36-character binary address strings

for instruction in instructions:
    if isinstance(instruction, Mask):
        cur_mask = instruction.mask
    elif isinstance(instruction, MemStore):
        if cur_mask is None:
            # Without this check address_decoder would fail with an
            # opaque TypeError from indexing None.
            raise ValueError(
                f"{instruction} appears before any mask instruction"
            )
        addr = instruction.addr
        binary_addr = to_36_bit_string(addr)
        val = instruction.val
        # One write fans out to every address the floating bits denote;
        # later writes overwrite earlier ones at the same address.
        for mem_addr in address_decoder(binary_addr, cur_mask):
            memory[mem_addr] = val

sum(memory.values())

2737766154126