# Imports

In [None]:
from typing import List, Dict, Set

# Input

In [None]:
# Slurp the whole input first; the parsing below does not need the file open.
with open("input.txt", "r") as fb:
    data = fb.read()

# One entry per input line: the line is cut at every "|" and each piece is
# trimmed and split on single spaces into its list of patterns.
displays = [
    [side.strip().split(" ") for side in line.split("|")]
    for line in data.splitlines()
]

# Variables

In [None]:
# Canonical seven-segment wiring: the sorted segment letters lit for each
# digit. The tuple is listed in digit order, so the position of a pattern
# is the digit it draws.
segments_to_digit_mapping = {
    segments: digit
    for digit, segments in enumerate(
        (
            "abcefg",   # 0
            "cf",       # 1
            "acdeg",    # 2
            "acdfg",    # 3
            "bcdf",     # 4
            "abdfg",    # 5
            "abdefg",   # 6
            "acf",      # 7
            "abcdefg",  # 8
            "abcdfg",   # 9
        )
    )
}

# Functions

In [None]:
def collect_wiring_counts(signal_patterns: List[str]) -> Dict[int, Dict[str, int]]:
    """Tally how often each wire occurs, grouped by pattern length.

    Args:
        signal_patterns: The patterns to analyse; each one is iterated
            character by character.

    Returns:
        A dict keyed by pattern length. Each value maps a wire letter to
        the number of times it occurs across all patterns of that length.
        Lengths are only present if at least one wire was counted for them,
        so an empty pattern contributes nothing.
    """
    counts_by_length = {}
    for pattern in signal_patterns:
        length = len(pattern)
        for wire in pattern:
            # Create the per-length bucket lazily, on the first wire seen.
            bucket = counts_by_length.setdefault(length, {})
            bucket[wire] = bucket.get(wire, 0) + 1

    return counts_by_length

In [None]:
def query_wiring_counts(
    wiring_counts: Dict[int, Dict[str, int]],
    signal_length: int,
    signal_count: int,
) -> Set[str]:
    """Return the wires that occur exactly ``signal_count`` times at one length.

    Args:
        wiring_counts: Per-length wire tallies, shaped like the result of
            ``collect_wiring_counts``.
        signal_length: Which pattern length to look at.
        signal_count: The exact occurrence count a wire must have.

    Returns:
        The set of matching wire letters (possibly empty).

    Raises:
        KeyError: If ``signal_length`` is not a key of ``wiring_counts``.
    """
    counts_for_length = wiring_counts[signal_length]
    return {wire for wire, seen in counts_for_length.items() if seen == signal_count}

# Solution 1

In [None]:
# Part 1: tally the output patterns whose length is 2, 3, 4 or 7. In the
# canonical mapping those lengths belong only to digits 1, 7, 4 and 8.
count = sum(
    1
    for _signals, output in displays
    for o in output
    if len(o) in (2, 3, 4, 7)
)

In [None]:
# Part 1 answer: how many output patterns have 2, 3, 4 or 7 characters.
# (Bare expression; presumably shown as the notebook cell's output.)
count

# Solution 2

## Exploration

In [None]:
# Tally the canonical (unscrambled) patterns so the per-length wire counts
# can be inspected for rules that single out each segment.
wiring_counts = collect_wiring_counts(segments_to_digit_mapping.keys())
for n in sorted(wiring_counts):
    info = wiring_counts[n]
    # Order by count first, then by wire letter, to make equal counts easy to spot.
    ranked = sorted(info.items(), key=lambda pair: (pair[1], pair[0]))
    print(n, ranked, "------------", sep="\n")

In [None]:
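# How each segment is singled out from the per-length wire counts.
# "len N : count K" means the wire occurs in K of the patterns that are N characters long.
# Segment a: the wire in the len 3 pattern that is not in the len 2 pattern (len 3 - len 2)
# Segment b: len 5 : count 1 & len 6 : count 3
# Segment c: len 5 : count 2 & len 6 : count 2
# Segment d: len 5 : count 3 & len 6 : count 2
# Segment e: len 5 : count 1 & len 6 : count 2
# Segment f: len 5 : count 2 & len 6 : count 3
# Segment g: remainder (the one wire left once a-f are assigned)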

## Solution

In [None]:
def get_wiring(signal_patterns):
    """Work out which scrambled wire drives which canonical segment.

    Args:
        signal_patterns: The scrambled patterns of one display.

    Returns:
        A dict mapping each scrambled wire letter to the canonical segment
        letter ("a".."g") it stands for.

    Raises:
        KeyError: If no wire was counted for one of the pattern lengths
            consulted (2, 3, 5 or 6).
        ValueError: If any rule does not pin down exactly one wire.
    """
    wiring_counts = collect_wiring_counts(signal_patterns)

    # Segment "a" is the wire used by the length-3 pattern but not the length-2 one.
    top_wire, = set(wiring_counts[3].keys()) - set(wiring_counts[2].keys())
    wire_to_segment = {top_wire: "a"}

    # Segments b-f each have a unique pair of
    # (occurrences among length-5 patterns, occurrences among length-6 patterns).
    count_rules = (
        ("b", 1, 3),
        ("c", 2, 2),
        ("d", 3, 2),
        ("e", 1, 2),
        ("f", 2, 3),
    )
    for segment, in_len_five, in_len_six in count_rules:
        candidates = query_wiring_counts(wiring_counts, 5, in_len_five)
        candidates = candidates & query_wiring_counts(wiring_counts, 6, in_len_six)
        wire, = candidates
        wire_to_segment[wire] = segment

    # Whatever letter has not been assigned yet must be segment "g".
    last_wire, = set("abcdefg") - set(wire_to_segment)
    wire_to_segment[last_wire] = "g"

    return wire_to_segment

In [None]:
def decode_output(signal_patterns: List[str], output: List[str]) -> int:
    """Decode one display's output patterns into the number they show.

    Args:
        signal_patterns: The scrambled patterns used to deduce the wiring.
        output: The scrambled output patterns, most significant digit first.

    Returns:
        The integer formed by concatenating the decoded digits.

    Raises:
        KeyError: If a wire has no mapping or the translated segments do not
            match any known digit.
        ValueError: If ``output`` is empty (there is nothing to parse), or if
            the wiring cannot be deduced from ``signal_patterns``.
    """
    wire_to_segment = get_wiring(signal_patterns)

    def digit_for(pattern):
        # Translate every wire, then sort so the key matches the canonical
        # (alphabetically ordered) spelling used by the lookup table.
        canonical = "".join(sorted(wire_to_segment[wire] for wire in pattern))
        return segments_to_digit_mapping[canonical]

    return int("".join(str(digit_for(pattern)) for pattern in output))

In [None]:
# Part 2: decode every display's output value and add them all up.
result = sum(
    decode_output(signal_patterns, output)
    for signal_patterns, output in displays
)

In [None]:
# Part 2 answer: the sum of all decoded output values.
# (Bare expression; presumably shown as the notebook cell's output.)
result