In [2]:
from functools import reduce
from itertools import permutations

from aoc import submit, raw
from collections import defaultdict

DAY = 8
# Each puzzle line has two '|'-separated halves (signal patterns, output values);
# parse every line into [patterns, outputs], each a list of wire-letter strings.
INPUT = [
    [half.split() for half in entry.split('|')]
    for entry in raw(day=DAY).splitlines()
]

In [3]:
@submit(day=DAY)
def part_one():
    """Count output values that show 1, 4, 7 or 8 — the digits with a unique segment count."""
    unique_lengths = {2, 3, 4, 7}  # segments lit for 1, 7, 4 and 8 respectively
    return sum(
        1
        for _, outputs in INPUT
        for value in outputs
        if len(value) in unique_lengths
    )


part_one: 440 (0.16 ms)
✅ That's the right answer!


In [51]:
def match_digits(segments):
    """Work out which scrambled pattern displays each digit 0-9 for one entry.

    Instead of building a wire substitution table, this deduces every digit from
    set relations with the uniquely sized patterns for 1, 4, 7 and 8.

    Args:
        segments: the entry's signal patterns, each a string of wire letters.
            Assumes exactly one pattern per digit, as in the puzzle input.

    Returns:
        A list of 10 sorted letter lists where index ``d`` is the pattern for
        digit ``d``, so ``digits.index(sorted(pattern))`` decodes a pattern.

    Raises:
        IndexError: if no pattern of length 2, 3, 4 or 7 is present.
        StopIteration: if no pattern satisfies one of the deduction rules below.
    """
    # Group by wire count here rather than via ``group_by``, which is defined in
    # a later cell, so this cell also works on a fresh kernel (Run All).
    by_length = defaultdict(list)
    for pattern in segments:
        wires = set(pattern)
        by_length[len(wires)].append(wires)

    # Each unique length maps to a list; take its single pattern. (Calling
    # ``next()`` on a list raises TypeError because lists are not iterators.)
    one, four, seven, eight = (by_length[n][0] for n in (2, 4, 3, 7))

    bd = four - one               # wires for canonical segments b and d
    eg = eight - (four | seven)   # wires for canonical segments e and g

    six_wire = by_length[6]       # candidates for 0, 6, 9
    five_wire = by_length[5]      # candidates for 2, 3, 5

    zero = next(s for s in six_wire if not s.issuperset(bd))   # 0 lacks d
    two = next(s for s in five_wire if s.issuperset(eg))       # only 2 has e
    three = next(s for s in five_wire if s.issuperset(one))    # only 3 has both c and f
    five = next(s for s in five_wire if s.issuperset(bd))      # only 5 has b
    six = next(s for s in six_wire if s.issuperset(bd | eg))   # only 6 has b, d, e and g
    nine = next(s for s in six_wire if not s.issuperset(eg))   # 9 lacks e

    ordered = [zero, one, two, three, four, five, six, seven, eight, nine]
    return [sorted(s) for s in ordered]

@submit(day=DAY)
def part_two():
    """Decode every entry's output value with ``match_digits`` and return their sum.

    NOTE(review): a later cell defines ``part_two`` again; when the notebook is
    run top to bottom, that later definition replaces this one.
    """
    total = 0
    for segments, outputs in INPUT:
        digits = match_digits(segments)
        value = 0
        for output in outputs:
            # Shift the running number one decimal place and append the next digit.
            value = value * 10 + digits.index(sorted(output))
        total += value
    return total

part_two: 1046281 (4.76 ms)
✅ That's the right answer!


In [40]:
## First version: decode by building a wire-substitution table from segment counts.

# Canonical seven-segment wiring: sorted letters of the lit segments -> digit shown.
DIGITS = dict(zip(
    ['abcefg', 'cf', 'acdeg', 'acdfg', 'bcdf', 'abdfg', 'abdefg', 'acf', 'abcdefg', 'abcdfg'],
    range(10),
))


def group_by(seq, key):
    """Group the items of ``seq`` into lists keyed by ``key(item)``.

    Returns a ``defaultdict(list)``; within each group, items keep their input order.
    """
    groups = defaultdict(list)
    for item in seq:
        groups[key(item)].append(item)
    return groups
def occurrences(seq):
    """Count, for each distinct item, how many members of ``seq`` contain it.

    Args:
        seq: an iterable of containers (sets in this notebook; strings, lists or
            a one-shot iterator also work).

    Returns:
        A list of ``(item, count)`` pairs, one per distinct item across all
        members. Pair order follows set iteration and is not meaningful.
        An empty ``seq`` yields ``[]``.
    """
    # Materialize once so a one-shot iterator is not exhausted by the union below,
    # and normalize members to sets for membership tests.
    groups = [set(member) for member in seq]
    # set().union(...) accepts zero arguments and non-set iterables, unlike the
    # unbound set.union(*seq), which raises TypeError for an empty seq.
    universe = set().union(*groups)
    return [(item, sum(item in group for group in groups)) for item in universe]
def get_subs(segments):
    """Build the scrambled-wire -> canonical-segment table for one entry.

    Args:
        segments: the entry's signal patterns, each a string of wire letters.

    Returns:
        dict mapping each scrambled wire letter to the canonical segment letter
        used by ``DIGITS`` ('a'-'g').
    """
    by_len = group_by([set(pattern) for pattern in segments], len)
    one, four, seven = by_len[2][0], by_len[4][0], by_len[3][0]

    # 7 differs from 1 only by the top segment.
    subs = {seven.difference(one).pop(): 'a'}

    # Among the six-segment digits (0, 6, 9), c, d and e each appear in two of
    # them while a, b, f, g appear in all three; membership in 1 or 4 then
    # separates c/f and d/b.
    for wire, count in occurrences(by_len[6]):
        if wire in one:
            segment = {2: 'c', 3: 'f'}.get(count)
        elif wire in four:
            segment = {2: 'd', 3: 'b'}.get(count)
        else:
            segment = None
        if segment is not None:
            subs[wire] = segment

    # With a, b, c, d, f settled, the five-segment digits (2, 3, 5) leave only
    # e (present in 2 alone) and g (present in all three).
    leftovers = [pattern.difference(subs.keys()) for pattern in by_len[5]]
    for wire, count in occurrences(leftovers):
        segment = {1: 'e', 3: 'g'}.get(count)
        if segment is not None:
            subs[wire] = segment

    return subs


@submit(day=DAY)
def part_two():
    """Sum all decoded output values using the wire-substitution approach.

    NOTE(review): this redefines ``part_two`` from an earlier cell; when the
    notebook is run top to bottom, this definition is the one left bound.
    """
    def decode(subs, pattern):
        # Translate scrambled wires to canonical segments, then look up the digit.
        return DIGITS[''.join(sorted(subs[wire] for wire in pattern))]

    total = 0
    for segments, outputs in INPUT:
        subs = get_subs(segments)
        digits = [decode(subs, output) for output in outputs]
        # Fold the digits left to right into one base-10 number.
        total += reduce(lambda acc, d: acc * 10 + d, digits, 0)
    return total


part_two: 1046281 (4.25 ms)
✅ That's the right answer!
