In [28]:
# Open file and parse each non-blank line into an [action, value] pair,
# e.g. 'mask = 10X1' -> ['mask', '10X1'] and 'mem[8] = 11' -> ['mem[8]', '11'].
# `with` closes the file handle once parsing is done; blank lines are skipped
# so a trailing empty line cannot break the `action, value` unpacking later.
with open('day14_inputs.txt') as file:
    content = [line.strip().split(' = ') for line in file if line.strip()]

content[:5]

[['mask', '1010X101010010101X00X00011XX11011111'],
 ['mem[1303]', '728'],
 ['mem[5195]', '213352120'],
 ['mem[34818]', '782'],
 ['mem[43971]', '29724050']]

### Part 1

Our program is a bitmask system: each memory address **n** (denoted by `mem[n]`) holds a 36-bit unsigned integer, i.e. a number between 0 and $2^{36} - 1$.

There are 2 types of input:
1. A memory write, e.g. `mem[3] = 16`, tries to write 16 to memory address 3 as its binary representation (`10000`, padded with 0s on the left to 36 binary digits).
2. A mask, which is applied to every memory write that follows it: each position where the mask holds a 1 or 0 overwrites the corresponding bit of the value being written, while an 'X' leaves that bit unchanged.

To give a simple example, if we had a 5 bit mask and a 5 bit address update:
- mask = 0X1X1
- mem[2] = 10000

We take the attempt to write 16 to memory address 2 and overwrite every digit where the mask is not an 'X':

```
mask:   0X1X1
value:  10000
result: 00101
```

The first, third, and fifth digits are overwritten - with the result being 00101 (5) being written to memory address 2.

A mask applies to all lines below it, until another mask is set.

In [29]:
import re 

def convert_to_binary(num, width=36):
    """Return ``num`` as a zero-padded binary string of ``width`` digits.

    Parameters
    ----------
    num : int
        Non-negative integer to convert; must fit in ``width`` bits.
    width : int, optional
        Number of binary digits in the result (default 36, the puzzle's word size).

    Returns
    -------
    str
        String of '0'/'1' characters, most significant bit first,
        e.g. ``convert_to_binary(16, 5) == '10000'``.

    Raises
    ------
    ValueError
        If ``num`` is negative or does not fit in ``width`` bits. (A greedy
        subtract-powers-of-two loop would silently return a wrong string here.)
    """
    if not 0 <= num < 2 ** width:
        raise ValueError('{} does not fit in {} unsigned bits'.format(num, width))
    return format(num, 'b').zfill(width)


def convert_to_int(binary_num, width=36):
    """Return the integer value of a fixed-width string of binary digits.

    Parameters
    ----------
    binary_num : str
        String of exactly ``width`` '0'/'1' characters, most significant bit
        first (e.g. the strings produced by ``convert_to_binary``).
    width : int, optional
        Expected number of digits (default 36, the puzzle's word size).

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If ``binary_num`` has the wrong length or contains anything other than
        '0'/'1' (for example an unresolved floating 'X' bit).
    """
    # Validate explicitly: int(..., 2) alone would also accept '0b' prefixes,
    # underscores and surrounding whitespace, and would not catch a wrong width.
    if len(binary_num) != width or set(binary_num) - {'0', '1'}:
        raise ValueError('Not a {}-digit binary string: {!r}'.format(width, binary_num))
    return int(binary_num, 2)


# Part 1: loop through content and either
# 1. Set a new mask (it applies to every following mem line until replaced)
# 2. Turn the value into binary, apply the mask, and store it at the address
dict_mem_addresses = {}

# Reset explicitly so this cell never reuses a mask left over from another
# cell or an earlier run (hidden kernel state).
mask = None

for action, value in content:

    # Set new mask if required
    if action == 'mask':
        mask = value
        if len(mask) != 36:
            raise ValueError('Invalid mask length')

    else:
        if mask is None:
            raise ValueError('Memory write before any mask was set: {}'.format(action))

        # Get address, e.g. 'mem[1303]' -> 1303 (int keys, as in Part 2)
        address = int(re.search('[0-9]+', action).group())

        # Convert value to a 36-digit binary string
        value_bin = convert_to_binary(int(value))

        # Mask digits '0'/'1' overwrite the value; 'X' keeps the value's digit
        masked_value_bin = ''.join(
            value_digit if mask_digit == 'X' else mask_digit
            for mask_digit, value_digit in zip(mask, value_bin)
        )

        # Write to memory (a later write to the same address replaces earlier ones)
        dict_mem_addresses[address] = masked_value_bin


# Finally, sum the values left in every written memory address
total = sum(convert_to_int(bin_value) for bin_value in dict_mem_addresses.values())

total

8566770985168

### Part 2

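Bitmask behaviour changed. The mask is now applied, bit by bit, to the memory **address**:
- A 0 leaves the corresponding address bit unchanged
- A 1 overwrites the address bit with '1'
- An X is *floating*, meaning that it takes on all possible values (both 0 and 1)
- The mask no longer alters the value being written. Instead, it determines which target address(es) the value is written to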

e.g.:
- mem[42] = 100
- address_num = 101010 (42 in binary, shown with 6 bits for brevity)
- mask = X1001X
- masked_address_num = X1101X

The X's take on all possible combinations, meaning that 4 memory addresses are written, all with the value of 100.

- X1101X
- 011010 (26)
- 011011 (27)
- 111010 (58)
- 111011 (59)

In [30]:
import copy

# Part 2: loop through content and either
# 1. Set a new mask
# 2. Apply the mask to the *address*, expand every floating 'X' into both
#    0 and 1, and write the (unmasked) value to each resulting address
dict_mem_addresses = {}

# Reset explicitly so this cell never silently reuses the last mask from the
# Part 1 cell (hidden kernel state) if a mem line ever came first.
mask = None

for action, value in content:

    # Set new mask if required
    if action == 'mask':
        mask = value
        if len(mask) != 36:
            raise ValueError('Invalid mask length')

    else:
        if mask is None:
            raise ValueError('Memory write before any mask was set: {}'.format(action))

        # Get address, e.g. 'mem[42]' -> '42'
        address = re.search('[0-9]+', action).group()

        # Convert address to a 36-digit binary string
        address_bin = convert_to_binary(int(address))

        # Mask rules: '0' keeps the address bit, '1' forces a 1, 'X' floats
        masked_address_bin = ''.join(
            address_digit if mask_digit == '0' else mask_digit
            for mask_digit, address_digit in zip(mask, address_bin)
        )

        # Expand floating bits. Start from a single empty prefix: a fixed digit
        # extends every prefix, while an 'X' doubles the set (one copy ending in
        # 0, one ending in 1). With no 'X' this yields exactly one address.
        addresses_bin = ['']
        for digit in masked_address_bin:
            if digit == 'X':
                addresses_bin = [prefix + bit for prefix in addresses_bin for bit in '01']
            else:
                addresses_bin = [prefix + digit for prefix in addresses_bin]

        # Write the same value to every resulting address
        value_int = int(value)
        for bin_addr in addresses_bin:
            dict_mem_addresses[convert_to_int(bin_addr)] = value_int


# Finally, sum the values left in every written memory address
total = sum(dict_mem_addresses.values())

total

4832039794082