In [1]:
input_small = """r, wr, b, g, bwu, rb, gb, br

brwrr
bggr
gbbr
rrbgbr
ubwu
bwurrg
brgr
bbrgwb"""


with open("input.txt") as f:
    input_large = f.read()


def parse_input(input_str):
    input_lines = input_str.strip().splitlines()
    patterns = input_lines[0].strip().split(", ")
    goals = input_lines[2:]
    return patterns, goals


from collections import Counter


def build_goal(goal, patterns):
    patterns = [p for p in patterns if p in goal]
    one_hot_goal = Counter(goal)
    letters = list(one_hot_goal.keys())
    one_hot_goal = [one_hot_goal[l] for l in letters]
    pattern_to_one_hot = {p: Counter(p) for p in patterns}
    for p in patterns:
        pattern_to_one_hot[p] = [pattern_to_one_hot[p].get(l, 0) for l in letters]

    return one_hot_goal, pattern_to_one_hot


In [2]:
from collections import deque


def can_build_goal_backtracking(goal, patterns):
    # filter patterns that cannot be used
    patterns = [p for p in patterns if p in goal]
    patterns.sort(key=len, reverse=True)

    stack = deque([0])
    visited = set()
    n = len(goal)
    while stack:
        i = stack.pop()
        visited.add(i)
        if i == n:
            return True
        for p in patterns:
            next_i = i + len(p)
            if goal[i:].startswith(p) and next_i not in visited:
                stack.append(next_i)
    return False


patterns, goals = parse_input(input_large)
n_possible = 0
for goal in goals:
    is_possible = can_build_goal_backtracking(goal, patterns)

    if is_possible:
        n_possible += 1
        print(f"YES: {goal}")
    else:
        print(f"NO: {goal}")
print(n_possible)

NO: buwugbgrgururgwrgrrugbwgrwurgbubrggruwugwgrwguuurwu
YES: bwbrurbwgurggbbwbrbwubrurrwrwwwruurbrrguuubg
NO: buubbubwwwgugwgwruwbrwbbgrrwwrurrbwgwbbrbugbbubbbwuwubrbg
NO: burugguwubwwbgwrbuwgrwwgwwrbubugwbwbrubbr
YES: wrgrurbwrgbwbgrbbgbrbburrgrbbrbbwgrgrrrwwuubgwgu
YES: wrbrrwwbwgrggwggwrbgwruwwbbgrguugrrubgrrgwgggrgbrguuw
YES: buwbbgbgwbgbbgubwwwbbwbwwruwwwgubgubwbbggwwuu
YES: bgbrgruwgbgbwbwbuwbbbbbgwuburggrgbbwbuwbrbwbuwrgrrrgggwr
YES: wbwgbuwggbgugwrubbwuuuwwuwurwwubrbrurbwguubrwggwrwgwrg
NO: buurrrubgurubggubrrwrgwgbrbbrrbuwurugwbwuwbugbgwww
NO: buubrrbrgrbwrgbwwbugbwrbbrwgrggwwuububbubb
NO: buubbugrbubuguruwrbwubuuurrbwgbwgwuwubuwgrgrwgrgbbruruwbw
NO: buuugwwrwubbuwwbugggwbwggbwwrrrbubwbwgbrgrwbbbwrbw
YES: rrgugrwruuwwuwgrrgwuuwrgbwgrrwbgbbwurggrruuugbbbbrb
NO: buurrbrwrwggrwbugubwguwbbgwbbrrwruugrgrgwwbgurugrrububr
YES: urwgurbgrgwbrbugrubwuububbrwbbgwrwwuwuubgrwbrub
YES: rubububrwrwgurrgurrurwwgwwbguwwbguugugbwrwrrguuurwuwu
NO: buuuubbbbwrburrrrruwurgrgrwuuggrrgrgbrgguuurrrgg

## Part 2

In [None]:
from functools import cache


@cache
def n_connected_to_n(idx: int, links: frozenset[tuple[int, int]]) -> int:
    if idx == 0:
        return 1
    connected_links = {l for l in links if l[1] == idx}
    # links -= connected_links
    # new_links = links - frozenset(connected_links)
    n_ways = sum(n_connected_to_n(l[0], links) for l in connected_links)
    return n_ways


def n_connected_to_n_iterative(
    idx: int, links: frozenset[tuple[int, int]], visited=None
) -> int:
    links = set(links)
    # Initialize a dictionary to store the results (mimicking caching)
    results = {0: 1}  # Base case: idx = 0 has exactly one way

    # Compute results for all indices from 1 to idx
    if visited is None:
        indices = range(1, idx + 1)
    else:
        # we need to sort because we need previous results
        indices = sorted(visited - {0})
    for current_idx in indices:
        connected_links = {l for l in links if l[1] == current_idx}
        results[current_idx] = sum(
            results[l[0]] for l in connected_links if l[0] in results
        )
        links -= connected_links

    return results[idx]


