In [None]:
import numpy as np
from itertools import product
import matplotlib.pyplot as plt

# Parse the puzzle input into a 2-D grid of single characters, one row per line.
# Blank or whitespace-only lines (e.g. a trailing empty line at the end of the
# file) are skipped: an empty row would make the nested list ragged, and
# np.array() could then no longer build a proper 2-D character array.
with open('input.txt') as f:
    plant_rows = [list(line.strip()) for line in f if line.strip()]

plant_map = np.array(plant_rows)
# Label every tile with a region id in a single row-major sweep.
# Id 0 means "not labelled yet"; real ids start at 1. Each tile only inspects
# its upper and left neighbours, both of which were labelled earlier in the sweep.
assigned_region = np.zeros(plant_map.shape, dtype=int)
n_rows = plant_map.shape[0]
n_cols = plant_map.shape[1]

for row in range(n_rows):
    for col in range(n_cols):
        plant = plant_map[row, col]
        matches_above = row > 0 and plant_map[row - 1, col] == plant
        matches_left = col > 0 and plant_map[row, col - 1] == plant

        if matches_above and matches_left:
            # This tile bridges two already-labelled neighbours: relabel the
            # whole left region with the upper region's id, then join it.
            label = assigned_region[row - 1, col]
            absorbed = assigned_region[row, col - 1]
            assigned_region[assigned_region == absorbed] = label
        elif matches_above:
            label = assigned_region[row - 1, col]
        elif matches_left:
            label = assigned_region[row, col - 1]
        else:
            # No matching neighbour: open a new region one past the largest
            # id currently present in the grid.
            label = assigned_region.max() + 1

        assigned_region[row, col] = label
        
# Part 1: sum of area * perimeter over all regions.
# `areas[i]` is the tile count of `regions[i]`; the perimeter of a region is the
# number of (tile, direction) pairs whose 4-neighbour is not part of the region
# (neighbours outside the grid are never in the set, so borders count too).
regions, areas = np.unique(assigned_region, return_counts=True)
perimeters = []
for region in regions:
    region_perimeter = 0
    # A set makes each membership test O(1); with a list every lookup scanned
    # the whole region, which made the work quadratic in the region size.
    region_points = set(zip(*np.where(assigned_region == region)))
    for y, x in region_points:
        for dy, dx in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            if (y + dy, x + dx) not in region_points:
                region_perimeter += 1
    perimeters.append(region_perimeter)

products = areas * perimeters
print(products.sum())

def show_map(plant_map, assigned_region):
    """Draw the plant grid as text, colouring each tile by its region id.

    Parameters
    ----------
    plant_map : 2-D array of single-character strings
        The grid to draw; every cell's character is rendered as text.
    assigned_region : 2-D integer array, same shape as ``plant_map``
        Region id of every tile; it selects the text colour.

    Notes
    -----
    Region ids are wrapped modulo the palette size (``cmap.N``) so that ids
    larger than the palette cycle through the colours instead of all
    resolving to the same palette entry.
    Positions are normalised by the column count for x and by the row count
    for y, so the whole grid falls inside [0, 1) on both axes regardless of
    the grid's aspect ratio.
    """
    cmap = plt.get_cmap('Paired')
    n_rows, n_cols = plant_map.shape

    for y, x in product(range(n_rows), range(n_cols)):
        # Wrap the id so it always indexes a valid palette entry.
        color = cmap(int(assigned_region[y, x]) % cmap.N)
        plt.text(x / n_cols, y / n_rows, plant_map[y, x], ha='center', va='center', color=color)
        
# show_map(plant_map, assigned_region)

In [None]:
# Part 2: sum of area * number of straight sides over all regions.
# A polygon has as many sides as corners, so corners are counted instead.
# Tile corners live on the half-integer lattice (y +/- 0.5, x +/- 0.5); the four
# tiles touching a corner are grouped below by the diagonal they lie on.
numbers_of_edges = []

diagonal_a = [(0.5, 0.5), (-0.5, -0.5)]
diagonal_b = [(0.5, -0.5), (-0.5, 0.5)]

for region in regions:
    region_points = set(zip(*np.where(assigned_region == region)))

    # Every corner of every tile in the region (shared corners collapse in the set).
    new_region_points = {
        (y + dy, x + dx)
        for y, x in region_points
        for dy, dx in product([0.5, -0.5], repeat=2)
    }

    only_vertices = set()
    # Corners where the region touches itself only diagonally count as two
    # polygon corners; this tallies the extra one for each of them.
    pinch_corners = 0

    for corner_y, corner_x in new_region_points:
        on_a = sum(
            (corner_y + dy, corner_x + dx) in region_points for dy, dx in diagonal_a
        )
        on_b = sum(
            (corner_y + dy, corner_x + dx) in region_points for dy, dx in diagonal_b
        )
        touching = on_a + on_b

        if touching in (1, 3):
            # One region tile around the corner: convex corner.
            # Three region tiles around the corner: concave corner.
            only_vertices.add((corner_y, corner_x))
        elif (on_a == 2) != (on_b == 2):
            # Exactly one diagonal is fully inside the region: two tiles meet
            # only at this point, which contributes two corners.
            only_vertices.add((corner_y, corner_x))
            pinch_corners += 1

    numbers_of_edges.append(len(only_vertices) + pinch_corners)

    # plt.plot(*zip(*new_region_points), 'o')
    # plt.plot(*zip(*region_points), 'x')
    # plt.plot(*zip(*only_vertices), 's', alpha=0.5)
    # plt.show()

print(sum(areas * numbers_of_edges))