# [Advent of Code 2020 Day 14](https://adventofcode.com/2020/day/14)

This one looks tough: more low-level computer stuff (bitmasks and memory addresses).

## Initial setup

In [299]:
import ipytest
import sys
sys.path.append("..")
from ansi import *
from comp import *
ipytest.autoconfig()

## Input Parsing

In [300]:
def parse_input(filename: str) -> list[tuple[str, str, str] | tuple[str, str]]:
    """Read the puzzle program in *filename* into a list of instruction tuples.

    ``mem[...] = ...`` lines become ``("mem", address, value)`` and
    ``mask = ...`` lines become ``("mask", mask)``; the fields are whatever
    ``parse`` captured (strings, per the return annotation). Lines starting
    with neither prefix are skipped.
    """
    program: list[tuple[str, str, str] | tuple[str, str]] = []

    for raw in yield_line(filename):
        if raw.startswith("mem"):
            slot, payload = parse(r"mem\[(\d+)\] = (\d+)", raw)
            program.append(("mem", slot, payload))
            continue
        if raw.startswith("mask"):
            captured = parse(r"mask = (.*)", raw)
            program.append(("mask", captured[0]))

    return program

## Part 1
Let's go with the straightforward approach — how else would you do it? First, I'm going to make a mask class.

In [301]:
class Mask:
    """A mutable 36-character bit string onto which a part-1 mask can be applied.

    The bits are held in ``self.mask`` as a list of single characters so
    individual positions can be overwritten in place.
    """

    def __init__(self, mask: str):
        assert len(mask) == 36
        self.mask: list[str] = [bit for bit in mask]

    def apply(self, other):
        """Overwrite, in place, every position where *other* is not ``"X"``."""
        assert len(other) == 36
        bits = self.mask
        for position, override in enumerate(other):
            if override != "X":
                bits[position] = override

    def __repr__(self):
        return "".join(bit for bit in self.mask)

In [302]:
%%ipytest
def test_mask_creation():
    """A new Mask stores its input as a list of characters."""
    bits = "000000000000000000000000000000001011"
    assert Mask(bits).mask == list(bits)

def test_mask_apply_11_73():
    """Example from the puzzle: 11 masked becomes 73."""
    example_mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X"
    cell = Mask("000000000000000000000000000000001011")
    cell.apply(example_mask)
    assert str(cell) == "000000000000000000000000000001001001"

def test_mask_apply_101_101():
    """Example from the puzzle: 101 is unchanged by the mask."""
    example_mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X"
    cell = Mask("000000000000000000000000000001100101")
    cell.apply(example_mask)
    assert str(cell) == "000000000000000000000000000001100101"

def test_mask_apply_0_64():
    """Example from the puzzle: 0 masked becomes 64."""
    example_mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X"
    cell = Mask("000000000000000000000000000000000000")
    cell.apply(example_mask)
    assert str(cell) == "000000000000000000000000000001000000"

