### Puzzle

[Advent of Code 2020, Day 14: Docking Data](https://adventofcode.com/2020/day/14)

### Imports

In [32]:
from itertools import product

### Load Input

In [33]:
# Store the location of the input directory
data_dir = '../../../data/2020'

# Read the puzzle input as a list of strings, one per line
# ("mask = ..." or "mem[<address>] = <value>"); parsing happens in the Part 1/Part 2 cells
with open(f"{data_dir}/day14_input.txt") as f:
    inputs = f.read().splitlines()

### Helper Functions

In [34]:
def decimal_to_binary(decimal_integer, n_bits=36):
    """Return the zero-padded binary representation of a non-negative integer.

    Parameters
    ----------
    decimal_integer : int
        Value to convert; must satisfy ``0 <= decimal_integer < 2**n_bits``.
    n_bits : int, optional
        Positive width of the result in characters (default 36, the word size
        used throughout this notebook).

    Returns
    -------
    str
        Exactly ``n_bits`` '0'/'1' characters, most significant bit first.

    Raises
    ------
    ValueError
        If ``decimal_integer`` does not fit in ``n_bits`` unsigned bits.
    """
    # Reject values that cannot be represented rather than returning a wrong string
    if not 0 <= decimal_integer < 2**n_bits:
        raise ValueError(
            f"{decimal_integer} does not fit in {n_bits} unsigned bits"
        )

    # '0{n}b' = binary digits, left-padded with zeros to width n
    return format(decimal_integer, f'0{n_bits}b')

In [35]:
def binary_to_decimal(binary_string):
    """Parse a string of binary digits (e.g. '1011') and return its integer value."""
    return int(binary_string, base=2)

### Part 1

In [36]:
def masked_bit(mask, bit):
    """Apply a Part 1 bitmask to a value and return the masked value.

    Each character of ``mask`` governs one bit of the result, most significant
    first: '0' or '1' forces that bit, 'X' keeps the corresponding bit of ``bit``.

    Parameters
    ----------
    mask : str
        Non-empty string of '0', '1' and 'X' characters (36 characters for this
        puzzle).
    bit : int
        The non-negative value being written to memory (a whole value, despite
        the name).

    Returns
    -------
    int
        The value after masking. Bits of ``bit`` above the mask width are cleared.

    Raises
    ------
    ValueError
        If ``mask`` is empty or contains characters other than '0', '1', 'X'.
    """
    # OR-ing with this sets every position where the mask has a '1'
    ones = int(mask.replace('X', '0'), 2)
    # AND-ing with this clears every position where the mask has a '0'
    # (and any bits above the mask width)
    keep = int(mask.replace('X', '1'), 2)

    return (bit & keep) | ones

In [50]:
# Memory state for Part 1: address -> masked value
mem = dict()
# No mask seen yet; reset explicitly so a stale mask from earlier kernel state is never reused
mask = None

# Go through the input line by line
for line in inputs:
    # Skip blank lines rather than crashing while parsing them
    if not line.strip():
        continue

    # A "mask = ..." line replaces the current mask for all following writes
    if line.startswith('mask'):
        mask = line.split(' ')[-1]

    # A "mem[<address>] = <value>" line: parse both numbers and store the value
    # with the current mask applied via masked_bit
    else:
        if mask is None:
            raise ValueError(f"memory write before any mask line: {line!r}")
        key = int(line.split('[')[1].split(']')[0])
        value = int(line.split(' ')[-1])
        mem[key] = masked_bit(mask, value)

In [51]:
# Part 1 answer: total of every value left in memory
sum(mem[address] for address in mem)

9615006043476


### Part 2

In [37]:
def masked_bitv2(mem_dict, mask, key, bit):
    """Write ``bit`` to every memory address produced by decoding ``key`` with a Part 2 mask.

    Each mask character acts on the matching bit of the address, most
    significant first:
      - '0': keep the address bit unchanged,
      - '1': force the address bit to 1,
      - 'X': "floating" bit that takes both values, so a mask with n X's
        expands to 2**n addresses (one address when there are no X's).

    Parameters
    ----------
    mem_dict : dict
        Memory mapping address -> value. It is updated IN PLACE.
    mask : str
        Non-empty string of '0', '1' and 'X' characters (36 characters for this
        puzzle).
    key : int
        Non-negative address from the ``mem[...]`` instruction. Bits above the
        mask width are dropped.
    bit : int
        Value to store at every decoded address (a whole value, despite the name).

    Returns
    -------
    dict
        The same ``mem_dict`` object, returned so callers can write
        ``mem = masked_bitv2(mem, ...)``.

    Raises
    ------
    ValueError
        If ``mask`` is empty or contains characters other than '0', '1', 'X'.
    """
    width = len(mask)

    # Bit positions (0 = least significant) of the floating 'X' characters,
    # listed from the most significant one down
    floating = [width - 1 - i for i, char in enumerate(mask) if char == 'X']

    # Force the '1' bits on ...
    ones = int(mask.replace('X', '0'), 2)
    # ... then clear the floating bits (and anything above the mask width)
    # so each combination below can simply be OR-ed in
    keep = int(mask.replace('0', '1').replace('X', '0'), 2)
    base_address = (key | ones) & keep

    # product() yields a single empty tuple when there are no X's, so that
    # case needs no special handling
    for choice in product((0, 1), repeat=len(floating)):
        address = base_address
        for position, value in zip(floating, choice):
            address |= value << position
        mem_dict[address] = bit

    return mem_dict

In [39]:
# Fresh memory for Part 2: a single write may hit several (floating) addresses
mem = dict()

for line in inputs:
    if not line.startswith('mask'):
        # "mem[<address>] = <value>": address sits between the brackets,
        # value follows the last space; masked_bitv2 expands and stores it
        key = int(line.split('[', 2)[1].partition(']')[0])
        value = int(line.rsplit(' ', 1)[-1])
        mem = masked_bitv2(mem, mask, key, value)
        continue

    # "mask = ...": this mask governs every write until the next mask line
    mask = line.rsplit(' ', 1)[-1]

In [41]:
# Part 2 answer: total of every value written to any decoded address
sum(stored for stored in mem.values())

4275496544925
