In [1]:
def get_input(fname='test.txt'):
    """Read a character grid from the text file ``fname``.

    Each non-blank line becomes one row: the list of its characters after
    surrounding whitespace (including the newline) is stripped. Blank or
    whitespace-only lines are skipped, so a trailing empty line in the file
    does not add an empty row (an empty row would break any later scan that
    indexes every row up to the width of the first row).

    Args:
        fname: path of the file to read.

    Returns:
        list[list[str]]: the grid, one list of single-character strings per row.
    """
    grid = []  # avoid shadowing the builtin `input`
    with open(fname) as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                grid.append(list(line))
    return grid

In [2]:
# Load the worked example from the puzzle statement first, then the personal puzzle input.
test_input, my_input = (get_input(path) for path in ('test.txt', 'input.txt'))

In [3]:
def print_grid(g):
    """Print grid ``g`` to stdout, one row per line, with each row's cells concatenated."""
    rendered_rows = (''.join(row) for row in g)
    print(*rendered_rows, sep='\n')

In [4]:
# Echo the parsed example grid so it can be compared against the puzzle text.
print_grid(g=test_input)

...#......
.......#..
#.........
..........
......#...
.#........
.........#
..........
.......#..
#...#.....


In [5]:
from collections import deque

def get_distances(g, expand=2):
    """Return expanded Manhattan distances between every pair of galaxies in ``g``.

    A row or column consisting only of ``'.'`` cells costs ``expand`` steps to
    cross; every other row/column costs 1 step. The distance between two
    galaxies is the sum of the costs of the rows entered between them plus the
    costs of the columns entered between them.

    Args:
        g: grid indexed as ``g[row][col]``; cells equal to ``'#'`` are galaxies.
        expand: cost of crossing an empty row/column. The default of 2 keeps the
            original behaviour (each empty line counts double).

    Returns:
        Nested dict ``{g1: {g2: distance}}`` with galaxies as ``(row, col)``
        tuples. Each unordered pair appears exactly once, with ``g1`` before
        ``g2`` in row-major scan order; the last galaxy has no key of its own.
        An empty grid (or one with fewer than two galaxies) yields ``{}``.
    """
    from itertools import accumulate

    if not g:
        # No rows means no galaxies; also avoids indexing g[0] below.
        return {}

    galaxies = [(i, j) for i, row in enumerate(g) for j, c in enumerate(row) if c == '#']

    row_costs = [expand if set(row) == {'.'} else 1 for row in g]
    col_costs = [expand if {row[j] for row in g} == {'.'} else 1 for j in range(len(g[0]))]

    # prefix[k] is the total cost of the first k lines, so the cost of entering
    # lines lo+1 .. hi is prefix[hi + 1] - prefix[lo + 1]: O(1) per pair
    # instead of re-summing a range for every pair.
    row_prefix = [0] + list(accumulate(row_costs))
    col_prefix = [0] + list(accumulate(col_costs))

    def span(prefix, a, b):
        lo, hi = (a, b) if a <= b else (b, a)
        return prefix[hi + 1] - prefix[lo + 1]

    return {
        g1: {
            g2: span(row_prefix, g1[0], g2[0]) + span(col_prefix, g1[1], g2[1])
            for g2 in galaxies[k + 1:]
        }
        for k, g1 in enumerate(galaxies[:-1])
    }

In [6]:
# Display the full pairwise distance table for the worked example.
get_distances(g=test_input)

{(0, 3): {(1, 7): 6,
  (2, 0): 6,
  (4, 6): 9,
  (5, 1): 9,
  (6, 9): 15,
  (8, 7): 15,
  (9, 0): 15,
  (9, 4): 12},
 (1, 7): {(2, 0): 10,
  (4, 6): 5,
  (5, 1): 13,
  (6, 9): 9,
  (8, 7): 9,
  (9, 0): 19,
  (9, 4): 14},
 (2, 0): {(4, 6): 11,
  (5, 1): 5,
  (6, 9): 17,
  (8, 7): 17,
  (9, 0): 9,
  (9, 4): 14},
 (4, 6): {(5, 1): 8, (6, 9): 6, (8, 7): 6, (9, 0): 14, (9, 4): 9},
 (5, 1): {(6, 9): 12, (8, 7): 12, (9, 0): 6, (9, 4): 9},
 (6, 9): {(8, 7): 6, (9, 0): 16, (9, 4): 11},
 (8, 7): {(9, 0): 10, (9, 4): 5},
 (9, 0): {(9, 4): 5}}

In [7]:
# Part 1 on the worked example; the puzzle statement gives 374 as the expected answer,
# so assert it to catch regressions on Restart & Run All.
part1_example_total = sum(sum(row.values()) for row in get_distances(test_input).values())
assert part1_example_total == 374, part1_example_total
part1_example_total

374

In [8]:
# Part 1 answer: sum of all pairwise distances for the personal puzzle input.
sum(sum(row.values()) for row in get_distances(my_input).values())

9609130

In [9]:
def get_distances2(g, expand=1_000_000):
    """Return expanded Manhattan distances between every pair of galaxies in ``g``.

    A row or column consisting only of ``'.'`` cells costs ``expand`` steps to
    cross; every other row/column costs 1 step. The distance between two
    galaxies is the sum of the costs of the rows entered between them plus the
    costs of the columns entered between them.

    Args:
        g: grid indexed as ``g[row][col]``; cells equal to ``'#'`` are galaxies.
        expand: cost of crossing an empty row/column (default 1_000_000).

    Returns:
        Nested dict ``{g1: {g2: distance}}`` with galaxies as ``(row, col)``
        tuples. Each unordered pair appears exactly once, with ``g1`` before
        ``g2`` in row-major scan order; the last galaxy has no key of its own.
        An empty grid (or one with fewer than two galaxies) yields ``{}``.
    """
    from itertools import accumulate

    if not g:
        # No rows means no galaxies; also avoids indexing g[0] below.
        return {}

    galaxies = [(i, j) for i, row in enumerate(g) for j, c in enumerate(row) if c == '#']

    row_costs = [expand if set(row) == {'.'} else 1 for row in g]
    col_costs = [expand if {row[j] for row in g} == {'.'} else 1 for j in range(len(g[0]))]

    # prefix[k] is the total cost of the first k lines, so the cost of entering
    # lines lo+1 .. hi is prefix[hi + 1] - prefix[lo + 1]: O(1) per pair
    # instead of re-summing a range for every pair.
    row_prefix = [0] + list(accumulate(row_costs))
    col_prefix = [0] + list(accumulate(col_costs))

    def span(prefix, a, b):
        lo, hi = (a, b) if a <= b else (b, a)
        return prefix[hi + 1] - prefix[lo + 1]

    return {
        g1: {
            g2: span(row_prefix, g1[0], g2[0]) + span(col_prefix, g1[1], g2[1])
            for g2 in galaxies[k + 1:]
        }
        for k, g1 in enumerate(galaxies[:-1])
    }

In [10]:
# Part 2 check on the worked example with expansion factor 10; the puzzle statement
# gives 1030 as the expected answer, so assert it to catch regressions.
part2_example_total = sum(sum(row.values()) for row in get_distances2(test_input, 10).values())
assert part2_example_total == 1030, part2_example_total
part2_example_total

1030

In [11]:
# Part 2 answer: each empty row/column is replaced by one million of them.
sum(sum(row.values()) for row in get_distances2(my_input, 1_000_000).values())

702152204842