# Day 10
## Part 1
Represent lights and buttons with sets and BFS. You could do this more efficiently with bitmasks as the lights and buttons are binary numbers.

In [4]:
from advent import read_input

def parse_data(s):
    data = []
    for line in s.strip().splitlines():
        fields = line.strip().split()
        lights = frozenset(
            i 
            for i, c in enumerate(fields[0][1:-1]) 
            if c == "#"
        )
        buttons = [
            frozenset(eval(field.replace(")", ",)")))
            for field in fields[1:-1]
        ]
        joltages = eval(fields[-1].replace("{", "[").replace("}", "]"))
        data.append((lights, buttons, joltages))
    return data

test_data = parse_data("""[.##.] (3) (1,3) (2) (2,3) (0,2) (0,1) {3,5,4,7}
[...#.] (0,2,3,4) (2,3) (0,4) (0,1,2) (1,2,3,4) {7,5,12,7,2}
[.###.#] (0,1,2,3,4) (0,3,4) (0,1,2,4,5) (1,2) {10,11,11,5,10,5}
""")

test_data

[(frozenset({1, 2}),
  [frozenset({3}),
   frozenset({1, 3}),
   frozenset({2}),
   frozenset({2, 3}),
   frozenset({0, 2}),
   frozenset({0, 1})],
  [3, 5, 4, 7]),
 (frozenset({3}),
  [frozenset({0, 2, 3, 4}),
   frozenset({2, 3}),
   frozenset({0, 4}),
   frozenset({0, 1, 2}),
   frozenset({1, 2, 3, 4})],
  [7, 5, 12, 7, 2]),
 (frozenset({1, 2, 3, 5}),
  [frozenset({0, 1, 2, 3, 4}),
   frozenset({0, 3, 4}),
   frozenset({0, 1, 2, 4, 5}),
   frozenset({1, 2})],
  [10, 11, 11, 5, 10, 5])]

In [5]:
from collections import deque

def min_presses(lights, buttons):
    q = deque([(frozenset(), 0)])
    seen = {frozenset()}
    while q:
        l, n = q.popleft()
        if l == lights:
            return n
        for b in buttons:
            new_ls = l ^ b
            if new_ls not in seen:
                q.append((new_ls, n + 1))
                seen.add(new_ls)

[min_presses(ls, bs) for ls, bs, _ in test_data]

[2, 3, 2]

In [6]:
def part_1(data):
    return sum(min_presses(ls, bs) for ls, bs, _ in data)

assert part_1(test_data) == 7

In [7]:
data = parse_data(read_input())

part_1(data)

484

## Part 2

Good old dependable dynamic programming has failed me here, working on the test data but failing on the first instance in the real data.

These are a set of equations, where a given joltage is the sum of the number of buttons affecting that joltage is pressed. We would like to solve for each button's number of presses but from attempting this on paper this is not always possible. So simplify the equations and then search from there.

In [8]:
from collections import Counter, namedtuple
from itertools import chain, combinations, permutations
from pyrsistent import pmap, pvector, pset, PSet
from functools import reduce
from dataclasses import dataclass

@dataclass
class Equation:
    joltage: int 
    buttons: PSet

    def __str__(self):
        return f"{self.joltage} = {' + '.join('n' + str(b) for b in self.buttons)}"

def create_eqs(buttons, joltages):
    eqs = pvector()
    for i, joltage in enumerate(joltages):
        bs = pset({bi for bi, b in enumerate(buttons) if i in b})
        eqs = eqs.append(Equation(joltage, bs))
    ns = pmap()
    return eqs, ns


