## Problem 1

In [1]:
def read_input(filename, base_dir='../inputs'):
    """Read a character grid from ``{base_dir}/{filename}.txt``.

    Each non-empty line becomes one row: a list of its single-character
    cells, with the trailing newline removed. Blank lines (for example an
    extra newline at the end of the file) are skipped so they do not turn
    into zero-length rows that later grid indexing would trip over.

    Args:
        filename: Input file name without the ``.txt`` extension.
        base_dir: Directory containing the input files. Defaults to the
            ``../inputs`` location this notebook has always used.

    Returns:
        A list of rows, each a list of one-character strings.
    """
    rows = []
    # `with` guarantees the file is closed even if reading raises part-way.
    with open(f'{base_dir}/{filename}.txt', 'r', encoding='utf-8') as f:
        for line in f:
            cells = line.rstrip('\n')
            if cells:
                rows.append(list(cells))
    return rows

In [2]:
# Clockwise quarter-turn table: maps the guard symbol for each facing to the
# symbol it shows after turning right in front of an obstacle.
rotate = dict(zip('>^<v', 'v>^<'))

def next_pos(guard_char, x, y):
    """Return the cell one step ahead of a guard standing at ``(x, y)``.

    ``x`` is the row index and ``y`` the column index, so ``'^'`` moves to
    the previous row and ``'>'`` to the next column.

    Raises:
        Exception: if ``guard_char`` is not one of ``'>'``, ``'^'``, ``'<'``, ``'v'``.
    """
    steps = (('>', 0, 1), ('^', -1, 0), ('<', 0, -1), ('v', 1, 0))
    for symbol, d_row, d_col in steps:
        if guard_char == symbol:
            return (x + d_row, y + d_col)
    raise Exception('Invalid char for guard')

def take_step(mtx, x, y):
    """Advance the guard standing at ``(x, y)`` by one tick, mutating ``mtx``.

    The guard's facing is read from ``mtx[x][y]``. Exactly one of:

    * the cell ahead is off the grid: ``(x, y)`` is marked ``'X'`` and
      ``(-1, -1)`` is returned to signal that the guard left the room;
    * the cell ahead is ``'#'``: the guard turns right in place (its symbol
      is replaced via ``rotate``) and ``(x, y)`` is returned unchanged;
    * otherwise ``(x, y)`` is marked ``'X'``, the guard symbol is written
      into the cell ahead, and that new position is returned.
    """
    facing = mtx[x][y]
    ahead_row, ahead_col = next_pos(facing, x, y)
    rows, cols = len(mtx), len(mtx[0])
    on_grid = 0 <= ahead_row < rows and 0 <= ahead_col < cols

    if on_grid and mtx[ahead_row][ahead_col] == '#':
        mtx[x][y] = rotate[facing]
        return (x, y)

    mtx[x][y] = 'X'
    if not on_grid:
        return (-1, -1)
    mtx[ahead_row][ahead_col] = facing
    return (ahead_row, ahead_col)

In [3]:
def count_visited(mtx):
    """Return how many cells of ``mtx`` are marked ``'X'`` (visited).

    Every row is scanned in full, so an empty grid yields 0 instead of
    raising ``IndexError``, and rows of differing length are all counted.
    """
    return sum(row.count('X') for row in mtx)

In [4]:
def solve1(input_filename):
    """Solve part 1: count the distinct cells the guard visits before leaving.

    Reads the grid, locates the guard (the first ``>``, ``^``, ``<`` or
    ``v`` in row-major order), and replays ``take_step`` until the guard
    steps off the grid. Each cell it leaves is marked ``'X'`` and counted
    by ``count_visited``.

    Raises:
        ValueError: if the grid contains no guard, or if the guard enters
            a cycle and would never leave (previously this looped forever).
    """
    mtx = read_input(input_filename)

    start = next(
        ((i, j)
         for i, row in enumerate(mtx)
         for j, cell in enumerate(row)
         if cell in ('>', '^', '<', 'v')),
        None,
    )
    if start is None:
        raise ValueError('No guard found in the input grid')
    x, y = start

    # A guard state is (row, col, facing symbol). Movement is deterministic,
    # so meeting the same state twice means the guard is trapped in a cycle.
    seen = set()
    while x != -1:
        state = (x, y, mtx[x][y])
        if state in seen:
            raise ValueError('Guard is stuck in a loop and never leaves')
        seen.add(state)
        x, y = take_step(mtx, x, y)

    return count_visited(mtx)

