In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
import numpy as np

In [None]:
# Raw puzzle input (a string, one report per line, as consumed by count_safe).
# NOTE(review): the arguments presumably select year 2024, day 2 — confirm in aoc_utils.
data = load_data(2024, 2)

In [None]:
# Each entry is (puzzle input, expected part 1 answer, expected part 2 answer).
tests = [
    # Worked example from the puzzle statement.
    ("7 6 4 2 1\n1 2 7 8 9\n9 7 6 2 1\n1 3 2 4 5\n8 6 4 4 1\n1 3 6 7 9\n", 2, 4),
    # A single outlier, last or second-to-last, can be dampened.
    ("1 2 3 4 5 100\n1 2 3 4 5 100 6\n", 0, 2),
    # Runs of equal values cannot be fixed by one removal.
    ("1 1 1 1 1\n1 1 1 1 2\n", 0, 0),
    # Short edge cases for the part 2 dampener.
    ("1 8\n", 0, 1),
    ("1 1\n", 0, 1),
    ("1 1 0\n", 0, 1),
    ("1 1 0 1\n", 0, 0),
    # A single level, deliberately without a trailing newline.
    ("1", 1, 1),
]

# Part 1

In [None]:
def is_monotonic(arr):
    """Return whether the levels in `arr` strictly decrease or strictly increase."""
    earlier, later = arr[:-1], arr[1:]
    strictly_decreasing = np.all(earlier > later)
    return strictly_decreasing or np.all(earlier < later)

def is_bounded(arr):
    """Return whether every step between neighbouring levels is at most 3 in size."""
    step_sizes = np.abs(arr[:-1] - arr[1:])
    return np.all(step_sizes <= 3)

def is_safe(arr):
    """Return whether a report is safe: strictly monotonic with bounded steps."""
    monotonic = is_monotonic(arr)
    return monotonic and is_bounded(arr)

In [None]:
def count_safe(data, safety_check=is_safe):
    """Count the number of safe reports in a bunch of reports.

    Parameters
    ----------
    data : str
        The input data, with one report per line (levels separated by
        whitespace). Blank lines, e.g. a trailing empty line, are ignored
        rather than counted as reports.
    safety_check : callable(ndarray) -> bool
        The function used to assess the safety of a single report.
        Defaults to `is_safe`.

    Returns
    -------
    int
        The number of reports for which `safety_check` is truthy.
    """
    reports = (
        np.array([int(v) for v in line.split()])
        for line in data.splitlines()
        # Skip blank lines: an empty report would trivially pass the checks.
        if line.strip()
    )
    # Counting with a generator of 1s yields a plain int (summing NumPy
    # booleans would produce a NumPy integer instead).
    return sum(1 for report in reports if safety_check(report))

In [None]:
# Validate count_safe against the hand-written `tests` table, then compute the
# part 1 answer on the real input.
# NOTE(review): check presumably compares against the part 1 column by default — confirm in aoc_utils.
check(count_safe, tests)
count_safe(data)

# Part 2

Some of the added test cases are edge cases that are probably absent from the AoC input files:

- `1`: safe
- `1 1`: safe (remove either value)
- `1 1 0`: safe (remove either of the `1`s)
- `1 8`: safe (remove either value)

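The safety check for part 2 first determines the dominant direction of the report (the most common sign among consecutive differences). It then records the indices of the differences that go against that direction or exceed 3 in magnitude (the `errors` variable).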

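- If a **single error position** is identified, the dampener has two candidate values to remove.  
E.g., `errors = [5]` means that the pair of values at positions (5, 6) is unsafe; removing the value at either index 5 or index 6 may fix it.
- **Two error positions** must be consecutive to be rectifiable.  
E.g., `errors = [5, 6]` means that the pairs at positions (5, 6) and (6, 7) are unsafe. The value at index 6 belongs to both pairs, so it is the only candidate whose removal can fix both.
- **Three or more error positions** are not recoverable.

This reasoning assumes that the dominant direction was detected correctly. That holds for every repairable report of four or more values. With only three values, however, a tie between one increasing and one decreasing step (e.g. `1 9 8`) makes the direction ambiguous.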

In [None]:
def is_mostly_safe(arr):
    """Identify if a report is _mostly_ safe.

    A report is safe if:
    - it is strictly monotonic,
    - the difference between two consecutive values is at most 3.

    A report is _mostly_ safe if it is safe, or if it can be made safe by
    removing a single value.

    Parameters
    ----------
    arr : ndarray
        The report, as a 1D sequence of integers.

    Returns
    -------
    bool
        Whether the report is _mostly_ safe.

    Notes
    -----
    The dominant direction is the most common sign among consecutive steps.
    A step is flagged as an error if it goes against that direction or its
    magnitude exceeds 3. Only the values touching the flagged steps are
    candidates for removal, and each candidate is confirmed with `is_safe`.
    """
    if len(arr) <= 1:
        return True
    # `diff` is previous minus next: a positive sign means the report decreases.
    diff = arr[:-1] - arr[1:]
    slopes, cnts = np.unique(np.sign(diff), return_counts=True)
    # np.unique returns sorted values and np.argmax keeps the first maximum,
    # so ties are resolved in the order -1, 0, +1.
    slope = slopes[np.argmax(cnts)]
    if slope == 0:
        # Flat steps are never valid. Sign 0 only wins when it ties with +1
        # (e.g. `1 1 0`) or when flat steps dominate; +1 is then the right
        # direction for the tie, and every flat step gets flagged either way.
        slope = 1
    errors = np.flatnonzero((np.abs(diff) > 3) | (np.sign(diff) != slope))
    if len(errors) == 0:
        return True
    first, last = errors[0], errors[-1]
    if last - first > 1:
        # Removing a value only changes the steps touching it, so flagged
        # steps that are not adjacent cannot all be fixed by one removal.
        return False
    # Candidates are the values touching the flagged step(s): indices
    # first..last+1. With two flagged steps the shared value is usually the
    # only fix. For 3-value reports, however, a -1/+1 tie can pick the wrong
    # direction (e.g. `1 9 8` or `2 1 8`), and dropping an end value is then
    # the fix. `is_safe` rejects any candidate that does not work.
    return any(is_safe(np.delete(arr, idx)) for idx in range(first, last + 2))

In [None]:
# Validate the part 2 variant (count_safe with the dampened check) against the
# `tests` table, then compute the part 2 answer on the real input.
# NOTE(review): the `2` presumably selects the part 2 expected column, and the extra
# keyword argument is presumably forwarded to count_safe — confirm in aoc_utils.
check(count_safe, tests, 2, safety_check=is_mostly_safe)
count_safe(data, safety_check=is_mostly_safe)