def simplify_step(eqs, ns):
    if any(
        eq.joltage < 0 or (eq.joltage >= 0 and len(eq.buttons) == 0)
        for eq in eqs
    ):
        raise ValueError
    # Remove empty equation
    for i, eq in enumerate(eqs):
        if len(eq.buttons) == 0:
            eqs = eqs.delete[i]
            return True, eqs, ns
    # Remove duplicate equations
    for i, j in combinations(range(len(eqs)), 2):
        if eqs[i] == eqs[j]:
            # print(f"Removing duplicate: {eqs[j]}")
            eqs = eqs.delete(j)
            return True, eqs, ns
    # Resolve equations with single variable
    for i, eq in enumerate(eqs):
        if len(eq.buttons) == 1:
            # print(f"Resolved: {eqs[i]}")
            x = next(iter(eq.buttons))
            ns = ns.set(x, eq.joltage)
            eqs = eqs.delete(i)
            for i in range(len(eqs)):
                if x in eqs[i].buttons:
                    eqs[i].buttons = eqs[i].buttons.remove(x)
                    eqs[i].joltage = eqs[i].joltage - ns[x]
            return True, eqs, ns
    # If a joltage is zero, set all to zero
    for i, eq in enumerate(eqs):
        if eq.joltage == 0:
            # print(f"All zero: {eqs[i]}")
            for b in eq.buttons:
                eqs.append(Equation(0, pset({b})))
            eqs = eqs.delete(i)
            return True, eqs, ns
    # Resolve pairs of equations where one has an extra variable
    for i, j in permutations(range(len(eqs)), 2):
        if eqs[j].buttons - eqs[i].buttons == pset():
            diff = eqs[i].buttons - eqs[j].buttons
            if len(diff) == 1:
                x = next(iter(diff))
                xjolt = eqs[i].joltage - eqs[j].joltage
                # print(f"Worked out: {eqs[i]} - {eqs[j]}")
                eqs = eqs.append(Equation(xjolt, pset({x})))
                return True, eqs, ns
            else: 
                # print(f"Diffing: {eqs[i]} - {eqs[j]}")
                eqs = eqs.append(Equation(eqs[i].joltage - eqs[j].joltage, diff))
                eqs = eqs.delete(i)
                return True, eqs, ns
    # If a variable is in all equations set it to the minimum joltage
    if len(eqs) > 1:
        in_all = list(reduce(
            lambda x, y: x.intersection(y), 
            (eq.buttons for eq in eqs)
        ))
        if in_all:
            i = in_all[0]
            min_joltage = min(eq.joltage for eq in eqs)
            # print(f"All equations contain {i}, setting to {min_joltage}")
            eqs = eqs.append(Equation(min_joltage, pset({i})))
            return True, eqs, ns
    return (False, eqs, ns)
        
    while simplify():
        pass

def simplify_eqs(eqs, ns):
    simplifying = True
    while simplifying:
        simplifying, eqs, ns = simplify_step(eqs, ns)
    return eqs, ns
        

_, bs, js = test_data[1]
eqs, ns = create_eqs(bs, js)
simplify_eqs(eqs, ns)

(pvector([]), pmap({0: 2, 1: 5, 2: 0, 3: 5, 4: 0}))

In [9]:
completely_resolved = 0
partially_resolved = 0
unresolved = 0
for i, (_, bs, js) in enumerate(data):
    eqs, ns = simplify_eqs(*create_eqs(bs, js))
    if not eqs:
        completely_resolved += 1
    elif ns:
        partially_resolved += 1
    else:
        unresolved += 1
print(len(data))
print(completely_resolved)
print(partially_resolved)
print(unresolved)

186
75
71
40


In [10]:
import math
from copy import deepcopy, copy

def min_presses(eqs, ns):
    # print("====")
    # for eq in eqs:
    #     print(eq)
    # print(ns)
    try:
        eqs, ns = simplify_eqs(eqs, ns)
    except ValueError:
        return math.inf
    if any(x < 0 for x in ns.values()):
        return math.inf
    if any(eq.joltage < 0 for eq in eqs):
        return math.inf
    if len(eqs) == 0:
        return sum(ns.values())
    # print("----")
    # for eq in eqs:
    #     print(eq)
    # print(ns)
    x = Counter(chain.from_iterable(eq.buttons for eq in eqs)).most_common(1)[0][0]

    presses = []
    for j in range(min(eq.joltage for eq in eqs if x in eq.buttons) + 1):
        presses.append(min_presses(deepcopy(eqs).append(Equation(j, pset({x}))), deepcopy(ns)))
    return min(presses)

In [11]:
_, bs, js = data[1]
eqs, ns = create_eqs(bs, js)
min_presses(eqs, ns)

157

In [12]:
import tqdm

def part_2(data):
    result = 0
    for _, bs, js in tqdm.tqdm(data):
        result += min_presses(*create_eqs(bs, js))
    return result

part_2(test_data)

100%|██████████████████████████████| 3/3 [00:00<00:00, 713.60it/s]


33

In [13]:
part_2(data)

100%|█████████████████████████| 186/186 [2:14:25<00:00, 43.36s/it]


19170

Over two hours running time and wrong.

It was stuck on this one for a long time, I'll take a look when I have more time.

In [18]:
eqs, ns = create_eqs(data[132][1], data[132][2])
for eq in eqs:
    print(eq)

75 = n0 + n1 + n9 + n2 + n10 + n4
61 = n0 + n1 + n9 + n2 + n10
74 = n0 + n3 + n4 + n5 + n6
98 = n1 + n9 + n2 + n10 + n4 + n6 + n7
68 = n8 + n1 + n9 + n2 + n10 + n4
34 = n10 + n7
123 = n0 + n1 + n10 + n3 + n4 + n5 + n6 + n7
60 = n0 + n1 + n10 + n6
50 = n8 + n2 + n10 + n4 + n5


In [21]:
eqs, ns = simplify_eqs(eqs, ns)
for eq in eqs:
    print(eq)
print(ns)

46 = n0 + n9 + n2 + n10
60 = n0 + n3 + n5 + n6
39 = n8 + n9 + n2 + n10
34 = n10 + n7
45 = n0 + n10 + n6
36 = n8 + n2 + n10 + n5
35 = n9 + n2 + n6
pmap({1: 15, 4: 14})
