## Part 1 — mask applied to written values ##

In [1]:
# Word size for this puzzle: every mask string is 36 characters wide (one
# character per bit), and values/addresses are treated as 36-bit numbers.
bitsize = 36

In [2]:
# Part-1 worked example: one mask line followed by three memory writes.
ex = '''mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 11
mem[7] = 101
mem[8] = 0'''.splitlines()

In [3]:
# Sanity check: int <-> binary-string round trip used by the mask helpers.
format(11, '#b'), int('0b1011', 0)

('0b1011', 11)

In [4]:
def get_masks(maskstr):
    '''Split a part-1 mask string into an OR-mask and an AND-mask.

    Each character of ``maskstr`` (surrounding whitespace ignored) controls
    one bit, most-significant first:
      '1' -> the bit is forced to 1
      '0' -> the bit is forced to 0
      'X' -> the bit is left unchanged

    Returns ``(ones, zeros)`` as ints: ``ones`` has a 1 exactly where the
    mask says '1' (OR it in to set bits), and ``zeros`` has a 0 exactly where
    the mask says '0' (AND it in to clear bits).

    Raises ValueError on any character other than 'X', '1' or '0'.
    '''
    bits = maskstr.strip()
    bad = next((ch for ch in bits if ch not in 'X10'), None)
    if bad is not None:
        raise ValueError(f'Value not allowed: {bad}')
    # 'X' must neither set a bit in the OR-mask nor clear one in the
    # AND-mask, so it becomes '0' in the former and '1' in the latter.
    ones = int(bits.replace('X', '0'), base=2)
    zeros = int(bits.replace('X', '1'), base=2)
    return ones, zeros

In [5]:
get_masks('XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X')

(64, 68719476733)

In [6]:
def apply_masks(ones, zeros, num):
    '''Apply a part-1 mask to ``num``.

    The OR-mask ``ones`` forces bits to 1, then the AND-mask ``zeros``
    forces bits to 0 (the pair as returned by ``get_masks``).
    '''
    forced_on = num | ones
    return forced_on & zeros

In [7]:
# Example mask: bit 6 forced to 1 (OR-mask 64), bit 1 forced to 0 (AND-mask 2**36 - 1 - 2).
assert apply_masks(64, 68719476733, 11) == 73    # 0b1011 -> 0b1001001
assert apply_masks(64, 68719476733, 101) == 101  # bit 6 already set, bit 1 already clear
assert apply_masks(64, 68719476733, 0) == 64     # only the forced bit remains

In [8]:
def process(lines):
    '''Run a part-1 initialization program and return the sum of memory.

    ``lines`` is an iterable of program lines (trailing newlines allowed):
      ``mask = <string of X/1/0>`` -- replaces the current value mask
      ``mem[<addr>] = <value>``    -- stores the masked value at <addr>
    Blank or whitespace-only lines are skipped. Before the first mask line an
    all-'X' mask is in effect, which leaves values that fit in ``bitsize``
    bits unchanged.

    Returns the sum of all values left in memory (later writes to an address
    replace earlier ones).
    Raises ValueError for any other kind of line.
    '''
    mem = {}
    # An all-'X' mask is the identity: nothing forced on, nothing forced off.
    setones, setzeros = get_masks('X' * bitsize)
    for line in lines:
        text = line.strip()
        if not text:
            # e.g. a trailing empty line in the input file
            continue
        if text.startswith('mask'):
            setones, setzeros = get_masks(text.split('=', 1)[1])
        elif text.startswith('mem'):
            lhs, rhs = text.split('=')
            loc = int(lhs[lhs.rindex('[') + 1:lhs.rindex(']')])
            mem[loc] = apply_masks(setones, setzeros, int(rhs))
        else:
            raise ValueError(f'Bad line: {line}')
    return sum(mem.values())

In [9]:
# Regression check on the part-1 worked example: its memory total must be 165.
ex_total = process(ex)
assert ex_total == 165, ex_total
ex_total

165

In [10]:
# Puzzle input: one program line per element, trailing '\n' kept.
with open('inputs/day14.input') as fp:
    data = list(fp)

In [11]:
process(lines=data)

14954914379452

## Part 2 — mask applied to memory addresses (floating bits) ##

In [12]:
# Part-2 worked example: two masks, each followed by one memory write.
ex2 = '''mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1'''.splitlines()

In [13]:
import itertools

In [14]:
def masked_locations(maskstr, loc):
    '''Return every address that a part-2 mask decodes ``loc`` into.

    Mask characters, most-significant bit first:
      '0' -> keep the address bit as-is
      '1' -> force the address bit to 1
      'X' -> "floating": the bit takes both values, so each X doubles the
             number of resulting addresses

    The mask is aligned with the LOW ``len(maskstr)`` bits of ``loc`` (36 for
    the puzzle input). Address bits above that width pass through unchanged
    instead of being misread as mask positions.

    Returns a list of ints in ascending order.
    Raises ValueError on any character other than '0', '1' or 'X'.
    '''
    width = len(maskstr)
    low_mask = (1 << width) - 1
    high = loc & ~low_mask  # bits not covered by the mask
    # Zero-pad so position i of the string lines up with position i of the mask.
    low_bits = format(loc & low_mask, 'b').zfill(width)
    choices = []
    for i, c in enumerate(maskstr):
        if c == '0':
            choices.append(low_bits[i])
        elif c == '1':
            choices.append('1')
        elif c == 'X':
            choices.append('01')
        else:
            raise ValueError(f'Invalid value: {c}')
    # product() varies the rightmost floating bit fastest, '0' before '1', so
    # the addresses come out in ascending order. `or '0'` covers an empty mask.
    return [high | int(''.join(p) or '0', base=2)
            for p in itertools.product(*choices)]

In [15]:
masked_locations(maskstr='000000000000000000000000000000X1001X', loc=42)

[26, 27, 58, 59]

In [16]:
def process2(lines):
    '''Run a part-2 initialization program and return the sum of memory.

    ``lines`` is an iterable of program lines (trailing newlines allowed):
      ``mask = <string of X/1/0>`` -- replaces the current address mask
      ``mem[<addr>] = <value>``    -- writes <value> to every address that
                                      the current mask decodes <addr> into
    Blank or whitespace-only lines are skipped. Before the first mask line an
    all-'0' mask is in effect, which leaves addresses unchanged.

    Returns the sum of all values left in memory (later writes to an address
    replace earlier ones).
    Raises ValueError for any other kind of line.
    '''
    mem = {}
    maskstr = '0' * bitsize
    for line in lines:
        text = line.strip()
        if not text:
            # e.g. a trailing empty line in the input file
            continue
        if text.startswith('mask'):
            maskstr = text.split('=', 1)[1].strip()
        elif text.startswith('mem'):
            lhs, rhs = text.split('=')
            loc = int(lhs[lhs.rindex('[') + 1:lhs.rindex(']')])
            val = int(rhs)
            for addr in masked_locations(maskstr, loc):
                mem[addr] = val
        else:
            raise ValueError(f'Bad line: {line}')
    return sum(mem.values())

In [17]:
# Regression check on the part-2 worked example: its memory total must be 208.
ex2_total = process2(ex2)
assert ex2_total == 208, ex2_total
ex2_total

208

In [18]:
process2(lines=data)

3415488160714