In [1]:
%matplotlib inline
import collections
import itertools
import math

In [2]:
# Small example grid used to sanity-check the solution before the real input.
# '.' cells are empty; every other character is an antenna of that frequency.
testlines = '''............
........0...
.....0......
.......0....
....0.......
......A.....
............
............
........A...
.........A..
............
............'''.splitlines()

In [3]:
# Puzzle input: one grid row per line; splitlines() drops the line terminators.
# An explicit encoding keeps decoding independent of the platform's locale default.
with open('day8input.txt', encoding='utf-8') as fp:
    data = fp.read().splitlines()

## Part 1 ##

In [4]:
def get_antenna_map(lines, empty='.'):
    """Locate every antenna in a grid, grouped by frequency.

    Parameters
    ----------
    lines : list of str
        Grid rows. Every character other than ``empty`` is treated as an
        antenna whose frequency is that character.
    empty : str, optional
        Character marking an empty cell (default ``'.'``).

    Returns
    -------
    positions : collections.defaultdict(list)
        Maps frequency character -> list of (row, col) positions, in
        row-major scan order.
    nrows, ncols : int
        Grid dimensions. ``ncols`` is the length of the first row; both are
        0 for empty input.
    """
    nrows = len(lines)
    # Guard empty input: lines[0] would otherwise raise IndexError.
    ncols = len(lines[0]) if lines else 0
    positions = collections.defaultdict(list)
    for row, line in enumerate(lines):
        for col, c in enumerate(line):
            if c != empty:
                positions[c].append((row, col))
    return positions, nrows, ncols

In [6]:
def get_pair_antinodes(pos1, pos2, nrows, ncols):
    """Return the in-bounds antinodes created by one pair of same-frequency antennas.

    The two candidates lie on the line through the pair, one full
    antenna-to-antenna offset beyond each antenna:
    ``pos2 + (pos2 - pos1)`` and ``pos1 - (pos2 - pos1)``.

    Parameters
    ----------
    pos1, pos2 : tuple of int
        (row, col) positions of the two antennas.
    nrows, ncols : int
        Grid dimensions; candidates outside [0, nrows) x [0, ncols) are dropped.

    Returns
    -------
    list of tuple of int
        Zero, one or two (row, col) antinodes; the one beyond ``pos2`` comes first.
    """
    drow, dcol = pos2[0] - pos1[0], pos2[1] - pos1[1]
    candidates = [(pos2[0] + drow, pos2[1] + dcol),
                  (pos1[0] - drow, pos1[1] - dcol)]
    return [(r, c) for r, c in candidates if 0 <= r < nrows and 0 <= c < ncols]

In [8]:
def get_all_antinodes(antenna_map, nrows, ncols):
    """Collect part-1 antinodes for every frequency.

    Every unordered pair of antennas sharing a frequency is passed to
    ``get_pair_antinodes``. The results for a frequency are concatenated, so
    a list may contain duplicates. A frequency only gets a key once it has at
    least one pair, i.e. two or more antennas.

    Returns
    -------
    collections.defaultdict(list)
        frequency -> list of (row, col) antinodes.
    """
    result = collections.defaultdict(list)
    for freq, antennas in antenna_map.items():
        for first, second in itertools.combinations(antennas, 2):
            result[freq] += get_pair_antinodes(first, second, nrows, ncols)
    return result

In [15]:
def part1(lines):
    """Count the distinct part-1 antinode cells in a grid.

    Parameters
    ----------
    lines : list of str
        Grid rows, e.g. from ``str.splitlines()``.

    Returns
    -------
    int
        Number of unique (row, col) antinodes over all frequencies; a cell
        produced by several pairs or frequencies is counted once.
    """
    antmap, nrows, ncols = get_antenna_map(lines)
    antinodes = get_all_antinodes(antmap, nrows, ncols)
    # Flatten every frequency's list lazily; the set removes duplicates.
    unique_antinodes = set(itertools.chain.from_iterable(antinodes.values()))
    return len(unique_antinodes)

