In [3]:
def parse_file(file):
    with open(file, "r") as f:
        rules, updates = f.read().split("\n\n")
    rules = filter(lambda x: len(x) > 0, rules.split("\n"))
    rules = [tuple(map(int, rule.split("|"))) for rule in rules]
    updates = filter(lambda x: len(x.strip()) > 0, updates.split("\n"))
    updates = [tuple(map(int, update.split(","))) for update in updates]
    return rules, updates


import pandas as pd


def create_graph(rules):
    graph = pd.DataFrame(rules)
    graph.columns = ["smaller", "greater"]
    n_to_greaters = graph.groupby("smaller")["greater"].agg(set).to_dict()
    n_to_smallers = graph.groupby("greater")["smaller"].agg(set).to_dict()
    return n_to_greaters, n_to_smallers

In [4]:
def check_update_valid(update, n_to_greaters, n_to_smallers):
    for i in range(len(update)):
        number = update[i]
        prev_numbers = set(update[:i])
        # print(prev_numbers,  n_to_greaters.get(number, set()), prev_numbers & n_to_greaters.get(number, set()))
        if prev_numbers and (prev_numbers & n_to_greaters.get(number, set())):
            return False
        next_numbers = set(update[i + 1 :])
        # print(next_numbers, n_to_smallers.get(number, set()), next_numbers & n_to_smallers.get(number, set()))
        if next_numbers and (next_numbers & n_to_smallers.get(number, set())):
            return False
    return True