def get_n_paths_backtracking(goal, patterns):
    # TODO: we need to save the number of ways to build the goal
    # in a recursive way

    # filter patterns that cannot be used
    patterns = [p for p in patterns if p in goal]
    patterns.sort(key=len, reverse=True)

    visited = set()
    links = set()
    n = len(goal)
    n_achieved = 0

    stack = deque([0])
    while stack:
        i = stack.pop()
        visited.add(i)
        if i == n:
            n_achieved += 1
            continue
        for p in patterns:
            next_i = i + len(p)
            if goal[i:].startswith(p):
                links.add((i, next_i))
                if next_i not in visited:
                    stack.append(next_i)

    if n_achieved == 0:
        return 0

    # links = frozenset(links) # for recursive with cache
    n_ways = n_connected_to_n_iterative(n, links, visited)

    return n_ways


patterns, goals = parse_input(input_large)
total_paths = 0
for goal in goals:
    n_paths = get_n_paths_backtracking(goal, patterns)
    total_paths += n_paths
    # print(f"{n_paths}: {goal}")

print(total_paths)

691316989225259


In [15]:
links = {
    (16, 20),
    (33, 36),
    (50, 52),
    (43, 46),
    (44, 45),
    (29, 32),
    (46, 48),
    (8, 9),
    (39, 42),
    (23, 25),
    (40, 41),
    (34, 37),
    (11, 14),
    (7, 10),
    (24, 26),
    (41, 42),
    (44, 47),
    (3, 6),
    (20, 22),
    (14, 15),
    (8, 11),
    (9, 10),
    (40, 43),
    (2, 4),
    (15, 16),
    (41, 44),
    (25, 27),
    (42, 43),
    (18, 21),
    (35, 37),
    (38, 39),
    (14, 17),
    (31, 33),
    (48, 49),
    (9, 12),
    (27, 29),
    (15, 18),
    (16, 17),
    (47, 50),
    (42, 45),
    (12, 13),
    (5, 7),
    (22, 23),
    (38, 41),
    (0, 2),
    (48, 51),
    (49, 50),
    (1, 3),
    (13, 14),
    (45, 46),
    (33, 35),
    (50, 51),
    (51, 52),
    (12, 15),
    (29, 31),
    (46, 47),
    (22, 25),
    (23, 24),
    (49, 52),
    (19, 20),
    (13, 16),
    (45, 48),
    (7, 9),
    (24, 25),
    (3, 5),
    (20, 21),
    (52, 53),
    (46, 49),
    (23, 26),
    (40, 42),
    (19, 22),
    (36, 38),
    (30, 31),
    (7, 11),
    (32, 34),
    (20, 23),
    (4, 6),
    (21, 22),
    (14, 16),
    (31, 32),
    (40, 44),
    (17, 18),
    (27, 28),
    (37, 39),
    (30, 33),
    (47, 49),
    (25, 28),
    (6, 8),
    (26, 29),
    (43, 45),
    (5, 6),
    (21, 24),
    (38, 40),
    (1, 2),
    (17, 20),
    (34, 36),
    (27, 30),
    (11, 13),
    (28, 29),
    (37, 41),
    (47, 51),
    (50, 53),
    (12, 14),
    (29, 30),
    (44, 46),
    (38, 42),
    (8, 10),
    (39, 43),
    (1, 4),
    (28, 31),
    (24, 27),
    (41, 43),
    (3, 4),
    (18, 20),
    (35, 36),
    (12, 16),
    (51, 53),
    (9, 11),
    (2, 5),
    (34, 40),
    (19, 21),
    (15, 17),
    (32, 33),
    (42, 44),
    (4, 5),
    (35, 38),
    (0, 1),
    (31, 34),
    (10, 11),
    (36, 39),
    (37, 38),
    (30, 32),
    (6, 7),
    (32, 35),
    (16, 18),
    (33, 34),
    (26, 28),
    (43, 44),
    (42, 46),
    (4, 7),
    (21, 23),
    (5, 8),
    (22, 24),
    (39, 40),
    (17, 19),
    (34, 35),
    (10, 13),
    (49, 51),
    (11, 12),
    (37, 40),
    (13, 15),
    (6, 9),
}

n = max(l[1] for l in links)
links = frozenset(links)

In [24]:
%%timeit


@cache
def n_connected_to_n(idx: int, links: frozenset[tuple[int, int]]) -> int:
    if idx == 0:
        return 1
    connected_links = {l for l in links if l[1] == idx}
    # links -= connected_links
    # new_links = links - frozenset(connected_links)
    n_ways = sum(n_connected_to_n(l[0], links) for l in connected_links)
    return n_ways


n_connected_to_n(n, links)

119 μs ± 983 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [36]:
%%timeit
n_connected_to_n_iterative(n, links)

75 μs ± 125 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## DP