In [1]:
import functools
import json
from pathlib import Path

In [2]:
# Small example input: eight packet pairs, each pair on two lines, pairs separated
# by a blank line. The cells below assert the expected answers for it
# (13 for part 1, 140 for part 2).
test_input = """[1,1,3,1,1]
[1,1,5,1,1]

[[1],[2,3,4]]
[[1],4]

[9]
[[8,7,6]]

[[4,4],4,4]
[[4,4],4,4,4]

[7,7,7,7]
[7,7,7]

[]
[3]

[[[]]]
[[]]

[1,[2,[3,[4,[5,6,7]]]],8,9]
[1,[2,[3,[4,[5,6,0]]]],8,9]"""

In [3]:
def parse_input(packet_data, pairs=True):
    """Parse puzzle text into packets.

    Each non-blank line is a packet literal made of nested lists of ints
    (e.g. ``[1,[2,3]]``). That syntax is valid JSON, so it is decoded with
    ``json.loads`` rather than ``eval``: the text comes from an external file
    and must never be executed as Python code.

    Args:
        packet_data: Raw puzzle text. With ``pairs=True`` packets come in
            two-line groups separated by a blank line.
        pairs: If True, return one ``(left, right)`` tuple per group. If
            False, return a flat list of every packet followed by the two
            divider packets ``[[2]]`` and ``[[6]]`` used in part 2.

    Returns:
        A list of packet tuples (``pairs=True``) or a flat list of packets.

    Raises:
        ValueError: (``json.JSONDecodeError``) if a line is not a valid packet.
    """
    # Normalise Windows line endings so the blank-line group split still works.
    text = packet_data.replace("\r\n", "\n").strip()
    if pairs:
        return [
            tuple(json.loads(line) for line in group.splitlines() if line.strip())
            for group in text.split("\n\n")
        ]
    packets = [json.loads(line) for line in text.splitlines() if line.strip()]
    # The divider literals are rebuilt on every call, so callers never share them.
    return packets + [[[2]], [[6]]]

def in_order(a, b):
    #print(f"Compare: {a} vs {b}")
    match a, b:
        case int(a), int(b):
            if a > b:
                return 1
            elif a < b:
                return -1
            else:
                return 0
        case list(a), list(b):
            for a2, b2 in zip(a, b):
                comparison = in_order(a2, b2)
                if abs(comparison) == 1:
                    return comparison
            if len(a) < len(b):
                #print("Left ran out of items")
                return -1
            elif len(a) > len(b):
                #print("Right ran out of items")
                return +1
            return 0
        case list(a), int(b):
            return in_order(a, [b])
        case int(a), list(b):
            return in_order([a], b)
        
def decode_key(packets):
    """Return the part-2 decoder key for a flat list of packets.

    The packets are ordered with ``in_order``. The key is the product of the
    1-based positions of the divider packets ``[[2]]`` and ``[[6]]``, which
    must already be in ``packets`` (``parse_input(..., pairs=False)`` appends
    them).

    A sorted copy is used, so the caller's list keeps its original order.
    The previous in-place sort silently reordered the argument.

    Raises:
        ValueError: if either divider packet is missing from ``packets``.
    """
    ordered = sorted(packets, key=functools.cmp_to_key(in_order))
    divider_1 = ordered.index([[2]]) + 1
    divider_2 = ordered.index([[6]]) + 1
    return divider_1 * divider_2

packets = parse_input(test_input)

# Expected comparator result for each example pair: -1 = right order, 1 = wrong order.
expected_results = [-1, -1, 1, -1, 1, -1, 1, 1]
for pair_number, expected in enumerate(expected_results):
    assert in_order(*packets[pair_number]) == expected

# 8 pairs -> 16 packets, plus the two divider packets added in flat mode.
assert len(parse_input(test_input, pairs=False)) == 18

In [4]:
# Part 1 - test: sum the 1-based indices of pairs that are already in the right order.
packets = parse_input(test_input)
assert sum(pair_no for pair_no, pair in enumerate(packets, start=1) if in_order(*pair) == -1) == 13

In [5]:
# Part 1: sum the 1-based indices of the real input's pairs that are in the right order.
packets = parse_input(Path("input.txt").read_text())
print(sum(pair_no for pair_no, pair in enumerate(packets, start=1) if in_order(*pair) == -1))

5503


In [6]:
# Part 2 - test: order all example packets plus dividers and check the decoder key.
packets = parse_input(test_input, pairs=False)
decoder_key = decode_key(packets)
assert decoder_key == 140

In [7]:
# Part 2: decoder key for the real puzzle input (the previous label said "test" by mistake).
packets = parse_input(Path("input.txt").read_text(), pairs=False)
print(decode_key(packets))

20952
