In [1]:
import re
import itertools
import numpy as np

### Part 1

In [2]:
# Load the puzzle grid as a list of row strings.
with open('../data/day4.txt') as f:
    # splitlines() (rather than split('\n')) does not produce a phantom
    # empty last row when the file ends with a newline; such a row would
    # make the grid ragged and break the column/diagonal searches below.
    user_input = f.read().splitlines()

In [3]:
# Small 10x10 sample grid used to sanity-check the search before running it
# on the real input; one string per row.
test_input = [
    'MMMSXXMASM',
    'MSAMXMSMSA',
    'AMXSXMAAMM',
    'MSAMASMSMX',
    'XMASAMXAMM',
    'XXAMMXXAMA',
    'SMSMSASXSS',
    'SAXAMASAAA',
    'MAMMMXMMMM',
    'MXMXAXMASX',
]

In [4]:
def find_diagonals(array):
    """Collect the top-left -> bottom-right diagonals of a 2-D grid.

    Parameters
    ----------
    array : sequence of sequences
        Grid indexed as ``array[row][col]``; the column count is taken
        from the first row.

    Returns
    -------
    dict
        Maps the offset ``row - col`` to the string formed by joining
        ``str()`` of every cell on that diagonal, in increasing row order.
    """
    row_count, col_count = len(array), len(array[0])

    cells_by_offset = {}
    for r in range(row_count):
        current_row = array[r]
        for c in range(col_count):
            # All cells sharing the same (row - col) lie on one diagonal.
            cells_by_offset.setdefault(r - c, []).append(str(current_row[c]))

    return {offset: ''.join(cells) for offset, cells in cells_by_offset.items()}

In [5]:
def find_anti_diagonals(array):
    """Collect the top-right -> bottom-left diagonals of a 2-D grid.

    Parameters
    ----------
    array : sequence of sequences
        Grid indexed as ``array[row][col]``; the column count is taken
        from the first row.

    Returns
    -------
    dict
        Maps the sum ``row + col`` to the string formed by joining
        ``str()`` of every cell on that anti-diagonal, in increasing row
        order (i.e. decreasing column order).
    """
    row_count = len(array)
    col_count = len(array[0])

    grouped = {}
    positions = ((r, c) for r in range(row_count) for c in range(col_count))
    for r, c in positions:
        # Cells with equal (row + col) share an anti-diagonal.
        bucket = grouped.setdefault(r + c, [])
        bucket.append(str(array[r][c]))

    result = {}
    for key, cells in grouped.items():
        result[key] = ''.join(cells)
    return result

In [6]:
def _grid_lines(n_rows, n_cols):
    """Yield the cell coordinates of every scannable line of an n_rows x n_cols grid.

    Each yielded item is a list of ``(row, col)`` tuples in scan order:
    rows (left to right), columns (top to bottom), main diagonals
    (top-left to bottom-right) and anti-diagonals (top-right to bottom-left).
    """
    for i in range(n_rows):
        yield [(i, j) for j in range(n_cols)]

    for j in range(n_cols):
        yield [(i, j) for i in range(n_rows)]

    # Main diagonals: constant k = i - j, ranging over -(n_cols-1) .. n_rows-1.
    for k in range(1 - n_cols, n_rows):
        i0, j0 = max(k, 0), max(-k, 0)
        length = min(n_rows - i0, n_cols - j0)
        yield [(i0 + t, j0 + t) for t in range(length)]

    # Anti-diagonals: constant k = i + j, ranging over 0 .. n_rows+n_cols-2.
    for k in range(n_rows + n_cols - 1):
        i0 = max(0, k - (n_cols - 1))
        j0 = k - i0
        length = min(n_rows - i0, j0 + 1)
        yield [(i0 + t, j0 - t) for t in range(length)]


def search_words(array, words):
    """Count (possibly overlapping) occurrences of ``words`` in a character grid.

    The grid is scanned along rows, columns and both diagonal directions.
    To find a word in both reading directions, pass it together with its
    reverse (e.g. ``['XMAS', 'SAMX']``).

    Parameters
    ----------
    array : sequence of sequences
        Rectangular grid of single characters indexed as ``array[row][col]``
        (e.g. a list of equal-length strings). Square and non-square grids
        are both supported.
    words : iterable of str
        Alternatives to search for. They are joined with ``|`` into a
        regular expression, so regex metacharacters keep their meaning.

    Returns
    -------
    tuple
        ``(total, indices)`` where ``total`` is the number of matches and
        ``indices`` is the set of ``(row, col)`` cells covered by at least
        one match.
    """
    # Nothing to scan / nothing to look for. An empty alternation would
    # otherwise match the empty string at every position.
    if not array or not array[0] or not words:
        return 0, set()

    pattern = '|'.join(words)
    # The lookahead makes matches zero-width so overlapping hits are all
    # reported; group 1 still captures the matched text.
    finder = re.compile(rf'(?=({pattern}))')

    total = 0
    indices = set()
    for cells in _grid_lines(len(array), len(array[0])):
        line = ''.join(str(array[i][j]) for i, j in cells)
        for match in finder.finditer(line):
            total += 1
            start = match.start()
            # Map the match back to grid cells via the recorded coordinates,
            # using the real match length instead of a fixed 4.
            indices.update(cells[start:start + len(match.group(1))])

    return total, indices

