In [1]:
from pathlib import Path

import re

from functools import reduce

from itertools import combinations_with_replacement, product, zip_longest

from collections import Counter

In [2]:
# Directory holding the Advent of Code 2020 puzzle inputs; change it here if
# the data lives elsewhere on this machine.
DATA_DIR = Path.home().joinpath('workstation', 'dev', 'Advent-of-Code-2020', 'data')
data_path = DATA_DIR / 'day14_input.txt'

In [3]:
# Fail fast, naming the missing path, instead of just displaying False and
# crashing later in the read cell; the check result is still displayed.
input_found = data_path.exists()
assert input_found, f"puzzle input not found: {data_path}"
input_found

True

In [4]:
# Load the whole puzzle input as text. strip() trims surrounding whitespace
# (notably the final newline), so the later split yields no empty trailing entry.
init_program = data_path.read_text().strip()

In [5]:
# One list entry per input line; the interpreters below dispatch on the
# 'mask' / 'mem' prefix of each entry and ignore anything else.
init_list = init_program.split('\n')

#### Part 1 — the mask overwrites bits of each value written to memory

In [6]:
class InitializationProgram:
    """Interpreter for the part-1 initialization program.

    The current mask rewrites individual bits of every value written to
    memory: a '0' or '1' mask character forces that bit, and any other
    character (normally 'X') leaves it unchanged.
    """

    def __init__(self):
        # Captures every run of decimal digits, e.g. "mem[8] = 11" -> ['8', '11'].
        self.regex_pattern = re.compile(r'([0-9]+)')
        # Address -> last (masked) value written there.
        self.mem = {}
        # Bit position (0 = least significant) -> forced bit value (0 or 1).
        self.bitmask_dict = {}

    @staticmethod
    def set_bit(v, index, x):
        """Return ``v`` with bit ``index`` set to 1 if ``x`` is truthy, else cleared.

        https://stackoverflow.com/questions/12173774/how-to-modify-bits-in-an-integer
        """
        mask = 1 << index
        v &= ~mask
        if x:
            v |= mask
        return v

    def process_instructions(self, init_list):
        """Execute each instruction line of ``init_list`` in order.

        Lines starting with 'mask' replace the current mask; lines starting
        with 'mem' perform a memory write. Any other line is ignored.
        """
        for step in init_list:
            if step.startswith('mask'):
                self.process_bitmask(step)
            elif step.startswith('mem'):
                self.process_memory_value(step)

    def get_sum_memory_values(self):
        """Return the sum of all values currently in memory (0 when empty)."""
        return sum(self.mem.values())

    def process_bitmask(self, step):
        """Parse a "mask = ..." line and make it the current mask.

        The previous mask is discarded entirely, so a position forced by an
        earlier mask but left unforced by this one is no longer overwritten.
        Surrounding whitespace is ignored so it cannot shift bit positions.
        """
        bitmask_string = step.split('mask = ')[1].strip()
        width = len(bitmask_string)
        # The leftmost mask character corresponds to the most significant bit.
        self.bitmask_dict = {
            width - 1 - string_index: int(bit)
            for string_index, bit in enumerate(bitmask_string)
            if bit in ('0', '1')
        }

    def process_memory_value(self, step):
        """Parse a "mem[address] = value" line and store the masked value."""
        mem_index, value = map(int, self.regex_pattern.findall(step))
        for bit_index, bit in self.bitmask_dict.items():
            value = self.set_bit(value, bit_index, bit)
        self.mem[mem_index] = value

In [7]:
# Fresh part-1 interpreter: empty memory and no mask applied yet.
program = InitializationProgram()

In [8]:
# Execute every instruction; written values accumulate in program.mem.
program.process_instructions(init_list)

In [9]:
# Part 1 answer: sum of all values left in memory.
program.get_sum_memory_values()

15172047086292

#### Part 2 — the mask decodes memory addresses ('X' bits float over 0 and 1)

In [10]:
class InitializationProgram2(InitializationProgram):
    """Part-2 interpreter: the mask is applied to memory addresses, not values.

    For each mask character: '0' keeps the address bit, '1' forces it to 1,
    and 'X' is floating -- the value is written to every address obtained by
    trying both 0 and 1 in each floating position.

    ``process_instructions`` and ``get_sum_memory_values`` are inherited; the
    base ``process_instructions`` dispatches to the overrides below.
    """

    def __init__(self):
        super().__init__()
        # Raw text of the current mask; '' (no masking) until a mask line is seen.
        self.bitmask_string = ''

    def process_bitmask(self, step):
        """Store the raw mask text from a "mask = ..." line.

        Raises:
            ValueError: if the mask contains characters other than '0', '1', 'X'
                (such characters would otherwise corrupt the decoded addresses).
        """
        bitmask_string = step.split('mask = ')[1].strip()
        invalid = set(bitmask_string) - set('01X')
        if invalid:
            raise ValueError(f"unexpected mask characters {sorted(invalid)} in {step!r}")
        self.bitmask_string = bitmask_string

    @staticmethod
    def replace_char_w_all_combinations(char, string, replacement_list):
        """Return every string obtained by replacing each occurrence of ``char``.

        Each non-overlapping occurrence of ``char`` in ``string`` is replaced
        independently by an item of ``replacement_list``; all combinations are
        returned in ``itertools.product`` order. A string without ``char``
        yields ``[string]``.
        """
        pieces = string.split(char)
        # n occurrences split the string into n + 1 pieces; counting this way
        # also stays correct when ``char`` is longer than one character.
        num_replacements = len(pieces) - 1
        res = []
        for replacement_tuple in product(replacement_list, repeat=num_replacements):
            res.append(''.join(
                piece + replacement
                for piece, replacement in zip_longest(pieces, replacement_tuple, fillvalue='')
            ))
        return res

    def process_memory_value(self, step):
        """Write the value of a "mem[address] = value" line to every decoded address.

        The mask is aligned with the least-significant end of the address, so
        masks of any width work; address bits above the mask's width are kept
        unchanged.
        """
        mem_index, value = map(int, self.regex_pattern.findall(step))
        width = len(self.bitmask_string)
        address_bits = format(mem_index, 'b').zfill(width)
        # Split off any address bits wider than the mask so the mask lines up
        # with the low-order bits instead of the high-order ones.
        split_at = len(address_bits) - width
        unmasked_high_bits, low_bits = address_bits[:split_at], address_bits[split_at:]
        masked_bits = ''.join(
            mem_char if mask_char == '0' else mask_char
            for mask_char, mem_char in zip(self.bitmask_string, low_bits)
        )
        decoded = self.replace_char_w_all_combinations(
            'X', unmasked_high_bits + masked_bits, ['0', '1'])
        for address in decoded:
            self.mem[int(address, 2)] = value

In [11]:
# Fresh part-2 interpreter (mask applied to addresses): empty memory, empty mask.
program_2 = InitializationProgram2()

In [12]:
# Execute every instruction; each write may touch several decoded addresses.
program_2.process_instructions(init_list)

In [13]:
# Part 2 answer: sum of all values left in memory.
program_2.get_sum_memory_values()

4197941339968