def total_middle_valid(updates, n_to_greaters, n_to_smallers):
    total = 0
    for update in updates:
        is_valid = check_update_valid(update, n_to_greaters, n_to_smallers)
        if is_valid:
            central_number = update[len(update) // 2]
            total += central_number
    return total


rules, updates = parse_file("input.txt")
n_to_greaters, n_to_smallers = create_graph(rules)
print(n_to_greaters)
print(n_to_smallers)
print(total_middle_valid(updates, n_to_greaters, n_to_smallers))

{11: {16, 17, 24, 25, 26, 28, 29, 31, 32, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 84, 87, 93, 96, 97}, 12: {11, 16, 17, 22, 24, 27, 28, 29, 31, 43, 47, 52, 55, 56, 57, 63, 67, 76, 77, 78, 84, 87, 93, 97}, 14: {11, 12, 22, 27, 33, 38, 44, 45, 49, 51, 52, 58, 63, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 98}, 16: {14, 17, 25, 26, 29, 32, 34, 43, 44, 45, 47, 51, 56, 58, 67, 75, 76, 77, 79, 84, 87, 88, 96, 97}, 17: {14, 25, 26, 32, 33, 34, 38, 44, 45, 49, 51, 58, 64, 65, 72, 73, 75, 79, 81, 82, 88, 94, 96, 98}, 22: {11, 16, 17, 24, 25, 27, 28, 29, 31, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 84, 87, 93, 96, 97}, 24: {14, 16, 17, 25, 26, 29, 32, 34, 43, 44, 45, 47, 51, 56, 58, 67, 75, 76, 77, 79, 84, 87, 96, 97}, 25: {12, 14, 26, 32, 33, 34, 38, 44, 45, 49, 51, 58, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 96, 98}, 26: {12, 14, 22, 32, 33, 38, 44, 45, 49, 51, 52, 58, 63, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 98}, 27: {11, 16, 17, 24, 25, 26, 28, 29, 31, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 

In [5]:
class Code(int):
    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        # define in negative way
        # if self < other, then not (self > other)
        # other < self
        x = self.value
        y = other.value

        """
        y in n_to_smallers.get(x, set()) -> y < x
        y in n_to_greaters.get(x, set()) -> x < y
        x in n_to_smallers.get(y, set()) -> x < y
        x in n_to_greaters.get(y, set()) -> y < x
        """
        if (y in n_to_greaters.get(x, set())) or (x in n_to_smallers.get(y, set())):
            return True
        elif (y in n_to_smallers.get(x, set())) or (x in n_to_greaters.get(y, set())):
            return False
        elif self.value == other.value:
            return False
        else:
            return x < y

    def __gt__(self, other):
        return other.__lt__(self)

    def __eq__(self, other):
        return self.value == other.value

    def __repr__(self) -> str:
        return f"<{self.value}>"

    def __hash__(self) -> int:
        return hash(self.value)


def total_middle_valid_correcting(updates, n_to_greaters, n_to_smallers):
    total = 0
    for update in updates:
        is_valid = check_update_valid(update, n_to_greaters, n_to_smallers)
        if is_valid:
            central_number = 0
            # central_number = update[len(update)//2]
        else:
            codes = [Code(i) for i in update]
            codes.sort()
            central_number = codes[len(codes) // 2].value

        total += central_number

    return total


In [6]:
rules, updates = parse_file("input.txt")
n_to_greaters, n_to_smallers = create_graph(rules)
print(n_to_greaters)
print(n_to_smallers)

print(total_middle_valid_correcting(updates, n_to_greaters, n_to_smallers))

{11: {16, 17, 24, 25, 26, 28, 29, 31, 32, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 84, 87, 93, 96, 97}, 12: {11, 16, 17, 22, 24, 27, 28, 29, 31, 43, 47, 52, 55, 56, 57, 63, 67, 76, 77, 78, 84, 87, 93, 97}, 14: {11, 12, 22, 27, 33, 38, 44, 45, 49, 51, 52, 58, 63, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 98}, 16: {14, 17, 25, 26, 29, 32, 34, 43, 44, 45, 47, 51, 56, 58, 67, 75, 76, 77, 79, 84, 87, 88, 96, 97}, 17: {14, 25, 26, 32, 33, 34, 38, 44, 45, 49, 51, 58, 64, 65, 72, 73, 75, 79, 81, 82, 88, 94, 96, 98}, 22: {11, 16, 17, 24, 25, 27, 28, 29, 31, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 84, 87, 93, 96, 97}, 24: {14, 16, 17, 25, 26, 29, 32, 34, 43, 44, 45, 47, 51, 56, 58, 67, 75, 76, 77, 79, 84, 87, 96, 97}, 25: {12, 14, 26, 32, 33, 34, 38, 44, 45, 49, 51, 58, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 96, 98}, 26: {12, 14, 22, 32, 33, 38, 44, 45, 49, 51, 52, 58, 63, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 98}, 27: {11, 16, 17, 24, 25, 26, 28, 29, 31, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 

## Second Form: Using a dictionary (not working)

The binary relations don't allow to build a [Total order](https://en.wikipedia.org/wiki/Total_order) relationship,
since for the input.txt some terms break the transivity property

In [22]:
a = 16
b = 14
c = 82


def manual_smaller(a, b, n_to_greaters, n_to_smallers):
    # a < b
    c1 = b in n_to_greaters.get(a, set())
    c2 = a in n_to_smallers.get(b, set())
    if c1 or c2:
        return True
    # b < a
    c3 = a in n_to_greaters.get(b, set())
    c4 = b in n_to_smallers.get(a, set())
    if c3 or c4:
        return False


res1 = manual_smaller(a, b, n_to_greaters, n_to_smallers)
print(f"{a} < {b} = {res1}")
res5 = manual_smaller(b, c, n_to_greaters, n_to_smallers)
print(f"{b} < {c} = {res5}")
res3 = manual_smaller(a, c, n_to_greaters, n_to_smallers)
print(f"{a} < {c} = {res3}")
# 82 < 16
res2 = manual_smaller(c, a, n_to_greaters, n_to_smallers)
print(f"{c} < {a} = {res2}")

16 < 14 = True
14 < 82 = True
16 < 82 = False
82 < 16 = True


In [26]:
# find rules that contains at last two out of the three numbers <16> < <14> < <82>
for rule in rules:
    if (a in rule) + (b in rule) + (c in rule) >= 2:
        print(f"rule: {rule[0]} < {rule[1]}")

rule: 14 < 82
rule: 16 < 14
rule: 82 < 16


In [8]:
class CodeFixed(int):
    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return not Code(self.value) >= Code(other.value)

    def __gt__(self, other):
        return not Code(self.value) <= Code(other.value)

    def __eq__(self, other):
        return self.value == other.value

    def __le__(self, other):
        return Code(self.value) <= Code(other.value)

    def __ge__(self, other):
        return Code(self.value) >= Code(other.value)

    def __repr__(self) -> str:
        return f"<{self.value}>"

In [9]:
def sort_basic(update):
    codes = [Code(i) for i in update]
    codes.sort()
    return tuple(code.value for code in codes)


def get_all_numbers_sorted(rules):
    all_numbers = set()
    for n, greaters in n_to_greaters.items():
        all_numbers.add(n)
        all_numbers.update(greaters)
    for n, smallers in n_to_smallers.items():
        all_numbers.add(n)
        all_numbers.update(smallers)
    sorted_numbers = sort_basic(all_numbers)
    return sorted_numbers

In [10]:
rules, updates = parse_file("input.txt")
n_to_greaters, n_to_smallers = create_graph(rules)
print(n_to_greaters)
print(n_to_smallers)

print(total_middle_valid_correcting(updates, n_to_greaters, n_to_smallers))

{11: {16, 17, 24, 25, 26, 28, 29, 31, 32, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 84, 87, 93, 96, 97}, 12: {11, 16, 17, 22, 24, 27, 28, 29, 31, 43, 47, 52, 55, 56, 57, 63, 67, 76, 77, 78, 84, 87, 93, 97}, 14: {11, 12, 22, 27, 33, 38, 44, 45, 49, 51, 52, 58, 63, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 98}, 16: {14, 17, 25, 26, 29, 32, 34, 43, 44, 45, 47, 51, 56, 58, 67, 75, 76, 77, 79, 84, 87, 88, 96, 97}, 17: {14, 25, 26, 32, 33, 34, 38, 44, 45, 49, 51, 58, 64, 65, 72, 73, 75, 79, 81, 82, 88, 94, 96, 98}, 22: {11, 16, 17, 24, 25, 27, 28, 29, 31, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 84, 87, 93, 96, 97}, 24: {14, 16, 17, 25, 26, 29, 32, 34, 43, 44, 45, 47, 51, 56, 58, 67, 75, 76, 77, 79, 84, 87, 96, 97}, 25: {12, 14, 26, 32, 33, 34, 38, 44, 45, 49, 51, 58, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 96, 98}, 26: {12, 14, 22, 32, 33, 38, 44, 45, 49, 51, 52, 58, 63, 64, 65, 72, 73, 78, 79, 81, 82, 88, 94, 98}, 27: {11, 16, 17, 24, 25, 26, 28, 29, 31, 34, 43, 47, 55, 56, 57, 67, 75, 76, 77, 

In [11]:
sorted_numbers = get_all_numbers_sorted(rules)

In [23]:
def is_total_order(cod_objects, order_relation):
    """
    Check if the given objects define a total order under the given order relation.

    Parameters
    ----------
        cod_objects (list)
        A list of Cod objects to check.
        order_relation (function)
        A function taking two Cod objects and returning True if the first is less than the second.

    Returns
    -------
        bool: True if the objects define a total order, False otherwise.


    References
    ----------
    https://en.wikipedia.org/wiki/Total_order
    """
    n = len(cod_objects)

    output = True
    # Check totality and transitivity
    for i in range(n):
        a = cod_objects[i]
        # Reflexivity: a <= a
        if not order_relation(a, a):
            print(f"Failed reflexivity: {a} and {a}")
            output = False

        for j in range(n):
            b = cod_objects[j]

            # Antisymmetry: if a <= b and b <= a, then a == b
            # (a and b -> c) equiv (!a or !b or c) -> (negating) (a and b and !c)
            if order_relation(a, b) and order_relation(b, a) and (not (a == b)):
                print(f"Failed antisymmetry: {a} and {b}")
                output = False

            # Transitivity: if a < b and b < c, then a < c
            for k in range(n):
                c = cod_objects[k]
                if (
                    order_relation(a, b)
                    and order_relation(b, c)
                    and not order_relation(a, c)
                ):
                    print(f"Failed transitivity: {a} < {b} < {c}")
                    output = False

            # Strongly connected: a <= b or b <= a
            if not (order_relation(a, b) or order_relation(b, a)):
                print(f"Failed strongly connected: {a} and {b}")
                output = False

    return output


def is_strict_total_order(strings, strict_order):
    """
    Check if the given strings define a strict total order under the given strict order relation.

    Parameters
    ----------
        strings (list of str): A list of strings to check.
        strict_order (function): A function taking two strings and returning True if the first is strictly less than the second.

    Returns
    -------
        bool: True if the strings define a strict total order, False otherwise.

    References
    ----------
    https://en.wikipedia.org/wiki/Total_order
    """
    n = len(strings)

    output = True
    # Check irreflexivity, asymmetry, transitivity, and connectedness
    for i in range(n):
        a = strings[i]

        # Irreflexivity: Not a < a
        if strict_order(a, a):
            print(f"Failed irreflexivity: {a} < {a}")
            output = False

        for j in range(n):
            b = strings[j]

            # Asymmetry: if a < b, then not b < a
            if strict_order(a, b) and strict_order(b, a):
                print(f"Failed asymmetry: {a} < {b} and {b} < {a}")
                output = False

            # Connectedness: if a != b, then a < b or b < a
            if a != b and not (strict_order(a, b) or strict_order(b, a)):
                print(f"Failed connectedness: {a} and {b} are not comparable")
                output = False

            for k in range(n):
                c = strings[k]

                # Transitivity: if a < b and b < c, then a < c
                if strict_order(a, b) and strict_order(b, c) and not strict_order(a, c):
                    print(f"Failed transitivity: {a} < {b} < {c}")
                    output = False

    return output


order_relation = lambda a, b: a <= b
strict_order_relation = lambda a, b: a < b


is_total_order([Code(i) for i in sorted_numbers], order_relation)
is_strict_total_order([Code(i) for i in sorted_numbers], strict_order_relation)
sorted_codes = [Code(i) for i in sorted_numbers]
sorted_codes.sort()
print(sorted_codes)

Failed transitivity: <16> < <14> < <82>
Failed transitivity: <16> < <14> < <72>
Failed transitivity: <16> < <14> < <94>
Failed transitivity: <16> < <14> < <73>
Failed transitivity: <16> < <14> < <64>
Failed transitivity: <16> < <14> < <98>
Failed transitivity: <16> < <14> < <81>
Failed transitivity: <16> < <14> < <49>
Failed transitivity: <16> < <14> < <38>
Failed transitivity: <16> < <14> < <33>
Failed transitivity: <16> < <14> < <65>
Failed transitivity: <16> < <14> < <12>
Failed transitivity: <16> < <14> < <78>
Failed transitivity: <16> < <14> < <63>
Failed transitivity: <16> < <14> < <52>
Failed transitivity: <16> < <14> < <22>
Failed transitivity: <16> < <14> < <27>
Failed transitivity: <16> < <14> < <11>
Failed transitivity: <16> < <88> < <82>
Failed transitivity: <16> < <88> < <72>
Failed transitivity: <16> < <88> < <94>
Failed transitivity: <16> < <88> < <73>
Failed transitivity: <16> < <88> < <64>
Failed transitivity: <16> < <88> < <98>
Failed transitivity: <16> < <88> < <81>


In [13]:
order_relation = lambda a, b: a <= b
strict_order_relation = lambda a, b: a < b


test1 = is_total_order([CodeFixed(i) for i in sorted_numbers], order_relation)
print(test1)
test2 = is_strict_total_order(
    [CodeFixed(i) for i in sorted_numbers], strict_order_relation
)
print(test2)

fixed_codes = [CodeFixed(i) for i in sorted_numbers]
fixed_codes.sort()
print(fixed_codes)

True
True
[<11>, <12>, <14>, <16>, <17>, <22>, <24>, <25>, <26>, <27>, <28>, <29>, <31>, <32>, <33>, <34>, <38>, <43>, <44>, <45>, <47>, <49>, <51>, <52>, <55>, <56>, <57>, <58>, <63>, <64>, <65>, <67>, <72>, <73>, <75>, <76>, <77>, <78>, <79>, <81>, <82>, <84>, <87>, <88>, <93>, <94>, <96>, <97>, <98>]


In [14]:
a = 16
b = 14
c = 82
print(Code(a) > Code(b))
print(Code(b) > Code(c))
print(Code(a) > Code(c))

False
False
True


In [15]:
sorted_numbers = get_all_numbers_sorted(rules)
print(sorted_numbers)
print(check_update_valid(sorted_numbers, n_to_greaters, n_to_smallers))
number_to_order = {code: i for i, code in enumerate(sorted_numbers)}
print(number_to_order)

(16, 14, 88, 82, 72, 94, 73, 64, 98, 81, 49, 38, 33, 65, 12, 78, 63, 52, 22, 27, 11, 31, 93, 57, 28, 55, 24, 97, 29, 76, 77, 56, 43, 47, 67, 87, 84, 17, 75, 25, 96, 34, 26, 32, 45, 44, 79, 58, 51)
False
{16: 0, 14: 1, 88: 2, 82: 3, 72: 4, 94: 5, 73: 6, 64: 7, 98: 8, 81: 9, 49: 10, 38: 11, 33: 12, 65: 13, 12: 14, 78: 15, 63: 16, 52: 17, 22: 18, 27: 19, 11: 20, 31: 21, 93: 22, 57: 23, 28: 24, 55: 25, 24: 26, 97: 27, 29: 28, 76: 29, 77: 30, 56: 31, 43: 32, 47: 33, 67: 34, 87: 35, 84: 36, 17: 37, 75: 38, 25: 39, 96: 40, 34: 41, 26: 42, 32: 43, 45: 44, 44: 45, 79: 46, 58: 47, 51: 48}


In [16]:
def insertion_sort(numbers):
    numbers = [Code(n) for n in numbers]
    for i in range(1, len(numbers)):
        j = i
        while j > 0 and numbers[j] < numbers[j - 1]:
            numbers[j], numbers[j - 1] = numbers[j - 1], numbers[j]
            j -= 1

    output = tuple(code.value for code in numbers)
    assert check_update_valid(output, n_to_greaters, n_to_smallers)
    return output


print(sort_basic((16, 82, 14)))
print(sort_basic((82, 14, 16)))

print(insertion_sort((82, 14, 16)))
print(insertion_sort((16, 82, 14)))


(14, 82, 16)
(16, 14, 82)


AssertionError: 

In [None]:
for i in range(len(sorted_numbers)):
    for j in range(i + 1, len(sorted_numbers)):
        if not (Code(sorted_numbers[i]) < Code(sorted_numbers[j])):
            print(sorted_numbers[i], sorted_numbers[j])

16 82
16 72
16 94
16 73
16 64
16 98
16 81
16 49
16 38
16 33
16 65
16 12
16 78
16 63
16 52
16 22
16 27
16 11
16 31
16 93
16 57
16 28
16 55
16 24
14 31
14 93
14 57
14 28
14 55
14 24
14 97
14 29
14 76
14 77
14 56
14 43
14 47
14 67
14 87
14 84
14 17
14 75
14 25
14 96
14 34
14 26
14 32
88 97
88 29
88 76
88 77
88 56
88 43
88 47
88 67
88 87
88 84
88 17
88 75
88 25
88 96
88 34
88 26
88 32
88 45
88 44
88 79
88 58
88 51
82 97
82 29
82 76
82 77
82 56
82 43
82 47
82 67
82 87
82 84
82 17
82 75
82 25
82 96
82 34
82 26
82 32
82 45
82 44
82 79
82 58
82 51
72 29
72 76
72 77
72 56
72 43
72 47
72 67
72 87
72 84
72 17
72 75
72 25
72 96
72 34
72 26
72 32
72 45
72 44
72 79
72 58
72 51
94 76
94 77
94 56
94 43
94 47
94 67
94 87
94 84
94 17
94 75
94 25
94 96
94 34
94 26
94 32
94 45
94 44
94 79
94 58
94 51
73 77
73 56
73 43
73 47
73 67
73 87
73 84
73 17
73 75
73 25
73 96
73 34
73 26
73 32
73 45
73 44
73 79
73 58
73 51
64 56
64 43
64 47
64 67
64 87
64 84
64 17
64 75
64 25
64 96
64 34
64 26
64 32
64 45
64 44
64 7

In [None]:
def sort_dict(update, number_to_order):
    return tuple(sorted(update, key=lambda x: number_to_order[x]))


test = (28, 55, 24, 77, 17, 14, 44)
print(sort_basic(test))
print(sort_dict(test, number_to_order))

(28, 55, 24, 77, 17, 14, 44)
(14, 28, 55, 24, 77, 17, 44)


In [None]:
test = (28, 55, 24, 77, 17, 14, 44)
codes = [Code(i) for i in test]
codes.sort()
print(codes)
test_sorted = tuple(sorted(test, key=lambda x: number_to_order[x]))
print(test_sorted)

[<28>, <55>, <24>, <77>, <17>, <14>, <44>]
(14, 28, 55, 24, 77, 17, 44)


In [None]:
def total_middle_number(updates, number_to_order, n_to_greaters, n_to_smallers):
    total = 0
    for update in updates:
        sorted_update = tuple(sorted(update, key=lambda x: number_to_order[x]))
        is_equal = sorted_update == update
        is_valid = check_update_valid(update, n_to_greaters, n_to_smallers)
        if not is_equal and is_valid:
            print(update, sorted_update)
        if sorted_update != update:
            central_number = sorted_update[len(sorted_update) // 2]
            total += central_number

    return total


print(total_middle_number(updates, number_to_order, n_to_greaters, n_to_smallers))

(28, 55, 24, 77, 17, 14, 44) (14, 28, 55, 24, 77, 17, 44)
(87, 84, 17, 25, 45, 44, 88) (88, 87, 84, 17, 25, 45, 44)
(29, 43, 47, 87, 75, 25, 34, 26, 32, 79, 82) (82, 29, 43, 47, 87, 75, 25, 34, 26, 32, 79)
(47, 67, 84, 75, 25, 34, 26, 32, 14, 44, 79, 58, 82, 72, 94, 64, 98) (14, 82, 72, 94, 64, 98, 47, 67, 84, 75, 25, 34, 26, 32, 44, 79, 58)
(72, 73, 98, 81, 78, 22, 27, 11, 57, 28, 55, 24, 16) (16, 72, 73, 98, 81, 78, 22, 27, 11, 57, 28, 55, 24)
(52, 22, 57, 55, 24, 16, 76, 77, 43, 47, 67, 87, 25) (16, 52, 22, 57, 55, 24, 76, 77, 43, 47, 67, 87, 25)
(77, 56, 47, 87, 84, 17, 25, 34, 32, 14, 45, 44, 58, 88, 82, 72, 94) (14, 88, 82, 72, 94, 77, 56, 47, 87, 84, 17, 25, 34, 32, 45, 44, 58)
(38, 33, 65, 12, 78, 63, 52, 22, 27, 11, 31, 93, 57, 28, 55, 24, 16, 97, 76, 77, 56, 43, 47) (16, 38, 33, 65, 12, 78, 63, 52, 22, 27, 11, 31, 93, 57, 28, 55, 24, 97, 76, 77, 56, 43, 47)
(43, 47, 67, 87, 75, 96, 14, 45, 79, 58, 51, 72, 94, 73, 64) (14, 72, 94, 73, 64, 43, 47, 67, 87, 75, 96, 45, 79, 58, 51