In [1]:
import utils

# Load the puzzle input through the project's `utils` helper (arguments are
# presumably year=2021, day=8 -- confirm against utils.get_aoc_input) and split
# it into a list with one string per input line. Later cells read this
# module-level `data`.
data = utils.get_aoc_input(2021, 8).splitlines()
# Sanity check: report how many lines were loaded.
print(len(data), "lines")


200 lines


In [2]:
# aoc 2021 day 8 part 1
def part1(lines=None):
    """Count output values whose segment count identifies the digit by itself.

    The digits 1, 7, 4 and 8 use 2, 3, 4 and 7 segments respectively, so an
    output value with one of those lengths can be recognised without decoding
    the wiring.

    Args:
        lines: Iterable of puzzle entries shaped like
            ``"<signal patterns> | <output values>"``. When omitted, the
            module-level ``data`` is used, so a bare ``part1()`` call behaves
            as before.

    Returns:
        Number of output values (the part right of ``" | "``) whose length is
        2, 3, 4 or 7.

    Raises:
        ValueError: If an entry does not split into exactly two parts on
            ``" | "``.
    """
    if lines is None:
        lines = data  # fall back to the notebook-wide puzzle input

    unique_lengths = {2, 3, 4, 7}
    count = 0
    for line in lines:
        # Only the output half matters here; the signal patterns are ignored.
        _, output = line.split(" | ")
        count += sum(len(value) in unique_lengths for value in output.split())
    return count

# Evaluate part1() on the loaded puzzle input and print the answer.
print("Part 1:", part1())

Part 1: 514


In [3]:
# aoc 2021 day 8 part 2
def part2(lines=None):
    """First-pass Part 2: add up the face values of the easily recognised digits.

    Only output values whose length pins down the digit are counted
    (length 2 -> 1, 3 -> 7, 4 -> 4, 7 -> 8); every other output value
    contributes nothing to the total.

    NOTE: a later cell defines ``part2`` again; once that cell has run, its
    definition replaces this one.

    Args:
        lines: Iterable of puzzle entries shaped like
            ``"<signal patterns> | <output values>"``. When omitted, the
            module-level ``data`` is used, so a bare ``part2()`` call behaves
            as before.

    Returns:
        Sum of the digit values of all output values with length 2, 3, 4 or 7.

    Raises:
        ValueError: If an entry does not split into exactly two parts on
            ``" | "``.
    """
    if lines is None:
        lines = data  # fall back to the notebook-wide puzzle input

    # Segment count -> the only digit drawn with that many segments.
    digit_by_length = {2: 1, 3: 7, 4: 4, 7: 8}

    total = 0
    for line in lines:
        # The signal patterns are not needed for this computation.
        _, output = line.split(" | ")
        for value in output.split():
            total += digit_by_length.get(len(value), 0)
    return total

# Run the part2 defined in this cell on the loaded input and print its result.
print("Part 2:", part2())

Part 2: 2564


In [4]:
def part2(lines=None):
    """Decode every scrambled display and sum the resulting output numbers.

    Each entry lists ten signal patterns followed by the output values. The
    patterns are mapped to digits in two passes, then each output value is
    looked up (ignoring letter order) and the digits are joined into one number.

    Args:
        lines: Iterable of puzzle entries shaped like
            ``"<signal patterns> | <output values>"``. When omitted, the
            module-level ``data`` is used, so a bare ``part2()`` call behaves
            as before.

    Returns:
        Sum over all entries of the decoded output number.

    Raises:
        ValueError: If an entry does not split into exactly two parts on
            ``" | "``, has no output values, or contains an output value that
            matches none of the entry's signal patterns (previously such a
            value was silently dropped, producing a wrong number).
        KeyError: If a 5- or 6-segment pattern has to be classified but the
            entry has no pattern of length 2 (digit 1) or length 4 (digit 4).
    """
    if lines is None:
        lines = data  # fall back to the notebook-wide puzzle input

    # Segment counts that identify a digit on their own:
    #   2 -> 1 (segments c,f)      3 -> 7 (segments a,c,f)
    #   4 -> 4 (segments b,c,d,f)  7 -> 8 (all segments)
    digit_by_unique_length = {2: 1, 3: 7, 4: 4, 7: 8}

    total = 0
    for line in lines:
        patterns, output = line.split(" | ")
        patterns = patterns.split()
        output = output.split()

        # Pass 1: pin down 1, 7, 4 and 8 from the pattern length alone.
        digit_map = {}
        for pattern in patterns:
            digit = digit_by_unique_length.get(len(pattern))
            if digit is not None:
                digit_map[digit] = set(pattern)

        # Pass 2: deduce the remaining digits from their overlap with 1 and 4.
        #   Length 5 (digits 2, 3, 5):
        #     contains every segment of 1          -> 3
        #     shares exactly two segments with 4   -> 2
        #     otherwise                            -> 5
        #   Length 6 (digits 0, 6, 9):
        #     contains every segment of 4          -> 9
        #     contains every segment of 1          -> 0
        #     otherwise                            -> 6
        for pattern in patterns:
            pattern_set = set(pattern)
            if len(pattern) == 5:
                if digit_map[1].issubset(pattern_set):
                    digit_map[3] = pattern_set
                elif len(pattern_set.intersection(digit_map[4])) == 2:
                    digit_map[2] = pattern_set
                else:
                    digit_map[5] = pattern_set
            elif len(pattern) == 6:
                if digit_map[4].issubset(pattern_set):
                    digit_map[9] = pattern_set
                elif digit_map[1].issubset(pattern_set):
                    digit_map[0] = pattern_set
                else:
                    digit_map[6] = pattern_set

        # Reverse mapping (segments -> digit) so each output value is a single
        # hash lookup; frozenset makes the segment sets usable as dict keys.
        lookup = {frozenset(segments): digit for digit, segments in digit_map.items()}

        decoded = []
        for value in output:
            key = frozenset(value)
            if key not in lookup:
                # Fail loudly: skipping the digit would shift the others and
                # yield a plausible-looking but wrong number.
                raise ValueError(
                    f"output value {value!r} matches no signal pattern in entry {line!r}"
                )
            decoded.append(str(lookup[key]))

        total += int("".join(decoded))
    return total