In [7]:
# Run the search on the sample grid; both reading directions are covered by
# passing the word and its reverse.
total, indices = search_words(array=test_input, words=['XMAS', 'SAMX'])
print(total)

18


In [8]:
def mask_array(array, indices):
    """Blank out every grid cell that is not listed in ``indices``.

    Parameters
    ----------
    array : sequence of sequences
        Grid indexed as ``array[row][col]``; the column count is taken
        from the first row.
    indices : collection of (row, col) tuples
        Cells to keep. Coordinates outside the grid are ignored.

    Returns
    -------
    list of str
        One string per row, with kept cells copied from ``array`` and all
        other cells replaced by ``'.'``.
    """
    n_rows = len(array)
    n_cols = len(array[0])

    # Start from an all-dots canvas and copy over only the selected cells.
    canvas = np.full((n_rows, n_cols), '.')
    for r, c in np.ndindex(n_rows, n_cols):
        if (r, c) in indices:
            canvas[r, c] = array[r][c]

    return list(map(''.join, canvas))

In [9]:
# Visual check: show only the sample-grid cells that belong to a match.
mask_array(array=test_input, indices=indices)

['....XXMAS.',
 '.SAMXMS...',
 '...S..A...',
 '..A.A.MS.X',
 'XMASAMX.MM',
 'X.....XA.A',
 'S.S.S.S.SS',
 '.A.A.A.A.A',
 '..M.M.M.MM',
 '.X.X.XMASX']

In [10]:
# Match count for the sample grid.
search_words(array=test_input, words=['XMAS', 'SAMX'])[0]

18

In [11]:
# Match count for the real puzzle input.
search_words(array=user_input, words=['XMAS', 'SAMX'])[0]

2297

### Part 2

In [12]:
# Explode each row string into a list of single characters so the grid can be
# handed to numpy as a 2-D array.
array = [list(line) for line in user_input]

In [13]:
# Every 3x3 window of the grid: tiles[i, j] is the window whose top-left
# corner sits at grid cell (i, j).
tiles = np.lib.stride_tricks.sliding_window_view(np.asarray(array), window_shape=(3, 3))

In [14]:
# A 3x3 tile flattened row-major is a 9-character string whose corners are
# at positions 0, 2, 6, 8 and whose centre is at position 4. The four
# alternatives below are the corner arrangements with 'A' in the centre and
# an M/S pair on each diagonal.
x_mas = re.compile(r'(M.M.A.S.S)|(M.S.A.M.S)|(S.M.A.S.M)|(S.S.A.M.M)')

total = 0
indices = set()
for i, j in np.ndindex(*tiles.shape[:2]):
    flattened = ''.join(tiles[i, j].ravel())
    if x_mas.search(flattened) is None:
        continue
    total += 1
    # Record the four corners and the centre, in grid coordinates.
    indices.update([
        (i, j),
        (i, j + 2),
        (i + 1, j + 1),
        (i + 2, j),
        (i + 2, j + 2),
    ])

In [15]:
# Grid dimensions: m rows, n columns (column count taken from the first row).
m, n = len(array), len(array[0])

In [16]:
# Dot-filled copy of the grid in which only cells belonging to a match keep
# their original character.
masked_array = np.full((m, n), '.')
for i, j in itertools.product(range(m), range(n)):
    if (i, j) in indices:
        masked_array[i, j] = array[i][j]

In [17]:
# Render the masked grid, one row per output line.
print(*(''.join(row) for row in masked_array), sep='\n')

.........S.M....S.M.........S.S.S.M.......M.M.............S.M.S.M.S.M.S.S................M.S.S.S.S....SMSS.S.M.....M.M....M.M........SMSM...
M.M......SAS....SAS.S...M.M..A...A..M.S.M.SAM..M.S.M....S.MA.A.A.A.A...A.M.S.MM.S...S.M...A...A.AS.M...AA..MAM..S.MSAM.....A.....M.S..AA....
.A.......SAM....SAMA..M.SA.MMSM.S.M..A.A..SAS...A.A......AS.M.S.M.S.M.M.M.A.A..A.....A...M.S.M.M.MA...MMMS.SAM...A.SAS....S.S.....AM.MSMS...
SSSS..S.MMSM....M.M.M.MAS.S.AA......M.S.M.SAM.SMMS.M....S.M......A.A.....M.S.MM.S.S.SMMS.M....M.SS.M...A..AS.S..SMMSAM.S.........M.SA.A.....
.AA....A.A.......A.A.AM.S..MSSSMM.M..A.A..M.MAMAM...............M.S.M............A.A..A.A......A....M.S.SM.M.M....AM.MAM.S.......M.S.S.S....
MMMM..S.M.S.S...S.S.S.S.M.S.AAA..A..M.S.M...M.SAMS.S......MMSS..................M.MMMM.S.MS.S.M.S....A......A....M.S.M.SAS.S......A...A.....
M.S.M......A.M.M.......A.A.MMSMMS.S....A......S.S.A.M.MS.S.AA...........SSMS........A......A......S.M.S....S.S........AM.SA....S.M.S.M.M....
.A.A......MSM