In [1]:
import bisect
import re
from dataclasses import asdict, dataclass
from pathlib import Path

import polars as pl
from aoc.polars_utils import pl_group_all

# Puzzle input; decoded explicitly as UTF-8 so the result does not depend on the platform locale.
data_file = Path("../Data/day3.txt").read_text(encoding="utf-8")


# Sample inputs: EXAMPLE is checked against part 1, EXAMPLE2 (which adds do()/don't() toggles) against part 2.
EXAMPLE = "xmul(2,4)%&mul[3,7]!@^do_not_mul(5,5)+mul(32,64]then(mul(11,8)mul(8,5))"
EXAMPLE2 = "xmul(2,4)&mul[3,7]!^don't()_mul(5,5)+mul(32,64](mul(11,8)undo()?mul(8,5))"


@dataclass()
class Result:
    """A single ``mul(number1,number2)`` instruction located in the input text."""

    # Left-hand operand.
    number1: int
    # Right-hand operand.
    number2: int
    # Offset of the ``mul`` match within the input string.
    start_index: int
    # Overwritten by adjust_results; starts out as disabled.
    is_enabled: bool = False


def adjust_results(results: list[Result], enabled_matches: list[tuple[bool, int]]):
    """Flag every multiplication with the state set by the closest preceding toggle.

    A toggle is a ``do()`` (enables) or ``don't()`` (disables) instruction.
    Multiplications that appear before any toggle are enabled.

    Args:
        results: parsed ``mul`` instructions; each one's ``is_enabled`` attribute
            is overwritten in place.
        enabled_matches: ``(is_enabled, start_index)`` pairs, one per toggle.

    Returns:
        A new list containing the same ``Result`` objects, in input order, with
        ``is_enabled`` assigned.
    """
    # Sort by position so the lookup does not depend on the caller's ordering.
    ordered_matches = sorted(enabled_matches, key=lambda match: match[1])
    match_positions = [start_index for _, start_index in ordered_matches]

    adjusted_results: list[Result] = []
    for result in results:
        # Count the toggles that start strictly before this multiplication.
        # Positions are compared directly, so a toggle at index 0 is honoured.
        preceding = bisect.bisect_left(match_positions, result.start_index)
        result.is_enabled = ordered_matches[preceding - 1][0] if preceding else True
        adjusted_results.append(result)

    return adjusted_results


def prepare(input: str):
    """Parse puzzle text into a LazyFrame with one row per ``mul(a,b)`` instruction.

    Columns:
        number1, number2: the two operands.
        start_index: offset of the ``mul`` match in ``input``.
        is_enabled: state assigned by ``adjust_results`` from the ``do()`` /
            ``don't()`` toggles in ``input``.

    An explicit schema is supplied so that text with no ``mul`` instructions still
    yields a frame with the expected columns and dtypes; otherwise later column
    references would fail.
    """
    multiplication_pattern = r"mul\((\d+),(\d+)\)"
    results = [
        Result(int(match.group(1)), int(match.group(2)), match.start())
        for match in re.finditer(multiplication_pattern, input)
    ]

    enable_pattern = r"(do|don't)\(\)"
    enabled_matches = [
        (match.group(0) == "do()", match.start())
        for match in re.finditer(enable_pattern, input)
    ]

    schema = {
        "number1": pl.Int64,
        "number2": pl.Int64,
        "start_index": pl.Int64,
        "is_enabled": pl.Boolean,
    }
    rows = [asdict(result) for result in adjust_results(results, enabled_matches)]
    return pl.LazyFrame(rows, schema=schema)


example_data = prepare(EXAMPLE)
example2_data = prepare(EXAMPLE2)
data = prepare(data_file)

# In EXAMPLE2, don't() switches off the 2nd and 3rd valid mul, and the do() inside
# "undo()" switches the last one back on.
example2_is_enabled = example2_data.collect().get_column("is_enabled").to_list()
assert example2_is_enabled == [True, False, False, True], example2_is_enabled

In [2]:
def part1(input: pl.LazyFrame):
    """Total ``number1 * number2`` across every row of ``input``.

    Returns the collected frame produced by ``pl_group_all``; callers read the
    total from row 0 of its ``multiplication`` column.
    """
    with_products = input.with_columns(
        (pl.col("number1") * pl.col("number2")).alias("multiplication")
    )

    def sum_products(grouped):
        # NOTE(review): presumably pl_group_all passes a group-by spanning all
        # rows — confirm in aoc.polars_utils.
        return grouped.agg(pl.col("multiplication").sum())

    return pl_group_all(with_products, sum_products).collect()


assert part1(example_data).get_column("multiplication")[0] == 161

result = part1(data)

assert result.get_column("multiplication")[0] == 166_357_705

# Display the frame already computed above rather than re-running part1(data).
result

multiplication
i64
166357705


In [3]:
def part2(input: pl.LazyFrame):
    """Like ``part1``, but only rows whose ``is_enabled`` flag is set contribute.

    Delegates to ``part1`` after filtering, so the multiply-and-sum pipeline is
    defined in one place.
    """
    return part1(input.filter(pl.col("is_enabled")))


assert part2(example2_data).get_column("multiplication")[0] == 48

result = part2(data)

assert result.get_column("multiplication")[0] == 88_811_886

# Display the frame already computed above rather than re-running part2(data).
result

multiplication
i64
88811886
