In [297]:
import numpy as np


def read_file(path: str):
    """Read a rectangular character grid from a text file.

    Each non-blank line becomes one row and each character one cell.
    Surrounding whitespace (including the newline) is stripped from every line.

    Args:
        path: Path of the text file to read.

    Returns:
        An ``np.matrix`` of single-character strings, one row per non-blank line.

    Raises:
        ValueError: If the non-blank lines do not all have the same length.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. an extra newline at EOF): they would otherwise
        # become empty rows and make the grid ragged.
        rows = [list(line.strip()) for line in f if line.strip()]

    widths = {len(row) for row in rows}
    if len(widths) > 1:
        raise ValueError(
            f"Grid rows in {path!r} have differing lengths: {sorted(widths)}"
        )
    return np.matrix(rows)


# NOTE: `input` shadows the Python builtin `input()` for the rest of the notebook.
input = read_file(path="input.txt")

In [298]:
# Preview the parsed grid; numpy elides the middle rows/columns of a large grid.
input

matrix([['S', 'S', 'S', ..., 'S', 'S', 'S'],
        ['M', 'A', 'A', ..., 'A', 'A', 'A'],
        ['M', 'S', 'M', ..., 'M', 'M', 'M'],
        ...,
        ['M', 'A', 'A', ..., 'A', 'M', 'A'],
        ['M', 'S', 'S', ..., 'S', 'X', 'S'],
        ['X', 'X', 'A', ..., 'M', 'A', 'X']], dtype='<U1')

In [299]:
from enum import Enum


class Direction(Enum):
    """A straight move on the grid.

    ``construct_start_end`` interprets the members as: RIGHT / LEFT add / subtract
    from the column index, UP / DOWN subtract / add from the row index. A list of
    two members (e.g. ``[Direction.UP, Direction.LEFT]``) expresses a diagonal.
    """

    # The string values look like labels only; nothing visible here reads them.
    LEFT = "l"
    RIGHT = "r"
    UP = "u"
    DOWN = "d"


def _step_towards(value, target):
    """Move ``value`` one unit towards ``target`` (unchanged if already equal)."""
    if value < target:
        return value + 1
    if value > target:
        return value - 1
    return value


def get_line(start, end):
    """Return every grid cell visited walking from ``start`` to ``end``.

    Each step moves the row and the column independently one cell towards
    ``end``. Horizontal, vertical and 45-degree diagonal segments therefore
    come out as straight lines. For unaligned endpoints the path goes
    diagonally until one coordinate matches, then straight. Both endpoints
    are included.

    Args:
        start: ``(row, col)`` of the first cell; any 2-element sequence.
        end: ``(row, col)`` of the last cell; any 2-element sequence.

    Returns:
        A list of ``[row, col]`` lists from ``start`` to ``end`` inclusive.
        The inputs are not modified.
    """
    row, col = start
    end_row, end_col = end
    cells = [[row, col]]
    # Compare scalars, not the caller's containers: a list never equals a
    # tuple, which made the previous `start != end` loop spin forever.
    while (row, col) != (end_row, end_col):
        row = _step_towards(row, end_row)
        col = _step_towards(col, end_col)
        cells.append([row, col])
    return cells


def construct_start_end(
    start: list[int, int], direction: Direction | list[Direction], len: int
):
    start = start[:]
    end = start[:]
    direction = [direction] if not isinstance(direction, list) else direction

    for partial_direction in direction:
        if partial_direction == Direction.RIGHT:
            end = [end[0], end[1] + len]
        elif partial_direction == Direction.LEFT:
            end = [end[0], end[1] - len]
        elif partial_direction == Direction.UP:
            end = [end[0] - len, end[1]]
        elif partial_direction == Direction.DOWN:
            end = [end[0] + len, end[1]]

    return end


def get_start_coords(input_matrix, char: str):
    """Return ``[row, col]`` for every cell of ``input_matrix`` equal to ``char``.

    Coordinates are plain Python ints in row-major order. The result is an
    empty list when ``char`` does not occur.
    """
    hits = np.asarray(input_matrix == char)
    return np.argwhere(hits).tolist()


def line_to_str(input, line):
    """Concatenate the characters of ``input`` found at each coordinate in ``line``.

    Each entry of ``line`` (e.g. ``[row, col]``) is used as a tuple index.
    """
    return "".join(input[tuple(point)] for point in line)

In [300]:
# Part 1: find every occurrence of WORD starting at each of its first letters.
# WORD may read in any of the 8 directions (straight or diagonal, either way).
WORD = "XMAS"

x_coords = get_start_coords(input, WORD[0])
directions = [
    Direction.RIGHT,
    Direction.LEFT,
    Direction.UP,
    Direction.DOWN,
    [Direction.UP, Direction.LEFT],
    [Direction.UP, Direction.RIGHT],
    [Direction.DOWN, Direction.LEFT],
    [Direction.DOWN, Direction.RIGHT],
]

# Largest valid row / column index (not the row / column counts).
rows = input.shape[0] - 1
cols = input.shape[1] - 1

res = []
for x in x_coords:
    start = x[:]
    for d in directions:
        # WORD spans len(WORD) cells, so its last cell is len(WORD) - 1 steps away.
        end = construct_start_end(start, d, len(WORD) - 1)

        # Skip candidates that would run off the grid; the start is always inside,
        # so an in-bounds end means the whole straight line is in bounds.
        if not (0 <= end[0] <= rows and 0 <= end[1] <= cols):
            continue

        line = get_line(start, end)
        t = line_to_str(input, line)
        if t == WORD:
            res.append(
                {
                    "start": start,
                    "end": end,
                    "direction": d,
                    "line": line,
                    "text": t,
                }
            )

In [301]:
# Part 1 answer: number of XMAS occurrences found above.
len(res)

2644

In [302]:
from itertools import product


def get_rotations_around_points(start):
    """Return the four diagonal neighbours of ``start``.

    The order is ``[[r-1, c-1], [r-1, c+1], [r+1, c-1], [r+1, c+1]]`` for
    ``start == [r, c]``. There is no bounds checking, so coordinates may be
    negative or lie past the grid edge.
    """
    row, col = start[0], start[1]
    return [[row + d_row, col + d_col] for d_row in (-1, 1) for d_col in (-1, 1)]


def draw_line(points: list[list[int, int]], line_length: int, max_x: int, max_y: int):
    """Extend the segment through two adjacent cells to ``line_length`` cells.

    The line starts at ``points[0]`` and continues through ``points[1]`` in the
    same direction, one cell per step. Cells outside ``[0, max_x] x [0, max_y]``
    are then dropped, so the result may be shorter than ``line_length``.

    Args:
        points: Two distinct cells that touch, including diagonally.
        line_length: Number of cells to generate before clipping to the grid.
        max_x: Largest valid first coordinate (inclusive).
        max_y: Largest valid second coordinate (inclusive).

    Returns:
        The in-bounds cells as ``[x, y]`` lists, in order from ``points[0]``.
        The input points are not modified.

    Raises:
        ValueError: If the two points are not adjacent or are the same cell.
    """
    x0, y0 = points[0][0], points[0][1]
    dx = points[1][0] - x0
    dy = points[1][1] - y0
    if abs(dx) > 1 or abs(dy) > 1:
        raise ValueError("Points provided need to be adjecent to each other.")
    if dx == 0 and dy == 0:
        # A zero step would just repeat the same cell line_length times.
        raise ValueError("Points provided need to be distinct.")

    # Cell i lies i steps from points[0]. Building it this way honours
    # line_length < 2 too (the old loop always emitted both input points).
    line = [[x0 + i * dx, y0 + i * dy] for i in range(line_length)]
    return [p for p in line if 0 <= p[0] <= max_x and 0 <= p[1] <= max_y]


def get_strings_for_centre(input, centre):
    """Read the four 3-cell diagonal strings passing through ``centre``.

    For each diagonal neighbour of ``centre``, a 3-cell line is drawn from that
    neighbour through ``centre`` and read as text. ``draw_line`` drops cells
    outside the grid, so strings near an edge can be shorter than 3 characters.
    """
    last_row = input.shape[0] - 1
    last_col = input.shape[1] - 1

    diagonals = []
    for neighbour in get_rotations_around_points(centre):
        diagonals.append(draw_line([neighbour, centre], 3, last_row, last_col))

    return [line_to_str(input, cells) for cells in diagonals]


def is_cross(input, centre):
    """Return True when ``centre`` is the middle of an X made of two "MAS" diagonals.

    ``get_strings_for_centre`` reads each diagonal through ``centre`` once from
    each end. A diagonal can read "MAS" from at most one end: both would need
    its end cells to be both M and S. So two or more matches means both
    diagonals spell MAS, forwards or backwards.
    """
    return get_strings_for_centre(input, centre).count("MAS") >= 2

In [303]:
# Part 2: count the "A" cells that sit at the centre of an X-MAS.
a_coords = get_start_coords(input, "A")
sum(is_cross(input, centre) for centre in a_coords)

1952