In [16]:
from aocd.models import Puzzle
import itertools
import functools
import re
puzzle = Puzzle(year=2023, day=12)
data = puzzle.input_data
%run helper.ipynb
# data = """???.### 1,1,3
# .??..??...?##. 1,1,3
# ?#?#?#?#?#?#?#? 1,3,1,6
# ????.#...#... 4,1,1
# ????.######..#####. 1,6,5
# ?###???????? 3,2,1"""

In [17]:
lookup = [".","#"]

class Row(object):
    def __init__(self, s):
        springs, groups = s.split(" ")
        self.springs = springs
        self.groups = list(map(int, groups.split(",")))

    def is_possibility(self, s):
        matches = re.findall(r"#+", s)
        if len(matches) != len(self.groups):
            return False
        if self.groups != [len(x) for x in matches]:
            return False
        return True

    def count_possibilities_simple(self):
        num_broken = len([x for x in self.springs if x == "#"])
        num_missing = len([x for x in self.springs if x == "?"])
        total_broken = sum(self.groups)
        num_broken_needed = total_broken - num_broken
        possibilities = list(itertools.product([0, 1], repeat=num_missing))
        possibilities = [p for p in possibilities if sum(p) == num_broken_needed]
        num_poss = 0
        for p in possibilities:
            new_str = ""
            p_idx = 0
            for c in self.springs:
                if c == "?":
                    new_str += lookup[p[p_idx]]
                    p_idx += 1
                else:
                    new_str += c
            if self.is_possibility(new_str):
                num_poss += 1
        return num_poss

In [18]:
rs = [Row(s).count_possibilities_simple() for s in data.split("\n")]

In [19]:
puzzle.answer_a = sum(rs)

In [112]:
lookup = [".","#"]

def find(s, ch):
    return [i for i, ltr in enumerate(s) if ltr == ch]

def sums(length, total_sum):
    if length == 1:
        yield (total_sum,)
    else:
        for value in range(total_sum + 1):
            for permutation in sums(length - 1, total_sum - value):
                yield (value,) + permutation

def min_len(groups):
    return sum(groups) + len(groups) - 1

@functools.cache
def solve_single_orig(spring_line, groups):
    if len(groups) == 0:
        return 1
    
    diff = len(spring_line) - min_len(groups)
    if diff == 0:
        return 1
    
    indexes = find(spring_line, "#")
    possibilities = sums(len(groups) + 1, diff)
    count = 0
    for p in possibilities:
        s = "."*p[0]
        for i in range(len(groups)):
            s += ("#" * groups[i]) + ("." * (p[i+1] + 1))
        if all([s[i] == "#" for i in indexes]):
            count += 1
    return count
       
def solve_single(spring_line, groups):
    if len(groups) == 0:
        if "#" not in spring_line:
            return 1
        else:
            return 0
    
    diff = len(spring_line) - min_len(groups)
    return solve_single_helper(spring_line, groups, diff, "")
    
def solve_single_helper(spring_line, groups, diff, current):
    if len(groups) == 0:
        if len(spring_line) == len(current):
            print(current)
            return 1
        return 0
    total = 0
    for i in range(diff+1):
        s = current + "."*i + "#"*groups[0]
        if len(groups) == 1:
            s = s + "." * (diff - i)
        else:
            s = s + "."
        indexes = find(spring_line[:len(s)], "#")
        if all([s[i] == "#" for i in indexes]):
            c = solve_single_helper(spring_line, groups[1:], diff-i, s)
            total += c
    return total
    

def count_poss(bad_springs, groups):
#     spacing = "    " * len(extra)
#     print(spacing, "Args: ", bad_springs, groups, len(extra))
    if (len(bad_springs) == 0) != (len(groups) == 0): # xor
#         print(spacing, "Zero failure")
        return 0
    if len(groups) == 0 and len(groups) == 0:
        return 1
    if min_len(groups) > sum([len(b)+1 for b in bad_springs]):
        print("Not enough")
        return 0
    total = 0
    for i in range(len(groups)+1):
        if min_len(groups[:i]) > len(bad_springs[0]):
