In [1]:
import itertools
from collections.abc import Iterable, Iterator
from dataclasses import dataclass

@dataclass(frozen=True)
class Tree:
    """A single tree in the forest grid.

    The dataclass is frozen, so instances are immutable and hashable and can
    be collected in sets.
    """
    # (x, y) grid coordinates: x is the index within a row, y is the row index.
    pos: tuple[int, int]
    # Height parsed from a single character of the input file.
    height: int

# Parse the input file into a grid of Tree objects: forest_map[y][x] is the
# tree at column x of row y.  The file is read inside a `with` block so the
# handle is closed as soon as parsing is done, and blank lines (for example a
# trailing empty line at the end of the file) are skipped so that no empty
# row ends up in the grid.
with open("input.txt") as input_file:
    rows = [line.strip() for line in input_file if line.strip()]

forest_map = [
    [Tree((x, y), int(h)) for (x, h) in enumerate(row)]
    for (y, row) in enumerate(rows)
]
    
def pretty_print(forest_map: list[list[Tree]], res = None, mark_tree = None):
    """Print the grid one row per line, each row as a list of 3-char cells.

    A tree is shown by its (centred) height when ``res`` is None or when the
    tree is contained in ``res``.  Any other tree is shown as ``" - "``,
    except ``mark_tree``, which is shown as ``>height<``.
    """
    def render(tree):
        # Selected trees (or every tree when no selection is given) show
        # their height.
        if res is None or tree in res:
            return f"{tree.height:^3}"
        # Unselected trees are blanked out ...
        if mark_tree != tree:
            return " - "
        # ... unless it is the tree the caller asked to highlight.
        return f">{tree.height}<"

    for line in forest_map:
        print(list(map(render, line)))

In [4]:
def get_col(forest_map: list[list[Tree]], col: int) -> list[Tree]:
    """Return the entries at index ``col`` of every row, first row first."""
    column = []
    for row in forest_map:
        column.append(row[col])
    return column

def visible_from_ground(col):
    """Yield each tree that is strictly taller than every tree before it.

    ``col`` is any iterable of objects with a ``height`` attribute, walked in
    the order given.  The running maximum starts at -1, so a first tree of
    height 0 is still yielded.
    """
    tallest_so_far = -1
    for tree in col:
        # Hidden behind (or level with) something already seen: skip it.
        if tree.height <= tallest_so_far:
            continue
        tallest_so_far = tree.height
        yield tree

def visible_trees(forest_map: list[list[Tree]]) -> set[Tree]:
    """Return every tree that can be seen from at least one edge of the grid.

    Each column is scanned top-to-bottom and bottom-to-top, and each row
    left-to-right and right-to-left; a tree is included when it is strictly
    taller than everything before it in any of those scans.

    Args:
        forest_map: Grid of trees as a list of rows.  The number of columns
            is taken from the first row.

    Returns:
        The set of visible ``Tree`` objects (empty for an empty grid).
    """
    visible: set[Tree] = set()
    if not forest_map:
        # No rows at all: nothing to index into, nothing visible.
        return visible

    for col_id in range(len(forest_map[0])):
        col = get_col(forest_map, col_id)
        visible.update(visible_from_ground(col))
        visible.update(visible_from_ground(reversed(col)))

    for row in forest_map:
        visible.update(visible_from_ground(row))
        visible.update(visible_from_ground(reversed(row)))

    return visible

# Collect every tree visible from at least one edge of the grid, then show
# how many there are as the cell's output.
res = visible_trees(forest_map)
len(res)

21

In [5]:
def visible_from_height(col: Iterable[Tree], height: int) -> Iterator[Tree]:
    """Yield trees from ``col`` until the view is blocked.

    This is a generator: trees are yielded in the order given, and iteration
    stops right after the first tree whose height is greater than or equal
    to ``height``.  That blocking tree is itself yielded.

    Args:
        col: Trees in viewing order, nearest first.  Any iterable works,
            including ``reversed(...)`` iterators.
        height: Height of the viewpoint.

    Yields:
        Each tree seen, ending with the blocking tree if there is one.
    """
    for t in col:
        yield t
        if t.height >= height:
            return

def find_view_from(t: Tree, forest_map: list[list[Tree]]) -> list[Iterator[Tree]]:
    """Return what can be seen from tree ``t`` in each of the four directions.

    Args:
        t: The viewpoint; its ``pos`` is used as ``(x, y)`` to index
            ``forest_map[y][x]``.
        forest_map: Grid of trees as a list of rows.

    Returns:
        A list of four generators, in the order left, right, above, below.
        Each yields the trees seen in that direction, nearest first, ending
        with the first tree at least as tall as ``t`` (if any).  ``t`` itself
        is never included.  The generators are single-use: once consumed
        (e.g. by ``list()`` or ``set()``) they yield nothing more.
    """
    x, y = t.pos
    row = forest_map[y]
    col = get_col(forest_map, x)

    # The left/above slices are reversed so that every view starts at the
    # neighbour closest to ``t`` and walks outward.
    left_of = visible_from_height(reversed(row[:x]), t.height)
    right_of = visible_from_height(row[x+1:], t.height)
    above = visible_from_height(reversed(col[:y]), t.height)
    below = visible_from_height(col[y+1:], t.height)

    return [left_of, right_of, above, below]

# Manual check on a single tree (row 3, column 2 of the grid).
# Note: this rebinds `res`, which held the set returned by visible_trees(),
# to a list of four generators.
res = find_view_from(forest_map[3][2], forest_map)


In [None]:
def score(views: Iterable[Iterable[Tree]]) -> int:
    """Multiply together the number of trees seen in each view.

    A view with no trees contributes a factor of 0, so a tree that sees
    nothing in some direction (for instance one on the edge of the grid)
    scores 0 rather than the product of its remaining views.

    Args:
        views: One iterable of trees per direction.  Generators are accepted
            and are consumed by this call.

    Returns:
        The product of the view lengths (1 when ``views`` is empty).
    """
    product = 1
    for view in views:
        # Count without requiring len(), so one-shot generators work too.
        product *= sum(1 for _ in view)
    return product

# Best score over every tree in the grid (rows are flattened into one stream).
max(score(find_view_from(tree, forest_map)) for tree in itertools.chain.from_iterable(forest_map))



In [7]:
# Visual check for a single tree: row 3, column 2 of the grid.
tree = forest_map[3][2]

# First print the whole grid.  With no `res` argument every tree is shown by
# its height, so `mark_tree` has no visible effect in this call.
pretty_print(forest_map, mark_tree=tree)
print()
# Then print only the trees returned by find_view_from().  The chosen tree is
# not part of its own view, so it is rendered as >height<.
pretty_print(forest_map, set(itertools.chain(*find_view_from(tree, forest_map))), mark_tree=tree)

[' 3 ', ' 0 ', ' 3 ', ' 7 ', ' 3 ']
[' 2 ', ' 5 ', ' 5 ', ' 1 ', ' 2 ']
[' 6 ', ' 5 ', ' 3 ', ' 3 ', ' 2 ']
[' 3 ', ' 3 ', ' 5 ', ' 4 ', ' 9 ']
[' 3 ', ' 5 ', ' 3 ', ' 9 ', ' 0 ']

[' - ', ' - ', ' - ', ' - ', ' - ']
[' - ', ' - ', ' 5 ', ' - ', ' - ']
[' - ', ' - ', ' 3 ', ' - ', ' - ']
[' 3 ', ' 3 ', '>5<', ' 4 ', ' 9 ']
[' - ', ' - ', ' 3 ', ' - ', ' - ']
