In [1]:
from pathlib import Path

# Raw text of the input file. The path is relative, so it is resolved against
# the kernel's current working directory.
data_file = Path("../Data/day4.txt").read_text()


# 10x10 sample grid. The assertions further down expect part1 to count 18
# matches in it and part2 to count 9.
EXAMPLE = """
MMMSXXMASM
MSAMXMSMSA
AMXSXMAAMM
MSAMASMSMX
XMASAMXAMM
XXAMMXXAMA
SMSMSASXSS
SAXAMASAAA
MAMMMXMMMM
MXMXAXMASX
"""

# Minimal grid holding exactly one diagonal "XMAS", running down and to the
# right from row 0, column 4 (the blank first line is dropped by prepare).
EXAMPLE4 = """
....X.....
.....M....
......A...
.......S..
"""


def prepare(input: str):
    """Turn raw text into a grid: one list of characters per non-empty line.

    Empty lines (for example the one right after an opening triple quote) are
    skipped. Lines are not stripped, so a whitespace-only line is kept as a row.

    Args:
        input: The raw multi-line text.

    Returns:
        A list of rows, each row being a list of single-character strings.
    """
    return [list(row) for row in input.splitlines() if row]


# Grids used by the assertions in the cells below.
example_data = prepare(EXAMPLE)
# Single-row grid with one horizontal "XMAS" at columns 4-7.
example_data2 = prepare("MMSXXMASM")
# NOTE(review): example_data3 is not referenced by any cell visible here.
example_data3 = prepare("MSAMXMSMSA")
example_data4 = prepare(EXAMPLE4)
# Grid built from the contents of the input file.
data = prepare(data_file)

In [2]:
SEARCH_WORD = "XMAS"


def check_coordinates(lines: list[list[str]], coordinates: list[tuple[int, int]]):
    """Return ``coordinates`` if the cells they address spell ``SEARCH_WORD``.

    Args:
        lines: The grid, as a list of rows of single characters. Rows may have
            different lengths.
        coordinates: ``(row, column)`` pairs, in reading order.

    Returns:
        The ``coordinates`` list itself when every pair lies inside the grid
        and the addressed characters, joined in order, equal ``SEARCH_WORD``;
        otherwise ``None``.
    """
    for row, col in coordinates:
        # Negative indices would silently wrap around to the other side of the
        # grid, so they are rejected explicitly.
        if row < 0 or col < 0 or row >= len(lines):
            return None
        # Compare against the addressed row's own length (not the first
        # row's), so a grid with uneven rows is handled correctly.
        if col >= len(lines[row]):
            return None

    word = "".join(lines[row][col] for row, col in coordinates)
    if word != SEARCH_WORD:
        return None

    return coordinates