# Run the full-decoding part2 from this cell on the loaded input and print the sum.
print("Part 2:", part2())

Part 2: 1012272


In [5]:
def part2_improved(data):
    """
    Sum the decoded output numbers of all scrambled seven-segment entries.

    Each element of ``data`` is a string ``"<signal patterns> | <output values>"``.
    The signal patterns are matched to digits by segment count and by overlap
    with the patterns for 1 and 4; the output values are then read off as one
    integer per entry.
    """
    def decode_line(line):
        signal_text, reading_text = line.split(" | ")
        signals = [set(word) for word in signal_text.split()]
        reading = [set(word) for word in reading_text.split()]

        # Index the patterns by how many segments they light; for the lengths
        # 2, 3, 4 and 7 this singles out the digits 1, 7, 4 and 8.
        last_of_length = {}
        for segments in signals:
            last_of_length[len(segments)] = segments
        known = {
            1: last_of_length[2],
            4: last_of_length[4],
            7: last_of_length[3],
            8: last_of_length[7],
        }

        # The patterns for 1 and 4 are never reassigned below, so grab them once.
        one = known[1]
        four = known[4]

        for segments in signals:
            lit = len(segments)
            if lit == 5:
                # 2, 3 or 5
                if segments >= one:
                    digit = 3
                elif len(four & segments) == 2:
                    digit = 2
                else:
                    digit = 5
            elif lit == 6:
                # 0, 6 or 9
                if segments >= four:
                    digit = 9
                elif segments >= one:
                    digit = 0
                else:
                    digit = 6
            else:
                continue
            known[digit] = segments

        # Invert to segments -> digit; frozenset makes the sets hashable.
        digit_for = {}
        for digit, segments in known.items():
            digit_for[frozenset(segments)] = digit

        # Read the output values left to right as decimal digits.
        text = ""
        for segments in reading:
            text += str(digit_for[frozenset(segments)])
        return int(text)

    return sum(map(decode_line, data))

# Run part2_improved on the loaded input and print the sum.
print("Part 2 Improved:", part2_improved(data))

Part 2 Improved: 1012272


In [6]:
# Alternative: More functional approach
def part2_functional(data):
    """Decode each entry with small single-purpose helpers and sum the results."""

    def identify_unique_digits(patterns):
        """Map 1, 4, 7 and 8 to their patterns using the segment count alone."""
        pattern_of_length = {}
        for segments in patterns:
            pattern_of_length[len(segments)] = segments
        return {
            digit: pattern_of_length[length]
            for digit, length in ((1, 2), (4, 4), (7, 3), (8, 7))
        }

    def classify_pattern(pattern, known):
        """Return the digit for a 5- or 6-segment pattern, or None for other sizes."""
        segment_count = len(pattern)
        if segment_count == 5:
            if pattern >= known[1]:
                return 3
            if len(known[4] & pattern) == 2:
                return 2
            return 5
        if segment_count == 6:
            if pattern >= known[4]:
                return 9
            if pattern >= known[1]:
                return 0
            return 6
        return None

    def decode_line(line):
        """Turn one ``"patterns | output"`` entry into its integer output value."""
        left, right = line.split(" | ")
        patterns = [set(token) for token in left.split()]
        output = [set(token) for token in right.split()]

        digit_map = identify_unique_digits(patterns)

        for candidate in patterns:
            digit = classify_pattern(candidate, digit_map)
            if digit is None:
                continue
            digit_map[digit] = candidate

        # Segments -> digit; frozenset so the sets can be dictionary keys.
        lookup = {frozenset(segments): digit for digit, segments in digit_map.items()}
        digits = [str(lookup[frozenset(segments)]) for segments in output]
        return int("".join(digits))

    total = 0
    for line in data:
        total += decode_line(line)
    return total

# Run part2_functional on the loaded input and print the sum.
print("Part 2 Functional:", part2_functional(data))

Part 2 Functional: 1012272
