In [1]:
import re

In [2]:
# Read the puzzle input, then turn it into grid rows: surrounding whitespace
# (e.g. the trailing newline) is stripped and the text is split on "\n".
with open("./input.txt", mode="r") as input_file:
    grid_text = input_file.read()

data = grid_text.strip().split("\n")

In [3]:
class Point:
    """A cell of the grid: `x` is the column index and `y` the row index."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f"Point({self.x},{self.y})"

    def distance(self, other: "Point") -> float:
        """
        Return the Euclidean distance from self to another point.

        For integer coordinates, orthogonal neighbours are at distance 1 and
        diagonal neighbours at distance sqrt(2).
        """
        # Squaring already discards the sign, so no abs() is needed on either axis.
        return ((other.x - self.x) ** 2 + (other.y - self.y) ** 2) ** 0.5

In [4]:
class Span:
    """
    A run of grid points on a single row, from `start` to `stop` inclusive.

    Only `start.y` is consulted for the row; `stop` contributes its x only.
    """

    def __init__(self, start: "Point", stop: "Point"):
        self.start, self.stop = start, stop

    def __repr__(self):
        return f"Span({self.start}, {self.stop})"

    def __contains__(self, point: "Point"):
        """
        Tell whether `point` lies on the span's row within its column range.
        """
        within_columns = self.start.x <= point.x <= self.stop.x
        return within_columns and point.y == self.start.y

    def __iter__(self):
        """
        Yield every grid point covered by the span, left to right.
        """
        yield from (
            Point(column, self.start.y)
            for column in range(self.start.x, self.stop.x + 1)
        )

    def distance(self, point: "Point"):
        """
        Return the smallest distance between `point` and any point of the span.

        Raises ValueError (from min) when the span covers no points.
        """
        distances = [cell.distance(point) for cell in self]
        return min(distances)

In [5]:
class Number:
    """
    A number found in the grid, paired with where it sits.

    `location` is whatever the caller supplies (parse() stores a Span here).
    Equality and hashing stay identity-based, so equal values at different
    places remain distinct objects.
    """

    def __init__(self, value, location):
        self.value, self.location = value, location

    def __repr__(self):
        return f"Number({self.value} at {self.location})"

In [6]:
class Symbol:
    """
    A single non-digit, non-'.' character found in the grid, with its location.

    `location` is whatever the caller supplies (parse() stores a Point here).
    """

    def __init__(self, value, location):
        self.value, self.location = value, location

    def __repr__(self):
        return f"Symbol({self.value} at {self.location})"

In [7]:
def parse(data):
    """
    Parse the grid and return a tuple of all numbers and all symbols.

    Parameters
    ----------
    data : list[str]
        Grid rows, one string per row; the row index becomes the y coordinate
        and the character index the x coordinate.

    Returns
    -------
    tuple[list, list]
        - list of Number: every maximal run of digits, as an int, located by the
          Span from its first to its last character (inclusive).
        - list of Symbol: every character that is neither '.' nor a digit,
          located by its Point.
        Both lists are ordered row by row, left to right.
    """
    numbers = []
    symbols = []
    for y, row in enumerate(data):
        # Raw strings: "\d" in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning since Python 3.12).
        for match in re.finditer(r"\d+", row):
            first, end = match.span()  # `end` is exclusive, hence `end - 1`
            numbers.append(
                Number(int(match.group()), Span(Point(first, y), Point(end - 1, y)))
            )

        for x, char in enumerate(row):
            if char != "." and not re.match(r"\d", char):
                symbols.append(Symbol(char, Point(x, y)))

    return numbers, symbols

# Part 1

In [8]:
def solve(data):
    """
    Return the sum of all numbers adjacent to at least one symbol.

    A number is adjacent to a symbol when one of its cells is at Euclidean
    distance at most sqrt(2) from the symbol, i.e. touching horizontally,
    vertically or diagonally on the integer grid. Each number is counted once,
    however many symbols touch it.
    """
    numbers, symbols = parse(data)
    max_adjacent_distance = 2**0.5  # distance to a diagonal neighbour

    # any() stops at the first touching symbol and avoids collecting duplicates
    # that would otherwise have to be removed with set() (which only works
    # because Number hashes by identity).
    return sum(
        number.value
        for number in numbers
        if any(
            number.location.distance(symbol.location) <= max_adjacent_distance
            for symbol in symbols
        )
    )

solve(data)

521601

# Part 2

In [9]:
def solve(data):
    """
    Return the sum of all gear products.

    A gear is a '*' symbol with exactly two numbers adjacent to it (distance at
    most sqrt(2), so diagonal contact counts); its product is those two
    numbers multiplied together.
    """
    numbers, symbols = parse(data)
    limit = 2**0.5  # Euclidean distance to a diagonal neighbour

    total = 0
    for star in (symbol for symbol in symbols if symbol.value == "*"):
        neighbours = [
            number
            for number in numbers
            if number.location.distance(star.location) <= limit
        ]
        # only a '*' touching exactly two numbers is a gear
        if len(neighbours) != 2:
            continue
        first, second = neighbours
        total += first.value * second.value

    return total

solve(data)

80694070