In [None]:
import numpy as np
from scipy.signal import correlate2d
from tqdm.auto import tqdm

In [None]:
def is_valid_index(array, index):
    """Tell whether ``index`` points at an element inside ``array``.

    ``index`` must be a NumPy integer array (it is compared element-wise), presumably
    holding one coordinate per axis of ``array``. Negative coordinates are treated as
    out of bounds, i.e. there is no Python-style wrap-around.
    """
    non_negative = (index >= 0).all()
    if not non_negative:
        # Stop early, like the short-circuiting ``and`` form, so the upper-bound
        # comparison is never evaluated for a negative index.
        return non_negative
    return (index < array.shape).all()

In [None]:
def array_is_in(arr, ref):
    """Return True if some element of ``ref`` is array-equal to ``arr``.

    ``ref`` may be any iterable of arrays (or array-likes). The comparison uses
    ``np.array_equal``, so shape and values must both match.

    A generator expression is used instead of a list so ``any`` can stop at the
    first match rather than comparing against (and consuming) every element.
    """
    return any(np.array_equal(arr, candidate) for candidate in ref)

In [None]:
def add_to_region(grid, visited, idx, region):
    """Flood-fill the 4-connected component of truthy ``grid`` cells containing ``idx``.

    Every newly reached cell is marked True in ``visited`` (mutated in place) and
    added to ``region`` as an ``(int, int)`` tuple. If ``idx`` is out of bounds,
    already visited, or falsy in ``grid``, nothing is added.

    An explicit stack replaces the original recursion. The recursive version used
    one Python stack frame per cell, so a large region exceeded the interpreter's
    recursion limit and raised RecursionError.

    Args:
        grid: 2-D array; truthy cells belong to the region being filled.
        visited: array with the same shape as ``grid``; filled cells are set True.
        idx: starting ``(row, col)`` index (any length-2 sequence of ints).
        region: set that receives the ``(row, col)`` tuples of the filled cells.
    """
    n_rows, n_cols = grid.shape
    stack = [(int(idx[0]), int(idx[1]))]
    while stack:
        row, col = stack.pop()
        # Bounds check without wrap-around: negative indices are rejected.
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        if visited[row, col] or not grid[row, col]:
            continue
        region.add((row, col))
        visited[row, col] = True
        stack.extend(
            ((row, col + 1), (row, col - 1), (row + 1, col), (row - 1, col))
        )


def get_regions(grid):
    """Split the truthy cells of a 2-D ``grid`` into 4-connected regions.

    Cells are scanned in row-major order. Each not-yet-seen truthy cell seeds a
    flood fill via ``add_to_region``, which collects that cell's whole component.

    Returns:
        A list of sets of ``(row, col)`` tuples, one set per region, ordered by
        the row-major position of each region's first cell.
    """
    seen = np.zeros_like(grid)
    found = []
    for row, col in np.ndindex(grid.shape[0], grid.shape[1]):
        if not grid[row][col] or seen[row][col]:
            continue
        cells = set()
        add_to_region(grid, seen, np.array([row, col]), cells)
        found.append(cells)
    return found

## Part 1

In [None]:
# Load the raw puzzle text; it is parsed into a grid in the next cell.
with open("data/day12/input.txt", mode="r") as input_file:
    map_raw = input_file.read()

In [None]:
# Parse the puzzle text into a 2-D array of single-character plant labels.
# splitlines() also handles "\r\n" endings, and blank lines are dropped. With
# split("\n"), the usual trailing newline left an empty last row, so the rows
# were ragged and np.array could not build a 2-D grid.
# NOTE: `map` shadows the builtin; the name is kept because later cells use it.
map = np.array([list(line) for line in map_raw.splitlines() if line])

In [None]:
plant_cost = {}
for plant in tqdm(np.unique(map)):
    plant_cost[str(plant)] = 0
    for idx_region in get_regions(map == plant):
        perimeter = 0
        for idx in idx_region:
            idx = np.array(idx)
            for offset in [[0, 1], [0, -1], [1, 0], [-1, 0]]:
                neighbour_idx = idx + offset
                if is_valid_index(map, neighbour_idx):
                    if map[*neighbour_idx] != plant:
                        perimeter += 1
                else:
                    perimeter += 1
        plant_cost[str(plant)] += perimeter * len(idx_region)

print(plant_cost)
print(sum(plant_cost.values()))

## Part 2

In [None]:
# Part 2: price of each region = area * number of straight sides.
# A rectilinear polygon, holes included, has as many sides as corners. The two
# kernels below detect corners by correlation and are rotated to cover all four
# orientations:
#   ext_kernel on the region mask: > 0 exactly at a region cell whose "up" and
#     "left" neighbours are both outside (a convex corner).
#   int_kernel on the complement mask: > 0 exactly at an outside cell whose "up",
#     "left" and "up-left" neighbours are all inside (a concave corner).
ext_kernel = np.array(
    [
        [0, -1, 0],
        [-1, 1, 0],
        [0, 0, 0],
    ]
)
int_kernel = np.array(
    [
        [-1, -1, 0],
        [-1, 1, 0],
        [0, 0, 0],
    ]
)


def count_sides(region):
    """Return the number of straight fence sides (= corners) of ``region``.

    Args:
        region: non-empty set of ``(row, col)`` tuples, as built by get_regions.

    Only the region's bounding box is correlated, not the whole map. Every convex
    corner sits on a region cell, and every concave corner sits on a cell whose
    row and column are shared with region cells, so all corners lie inside the
    box. Cells just outside the box are never part of the region, which is
    exactly what the fill values encode (0 for the mask, 1 for its complement).
    The count is therefore identical to correlating the full map.
    """
    rows, cols = zip(*region)
    row0, col0 = min(rows), min(cols)
    mask = np.zeros((max(rows) - row0 + 1, max(cols) - col0 + 1), dtype=int)
    mask[np.array(rows) - row0, np.array(cols) - col0] = 1
    outside = 1 - mask
    corners = 0
    for k in range(4):
        convex = correlate2d(mask, np.rot90(ext_kernel, k=k), mode="same") > 0
        concave = (
            correlate2d(outside, np.rot90(int_kernel, k=k), mode="same", fillvalue=1)
            > 0
        )
        corners += int(convex.sum()) + int(concave.sum())
    return corners


plant_cost = {}
for plant in tqdm(np.unique(map)):
    plant_cost[str(plant)] = 0
    for idx_region in get_regions(map == plant):
        plant_cost[str(plant)] += len(idx_region) * count_sides(idx_region)

print(plant_cost)
print(sum(plant_cost.values()))