In [1]:
# Read the whole puzzle input; `with` guarantees the file is closed even if read() fails.
with open('input.txt') as f:
    raw = f.read()

In [2]:
# One program line per element. splitlines() drops a trailing newline without
# assuming one exists (raw[:-1] would cut the last character otherwise) and
# also tolerates '\r\n' line endings.
data = raw.splitlines()

In [3]:
# Example program for the first decoder; expected memory sum is 165.
sample_data = [
    'mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X',
    'mem[8] = 11',
    'mem[7] = 101',
    'mem[8] = 0',
]

In [4]:
import re

class Initializer:
    """First decoder: runs a program of mask/memory-write lines, masking values.

    program_array: list of program lines, each either
        'mask = <mask string of X/0/1>' or 'mem[<address>] = <decimal value>'.
    Written values are stored in self.values, keyed by the address text exactly
    as it appears between the brackets (a string).
    """

    def __init__(self, program_array):
        self.program = program_array
        self.mask = ''        # current mask; replaced by every 'mask = ...' line
        self.values = dict()  # address text -> last masked value written there

    def action(self, line):
        """Execute a single program line; unknown lines are reported and skipped."""
        if line.startswith('mask'):
            self.mask = line.partition(' = ')[-1]
        elif line.startswith('mem'):
            # raw string: '\[' in a plain literal is an invalid escape sequence
            index = re.search(r'\[(.*)\]', line)[1]
            bin_number = bin(int(line.partition(' = ')[-1]))
            self.values[index] = int(self.apply_mask(bin_number), 2)
        else:
            print(f'Unknown command {line}')

    def apply_mask(self, bin_value):
        """Apply the current mask to a value.

        bin_value: '0b'-prefixed binary string, as produced by bin().
        Returns an unprefixed binary string with one digit per mask symbol:
        'X' keeps the value's bit (0 above the value's highest bit), while
        '0'/'1' overwrite it. Value bits above the mask's width are dropped.
        Unknown mask symbols are reported and contribute no digit.
        """
        # Left-pad so every mask position has a value digit to read.
        digits = bin_value[2:].zfill(len(self.mask))
        # Align mask and value at their least-significant ends.
        offset = len(digits) - len(self.mask)
        result = []
        for pos, symbol in enumerate(self.mask):
            if symbol == 'X':
                result.append(digits[offset + pos])
            elif symbol in ('0', '1'):
                result.append(symbol)
            else:
                print(f'Unknown symbol in mask: {symbol}, {self.mask}')
        return ''.join(result)

    def initialize(self):
        """Run every line of the program in order."""
        for line in self.program:
            self.action(line)

    def sum_in_memory(self):
        """Return the sum of all values currently stored in memory."""
        return sum(self.values.values())
        

In [5]:
test = Initializer(sample_data)
test.initialize()
sample_total = test.sum_in_memory()
# assert so a regression fails loudly on Restart & Run All; the boolean is still
# printed so the recorded cell output stays accurate
assert sample_total == 165, f'expected 165, got {sample_total}'
print(sample_total == 165)

True


In [6]:
# Run the first decoder over the real input and display the memory total.
ini = Initializer(program_array=data)
ini.initialize()
part1_total = ini.sum_in_memory()
part1_total

13105044880745

In [37]:
# Example program for the second decoder; expected memory sum is 208.
sample_data2 = [
    'mask = 000000000000000000000000000000X1001X',
    'mem[42] = 100',
    'mask = 00000000000000000000000000000000X0XX',
    'mem[26] = 1',
]

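For the second decoder (`Initializer2`), the mask is applied to the memory **address** rather than to the value:

- If the bitmask bit is `0`, the corresponding memory address bit is unchanged.
- If the bitmask bit is `1`, the corresponding memory address bit is overwritten with `1`.
- If the bitmask bit is `X`, the corresponding memory address bit is *floating*: it takes both values, `0` and `1`. An address with $n$ floating bits therefore expands to $2^n$ addresses, and the value is written to every one of them.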

In [67]:
class Initializer2(Initializer):
    """Second decoder: the mask rewrites memory addresses instead of values.

    Per address bit, mask '0' keeps the bit, '1' forces it to 1, and 'X' makes
    it floating (both 0 and 1), so each write may hit several addresses.
    Addresses are stored in self.values as binary strings of len(self.mask).
    The constructor is inherited unchanged from Initializer.
    """

    def action(self, line):
        """Execute a single program line; unknown lines are reported and skipped."""
        if line.startswith('mask'):
            self.mask = line.partition(' = ')[-1]
        elif line.startswith('mem'):
            # raw string: '\[' in a plain literal is an invalid escape sequence
            index = bin(int(re.search(r'\[(.*)\]', line)[1]))
            number = int(line.partition(' = ')[-1])
            self.write(index, number)
        else:
            print(f'Unknown command {line}')

    def apply_mask(self, *args):
        """Intentionally a no-op: masking is done on addresses inside write()."""
        pass

    def write(self, index, number):
        """Store `number` at every address decoded from `index` under the current mask.

        index: address as a '0b'-prefixed binary string (as produced by bin()).
        number: integer value written to each decoded address.
        Unknown mask symbols are reported and contribute no address digit.
        """
        # Left-pad so every mask position has an address bit to read.
        bits = index[2:].zfill(len(self.mask))
        # Align mask and address at their least-significant ends.
        offset = len(bits) - len(self.mask)
        masked_bits = []
        for pos, symbol in enumerate(self.mask):
            if symbol == '0':
                masked_bits.append(bits[offset + pos])
            elif symbol in ('1', 'X'):
                masked_bits.append(symbol)
            else:
                print(f'Unknown symbol in mask: {symbol}, {self.mask}')

        for address in self.unfloat_address(''.join(masked_bits)):
            self.values[address] = number

    @staticmethod
    def unfloat_address(index):
        """Expand every 'X' in `index` into both '0' and '1'.

        Returns the list of all 2**k concrete address strings (k = number of
        'X' symbols). Each 'X' yields all previous prefixes with '0' appended,
        followed by the same prefixes with '1' appended.
        """
        addresses = ['']
        for symbol in index:
            if symbol == 'X':
                addresses = [a + '0' for a in addresses] + [a + '1' for a in addresses]
            else:
                addresses = [a + symbol for a in addresses]
        return addresses

In [73]:
test2 = Initializer2(sample_data2)
test2.initialize()
sample2_total = test2.sum_in_memory()
# assert so a regression fails loudly on Restart & Run All; the boolean is still
# printed so the recorded cell output stays accurate
assert sample2_total == 208, f'expected 208, got {sample2_total}'
print(sample2_total == 208)

True


In [76]:
# Run the second (address-masking) decoder over the real input and display the total.
ini2 = Initializer2(program_array=data)
ini2.initialize()
part2_total = ini2.sum_in_memory()
part2_total

3505392154485