In [357]:
import re
import itertools

In [358]:
# The literal word to search for, and the same word spelled backwards.
XMAS_PATTERN = r"XMAS"
SAMX_PATTERN = XMAS_PATTERN[::-1]


def find_matches(input, pattern="XMAS"):
    """Count occurrences of ``pattern`` read forwards or backwards in each row.

    Parameters
    ----------
    input : iterable of str
        Rows to scan, e.g. the output of one of the ``*_unwrap`` helpers.
    pattern : str, default "XMAS"
        Literal word to look for. Regex metacharacters are escaped, so the
        word is matched literally, both as written and reversed. Matches are
        found with ``re.finditer`` and are therefore non-overlapping per
        direction.

    Returns
    -------
    count : int
        Total number of forward plus backward matches over all rows.
    locations : list of (int, int)
        One ``(row_index, end)`` pair per match. ``end`` is the exclusive end
        offset of the match within that row. Pairs are ordered by row, then
        by match start.
    """
    # Escape first so reversing the word cannot produce a malformed or
    # different regex when the word contains metacharacters.
    forward_pattern = re.escape(pattern)
    backward_pattern = re.escape(pattern[::-1])

    count = 0
    locations = []
    for row_index, row in enumerate(input):
        # Materialise into a list: the matches are needed twice (to count
        # them and to record their positions), and an iterator can only be
        # consumed once.
        matches = sorted(
            itertools.chain(
                re.finditer(backward_pattern, row),
                re.finditer(forward_pattern, row),
            ),
            key=lambda match: match.start(),
        )
        count += len(matches)
        # Accumulate across rows rather than overwriting per row.
        locations.extend((row_index, match.end()) for match in matches)
    return count, locations

def horizontal_unwrap(puzzle):
    """Return the rows of ``puzzle`` as-is; rows already read left to right.

    The very same object is returned (no copy is made).
    """
    rows = puzzle
    return rows

def vertical_unwrap(puzzle):
    """Return the columns of ``puzzle`` as strings read top to bottom.

    Columns come from ``zip(*puzzle)``, so ragged input is truncated to the
    length of the shortest row.
    """
    return ["".join(column) for column in zip(*puzzle)]

def _diagonal_cells(start_row, start_col, row_step, row_count, col_count):
    """Return the ``(row, col)`` cells of one diagonal, walking rightwards.

    Starts at ``(start_row, start_col)``. Each step moves one column right
    and ``row_step`` rows (+1 = south-east, -1 = north-east), stopping as
    soon as the walk leaves the ``row_count`` x ``col_count`` grid.
    """
    cells = []
    row, col = start_row, start_col
    while 0 <= row < row_count and col < col_count:
        cells.append((row, col))
        row += row_step
        col += 1
    return cells


def diagonal_unwrap(puzzle):
    """Flatten every diagonal of ``puzzle`` into a string.

    Each diagonal is read left to right. South-east diagonals come first,
    then north-east ones. Within each direction, every cell belongs to
    exactly one diagonal: starts are column 0 (every row) plus the edge row
    for columns 1.., so the corner diagonal is not emitted twice.

    Parameters
    ----------
    puzzle : sequence of str (or of lists of single-character strings)
        Rectangular grid. The column count is taken from the first row.

    Returns
    -------
    joined : list of str
        The characters along each diagonal, concatenated.
    indices : list of list of (int, int)
        For each entry of ``joined``, the ``(row, col)`` cells it was read
        from, in the same order.
    """
    if not puzzle:
        return [], []
    row_count = len(puzzle)
    col_count = len(puzzle[0])

    # (start_row, start_col, row_step) for every diagonal.
    starts = (
        # lower left - southeast direction: start on column 0, top row down
        [(row, 0, 1) for row in range(row_count)]
        # upper right - southeast direction: start on row 0, columns 1..
        + [(0, col, 1) for col in range(1, col_count)]
        # lower left - northeast direction: start on column 0, bottom row up
        + [(row_count - 1 - row, 0, -1) for row in range(row_count)]
        # upper right - northeast direction: start on the bottom row, columns 1..
        + [(row_count - 1, col, -1) for col in range(1, col_count)]
    )
    indices = [
        _diagonal_cells(row, col, step, row_count, col_count)
        for row, col, step in starts
    ]
    joined = ["".join(puzzle[row][col] for row, col in cells) for cells in indices]
    return joined, indices

In [359]:
INPUT_FILE = 'test_input.txt'
# Load the word-search grid as a list of row strings (line endings stripped).
with open(INPUT_FILE, 'r') as puzzle_file:
    input = puzzle_file.read().splitlines()

In [360]:
def print_puzzle(puzzle):
    """Print each row of ``puzzle`` on its own line (nothing for an empty grid)."""
    for line in puzzle:
        print(line)

In [361]:
def part_1(puzzle):
    """Solve part 1: count "XMAS" in every orientation of the word search.

    The grid is flattened into rows, columns and diagonals. ``find_matches``
    counts forward and backward occurrences in each view, and the
    per-view counts are summed.
    """
    views = [horizontal_unwrap(puzzle), vertical_unwrap(puzzle)]
    diagonal_rows, _ = diagonal_unwrap(puzzle)
    views.append(diagonal_rows)

    total_count = 0
    for view in views:
        view_count, _ = find_matches(view, pattern="XMAS")
        total_count += view_count
    return total_count

In [362]:
# Solve part 1 for the loaded grid and show the result.
part_1_answer = part_1(puzzle=input)
print(part_1_answer)

[['M', 'S', 'X', 'M', 'A', 'X', 'S', 'A', 'M', 'X'], ['M', 'M', 'A', 'S', 'M', 'A', 'S', 'M', 'S'], ['A', 'S', 'A', 'M', 'S', 'A', 'M', 'A'], ['M', 'M', 'A', 'M', 'M', 'X', 'M'], ['X', 'X', 'S', 'A', 'M', 'X'], ['X', 'M', 'X', 'M', 'A'], ['S', 'A', 'M', 'X'], ['S', 'A', 'M'], ['M', 'X'], ['M'], ['M', 'S', 'X', 'M', 'A', 'X', 'S', 'A', 'M', 'X'], ['M', 'A', 'S', 'A', 'M', 'X', 'X', 'A', 'M'], ['M', 'M', 'X', 'S', 'X', 'A', 'S', 'A'], ['S', 'X', 'M', 'M', 'A', 'M', 'S'], ['X', 'M', 'A', 'S', 'M', 'A'], ['X', 'S', 'A', 'M', 'M'], ['M', 'M', 'M', 'X'], ['A', 'S', 'M'], ['S', 'A'], ['M', 'A', 'X', 'M', 'M', 'M', 'M', 'A', 'S', 'M'], ['M', 'A', 'S', 'M', 'A', 'S', 'A', 'M', 'S'], ['S', 'M', 'A', 'S', 'A', 'M', 'S', 'A'], ['S', 'X', 'A', 'M', 'X', 'M', 'M'], ['X', 'M', 'A', 'S', 'X', 'X'], ['X', 'S', 'X', 'M', 'X'], ['M', 'M', 'A', 'S'], ['A', 'S', 'M'], ['M', 'M'], ['M'], ['M', 'A', 'X', 'M', 'M', 'M', 'M', 'A', 'S', 'M'], ['X', 'M', 'A', 'S', 'X', 'X', 'S', 'M', 'A'], ['M', 'M', 'M', 'A', '

In [363]:
test = re.compile(XMAS_PATTERN).findall(input[0])

In [364]:
# Echo each literal match found in the first row.
for match_text in test:
    print(match_text)

XMAS