def check_right(lines: list[list[str]], row: int, column: int):
    """Look for the search word reading rightwards from ``(row, column)``.

    Builds the four cells ``(row, column)`` .. ``(row, column + 3)`` and
    returns whatever ``check_coordinates`` returns for them.
    """
    cells = [(row, column + step) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_left(lines: list[list[str]], row: int, column: int):
    """Look for the search word reading leftwards from ``(row, column)``.

    Builds the four cells ``(row, column)`` .. ``(row, column - 3)`` and
    returns whatever ``check_coordinates`` returns for them.
    """
    cells = [(row, column - step) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_up(lines: list[list[str]], row: int, column: int):
    """Look for the search word reading upwards from ``(row, column)``.

    Builds the four cells ``(row, column)`` .. ``(row - 3, column)`` and
    returns whatever ``check_coordinates`` returns for them.
    """
    cells = [(row - step, column) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_down(lines: list[list[str]], row: int, column: int):
    """Look for the search word reading downwards from ``(row, column)``.

    Builds the four cells ``(row, column)`` .. ``(row + 3, column)`` and
    returns whatever ``check_coordinates`` returns for them.
    """
    cells = [(row + step, column) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_right_up(lines: list[list[str]], row: int, column: int):
    """Look for the search word on the up-right diagonal from ``(row, column)``.

    Each step decreases the row by one and increases the column by one; the
    four resulting cells are passed to ``check_coordinates`` and its result is
    returned.
    """
    cells = [(row - step, column + step) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_right_down(lines: list[list[str]], row: int, column: int):
    """Look for the search word on the down-right diagonal from ``(row, column)``.

    Each step increases both the row and the column by one; the four resulting
    cells are passed to ``check_coordinates`` and its result is returned.
    """
    cells = [(row + step, column + step) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_left_up(lines: list[list[str]], row: int, column: int):
    """Look for the search word on the up-left diagonal from ``(row, column)``.

    Each step decreases both the row and the column by one; the four resulting
    cells are passed to ``check_coordinates`` and its result is returned.
    """
    cells = [(row - step, column - step) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def check_left_down(lines: list[list[str]], row: int, column: int):
    """Look for the search word on the down-left diagonal from ``(row, column)``.

    Each step increases the row by one and decreases the column by one; the
    four resulting cells are passed to ``check_coordinates`` and its result is
    returned.
    """
    cells = [(row + step, column - step) for step in range(4)]
    return check_coordinates(lines=lines, coordinates=cells)


def make_coordinate_entry(entry: list[tuple[int, int]]):
    """Build an order-independent string key for a list of coordinates.

    Every coordinate is converted with ``str`` (e.g. ``"(0, 4)"``), the
    resulting strings are sorted lexicographically, and they are concatenated
    without a separator.
    """
    labels = [str(cell) for cell in entry]
    labels.sort()
    return "".join(labels)


def get_coordinates(lines: list[list[str]]):
    """Collect a key for every occurrence of the search word in the grid.

    Every cell containing ``"X"`` is used as a starting point and probed in
    all eight directions. Each successful probe contributes one string key
    (see ``make_coordinate_entry``) to the returned set.

    Note: a later cell in this notebook defines another function with the
    same name; once that cell has run, it replaces this one.

    Args:
        lines: The grid, as a list of rows of single characters.

    Returns:
        A set of string keys, one per distinct match.
    """
    # Same probing order as the directions are defined above.
    direction_checks = (
        check_right,
        check_left,
        check_up,
        check_down,
        check_right_up,
        check_right_down,
        check_left_up,
        check_left_down,
    )

    found: set[str] = set()
    for row, line in enumerate(lines):
        for column, character in enumerate(line):
            # Only an "X" can start a match.
            if character != "X":
                continue

            for probe in direction_checks:
                match = probe(lines=lines, row=row, column=column)
                if match:
                    found.add(make_coordinate_entry(match))

    return found


def part1(lines: list[list[str]]):
    """Return how many distinct matches ``get_coordinates`` finds in ``lines``."""
    return len(get_coordinates(lines))


# Sanity checks on the small grids: one horizontal match and one diagonal
# match, each identified by its sorted coordinate key.
assert part1(example_data2) == 1
assert "(0, 4)(0, 5)(0, 6)(0, 7)" in get_coordinates(example_data2)
assert part1(example_data4) == 1
assert "(0, 4)(1, 5)(2, 6)(3, 7)" in get_coordinates(example_data4)
# The full sample grid is expected to contain 18 matches.
assert part1(example_data) == 18

result = part1(data)

# NOTE(review): 2576 is presumably an earlier wrong answer kept as a
# regression guard; it is redundant given the equality check that follows.
assert result != 2576
assert result == 2551

print("result is", result)

result is 2551


In [3]:
def find_xmas(lines: list[list[str]], coordinates: list[tuple[int, int]]):
    """Return ``coordinates`` if the cells they address spell "MAS" or "SAM".

    Args:
        lines: The grid, as a list of rows of single characters. Rows may have
            different lengths.
        coordinates: ``(row, column)`` pairs, in reading order.

    Returns:
        The ``coordinates`` list itself when every pair lies inside the grid
        and the addressed characters, joined in order, equal ``"MAS"`` or
        ``"SAM"``; otherwise ``None``.
    """
    for row, col in coordinates:
        # Negative indices would silently wrap around to the other side of the
        # grid, so they are rejected explicitly.
        if row < 0 or col < 0 or row >= len(lines):
            return None
        # Compare against the addressed row's own length (not the first
        # row's), so a grid with uneven rows is handled correctly.
        if col >= len(lines[row]):
            return None

    word = "".join(lines[row][col] for row, col in coordinates)
    if word not in ("MAS", "SAM"):
        return None

    return coordinates


def check_left_up_right_down(lines: list[list[str]], row: int, column: int):
    """Check the diagonal through ``(row, column)`` from upper-left to lower-right.

    The three cells ``(row - 1, column - 1)``, ``(row, column)`` and
    ``(row + 1, column + 1)`` are passed to ``find_xmas``, whose result is
    returned.
    """
    cells = [(row + shift, column + shift) for shift in (-1, 0, 1)]
    return find_xmas(lines=lines, coordinates=cells)


def check_left_down_to_right_up(lines: list[list[str]], row: int, column: int):
    """Check the diagonal through ``(row, column)`` from lower-left to upper-right.

    The three cells ``(row + 1, column - 1)``, ``(row, column)`` and
    ``(row - 1, column + 1)`` are passed to ``find_xmas``, whose result is
    returned.
    """
    cells = [(row - shift, column + shift) for shift in (-1, 0, 1)]
    return find_xmas(lines=lines, coordinates=cells)


def get_coordinates(lines: list[list[str]]):
    """Collect a key for every "A" whose two diagonals both pass ``find_xmas``.

    Note: this definition reuses the name of the function defined in an
    earlier cell and replaces it once this cell has run, so ``part1`` would
    call this version if it were invoked again afterwards.

    Args:
        lines: The grid, as a list of rows of single characters.

    Returns:
        A set of string keys; each key is the sorted, concatenated ``str`` of
        the distinct cells covered by both diagonals of one match.
    """
    found: set[str] = set()
    for row, line in enumerate(lines):
        for column, character in enumerate(line):
            # Only an "A" can sit at the centre of a cross.
            if character != "A":
                continue

            rising = check_left_down_to_right_up(lines=lines, row=row, column=column)
            if not rising:
                continue

            falling = check_left_up_right_down(lines=lines, row=row, column=column)
            if not falling:
                continue

            # The centre cell appears in both diagonals; the set drops the
            # duplicate before the key is built.
            unique_cells = {str(cell) for cell in rising + falling}
            found.add("".join(sorted(unique_cells)))

    return found


def part2(lines: list[list[str]]):
    """Return how many distinct matches ``get_coordinates`` finds in ``lines``."""
    matches = get_coordinates(lines)
    return len(matches)


# The sample grid is expected to contain 9 matches for part 2.
assert part2(example_data) == 9

# Rebinds `result`, which held the part 1 answer until now.
result = part2(data)

assert result == 1985


print("result is", result)

result is 1985