## Problem 2

In [19]:
# One-letter tag per facing (Right, Up, Left, Down); part 2 appends these to a
# cell's string to record each direction the guard had while standing there.
visited = dict(zip('>^<v', 'RULD'))

def take_step2(mtx, x, y, direction):
    """Advance the part-2 guard by one tick and report whether it is looping.

    Unlike ``take_step``, the guard symbol is not kept in the grid. Instead,
    each cell the guard steps from accumulates the ``visited`` letter of the
    facing it had there. Both ``'#'`` and the trial obstacle ``'O'`` block
    movement.

    Args:
        mtx: Grid of cell strings, mutated in place.
        x, y: Current row and column of the guard.
        direction: Current facing, one of ``'>'``, ``'^'``, ``'<'``, ``'v'``.

    Returns:
        ``(row, col, direction, in_loop)`` after the tick. When the guard
        leaves the grid this is ``(-1, -1, '', False)``. ``in_loop`` is True
        when the new (cell, facing) state has already been stepped from
        before, which means the guard will cycle forever.
    """
    n, m = len(mtx), len(mtx[0])

    (i, j) = next_pos(direction, x, y)
    # Record the state being left; every branch below needs this, and
    # (i, j) never equals (x, y), so doing it first changes nothing else.
    mtx[x][y] += visited[direction]

    if i < 0 or i >= n or j < 0 or j >= m:
        return (-1, -1, '', False)

    if mtx[i][j] in ['#', 'O']:
        new_direction = rotate[direction]
        # Check for a repeat after turning in place too. A guard hemmed in
        # on all four sides only ever turns, so the move-time check below
        # would never fire and the caller would spin forever.
        in_loop = visited[new_direction] in mtx[x][y]
        return (x, y, new_direction, in_loop)

    in_loop = visited[direction] in mtx[i][j]
    return (i, j, direction, in_loop)

In [20]:
def copy_mtx(mtx):
    """Return a copy of ``mtx`` whose rows are independent lists.

    Only the row lists are duplicated (via ``row.copy()``). The cell values
    themselves are shared, so mutating or reassigning a cell in the copy
    never affects the original grid's rows.
    """
    return [row.copy() for row in mtx]

def test_obstacle_pos(mtx, start_x, start_y, obs_x, obs_y):
    """Simulate the guard with an extra obstacle at ``(obs_x, obs_y)``.

    The simulation runs on a copy of ``mtx``, so the caller's grid is left
    untouched. The starting facing is read from the original grid at
    ``(start_x, start_y)``.

    Returns:
        True if the guard ends up in a loop, False if it leaves the grid.
    """
    grid = copy_mtx(mtx)
    grid[obs_x][obs_y] = 'O'
    row, col = start_x, start_y
    facing = mtx[row][col]
    while True:
        row, col, facing, looped = take_step2(grid, row, col, facing)
        if looped or row == -1:
            return looped

In [24]:
def solve2(input_filename):
    """Solve part 2: count cells where one new obstacle traps the guard in a loop.

    Only cells on the guard's original route can change its path. The route
    is therefore traced first with ``take_step``, and every visited cell
    other than the start is tried as an obstacle via ``test_obstacle_pos``.

    Raises:
        ValueError: if the grid contains no guard, or if the guard already
            loops without any added obstacle (the initial trace would never
            end; previously this hung).
    """
    mtx = read_input(input_filename)

    start = next(
        ((i, j)
         for i, row in enumerate(mtx)
         for j, cell in enumerate(row)
         if cell in ('>', '^', '<', 'v')),
        None,
    )
    if start is None:
        raise ValueError('No guard found in the input grid')
    start_x, start_y = start

    # Trace the unobstructed route on a copy; its 'X' cells are the only
    # places where an added obstacle can alter the guard's path.
    route = copy_mtx(mtx)
    x, y = start
    seen = set()
    while x != -1:
        state = (x, y, route[x][y])
        if state in seen:
            raise ValueError('Guard is stuck in a loop and never leaves')
        seen.add(state)
        x, y = take_step(route, x, y)

    # Each candidate re-simulates on a fresh grid copy: simple, not fast.
    return sum(
        1
        for i, row in enumerate(route)
        for j, cell in enumerate(row)
        if cell == 'X'
        and (i, j) != start
        and test_obstacle_pos(mtx, start_x, start_y, i, j)
    )