In [1]:
class BitMask:
    """A bit mask given as an MSB-first string of '0', '1' and 'X' characters.

    Only the non-'X' characters are stored, in ``self.masked``, keyed by bit
    position counted from the least significant bit (position 0).  Values and
    addresses are handled as ``WIDTH``-bit unsigned integers.
    """

    # Number of bits every value/address is padded (and truncated) to.
    WIDTH = 36

    def __init__(self):
        # bit position (0 = least significant) -> mask character ('0' or '1')
        self.masked = {}

    def update(self, rep):
        """Replace the current mask with ``rep`` (MSB-first string of '0'/'1'/'X')."""
        self.masked = {i: ch for i, ch in enumerate(reversed(rep)) if ch != 'X'}

    @staticmethod
    def zeropad(value, places):
        """Return the low ``places`` bits of ``value`` as a zero-padded binary string.

        ``value`` is reduced modulo ``2**places`` first, so wider values keep only
        their low bits and negative values are taken in two's complement.
        ``places <= 0`` yields the empty string.
        """
        if places <= 0:
            return ""
        # Masking avoids bin()'s '-0b' prefix for negatives and truncates wide values.
        return bin(value & ((1 << places) - 1))[2:].zfill(places)

    def apply(self, value):
        """Problem 1 rule: overwrite each bit of ``value`` where the mask has 0/1.

        Positions without a stored mask character ('X') keep their bit.
        Returns the masked value as an int.
        """
        bits = list(reversed(self.zeropad(value, self.WIDTH)))  # LSB-first
        for pos, ch in self.masked.items():
            if pos < len(bits):  # mask positions beyond WIDTH have nothing to overwrite
                bits[pos] = ch
        return int("".join(reversed(bits)), 2)

    def apply_v2(self, value):
        """Problem 2 rule: expand ``value`` (an address) into every matching address.

        If the bitmask bit is 0, the corresponding memory address bit is unchanged.
        If the bitmask bit is 1, the corresponding memory address bit is overwritten with 1.
        If the bitmask bit is X, the corresponding memory address bit is floating.

        Any position with no stored mask character floats.  This includes
        positions beyond the length of the mask string.  Returns a list of
        ``2**k`` ints, where ``k`` is the number of floating bits.
        """
        from itertools import product

        base = list(reversed(self.zeropad(value, self.WIDTH)))  # LSB-first
        floating = []
        for pos in range(len(base)):
            mask_ch = self.masked.get(pos)
            if mask_ch is None:  # 'X' was not stored by update()
                floating.append(pos)
            elif mask_ch == '1':
                base[pos] = '1'

        values = []
        # product('01', repeat=k) counts 0..2**k-1 in binary, one tuple per combination.
        for combo in product('01', repeat=len(floating)):
            bits = base[:]
            for pos, bit in zip(floating, combo):
                bits[pos] = bit
            values.append(int("".join(reversed(bits)), 2))
        return values

In [2]:
# Read the puzzle input: one instruction per line ("mask = ..." or "mem[A] = V").
with open('masks.txt') as masks_file:
    data = masks_file.read().splitlines()
    

# Problem 1: apply the mask to each written value, then sum memory

In [3]:
# Small worked example; with value masking its memory sum is 165.
test_data = """mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 11
mem[7] = 101
mem[8] = 0""".splitlines()


def run_program_v1(lines):
    """Run the program with value masking and return the sum of all memory.

    ``mask = ...`` lines replace the current mask.  ``mem[A] = V`` lines store
    ``V`` with the mask applied (``BitMask.apply``) at address ``A``.
    Blank lines are ignored.
    """
    mask = BitMask()
    memory = {}
    for line in lines:
        parts = line.split()
        if not parts:  # tolerate blank lines instead of raising IndexError
            continue
        if parts[0] == 'mask':
            mask.update(parts[2])
        else:
            address = int(parts[0][4:-1])  # "mem[8]" -> 8
            memory[address] = mask.apply(int(parts[2]))
    return sum(memory.values())


# Validate on the example before running on the real input.
assert run_program_v1(test_data) == 165

run_program_v1(data)

15018100062885

# Problem 2: apply the mask to each memory address (floating bits), then sum memory

In [4]:
# Small worked example; with address masking its memory sum is 208.
test_data = """mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1""".splitlines()


def run_program_v2(lines):
    """Run the program with address masking and return the sum of all memory.

    ``mask = ...`` lines replace the current mask.  ``mem[A] = V`` lines write
    ``V`` to every address produced by ``BitMask.apply_v2(A)``.
    Blank lines are ignored.
    """
    mask = BitMask()
    memory = {}
    for line in lines:
        parts = line.split()
        if not parts:  # tolerate blank lines instead of raising IndexError
            continue
        if parts[0] == 'mask':
            mask.update(parts[2])
        else:
            value = int(parts[2])
            address = int(parts[0][4:-1])  # "mem[42]" -> 42
            memory.update(dict.fromkeys(mask.apply_v2(address), value))
    return sum(memory.values())


# Validate on the example before running on the real input.
assert run_program_v2(test_data) == 208

run_program_v2(data)

5724245857696