In [61]:
# Possible orientations of the word:
# horizontal, vertical and diagonal.
# Each can also be read in reverse,
# and occurrences can overlap one another.

# Convert the rows, columns and diagonals of the matrix to individual strings,
# then use a regex to find XMAS/SAMX substrings in each of them.

import re

class XMAS_Search:
    """Count occurrences of the word XMAS in a grid of characters.

    The word may run along a row, a column or either diagonal direction,
    forwards or backwards, and occurrences may overlap.

    The grid is given as a list of rows, each row a list of single
    characters. The width is taken from the first row, so the grid is
    assumed to be rectangular.
    """

    # Zero-width lookahead: nothing is consumed, so overlapping matches
    # such as the two in "XMASAMX" are each reported.
    pattern = re.compile(r'(?=XMAS|SAMX)')
    # A line shorter than the search word can never contain it.
    min_length = 4

    def __init__(self, matrix: list):
        self.matrix = matrix

    def getLevels(self):
        """Return every column, row and diagonal of the grid as a string.

        Lines shorter than ``min_length`` are dropped. The result lists
        columns first, then rows, then the x+y diagonals, then the x-y
        diagonals. An empty grid (no rows, or an empty first row) yields
        an empty list.
        """
        if not self.matrix or not self.matrix[0]:
            return []

        max_col = len(self.matrix[0])
        max_row = len(self.matrix)
        cols = [[] for _ in range(max_col)]
        rows = [[] for _ in range(max_row)]
        fdiag = [[] for _ in range(max_row + max_col - 1)]
        bdiag = [[] for _ in range(len(fdiag))]
        # x - y ranges over [-(max_row - 1), max_col - 1]; subtracting the
        # minimum shifts it to a valid list index starting at 0.
        min_bdiag = -max_row + 1

        for x in range(max_col):
            for y in range(max_row):
                char = self.matrix[y][x]
                cols[x].append(char)
                rows[y].append(char)
                fdiag[x + y].append(char)
                bdiag[x - y - min_bdiag].append(char)

        return [
            ''.join(line)
            for group in (cols, rows, fdiag, bdiag)
            for line in group
            if len(line) >= self.min_length
        ]

    def findSubstring(self, level: str):
        """Return how many times XMAS or SAMX starts within ``level``."""
        return len(self.pattern.findall(level))

    def getXMASCount(self):
        """Return the total number of XMAS/SAMX occurrences over all lines."""
        return sum(self.findSubstring(level) for level in self.getLevels())



        


In [62]:
# Read the test grid: one list of characters per line of the file.
with open('data/test/4.txt', 'r', encoding='utf-8') as handle:
    lines = handle.read().splitlines()
matrix = [list(line) for line in lines]

# The last expression is the cell's displayed result.
part1 = XMAS_Search(matrix)
part1.getXMASCount()

18

In [64]:
# Read the puzzle input grid: one list of characters per line of the file.
with open('data/input/4.txt', 'r', encoding='utf-8') as handle:
    lines = handle.read().splitlines()
matrix = [list(line) for line in lines]

# The last expression is the cell's displayed result.
part1 = XMAS_Search(matrix)
part1.getXMASCount()

2336

In [97]:
# Part 2: the 'A' is always at the centre of an X-MAS shape,
# so we find every interior 'A' (the nuclei) and then check both diagonals through it.

class Part2:
    """Count X-MAS shapes: two diagonal MAS/SAM words crossing at one 'A'.

    The grid is given as a list of rows, each row a list of single
    characters. The width is taken from the first row, so the grid is
    assumed to be rectangular.
    """

    def __init__(self, matrix: list):
        self.matrix = matrix

    def findNuclei(self):
        """Return (row, column) coordinates of every 'A' not on the border.

        Border cells are skipped because an 'A' there cannot have all four
        diagonal neighbours. An empty grid yields an empty list.
        """
        if not self.matrix:
            return []

        # Last valid row / column index; interior cells lie strictly inside.
        vlim = len(self.matrix) - 1
        hlim = len(self.matrix[0]) - 1
        # Iterate this instance's grid, not a module-level variable.
        return [
            (line_idx, char_idx)
            for line_idx, line in enumerate(self.matrix)
            for char_idx, char in enumerate(line)
            if char == 'A' and 1 <= char_idx < hlim and 1 <= line_idx < vlim
        ]

    def checkNucleus(self, nucleus: tuple):
        """Return True if both diagonals through ``nucleus`` hold one M and one S.

        ``nucleus`` is a (row, column) pair. The centre character itself is
        not inspected here; findNuclei only supplies coordinates of 'A's.
        Coordinates on or outside the border return False, since negative
        indices would otherwise wrap around to the opposite edge.
        """
        mat = self.matrix
        row, col = nucleus
        if not mat:
            return False
        if not (1 <= row < len(mat) - 1 and 1 <= col < len(mat[0]) - 1):
            return False

        ends = {"M", "S"}
        # Top-left to bottom-right diagonal.
        bdiag = {mat[row - 1][col - 1], mat[row + 1][col + 1]} == ends
        # Bottom-left to top-right diagonal.
        fdiag = {mat[row + 1][col - 1], mat[row - 1][col + 1]} == ends
        return bdiag and fdiag

    def getXMASCount(self):
        """Return the number of interior 'A's that are the centre of an X-MAS."""
        return sum(1 for coord in self.findNuclei() if self.checkNucleus(coord))


        


    

            

In [98]:
# Read the test grid: one list of characters per line of the file.
with open('data/test/4.txt', 'r', encoding='utf-8') as handle:
    lines = handle.read().splitlines()
matrix = [list(line) for line in lines]

# The last expression is the cell's displayed result.
part2 = Part2(matrix)
part2.getXMASCount()

9

In [99]:
# Read the puzzle input grid: one list of characters per line of the file.
with open('data/input/4.txt', 'r', encoding='utf-8') as handle:
    lines = handle.read().splitlines()
matrix = [list(line) for line in lines]

# The last expression is the cell's displayed result.
part2 = Part2(matrix)
part2.getXMASCount()

1831