In [17]:
# Sanity check on the small example grid; a parenthesised assert(...) would
# silently pass if a message were ever added inside the parentheses.
assert part1(testlines) == 14, 'part1 should find 14 antinodes in testlines'

In [18]:
# Part 1 answer for the puzzle input (shown as this cell's output).
part1(data)

303

## Part 2 ##

In [19]:
def _walk_in_bounds(start, drow, dcol, nrows, ncols):
    """Collect start, start + d, start + 2d, ... until a cell falls outside the grid."""
    cells = []
    row, col = start
    while 0 <= row < nrows and 0 <= col < ncols:
        cells.append((row, col))
        row, col = row + drow, col + dcol
    return cells


def get_pair_antinodes2(pos1, pos2, nrows, ncols):
    """Return every in-bounds grid cell collinear with a pair of antennas (part 2).

    The step is the pair's offset ``pos2 - pos1`` divided by its gcd, so every
    lattice point on the line is reached. When the gcd is 1 this is exactly
    the full offset. The function walks forward from ``pos2`` and backward
    from ``pos1``, then adds any cells strictly between the two antennas.
    Those cells exist only when the gcd is greater than 1. Both antennas are
    included when in bounds.

    Parameters
    ----------
    pos1, pos2 : tuple of int
        (row, col) positions of two distinct antennas.
    nrows, ncols : int
        Grid dimensions; only cells in [0, nrows) x [0, ncols) are returned.

    Returns
    -------
    list of tuple of int
        (row, col) cells, in this order: from ``pos2`` onward, then from
        ``pos1`` backward, then any cells strictly between them.

    Raises
    ------
    ValueError
        If ``pos1 == pos2``: there is no direction, and the walk would never end.
    """
    drow, dcol = pos2[0] - pos1[0], pos2[1] - pos1[1]
    if drow == 0 and dcol == 0:
        raise ValueError('antenna positions must be distinct')
    # math.gcd is non-negative, and the division is exact, so signs are preserved.
    g = math.gcd(drow, dcol)
    srow, scol = drow // g, dcol // g
    antinodes = _walk_in_bounds(pos2, srow, scol, nrows, ncols)
    antinodes += _walk_in_bounds(pos1, -srow, -scol, nrows, ncols)
    # Lattice points strictly between the antennas (none when g == 1).
    for k in range(1, g):
        r, c = pos1[0] + k * srow, pos1[1] + k * scol
        if 0 <= r < nrows and 0 <= c < ncols:
            antinodes.append((r, c))
    return antinodes

In [20]:
def get_all_antinodes2(antenna_map, nrows, ncols):
    """Collect part-2 antinodes for every frequency.

    Each unordered pair (i < j, in list order) of antennas sharing a frequency
    is passed to ``get_pair_antinodes2``, and the results are concatenated per
    frequency, so duplicates are possible. Frequencies with fewer than two
    antennas get no key.

    Returns
    -------
    collections.defaultdict(list)
        frequency -> list of (row, col) antinodes.
    """
    collected = collections.defaultdict(list)
    for freq in antenna_map:
        antennas = list(antenna_map[freq])
        for i, first in enumerate(antennas):
            for second in antennas[i + 1:]:
                collected[freq].extend(get_pair_antinodes2(first, second, nrows, ncols))
    return collected

In [None]:
def part2(lines):
    """Count the distinct part-2 antinode cells in a grid.

    Parameters
    ----------
    lines : list of str
        Grid rows, e.g. from ``str.splitlines()``.

    Returns
    -------
    int
        Number of unique (row, col) cells produced by ``get_pair_antinodes2``
        over all same-frequency antenna pairs; shared cells are counted once.
    """
    antmap, nrows, ncols = get_antenna_map(lines)
    antinodes = get_all_antinodes2(antmap, nrows, ncols)
    # Flatten every frequency's list lazily; the set removes duplicates.
    unique_antinodes = set(itertools.chain.from_iterable(antinodes.values()))
    return len(unique_antinodes)