# In [63]:
def search_grid(grid, target):
    """Count occurrences of ``target`` in ``grid`` along all eight directions.

    Each cell ``grid[row][col]`` is tried as a starting point in each of the
    eight compass directions. A match is counted once per (start cell,
    direction) pair, so a palindromic word is counted once for each direction
    it reads in.

    Args:
        grid: sequence of rows indexed as ``grid[row][col]``; rows may be
            lists of characters or strings. Rows may differ in length:
            positions past the end of a row are treated as out of bounds.
        target: the word to search for.

    Returns:
        The number of matches. An empty grid or empty target yields 0.
    """
    if not grid or not target:
        return 0

    rows = len(grid)
    # (d_row, d_col) step for each of the eight search directions.
    directions = (
        (0, 1),    # right  (W -> E)
        (0, -1),   # left   (E -> W)
        (1, 0),    # down   (N -> S)
        (-1, 0),   # up     (S -> N)
        (1, 1),    # diagonal NW -> SE
        (-1, -1),  # diagonal SE -> NW
        (1, -1),   # diagonal NE -> SW
        (-1, 1),   # diagonal SW -> NE
    )

    def check_direction(row, col, dr, dc):
        """Return True if ``target`` is spelled from (row, col) stepping (dr, dc)."""
        for i, ch in enumerate(target):
            r, c = row + i * dr, col + i * dc
            # Bounds are checked against the actual row length so ragged grids
            # (e.g. one containing an empty row) fail the match instead of raising.
            if not (0 <= r < rows and 0 <= c < len(grid[r])) or grid[r][c] != ch:
                return False
        return True

    return sum(
        check_direction(row, col, dr, dc)
        for row in range(rows)
        for col in range(len(grid[row]))
        for dr, dc in directions
    )

def check_surrounding_values(grid, row, col, directions, cross_types):
    """Return True if the cells around (row, col) match any pattern in ``cross_types``.

    Args:
        grid: rows indexed as ``grid[row][col]``.
        row, col: the centre cell.
        directions: mapping of direction name -> (d_row, d_col) offset.
        cross_types: iterable of mappings direction name -> expected character.
            Every name used in a pattern must be a key of ``directions``.

    Returns:
        True if at least one pattern matches every one of its cells, otherwise
        False. A neighbour that falls outside the grid never matches; it is not
        wrapped around via Python's negative indexing.
    """
    surrounding_values = {}
    for name, (dr, dc) in directions.items():
        r, c = row + dr, col + dc
        in_bounds = 0 <= r < len(grid) and 0 <= c < len(grid[r])
        # None never equals a grid character, so out-of-bounds cells fail any pattern.
        surrounding_values[name] = grid[r][c] if in_bounds else None

    return any(
        all(surrounding_values[name] == char for name, char in pattern.items())
        for pattern in cross_types
    )

def part2(puzzle=None):
    """Count X-shaped "MAS" crosses centred on an 'A'; print and return the total.

    A cross is an 'A' whose four diagonal neighbours match one of the four
    corner layouts in ``cross_types``, i.e. both diagonals read "MAS" in either
    direction.

    Args:
        puzzle: grid indexed as ``puzzle[row][col]``. Defaults to the
            module-level ``grid`` so the original zero-argument call still works.

    Returns:
        The number of crosses found. The count is also printed.
    """
    if puzzle is None:
        puzzle = grid  # module-level grid assigned by the notebook cell

    directions = {
        'top_left': (-1, -1),
        'top_right': (-1, +1),
        'bot_left': (+1, -1),
        'bot_right': (+1, +1)
    }

    # Each layout places the two M's on one side and the two S's on the other.
    cross_types = [
        {
            'top_left': 'M',
            'top_right': 'S',
            'bot_left': 'M',
            'bot_right': 'S'
        },
        {
            'top_left': 'S',
            'top_right': 'M',
            'bot_left': 'S',
            'bot_right': 'M'
        },
        {
            'top_left': 'S',
            'top_right': 'S',
            'bot_left': 'M',
            'bot_right': 'M'
        },
        {
            'top_left': 'M',
            'top_right': 'M',
            'bot_left': 'S',
            'bot_right': 'S'
        }
    ]

    total_x = 0
    # An X needs a neighbour on every side, so the first/last row and column
    # can never hold its centre; an empty or 1-2 row grid yields no iterations.
    for row in range(1, len(puzzle) - 1):
        for col in range(1, len(puzzle[row]) - 1):
            if puzzle[row][col] == 'A' and check_surrounding_values(
                puzzle, row, col, directions, cross_types
            ):
                total_x += 1

    print(total_x)
    return total_x
    
def read_grid_from_file(filename):
    """Read a character grid from a text file, one row per line.

    Leading and trailing whitespace (including the line terminator) is
    stripped from each line, and blank lines are skipped. This stops a
    trailing empty line from adding an empty row that would break
    length-based bounds checks.

    Args:
        filename: path of the file to read.

    Returns:
        A list of rows, each a list of single-character strings.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        stripped = (line.strip() for line in file)
        return [list(text) for text in stripped if text]

# In [64]:
target = "XMAS"
grid = read_grid_from_file("day04_input.csv")

# Part 1: occurrences of the target word in all eight directions.
count = search_grid(grid=grid, target=target)
print(count)

# Part 2: X-shaped "MAS" crosses around each 'A' (part2 prints its own total).
part2()

# Output:
# 2662
# 2034