#             print(spacing, f"Stopping on {i} due to length")
            break
#         print(spacing, "  Trying" + str((bad_springs[0], groups[:i])))
        remaining_count = count_poss(bad_springs[1:], groups[i:])
        if remaining_count != 0:
            print("solving", bad_springs[0], groups[:i])
            single = solve_single(bad_springs[0], groups[:i])
            total += (single * remaining_count)
    return total
    
class Row2(object):
    def __init__(self, s, expand=False):
        self.original = s
        springs, groups = s.split(" ")
        self.springs = springs
        self.groups = list(map(int, groups.split(",")))
        if expand:
            self.springs = "?".join([self.springs] * 5)
            self.groups = self.groups * 5
        self.bad_springs = re.findall(r"[#?]+", self.springs)

In [113]:
rows = [Row2(s, False) for s in data.split("\n")]
a = rows[5]

In [108]:
count_poss(tuple(a.bad_springs), tuple(a.groups))

solving ???#???#?#????? (2, 3, 3, 1, 2)
##.###.###.#.##


1

In [96]:
a.bad_springs[0], tuple(a.groups)

('???#???#?#?????', (2, 3, 3, 1, 2))

In [100]:
total = 0
for i in range(len(rows)):
    c = count_poss(tuple(rows[i].bad_springs), tuple(rows[i].groups))
    if c != rs[i]:
        print(rows[i].original, "====", c, rs[i])
    total += c
print(total)

solving ??#???##??#?? (4, 2, 2)
####..##.##..
####..##..##.
.####.##.##..
.####.##..##.
solving ? (1,)
#
solving #????????? (9,)
#########.
solving ??#? (2,)
.##.
..##
solving ????# (4,)
.####
solving #???#? (2, 1, 1)
solving #???#? (1, 1)
#...#.
solving ??# (2,)
.##
solving ???#?#????? (1, 1, 2, 3)
#..#.##.###
.#.#.##.###
solving ?#?? (4,)
####
solving ???#???#?#????? (2, 3, 3, 1, 2)
##.###.###.#.##
solving #??###? (6,)
######.
solving ???? (3,)
###.
.###
solving ??? (2,)
##.
.##
solving ??# (1,)
..#
solving ????#?#?#? (1, 6)
#..######.
#...######
.#.######.
.#..######
..#.######
solving ?## (3,)
###
solving ?????? (1,)
#.....
.#....
..#...
...#..
....#.
.....#
solving #????#? (2, 4)
##.####
solving ?#?? (3,)
###.
.###
solving ????? (1, 1, 1)
#.#.#
solving ?????##?## (1, 4, 2)
#..####.##
.#.####.##
solving # (1,)
#
solving ????? (1, 2)
#.##.
#..##
.#.##
solving ?#?? (2,)
##..
.##.
solving # (1,)
#
solving ????? (2,)
##...
.##..
..##.
...##
solving ?#?? (2, 1)
##.#
solving ??##????????

In [114]:
count_poss(("??????####???????", "?"), (1, 6, 3))

Not enough
Not enough
Not enough


0

In [39]:
list(sums(3,5))

[(0, 0, 5),
 (0, 1, 4),
 (0, 2, 3),
 (0, 3, 2),
 (0, 4, 1),
 (0, 5, 0),
 (1, 0, 4),
 (1, 1, 3),
 (1, 2, 2),
 (1, 3, 1),
 (1, 4, 0),
 (2, 0, 3),
 (2, 1, 2),
 (2, 2, 1),
 (2, 3, 0),
 (3, 0, 2),
 (3, 1, 1),
 (3, 2, 0),
 (4, 0, 1),
 (4, 1, 0),
 (5, 0, 0)]

In [328]:
find('#???????#?', '#')

[0, 8]