In [1]:
# | default_exp math
from nbdev import *
from nbdev.showdoc import *

# Mathy functions

In [2]:
# | exporti
from functools import reduce
from math import sqrt, gcd
from bisect import bisect_left
from itertools import combinations, permutations, chain
from aocutils.special import UnionFind
import heapq

In [3]:
# | export
def factors(n):
    """
    return set of divisors of a number
    """
    step = 2 if n % 2 else 1
    return set(
        reduce(
            list.__add__,
            ([i, n // i] for i in range(1, int(sqrt(n)) + 1, step) if n % i == 0),
        )
    )

In [4]:
assert factors(20) == {1, 2, 4, 5, 10, 20}

In [5]:
# | export
def gcd(a, b):
    largest = max(a, b)
    smallest = min(a, b)
    if a == b:
        return a
    if not largest % smallest:
        return smallest
    while True:
        rest = largest % smallest
        if rest == 0:
            return prevrest
        else:
            prevrest = rest
            largest = smallest
            smallest = rest


def lcm(a):
    lcm = a[0]
    for i in a[1:]:
        lcm = lcm * i // gcd(lcm, i)
    return lcm

In [6]:
assert gcd(12, 8) == 4
assert gcd(12, 4) == 4
assert gcd(12, 12) == 12
assert lcm([4, 6, 7]) == 84

In [7]:
def power(a, b, M=None):
    # computes a**b. Actually python pow does this with optional third argument
    res = 1
    while b:
        if b % 2 == 1:
            res = (res * a) % M if M else res * a
            print("res", res)
        a *= a
        b //= 2
    return res


power(3, 2)

res 9


9

In [8]:
# | export
def crt(remainders, moduli):
    """
    Chinese remainder theorem
    """
    cur_rem = remainders[0]
    cur_mod = moduli[0]
    for rem, mod in zip(remainders[1:], moduli[1:]):
        i = 0
        while True:
            if (cur_rem + i * cur_mod) % mod == rem % mod:
                cur_rem += i * cur_mod
                cur_mod = lcm((cur_mod, mod))
                break
            else:
                i += 1
    print("Returning remainder and modulo. First valid number is the remainder")
    return cur_rem, cur_mod

In [9]:
rests, mods = (
    [0, -27, -37, -45, -54, -56, -66, -68, -81],
    [37, 41, 433, 23, 17, 19, 29, 593, 13],
)
assert crt(rests, mods) == (600691418730595, 1090937521514009)

Returning remainder and modulo. First valid number is the remainder


In [10]:
assert crt((3, 5, 2), (4, 6, 5)) == (47, 60)
assert crt((1, 0, 1, 3), (4, 3, 5, 7)) == (381, 420)
assert crt((1, 1, 0, 3), (4, 5, 3, 7)) == (381, 420)

Returning remainder and modulo. First valid number is the remainder
Returning remainder and modulo. First valid number is the remainder
Returning remainder and modulo. First valid number is the remainder


In [11]:
lcm([5, 7, 11])

385

In [12]:
crt([4, 4, 6], [5, 7, 11])

Returning remainder and modulo. First valid number is the remainder


(39, 385)

In [13]:
# | export
def mul_inv(a, b):
    # solves e.g. 40x === 1(mod 7) --> 3
    # since 40-35 --> 5x === 1mod(7), if x would be 3, 15 === 1 (mod 7)
    # this is called the multiplicative inverse
    b0 = b
    x0, x1 = 0, 1
    if b == 1:
        return 1
    while a > 1:
        q = a // b
        a, b = b, a % b
        x0, x1 = x1 - q * x0, x0
    if x1 < 0:
        x1 += b0
    return x1

In [14]:
assert mul_inv(17, 29) == 12
assert mul_inv(40, 7) == 3

In [15]:
# | export

# first try at implementing a segment tree
class Segment:
    def __init__(self, array, func):
        self.length = len(array)
        self.func = func
        self.data = [0] * self.length + array

        for idx in range(self.length - 1, -1, -1):
            self.data[idx] = self.func(self.data[idx * 2], self.data[idx * 2 + 1])

    def update(self, idx, val):
        idx += self.length
        while idx > 0:
            self.data[idx] = self.func(self.data[idx], val)
            idx //= 2

    def __call__(self, leftidx, rightidx):
        self.query(leftidx, rightidx)

    def query(self, leftidx, rightidx):
        l = leftidx + self.length
        r = rightidx + self.length
        res = self.data[l]
        while l < r:
            if l % 2:
                res = self.func(res, self.data[l])
                l += 1

            if r % 2:
                res = self.func(res, self.data[r - 1])
                r -= 1
            l, r = l // 2, r // 2
        return res

In [16]:
array = [1, 2, 3, 0, 10, 100, 5, 5]
s = Segment(array, min)

In [17]:
s.data


[0, 0, 0, 5, 1, 0, 10, 5, 1, 2, 3, 0, 10, 100, 5, 5]

In [18]:
s.update(7, 2)
s.data

[0, 0, 0, 2, 1, 0, 10, 2, 1, 2, 3, 0, 10, 100, 5, 2]

In [19]:
# | export
def lis(nums, increase=True):
    """
    Computes the length of the longest in(de)creasing subsequence
    """
    previous = [-1] * len(nums)
    current = []
    ans = []

    for i, num in enumerate(nums):
        idx = bisect_left(current, num)
        previous[i] = current[idx - 1] if idx else -1
        if idx == len(current):
            current.append(num)
            ans.append(num)
        else:
            current[idx] = num
        print(current)

    return ans

In [20]:
lis([2, 5, 3, 4, 0])

[2]
[2, 5]
[2, 3]
[2, 3, 4]
[0, 3, 4]


[2, 5, 4]

In [21]:
lis([2, 5, 3, 4])

[2]
[2, 5]
[2, 3]
[2, 3, 4]


[2, 5, 4]

In [22]:
# | export
def all_combinations(it, n=None):
    if not n:
        n = len(it) - 1
    for i in range(1, n + 1):
        for comb in combinations(it, i):
            yield comb

In [23]:
# | export
def all_permutations(it, n=None):
    if not n:
        n = len(it) - 1
    for i in range(1, n + 1):
        for perm in permutations(it, i):
            yield perm

In [24]:
list(all_combinations(range(4), 2))

[(0,), (1,), (2,), (3,), (0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

In [25]:
edges = [(1, 2, 3), (1, 3, 5), (2, 3, 0), (0, 1, 10)]
costs = []
nodes = set()
totalcost = 0
for a, b, cost in edges:
    heapq.heappush(costs, (cost, a, b))
    nodes.add(a)
    nodes.add(b)
print(nodes)
while nodes:
    cost, a, b = heapq.heappop(costs)
    if a in nodes or b in nodes:
        nodes.discard(a)
        nodes.discard(b)
        totalcost += cost
cost

{0, 1, 2, 3}


10

In [26]:
def mst(edges):
    # implements kruskall with unionfind
    edges.sort(key=lambda x: (-x[2]))
    nodes = set(chain.from_iterable([[a, b] for a, b, cost in edges]))

    uf = UnionFind(nodes)
    totalcost = 0
    while not uf.is_spanning():
        a, b, cost = edges.pop()
        if uf.get_parent(a) != uf.get_parent(b):
            uf.union(a, b)
            totalcost += cost
    return totalcost


edges = [(1, 2, 3), (1, 3, 5), (2, 3, 1), (0, 1, 10)]
mst(edges)

14

In [27]:
from itertools import chain

edges = [(1, 2, 3), (1, 3, 5), (2, 3, 1), (0, 1, 10)]
test = [[a, b] for a, b, c in edges]
set(chain.from_iterable(test))

{0, 1, 2, 3}