In [1]:
# Read file
import time

def read_file(filename):
    with open(filename) as infile:
        lines = [line.strip() for line in infile.readlines()]
    return lines

In [2]:
# Part 1
def split_mask(mask):
    val = 1
    pos = 0
    neg = 0
    for c in reversed(mask):
        if c == 'X':
            pos |= val
        elif int(c):
            pos |= val
            neg |= val
        val *= 2
    return pos, neg

def apply_mask(value, pos, neg):
    return (value & pos) | neg

def process(lines):
    mem = {}
    mask = (0, 0)
    for line in lines:
        k, v = line.split(' = ')
        if k == 'mask':
            mask = split_mask(v)
        else:
            mem[int(k[4:-1])] = apply_mask(int(v), *mask)
    return mem

In [3]:
# Test part 1
start = time.time()
print(sum(process(read_file("test01.txt")).values()) == 165)
time.time() - start

True


0.0

In [4]:
# Solve part 1
start = time.time()
print(sum(process(read_file("input.txt")).values()))
time.time() - start

11327140210986


0.0

In [5]:
# Part 2
def decode_address(mask, address):
    decoded = list(mask)
    val = 1
    for i in range(len(mask) - 1, -1, -1):
        if mask[i] == '0' and (address & val):
            decoded[i] = '1'
        val *= 2
    return decoded

def collision(a, b):
    for i in range(len(a)):
        if a[i] != 'X' and b[i] != 'X' and a[i] != b[i]:
            return False
    return True

def remainder(old, new):
    remains = []
    alt = old[:]
    for i in range(len(alt)):
        if alt[i] == 'X' and new[i] != 'X':
            alt[i] = '1' if new[i] == '0' else '0'
            remains.append(alt[:])
            alt[i] = '1' if alt[i] == '0' else '0'
    return remains

def process2(lines):
    mem = []
    mask = ""
    for line in lines:
        k, v = line.split(' = ')
        if k == 'mask':
            mask = v
        else:
            newaddr = decode_address(mask, int(k[4:-1]))
            newmem = []
            for (addr, val) in mem:
                if collision(addr, newaddr):
                    rem = remainder(addr, newaddr)
                    newmem += list(zip(rem, [val] * len(rem)))
                else:
                    newmem.append((addr, val))
            newmem.append((newaddr, int(v)))
            mem = newmem
    return mem

def sum_mem(mem):
    tot = 0
    for (k, v) in mem:
        tot += v * 2 ** k.count('X')
    return tot

In [6]:
# Test part 2
start = time.time()
print(sum_mem(process2(read_file("test02.txt"))) == 208)
time.time() - start

True


0.0

In [7]:
# Solve part 2
start = time.time()
print(sum_mem(process2(read_file("input.txt"))))
time.time() - start

2308180581795


0.20003938674926758