## Advent of Code 2020 — Day 14

https://adventofcode.com/2020/day/14

In [1]:
import itertools

import aocd

In [2]:
# Puzzle input: one instruction per line (`mask = ...` or `mem[addr] = value`).
# splitlines() already returns a list, so no extra comprehension is needed.
lines = aocd.get_data(day=14, year=2020).splitlines()
len(lines)

521

In [3]:
# Peek at the first five instructions.
lines[0:5]

['mask = 100110X100000XX0X100X1100110X001X100',
 'mem[21836] = 68949',
 'mem[61020] = 7017251',
 'mask = X00X0011X11000X1010X0X0X110X0X011000',
 'mem[30885] = 231192']

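### Solution to Part 1

Each `mask` line sets a 36-character bitmask. For every `mem[addr] = value` write, a `0` or `1` in the mask overwrites the corresponding bit of the value, and an `X` leaves that bit unchanged. The answer is the sum of all values left in memory.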

In [4]:
def maskable(value: int, *, width: int = 36) -> list:
    """Return `value` as a list of '0'/'1' characters, most significant bit first.

    The list is left-padded with zeros to `width` characters (36 by default,
    the mask length the rest of this notebook asserts), so it can be zipped
    position-by-position against a mask string.

    Values needing more than `width` bits come back unpadded and therefore
    longer than `width`; callers check the length themselves.

    Raises:
        ValueError: if `value` is negative. There is no meaningful fixed-width
            bit list for it, and ``bin()`` slicing would leak a 'b' character.
    """
    if value < 0:
        raise ValueError(f'cannot mask negative value {value}')
    return list(format(value, f'0{width}b'))

In [5]:
def mask_value(value: int, *, mask: str) -> int:
    """Apply a part-1 bitmask to `value` and return the resulting integer.

    Wherever `mask` has 'X', the corresponding bit of `value` is kept.
    Every other mask character replaces that bit outright.
    Both `mask` and the zero-padded binary form of `value` must be 36
    characters long.
    """
    value_bits = maskable(value)
    assert len(mask) == 36
    assert len(value_bits) == 36
    result_bits = ''.join(
        v if m == 'X' else m
        for m, v in zip(mask, value_bits)
    )
    return int(result_bits, 2)

In [6]:
# Sample part-1 mask. Bits are numbered from 0 at the right:
# bit 6 is forced to 1, bit 1 is forced to 0, and every 'X' bit passes through.
mask = 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X'

In [7]:
# 11 == 0b1011: the mask clears bit 1 and sets bit 6, giving 0b1001001 == 73.
mask_value(0b1011, mask=mask)

73

In [8]:
# 101 == 0b1100101 already has bit 6 set and bit 1 clear, so it comes back unchanged.
mask_value(0b1100101, mask=mask)

101

In [9]:
def parse_line(line: str):
    """Parse one program line into a ``(command, payload)`` pair.

    Returns:
        ``('mask', mask_string)`` for ``mask = ...`` lines, or
        ``('mem', (address, value))`` for ``mem[address] = value`` lines,
        with address and value converted to ``int``.

    Raises:
        ValueError: if the line is not of the form ``<lhs> = <rhs>``, or the
            left-hand side is neither ``mask`` nor ``mem[<int>]``.
    """
    cmd, value = line.split(' = ')
    if cmd == 'mask':
        return ('mask', value)
    # Check the full `mem[...]` shape rather than blindly slicing, so a
    # malformed or unknown instruction is rejected instead of misread.
    if cmd.startswith('mem[') and cmd.endswith(']'):
        addr = int(cmd[len('mem['):-1])
        return ('mem', (addr, int(value)))
    raise ValueError(f'unrecognized instruction: {line!r}')

In [10]:
def store_mem(lines):
    """Run the part-1 program, yielding ``(address, masked_value)`` per write.

    A `mask` line replaces the current mask. Each `mem` line yields its address
    together with its value passed through the current mask. Pairs are yielded
    in program order, so a later write to the same address comes later.
    """
    current_mask = None
    for line in lines:
        kind, payload = parse_line(line)
        if kind == 'mask':
            current_mask = payload
            continue
        if kind != 'mem':
            raise RuntimeError('cannot happen')
        address, raw_value = payload
        yield address, mask_value(raw_value, mask=current_mask)

In [11]:
# Building a dict from the writes keeps only the last value stored at each address.
mem = dict(store_mem(lines))
len(mem)

354

In [12]:
# Part 1 answer: the sum of every value left in memory.
sum(mem[address] for address in mem)

10452688630537

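### Solution to Part 2

In part 2 the mask is applied to the *address* instead of the value. A `0` leaves the address bit unchanged, a `1` sets it to 1, and an `X` is "floating", meaning it takes both values. Each write therefore goes to every address obtained by trying all combinations of the floating bits, and the value itself is stored unmasked.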

In [13]:
def overwrite_floating(masked: str, *, bits: str) -> str:
    assert masked.count('X') == len(bits)
    overwritten = []
    bit = iter(bits)
    for c in masked:
        if c == 'X':
            overwritten.append(next(bit))
        else:
            overwritten.append(c)
    return int(''.join(overwritten), 2)

In [14]:
def floating_addrs(masked: str, *, floating: int) -> list:
    """Expand an address pattern with floating 'X' bits into every concrete address.

    Args:
        masked: address pattern (string or list of characters) made of
            '0', '1' and 'X'.
        floating: number of 'X' characters in `masked`.

    Returns:
        A list of ``2 ** floating`` integers, one for each assignment of 0/1
        to the floating bits. Assignments are ordered by their value read as
        a binary number, from all zeros to all ones.
    """
    # product('01', repeat=n) yields every n-bit combination in counting order.
    # Unlike bin(i)[2:].zfill(n), it yields the single empty combination when
    # n == 0, so a pattern without any 'X' maps to exactly one address.
    return [
        overwrite_floating(masked, bits=''.join(combo))
        for combo in itertools.product('01', repeat=floating)
    ]

In [15]:
def masked_addrs(addr: int, *, mask: str) -> list:
    """Return every memory address that a part-2 write to `addr` targets under `mask`.

    Mask semantics per bit:
    - '0' keeps the address bit.
    - '1' forces the bit to 1.
    - 'X' makes the bit floating, so it takes both 0 and 1.

    Any other mask character raises RuntimeError. Both `mask` and the padded
    address must be 36 characters.
    """
    addr_bits = maskable(addr)
    assert len(mask) == 36
    assert len(addr_bits) == 36
    pattern = []
    for mask_bit, addr_bit in zip(mask, addr_bits):
        if mask_bit not in ('0', '1', 'X'):
            raise RuntimeError('cannot happen')
        # '1' and 'X' both copy the mask character; only '0' keeps the address bit.
        pattern.append(addr_bit if mask_bit == '0' else mask_bit)
    floating = pattern.count('X')
    if not floating:
        return [int(''.join(pattern), 2)]
    return floating_addrs(pattern, floating=floating)

In [16]:
# Sample part-2 mask. Bits are numbered from 0 at the right:
# bits 0 and 5 float, bits 1 and 4 are forced to 1, and the '0' bits keep the address bit.
mask = '000000000000000000000000000000X1001X'

In [17]:
# 42 == 0b101010. The mask has two floating bits, so the write hits 2 ** 2 addresses.
len(masked_addrs(0b101010, mask=mask))

4

In [18]:
def store_mem_part2(lines):
    """Run the part-2 program, yielding ``(address, value)`` for each concrete write.

    A `mask` line replaces the current mask. Each `mem` line writes its value,
    unmasked, to every address that ``masked_addrs`` produces under the
    current mask. Pairs are yielded in program order.
    """
    current_mask = None
    for line in lines:
        kind, payload = parse_line(line)
        if kind == 'mask':
            current_mask = payload
        elif kind == 'mem':
            address, raw_value = payload
            yield from (
                (target, raw_value)
                for target in masked_addrs(address, mask=current_mask)
            )
        else:
            raise RuntimeError('cannot happen')

In [19]:
# Building a dict from the writes keeps only the last value stored at each address.
mem = dict(store_mem_part2(lines))
len(mem)

80345

In [20]:
# Part 2 answer: the sum of every value left in memory.
sum(mem[address] for address in mem)

2881082759597