## Part 1

In [75]:
from dataclasses import dataclass
from queue import PriorityQueue

In [342]:
TEST_INFILE = "inputs/day_10_test_1.txt"
INFILE = "inputs/day_10.txt"
# Set to True to run on the small worked example instead of the real puzzle input.
USE_TEST_INPUT = False

with open(TEST_INFILE if USE_TEST_INPUT else INFILE) as infile:
    lines = infile.read().splitlines()

In [343]:
@dataclass
class Problem:
    """One parsed input line: target light pattern, available buttons, joltage targets."""

    # Target indicator pattern, one "." / "#" character per position.
    mask: tuple[str, ...]
    # Each button is the tuple of positions it affects when pressed.
    buttons: list[tuple[int, ...]]
    # Per-position joltage targets (parsed from the final "{...}" token).
    joltages: list[int]

In [344]:
def _parse_int_group(token):
    """Parse a bracketed, comma-separated integer group such as "(1,3)" or "{3,5,4,7}".

    The first and last characters (the brackets) are dropped. A single value
    like "(3)" yields a 1-tuple, and an empty group like "()" yields ().
    Parsed explicitly rather than with eval(): the input file is data, not code.
    """
    return tuple(int(part) for part in token[1:-1].split(",") if part)


def parse_problem(line):
    """Parse one input line: a mask token, zero or more button tokens, then a joltage token."""
    tokens = line.split()
    # Drop the surrounding brackets and keep one character per indicator light.
    mask = tuple(tokens[0][1:-1])
    buttons = [_parse_int_group(token) for token in tokens[1:-1]]
    joltages = list(_parse_int_group(tokens[-1]))
    return Problem(mask, buttons, joltages)


# Skip blank lines so a stray empty line in the input does not crash parsing.
problems = [parse_problem(line) for line in lines if line.strip()]

In [345]:
def get_adjacent(mask, buttons):
    """List the masks reachable from `mask` with a single button press.

    Pressing a button flips each position it lists: "." becomes "#", and
    anything else becomes ".". The flip is decided from the ORIGINAL mask,
    so an index repeated within one button is still flipped only once.
    One neighbor tuple is returned per button, in button order.
    """
    def press(button):
        state = list(mask)
        for idx in button:
            state[idx] = "#" if mask[idx] == "." else "."
        return tuple(state)

    return [press(button) for button in buttons]

In [346]:
def dijkstra(start, goal, buttons, neighbors=None):
    """Shortest-path search where every edge (one button press) costs 1.

    Args:
        start: starting node (must be hashable).
        goal: node at which the search stops once it is popped from the queue.
        buttons: passed through unchanged to the neighbor function.
        neighbors: optional callable ``(node, buttons) -> iterable of nodes``;
            defaults to ``get_adjacent``.

    Returns:
        (preds, dists): predecessor map (``preds[start] is None``) and the best
        known distance of every node discovered before the search stopped.
        If `goal` is unreachable it is absent from `dists`.
    """
    import heapq
    from itertools import count

    if neighbors is None:
        neighbors = get_adjacent

    dists = {start: 0}
    preds = {start: None}
    # The counter breaks priority ties so the heap never has to compare nodes.
    tie = count()
    heap = [(0, next(tie), start)]
    while heap:
        dist, _, node = heapq.heappop(heap)
        # A node may be pushed several times; skip entries superseded by a shorter path.
        if dist > dists[node]:
            continue
        if node == goal:
            break

        new_dist = dist + 1
        for neighbor in neighbors(node, buttons):
            if neighbor not in dists or new_dist < dists[neighbor]:
                dists[neighbor] = new_dist
                preds[neighbor] = node
                heapq.heappush(heap, (new_dist, next(tie), neighbor))

    return preds, dists
    

In [347]:
def solve_problem(problem):
    """Fewest button presses turning an all-"." pattern into `problem.mask`.

    Raises:
        ValueError: if no sequence of presses reaches the target mask.
    """
    goal = problem.mask
    start = ("." ,) * len(goal)
    _, dists = dijkstra(start, goal, problem.buttons)
    if goal not in dists:
        raise ValueError(f"target mask {''.join(goal)!r} is unreachable with the given buttons")
    return dists[goal]

In [348]:
sum(map(solve_problem, problems))

547

## Part 2

In [367]:
import numpy as np
from scipy.optimize import milp, LinearConstraint, Bounds

In [387]:
def solve_integer_min_sum(A, b):
    """
    Minimize sum_i x_i subject to A x = b, x >= 0 and every x_i an integer.

    Args:
        A: (m, n) coefficient matrix (array-like).
        b: right-hand side with m entries; any shape that ravels to (m,)
           (e.g. a column vector) is accepted.

    Returns:
        (x, total, res): the integer solution vector of length n, its sum,
        and the raw scipy ``OptimizeResult``.

    Raises:
        ValueError: if shapes disagree, the MILP is infeasible / fails, or the
            rounded solution does not satisfy A x = b.
    """
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float).ravel()   # accept (m,) or (m, 1)
    m, n = A.shape
    if b.shape != (m,):
        raise ValueError(f"b has {b.size} entries but A has {m} rows")

    # Objective: minimize sum(x_i)
    c = np.ones(n, dtype=float)

    # Equality constraint: A x = b  =>  lb = ub = b
    lc_eq = LinearConstraint(A, lb=b, ub=b)

    # Bounds: x_i >= 0, no upper bound
    bounds = Bounds(lb=np.zeros(n), ub=np.full(n, np.inf))

    # All variables integer
    integrality = np.ones(n, dtype=int)

    res = milp(
        c=c,
        constraints=[lc_eq],
        bounds=bounds,
        integrality=integrality,
    )

    if not res.success:
        raise ValueError(f"MILP infeasible or failed: {res.message}")

    # Solver output carries small floating-point noise; snap to integers...
    x = np.rint(res.x).astype(int)
    # ...and make sure the snapped vector still satisfies the equality exactly.
    if not np.allclose(A @ x, b):
        raise ValueError("rounded MILP solution does not satisfy A x = b")
    return x, x.sum(), res

In [388]:
def solve(problem):
    """Minimum total button presses that hit every joltage target exactly.

    Builds the 0/1 incidence matrix A with A[i, j] = 1 when button j lists
    joltage index i, then solves min sum(x) s.t. A x = joltages, x >= 0 integer.
    """
    incidence = np.array(
        [
            [1 if row in button else 0 for button in problem.buttons]
            for row in range(len(problem.joltages))
        ]
    )
    return solve_integer_min_sum(incidence, problem.joltages)[1]

In [389]:
sum(map(solve, problems))

np.int64(21111)