---
# --- Day 12: Hot Springs ---
---

In [10]:
from typing import List, Tuple
import re
import numpy as np
from collections import Counter
from tqdm.notebook import tqdm

In [2]:
V = lambda *x: np.array(x)

## Load data

In [14]:
full_puzzle_data = False

In [15]:
def read_input(full_puzzle_data: bool) -> (List[str], List[List[int]]):
    file_suffix = "" if full_puzzle_data else "_test"
    records = []
    counts = []
    with open(f"data/day12_input{file_suffix}.txt", "r") as f:
        for row in f.read().splitlines():
            record, count_string = row.split(" ")
            records.append(record)
            counts.append([int(c) for c in count_string.split(",")])
    return records, counts

In [16]:
records, counts = read_input(full_puzzle_data)

## --- Part One ---

In [17]:
def solve_record_iteratively(record: str, c: List[int]) -> str:
    
    backward = False
    stuck_count = 0
    
    check = V(0, len(record))    
   
    while stuck_count < 2 and len(c) > 0:
        
        if len(c) == sum(c) == 2 and re.search(r"\?\?\?", record) is not None:
            record = record.replace("???", "#.#")
            c = []
            stuck_count = 0
            continue            
        
        n = c[-1] if backward else c[0]                
        sr = record[check[0]:check[1]]
        if backward:
            sr = sr[::-1]
               
        group = re.search(r"[\?#]+", sr)

        assert group is not None, "Impossible to find another potential group of springs"
        
        a, b = group.span()
        if b - a == n:
            gstring = "".join(["#"]*n) + "".join(["."]*(b-a-n))
            sr = sr[:a] + gstring + sr[b:]
            if backward:
                sr = sr[::-1]
            record = record[:check[0]] + sr + record[check[1]:]            
            check = check - V(0, b) if backward else check + V(b, 0)
            c = c[:-1] if backward else c[1:]
            stuck_count = 0
        else:
            stuck_count += 1
            
        backward = not backward
        
    return record

def assign_hash_group(record: str, counts: List[int]) -> str:
    c = counts.copy()
    while len(c) > 0:
        m = max(c)
        if Counter(c)[m] != 1:
            break
        groups = {}
        for mm in re.finditer(r"(#+)", r):
            groups[len(mm.group())] = groups.get(len(mm.group()), []) + [mm.span()]
        if (not m in groups) or len(groups[m]) != 1:
            break
        rstring = list(record)
        a, b = groups[m][0]
        if a > 0:
            rstring[a-1] = "."
        if b <= len(rstring) - 1:
            rstring[b] = "."
        record = "".join(rstring)
        c.remove(m)
    return record
        
def simplify_record(r: str, c: List[int]) -> str:
    r = assign_hash_group(r, c)
    r = solve_record_iteratively(r, c)
    return r

In [18]:
def is_record_valid(record: str, counts: List[int]) -> bool:
    if "?" in record:
        return False
    springs = V(*[len(mm.group()) for mm in re.finditer(r"(#+)", record)])
    return np.array_equal(springs, V(*counts))

In [62]:
def check_record_start(record: str, first_group_count: int) -> bool:
    print(record)
    r = record[:record.find("?")]       
    n = [len(rr) for rr in re.findall(r"[\#]+", r)]
    print(n)
    print(counts[:len(n)])
    return np.array_equal(V(*n), V(*counts[:len(n)]))

In [99]:
re.findall(r"[#]+", re.match(r"^[.#]+[.|\$]", "..###.##.").group())

['###', '##']

In [68]:
def count_possible_assignments(record: str, counts: List[int]) -> bool:
    i = record.find("?")
    if i == -1:
        if is_record_valid(record, counts):
            return 1
        else:
            return 0        
    else:
        if not check_record_start(record, counts):
            print(record)
            print(counts)
            return 0
        #record = simplify_record(record, counts)        
        r1 = list(record)
        r1[i] = "#"
        r2 = list(record)
        r2[i] = "."        
        return count_possible_assignments("".join(r1), counts) + count_possible_assignments("".join(r2), counts)

In [69]:
sum_counts = 0
for i, r in tqdm(enumerate(records), total=len(records)):
    n = count_possible_assignments(r, counts[i])
    sum_counts += n

  0%|          | 0/6 [00:00<?, ?it/s]

.##..??...###.
[1, 1, 3]


AssertionError: Impossible to find another potential group of springs

In [65]:
print(sum_counts)

9


## --- Part Two ---