In [76]:
import re
import itertools

In [2]:
# Small sample program (one mask, three memory writes) used to check `execute`.
test_input = '\n'.join([
    'mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X',
    'mem[8] = 11',
    'mem[7] = 101',
    'mem[8] = 0',
])

In [117]:
# Small sample program with two masks and only a few floating 'X' bits,
# so every decoded address can be enumerated cheaply.
test_input2 = '\n'.join([
    'mask = 000000000000000000000000000000X1001X',
    'mem[42] = 100',
    'mask = 00000000000000000000000000000000X0XX',
    'mem[26] = 1',
])

In [3]:
def parse_program(s):
    """Parse program text into a flat list of ``(mask, address, value)`` triples.

    ``s`` is made of ``mask = ...`` lines, each followed by any number of
    ``mem[address] = value`` lines. Every memory write is paired with the most
    recent mask string; ``address`` and ``value`` are converted to ``int``.
    Any text before the first ``mask = `` is ignored.

    Raises
    ------
    ValueError
        If a ``mask = `` is not followed by a mask of '0'/'1'/'X' characters
        and whitespace.
    """
    # Joining lines turns each block into "<mask> mem[a] = v mem[b] = w ...", e.g.
    # '0X0X1110X1010X1X10010X0011010X100110 mem[40190] = 23031023 mem[13516] = 384739600 '
    blocks = s.replace('\n', ' ').split('mask = ')[1:]

    all_instructions = []
    for block_s in blocks:
        # Raw strings avoid invalid-escape warnings; [01X] (not [0|1|X]) so a
        # literal '|' is not accepted as a mask character.
        block_match = re.match(r'([01X]+)\s(.*)', block_s)
        if block_match is None:
            raise ValueError(f'malformed mask block: {block_s!r}')
        mask = block_match.group(1)
        rest = block_match.group(2)

        all_instructions.extend(
            (mask, int(address), int(value))
            for address, value in re.findall(r'mem\[(\d+)\]\s=\s(\d+)', rest)
        )

    return all_instructions

In [4]:
# Instruction triples for the small part-1 sample.
test_program = parse_program(s=test_input)

In [118]:
# Instruction triples for the two-mask sample.
test_program2 = parse_program(s=test_input2)

In [5]:
# Parse the full input file; `with` closes the handle once it has been read
# (the bare open(...).read() left it open until garbage collection).
with open('./inputs/14') as input_file:
    program = parse_program(input_file.read())

In [6]:
def overwrite(mask_str, number):
    """Apply a value mask to ``number`` and return the result.

    Each character of ``mask_str`` (leftmost = most significant bit) controls
    one bit: '0' forces it to 0, '1' forces it to 1, and 'X' keeps the
    corresponding bit of ``number`` unchanged.
    """
    # Keep mask: 1 wherever the mask has 'X' (bit passes through), 0 wherever
    # it has '0' or '1' (bit is about to be overwritten).
    keep = int(mask_str.translate(str.maketrans('1X', '01')), 2)
    # Force mask: the mask's literal '1' bits; 'X' positions contribute nothing.
    force = int(mask_str.translate(str.maketrans('X', '0')), 2)
    return (number & keep) | force

In [7]:
def execute(program):
    """Run a value-masking program and return the sum of all stored values.

    ``program`` is an iterable of ``(mask, location, value)`` triples. Each
    value is passed through ``overwrite`` with its mask before being stored at
    ``location``; a later write to the same location replaces an earlier one.
    """
    memory = {location: overwrite(mask, value) for mask, location, value in program}
    return sum(memory.values())

In [8]:
# Check the sample against its known result (165) so a fresh Run All fails
# loudly if `execute` regresses, then display it.
part1_sample_result = execute(test_program)
assert part1_sample_result == 165, part1_sample_result
part1_sample_result

165

In [9]:
# Part 1: sum of memory after running the full input program.
execute(program=program)

6631883285184

In [99]:
# Address-decoder mask: ORs in the mask's '1' bits, clears every floating 'X'
# bit and records its position, so the caller can add back each combination.
def get_floating_bin(mask, value):
    """Apply an address-decoder mask to ``value``.

    Parameters
    ----------
    mask : str
        Characters '0', '1' and 'X', most significant bit first. Its length
        sets the bit width (36 for the masks in this notebook); the width is
        no longer hard-coded.
    value : int
        Address to decode; assumed non-negative.

    Returns
    -------
    (list[int], int)
        The bit positions (0 = least significant) of the 'X' characters, in
        left-to-right mask order. Also ``value`` with the mask's '1' bits set
        and all floating bits cleared. '0' mask bits leave ``value``
        unchanged, as do any bits of ``value`` above the mask width.

    Raises
    ------
    ValueError
        If ``mask`` contains a character other than '0', '1' or 'X'.
    """
    width = len(mask)
    mask_positions = []
    set_bits = 0
    for i, bit_char in enumerate(mask):
        bit = width - 1 - i  # leftmost character is the highest bit
        if bit_char == 'X':
            mask_positions.append(bit)
        elif bit_char == '1':
            set_bits |= 1 << bit
        elif bit_char != '0':
            raise ValueError(f'invalid mask character {bit_char!r} in {mask!r}')

    floating_bits = sum(1 << bit for bit in mask_positions)
    return mask_positions, (value | set_bits) & ~floating_bits

In [128]:
def execute2(program):
    """Run an address-decoder program and return the sum of all stored values.

    For each ``(mask, location, value)`` triple, ``get_floating_bin`` yields
    the floating bit positions and a base address. ``value`` is stored,
    unmodified, at every address formed by adding the base address and any
    subset of ``2**position`` for those positions.
    """
    memory = {}
    for mask, location, value in program:
        floating_positions, base_address = get_floating_bin(mask, location)
        # One (0, 2**bit) choice per floating bit; the product enumerates every subset.
        choices = [(0, 2 ** bit) for bit in floating_positions]
        decoded = [base_address + sum(offsets) for offsets in itertools.product(*choices)]
        for address in decoded:
            memory[address] = value
    return sum(memory.values())

In [134]:
# Check the decoder on the two-mask sample first.
# Expected 208: addresses 58, 59 hold 100; 16-19 and 24-27 hold 1.
assert execute2(test_program2) == 208
execute2(program)

3161838538691