In [1]:
# (row_delta, col_delta) steps to the eight neighbouring cells, starting at
# the upper-left neighbour and walking round clockwise.
clockwise_modifiers = [
    (-1, -1),
    (-1, 0),
    (-1, 1),
    (0, 1),
    (1, 1),
    (1, 0),
    (1, -1),
    (0, -1),
]
# Every second clockwise step is a diagonal: upper-left, upper-right,
# lower-right, lower-left (so entries two apart are opposite corners).
x_modifiers = clockwise_modifiers[::2]

In [2]:
# Small 10 x 10 example grid, one row per line (no trailing newline).
test_input = "\n".join(
    [
        "MMMSXXMASM",
        "MSAMXMSMSA",
        "AMXSXMAAMM",
        "MSAMASMSMX",
        "XMASAMXAMM",
        "XXAMMXXAMA",
        "SMSMSASXSS",
        "SAXAMASAAA",
        "MAMMMXMMMM",
        "MXMXAXMASX",
    ]
)
# Path of the full puzzle input, relative to the working directory.
input_fpath = "day4.txt"

In [3]:
def parse_input(input: str) -> dict:
    """Turn a text grid into a ``{(row, col): character}`` mapping.

    The text is split on runs of whitespace, so each whitespace-separated
    token becomes one row and surrounding/trailing whitespace is ignored.
    Keys are inserted row by row, left to right.

    Args:
        input: The raw grid text.

    Returns:
        A dict mapping ``(row_index, col_index)`` to the character found there.
    """
    return {
        (row_number, col_number): letter
        for row_number, line in enumerate(input.split())
        for col_number, letter in enumerate(line)
    }


def wordsearch(character_location_dict: dict, target_word: str = "XMAS") -> int:
    """Count how many times ``target_word`` appears in the letter grid.

    A match is a straight run of cells, in any of the eight horizontal,
    vertical or diagonal directions, whose characters spell ``target_word``
    from its first letter to its last. Every (start cell, direction) pair
    that spells the word counts once, so a one-letter word is counted once
    per direction at each matching cell.

    Args:
        character_location_dict: Mapping from ``(row, col)`` to the
            character at that position, e.g. the output of ``parse_input``.
        target_word: The word to look for. Defaults to ``"XMAS"``.

    Returns:
        The number of matches; 0 when ``target_word`` is empty.
    """
    if not target_word:
        return 0

    # The eight unit steps to the neighbouring cells.
    directions = [
        (row_step, col_step)
        for row_step in (-1, 0, 1)
        for col_step in (-1, 0, 1)
        if (row_step, col_step) != (0, 0)
    ]
    # Derived from target_word rather than hard-coded, so the start letter
    # can never disagree with the word being searched for.
    first_letter = target_word[0]

    word_count = 0
    for (start_row_index, start_col_index), character in character_location_dict.items():
        if character != first_letter:
            continue
        for row_step, col_step in directions:
            # Compare the word letter by letter along this direction. Cells
            # off the grid come back as None from .get() and never match;
            # all() stops at the first mismatch. This is an exact positional
            # comparison, not a substring test, so it is correct for any word.
            if all(
                character_location_dict.get(
                    (start_row_index + offset * row_step, start_col_index + offset * col_step)
                )
                == letter
                for offset, letter in enumerate(target_word)
            ):
                word_count += 1
    return word_count


def xmas_search(character_location_dict: dict) -> int:
    """Count X-shaped "MAS" crosses in the letter grid.

    A cross is an ``"A"`` whose two diagonals each read ``"MAS"`` in one
    direction or the other, i.e. each diagonal has an ``"M"`` at one end and
    an ``"S"`` at the other.

    Args:
        character_location_dict: Mapping from ``(row, col)`` to the
            character at that position.

    Returns:
        The number of ``"A"`` cells that are the centre of such a cross.
    """
    ends = {"M", "S"}
    get = character_location_dict.get
    obj_count = 0

    for (row_index, col_index), character in character_location_dict.items():
        if character != "A":
            continue
        # Pair each corner explicitly with the one diagonally opposite it.
        # An off-grid corner yields None, which keeps the set from ever
        # equalling {"M", "S"}.
        main_diagonal = {
            get((row_index - 1, col_index - 1)),
            get((row_index + 1, col_index + 1)),
        }
        anti_diagonal = {
            get((row_index - 1, col_index + 1)),
            get((row_index + 1, col_index - 1)),
        }
        if main_diagonal == ends and anti_diagonal == ends:
            obj_count += 1

    return obj_count


In [4]:
# Read the puzzle input with an explicit encoding (the default is
# platform-dependent) and release the file handle before solving.
with open(input_fpath, "r", encoding="utf-8") as f:
    day4_input = parse_input(f.read())

print(f"Part 1 word count: {wordsearch(day4_input)}")
print(f"Part 2 x-mas count: {xmas_search(day4_input)}")

Part 1 word count: 2397
Part 2 x-mas count: 1824
