In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
from itertools import product

In [None]:
# Puzzle input fed to both word_count solvers below. The arguments are
# presumably (year, day) — confirm against aoc_utils.load_data.
data = load_data(2024, 4)

In [None]:
# Worked example consumed by the check() cells further down.
# Each entry is a tuple laid out as (data, part_1, part_2): the input text
# followed by the expected answers for Part 1 and Part 2.
# NOTE(review): how check() consumes these tuples is defined in aoc_utils,
# not in this notebook — confirm there.
tests = [
    (
        """MMMSXXMASM
MSAMXMSMSA
AMXSXMAAMM
MSAMASMSMX
XMASAMXAMM
XXAMMXXAMA
SMSMSASXSS
SAXAMASAAA
MAMMMXMMMM
MXMXAXMASX
""",
        18,
        9,
    ),
]

# Part 1

In [None]:
def parse_letters(data):
    """Turn a block of text into a coordinate-keyed letter grid.

    Args:
        data: Multi-line string; each line is one row of the grid.

    Returns:
        A ``(letters, width, height)`` tuple. ``letters`` maps ``(x, y)``
        (0-based column, 0-based row) to the character at that position,
        ``width`` is the length of the longest row and ``height`` is the
        number of rows. Empty input yields ``({}, 0, 0)``.
    """
    rows = data.splitlines()
    letters = {
        (x, y): c
        for y, line in enumerate(rows)
        for x, c in enumerate(line)
    }
    # Derive the dimensions from the rows themselves instead of from leaked
    # loop variables: those are unbound when there is nothing to iterate
    # over and only reflect the last row when rows differ in length.
    width = max((len(line) for line in rows), default=0)
    return letters, width, len(rows)

In [None]:
def is_word_here(letters, x, y, dx, dy, word="XMAS"):
    """Check whether ``word`` is spelled out from ``(x, y)`` along ``(dx, dy)``.

    Args:
        letters: Mapping of ``(x, y)`` to a single character.
        x: Column of the first letter.
        y: Row of the first letter.
        dx: Column step applied for each subsequent letter.
        dy: Row step applied for each subsequent letter.
        word: The word to look for; defaults to ``"XMAS"``.

    Returns:
        True if every letter of ``word`` sits at its expected position,
        False otherwise.
    """
    # dict.get returns None for positions missing from the mapping, and None
    # never equals a character, so no explicit bounds check is needed.
    return all(
        letters.get((x + i * dx, y + i * dy)) == c
        for i, c in enumerate(word)
    )

In [None]:
def word_count(data):
    """Count matches of ``is_word_here`` over the whole grid (Part 1).

    Every grid position is tried as a starting point in each of the eight
    straight-line directions (horizontal, vertical and diagonal).

    Args:
        data: The puzzle text, one grid row per line.

    Returns:
        The number of (start position, direction) pairs that match.
    """
    grid, n_cols, n_rows = parse_letters(data)
    # All (dx, dy) steps in {-1, 0, 1} x {-1, 0, 1} except standing still.
    directions = [
        step for step in product([-1, 0, 1], [-1, 0, 1]) if step != (0, 0)
    ]
    return sum(
        is_word_here(grid, col, row, step_x, step_y)
        for step_x, step_y in directions
        for col in range(n_cols)
        for row in range(n_rows)
    )

In [None]:
# Run the Part 1 solver through check() with the sample in `tests`.
# NOTE(review): check() comes from aoc_utils; presumably it compares the
# solver's result with the part_1 value of each entry — confirm there.
check(word_count, tests)
# Last expression of the cell: the Part 1 result for the real input.
word_count(data)

# Part 2

In [None]:
def is_pattern_here(letters, x, y, pattern):
    """Tell whether ``pattern`` matches the grid when anchored at ``(x, y)``.

    Args:
        letters: Mapping of ``(x, y)`` to a single character.
        x: Column the pattern's ``(0, 0)`` offset is placed on.
        y: Row the pattern's ``(0, 0)`` offset is placed on.
        pattern: Mapping of ``(x_offset, y_offset)`` to the required character.

    Returns:
        True if every pattern cell finds its character in ``letters``,
        False as soon as one cell differs or is missing from the mapping.
    """
    for (off_x, off_y), expected in pattern.items():
        # A position absent from the mapping yields None and so never matches.
        if letters.get((x + off_x, y + off_y)) != expected:
            return False
    return True

In [None]:
def rotate(pattern):
    """Return a copy of ``pattern`` turned by a quarter turn.

    Each cell at offset ``(x, y)`` moves to ``(y, max_x - x)``, where
    ``max_x`` is the largest x offset present in ``pattern``.

    Args:
        pattern: Mapping of ``(x_offset, y_offset)`` to a character.

    Returns:
        A new dict with the rotated offsets; ``pattern`` itself is not
        modified. An empty pattern rotates to an empty dict.
    """
    # default=0 keeps max() from raising ValueError on an empty pattern.
    max_x = max((x for x, _ in pattern), default=0)
    return {(y, max_x - x): c for (x, y), c in pattern.items()}

In [None]:
def word_count(data):
    """Count matches of the M/S/A cross pattern in any orientation (Part 2).

    The pattern is a 3x3 cross with "A" in the centre and "M"/"S" on the
    corners. It is tested at every grid position in four orientations, each
    obtained by applying ``rotate`` to the previous one.

    Note: this definition reuses the name ``word_count`` and therefore
    replaces the Part 1 function once this cell has run.

    Args:
        data: The puzzle text, one grid row per line.

    Returns:
        The number of (position, orientation) pairs that match.
    """
    grid, n_cols, n_rows = parse_letters(data)
    shape = {
        (0, 0): "M",
        (2, 0): "S",
        (1, 1): "A",
        (0, 2): "M",
        (2, 2): "S",
    }
    total = 0
    for _ in range(4):
        # Anchors near the right/bottom edge simply fail to match, because
        # the pattern's cells fall outside the grid mapping.
        total += sum(
            is_pattern_here(grid, col, row, shape)
            for col in range(n_cols)
            for row in range(n_rows)
        )
        shape = rotate(shape)
    return total

In [None]:
# Run the Part 2 solver through check() with the sample in `tests`.
# NOTE(review): the trailing 2 presumably selects the part_2 expected value
# of each entry — confirm against aoc_utils.check.
check(word_count, tests, 2)
# Last expression of the cell: the Part 2 result for the real input.
word_count(data)