In [118]:
import numpy as np
import os
from pathlib import Path
import re
from collections import Counter, defaultdict

# Notebooks have no `__file__`, so locate the puzzle input relative to the
# current working directory (symlinks resolved).
FOLDER = Path.cwd().resolve() / 'data'
in_file = 'day8.txt'

# Each entry is [signal_patterns, output_value] split on ' | '.
# Blank lines are skipped so a stray empty line in the input file does not
# produce a malformed one-element entry that later breaks tuple unpacking.
with open(FOLDER / in_file, encoding='utf-8') as f:
    data = [line.strip().split(' | ') for line in f if line.strip()]


In [119]:
from collections import Counter

# Seven-segment lookup table: each displayed digit as a 7-bit int with one bit
# per segment, from a = 64 (0b1000000) down to g = 1 (0b0000001):
#
#      aaaa
#     b    c
#     b    c
#      dddd
#     e    f
#     e    f
#      gggg
bits = {
    0b1110111: '0',  # abc efg
    0b0010010: '1',  #   c  f
    0b1011101: '2',  # a cde g
    0b1011011: '3',  # a cd fg
    0b0111010: '4',  #  bcd f
    0b1101011: '5',  # ab d fg
    0b1101111: '6',  # ab defg
    0b1010010: '7',  # a c  f
    0b1111111: '8',  # abcdefg
    0b1111011: '9'   # abcd fg
}

def rev_dict(l):
    '''
    Group characters by how often they occur across all strings in `l`.

    Returns a plain dict mapping each occurrence count to the list of
    characters having that count, listed in order of first appearance.
    '''
    grouped = defaultdict(list)
    for char, freq in Counter(''.join(l)).items():
        grouped[freq].append(char)
    return dict(grouped)

def map_wires(l):
    '''
    Determine which input character lights up which segment.
    Return a dict mapping each input character to its segment bit.

    Segment bits: a=64, b=32, c=16, d=8, e=4, f=2, g=1.

    Counting letters across all ten digit patterns gives
    a=8, b=6, c=8, d=7, e=4, f=9, g=7, which identifies b, e and f directly.
    Recounting without the 4-segment pattern (digit 4 = bcdf) gives
    a=8, b=5, c=7, d=6, e=4, f=8, g=7, which identifies d; the remaining
    ambiguous pairs (a/f, a/c, c/g) are then separated by elimination.

    Args:
        l: whitespace-separated string of the ten scrambled signal patterns.

    Raises:
        ValueError: if the deduction does not yield seven distinct wires.
        Other malformed input may surface as IndexError/KeyError/StopIteration
        from the count lookups below.
    '''
    patterns = l.split()

    # letter counts over all ten patterns
    counts_all = rev_dict(patterns)
    # letter counts over every pattern except the 4-segment one (digit 4)
    counts_no_four = rev_dict([p for p in patterns if len(p) != 4])

    # Deliberately not named `bits`, which would shadow the global
    # digit lookup table used by decode_input.
    wire_for_bit = {
        64: None,                  # a (resolved below)
        32: counts_all[6][0],      # b: lit 6 times overall
        16: None,                  # c (resolved below)
        8:  counts_no_four[6][0],  # d: lit 6 times once '4' is removed
        4:  counts_all[4][0],      # e: lit 4 times overall
        2:  counts_all[9][0],      # f: lit 9 times overall
        1:  None,                  # g (resolved below)
    }
    # a and f are both lit 8 times once '4' is removed; f is already known
    wire_for_bit[64] = next(c for c in counts_no_four[8] if c != wire_for_bit[2])
    # a and c are both lit 8 times overall; a is now known
    wire_for_bit[16] = next(c for c in counts_all[8] if c != wire_for_bit[64])
    # c and g are both lit 7 times once '4' is removed; c is now known
    wire_for_bit[1] = next(c for c in counts_no_four[7] if c != wire_for_bit[16])

    # Inverting a dict with duplicate values would silently drop segments.
    if len(set(wire_for_bit.values())) != 7:
        raise ValueError(f'could not decode seven distinct wires from {l!r}')

    return {wire: bit for bit, wire in wire_for_bit.items()}

def word_to_digit(word, bit_map):
    '''
    Combine the segment bits of every character in `word` into one pattern.

    Despite the name, this returns the 7-bit segment pattern, not the digit;
    look the result up in `bits` to get the digit character. Bits are merged
    with OR so a repeated character cannot carry into a neighbouring bit and
    alias a different (possibly valid) pattern, as `+` would.

    Args:
        word: scrambled signal pattern, e.g. 'cf'.
        bit_map: mapping from input character to its segment bit.

    Returns:
        int bit pattern (0 for an empty word).
    '''
    pattern = 0
    for char in word:
        pattern |= bit_map[char]
    return pattern

def decode_input(s, wires):
    '''
    Decode a scrambled output value into its integer reading.

    Each whitespace-separated token in `s` is converted to a segment bit
    pattern using `wires` and looked up in the global `bits` table. The
    digit characters are concatenated and parsed with int(), so leading
    zeros are dropped.
    '''
    digit_chars = []
    for token in s.split():
        digit_chars.append(bits[word_to_digit(token, wires)])
    return int(''.join(digit_chars))
    


In [127]:
def solution_one(data):
    '''
    Count output tokens whose length alone identifies the digit.

    Digits 1, 7, 4 and 8 light 2, 3, 4 and 7 segments respectively, and no
    other digit uses those counts, so these tokens are recognisable without
    decoding the wiring.

    Args:
        data: iterable of (signal_patterns, output_value) string pairs.

    Returns:
        Number of such tokens across all output values.
    '''
    unique_lengths = {2, 3, 4, 7}
    return sum(
        len(token) in unique_lengths
        for _, output in data
        for token in output.split()
    )

# Part 1: number of output tokens whose segment count alone identifies the digit.
solution_one(data)

548

In [120]:
def solution_two(data):
    '''
    Sum the decoded output values of every display entry.

    For each (signal_patterns, output_value) pair, the wiring is deduced from
    the patterns with map_wires and then used by decode_input to read the
    output value.
    '''
    return sum(
        decode_input(output, map_wires(patterns))
        for patterns, output in data
    )

# Part 2: sum of all decoded output values (shown as the cell's result,
# consistent with the part-1 cell).
solution_two(data)


1074888