[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                                                                         [100%][0m
[32m[32m[1m4 passed[0m[32m in 0.01s[0m[0m


OK, now a helper function to convert an integer to a 36-bit binary string.

In [303]:
def int_to_mask(num: int) -> str:
    """Return *num* as a zero-padded, 36-character binary string.

    Args:
        num: an unsigned integer that fits in 36 bits.

    Returns:
        A string of exactly 36 ``"0"``/``"1"`` characters, most significant bit first.

    Raises:
        ValueError: if *num* is negative or needs more than 36 bits.
    """
    # Slicing bin() output breaks for negatives ("-0b101"[2:] == "b101") and
    # yields more than 36 chars for large values, so reject those explicitly.
    if not 0 <= num < 1 << 36:
        raise ValueError(f"{num} does not fit in an unsigned 36-bit value")
    return f"{num:036b}"

In [304]:
%%ipytest
def test_int_to_mask():
    """int_to_mask left-pads each value with zeros to 36 binary digits."""
    expected = {
        11: "000000000000000000000000000000001011",
        73: "000000000000000000000000000001001001",
        101: "000000000000000000000000000001100101",
        0: "000000000000000000000000000000000000",
        64: "000000000000000000000000000001000000",
    }
    for number, bits in expected.items():
        assert int_to_mask(number) == bits

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.01s[0m[0m


In [305]:
def part_one(data: list[tuple[str, str, str] | tuple[str, str]]) -> int:
    """Run the program with part-1 rules and return the sum of all stored values.

    In part 1 the current mask rewrites the *value* being written: every
    non-``X`` mask character replaces the corresponding bit of the value
    (via ``Mask.apply``), and the result is stored at the given address.

    Args:
        data: instructions as produced by ``parse_input`` — ``("mask", mask)``
            or ``("mem", address, value)`` with string fields.

    Returns:
        The sum of every memory cell's final value.

    Raises:
        ValueError: if a ``mem`` write occurs before any ``mask`` instruction,
            or an instruction has an unknown command.
    """
    memory: dict[int, Mask] = {}
    mask: str | None = None

    for item in data:
        if item[0] == "mem":
            # An explicit check (not assert) so it still fires under `python -O`.
            if mask is None:
                raise ValueError("mem instruction encountered before any mask")
            address, value = item[1], item[2]
            cell = Mask(int_to_mask(int(value)))
            cell.apply(mask)
            # Key by the numeric address (as part_two does) so spellings such
            # as "8" and "08" refer to the same memory cell.
            memory[int(address)] = cell
        elif item[0] == "mask":
            mask = item[1]
        else:
            raise ValueError(f"Invalid command {item[0]}")

    return sum(int(str(cell), 2) for cell in memory.values())

In [306]:
%%ipytest
def test_part_one():
    """Part 1 reproduces the worked example and the accepted answer for the real input."""
    cases = (
        ("example1", 165),
        ("input", 11926135976176),
    )
    for filename, answer in cases:
        assert part_one(parse_input(filename)) == answer

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.01s[0m[0m


## Part 2
Oh boy, binary maths. In hindsight I shouldn't have made a class in the first part, so I'll go back to a procedural style.

In [307]:
def apply_mask(original: str, mask: str) -> str:
    """Overlay a part-2 *mask* onto the 36-bit string *original*.

    A mask ``"1"`` forces that bit to ``"1"``, a mask ``"X"`` marks it as
    floating (``"X"``), and any other mask character leaves the original bit.
    """
    assert len(original) == 36
    assert len(mask) == 36
    merged = [
        mask_bit if mask_bit in ("1", "X") else orig_bit
        for orig_bit, mask_bit in zip(original, mask)
    ]
    return "".join(merged)

In [308]:
%%ipytest
def test_apply_mask():
    """Example from the puzzle: address 42 under mask X1001X."""
    address_bits = "000000000000000000000000000000101010"
    mask_bits = "000000000000000000000000000000X1001X"
    expected = "000000000000000000000000000000X1101X"
    assert apply_mask(address_bits, mask_bits) == expected

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.01s[0m[0m


In [309]:
def get_addresses(mask: str) -> list[int]:
    """Expand every floating ``"X"`` in *mask* and return all resulting integers.

    Each ``X`` is tried as ``"0"`` then ``"1"``, with the leftmost ``X`` varying
    slowest, so the output is ordered like a binary count over the ``X`` bits.
    A mask with no ``X`` yields a single-element list.
    """
    floating = [position for position, bit in enumerate(mask) if bit == "X"]
    width = len(floating)
    template = list(mask)
    addresses = []

    for combo in range(2 ** width):
        for rank, position in enumerate(floating):
            # The first floating position is the most significant bit of `combo`.
            shift = width - 1 - rank
            template[position] = "1" if (combo >> shift) & 1 else "0"
        addresses.append(int("".join(template), 2))

    return addresses

In [310]:
%%ipytest
def test_sum_float():
    """Floating bits expand to every combination, in ascending order for these examples."""
    examples = {
        "000000000000000000000000000000X1101X": [26, 27, 58, 59],
        "00000000000000000000000000000001X0XX": [16, 17, 18, 19, 24, 25, 26, 27],
    }
    for floating_mask, addresses in examples.items():
        assert get_addresses(floating_mask) == addresses
    assert get_addresses("00000000000000000000000000000001X0XX") == [16, 17, 18, 19, 24, 25, 26, 27]

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.01s[0m[0m


In [311]:
def part_two(data: list[tuple[str, str, str] | tuple[str, str]]) -> int:
    """Run the program with part-2 rules and return the sum of all stored values.

    In part 2 the mask decodes the *address*: ``apply_mask`` overlays the
    current mask onto the 36-bit address, ``get_addresses`` expands each
    floating bit, and the unmodified value is written to every resulting
    address.
    """
    written: dict[int, str] = {}
    current_mask = None

    for instruction in data:
        command = instruction[0]
        if command == "mem":
            assert current_mask is not None
            target, value = instruction[1], instruction[2]
            decoded = apply_mask(int_to_mask(int(target)), current_mask)
            for slot in get_addresses(decoded):
                written[slot] = value
        elif command == "mask":
            current_mask = instruction[1]
        else:
            raise Exception(f"Invalid command {command}")

    return sum(int(stored) for stored in written.values())

In [312]:
%%ipytest
def test_part_two():
    """Part 2 reproduces the worked example and the accepted answer for the real input."""
    cases = (
        ("example2", 208),
        ("input", 4330547254348),
    )
    for filename, answer in cases:
        assert part_two(parse_input(filename)) == answer

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.07s[0m[0m
