In [None]:
# | default_exp math
from nbdev import *
from nbdev.showdoc import *

# Mathy functions

In [None]:
# | exporti
from functools import reduce
import operator
import heapq
from math import sqrt, gcd
import math
from bisect import bisect_left
from itertools import combinations, permutations, chain
from aocutils.special import UnionFind

In [None]:
# | export
def factors(n):
    """
    Return the set of all positive divisors of the positive integer ``n``.

    Trial division up to the exact integer square root: every divisor i <= isqrt(n)
    contributes the pair (i, n // i). For odd n only odd candidates are tried,
    because an odd number has no even divisors.

    Raises ValueError for n < 1 (0 has infinitely many divisors; negatives are unsupported).
    """
    if n < 1:
        raise ValueError(f"factors() requires a positive integer, got {n!r}")
    step = 2 if n % 2 else 1
    divisors = set()
    # math.isqrt is exact for arbitrarily large ints, unlike int(sqrt(n)) on floats
    for i in range(1, math.isqrt(n) + 1, step):
        if n % i == 0:
            divisors.add(i)
            divisors.add(n // i)
    return divisors

In [None]:
assert sorted(factors(20)) == [1, 2, 4, 5, 10, 20]

In [None]:
# | export
def gcd(a, b):
    """
    Greatest common divisor of the integers ``a`` and ``b`` (Euclid's algorithm),
    e.g. gcd(10, 15) == 5.

    The result is always non-negative; gcd(0, n) == abs(n) and gcd(0, 0) == 0,
    the same conventions as ``math.gcd`` (which this definition shadows in the module).
    """
    a, b = abs(a), abs(b)
    # gcd(a, b) == gcd(b, a % b); terminates because the second value strictly shrinks
    while b:
        a, b = b, a % b
    return a


def lcm(iterable):
    """
    Least common multiple of the integers in ``iterable``, e.g. lcm([4, 6, 7]) == 84.

    Accepts any iterable (list, tuple, generator). The result is non-negative and is 0
    if any value is 0. Uses ``math.gcd`` so it works independently of this module's gcd.

    Raises ValueError for an empty iterable.
    """
    values = iter(iterable)
    try:
        result = abs(next(values))
    except StopIteration:
        raise ValueError("lcm() needs at least one value") from None
    for value in values:
        # lcm(x, 0) == 0; the guard also avoids dividing by gcd(0, 0) == 0
        result = abs(result * value) // math.gcd(result, value) if result and value else 0
    return result

In [None]:
# (arguments, expected) pairs for the hand-written gcd / lcm
for (x, y), expected in [((12, 8), 4), ((12, 4), 4), ((12, 12), 12)]:
    assert gcd(x, y) == expected

for values, expected in [([4, 6, 7], 84), ([10, 15], 30), ([5, 7, 11], 385)]:
    assert lcm(values) == expected

In [None]:
def power(a, b, M=None):
    """
    Compute ``a**b``, or ``a**b % M`` when a modulus ``M`` is given, by binary
    (square-and-multiply) exponentiation using O(log b) multiplications.

    Python's built-in ``pow(a, b, M)`` does the same; this is a reference implementation.
    A falsy ``M`` (None or 0) means "no modulus".

    Raises ValueError for a negative exponent ``b`` (the halving loop would never end).
    """
    if b < 0:
        raise ValueError(f"power() requires a non-negative exponent, got {b!r}")
    if M:
        # reduce up front so the repeatedly squared base stays below M**2
        a %= M
        res = 1 % M  # x**0 % 1 == 0, matching pow()
    else:
        res = 1
    while b:
        if b % 2 == 1:
            res = (res * a) % M if M else res * a
        a = (a * a) % M if M else a * a
        b //= 2
    return res


assert power(3, 2) == pow(3,2)
assert power(3, 200, 7) == pow(3, 200, 7)

res 9


In [None]:
# | export
def crt(remainders, moduli):
    """
    Chinese remainder theorem: find x with x ≡ remainders[k] (mod moduli[k]) for every k.

    The congruences are merged one at a time by sieving: for each new (rem, mod) the
    candidates cur_rem + i * cur_mod are tried until one is ≡ rem (mod mod). The
    (positive) moduli need not be pairwise coprime.

    Returns (x, m) where m is the lcm of the moduli and x is the smallest non-negative
    solution ("first valid number"); every solution is x + k * m.

    Raises ValueError when the inputs are empty or of different lengths, or when the
    system has no solution.
    """
    if not moduli or len(remainders) != len(moduli):
        raise ValueError("crt() needs equally long, non-empty remainders and moduli")
    cur_mod = moduli[0]
    # normalise so the running answer is always in [0, cur_mod)
    cur_rem = remainders[0] % cur_mod
    for rem, mod in zip(remainders[1:], moduli[1:]):
        target = rem % mod
        # cur_rem + i * cur_mod (mod mod) repeats with this period, so if no i below it
        # matches, the congruences are incompatible (the old loop spun forever here)
        period = mod // math.gcd(cur_mod, mod)
        for i in range(period):
            if (cur_rem + i * cur_mod) % mod == target:
                cur_rem += i * cur_mod
                cur_mod *= period  # == lcm(cur_mod, mod)
                break
        else:
            raise ValueError(f"no solution: x ≡ {rem} (mod {mod}) is incompatible")
    return cur_rem, cur_mod

In [None]:
# https://adventofcode.com/2020/day/13
rests = [0, -27, -37, -45, -54, -56, -66, -68, -81]
mods = [37, 41, 433, 23, 17, 19, 29, 593, 13]
assert crt(rests, mods) == (600691418730595, 1090937521514009)

# smaller system: only the first valid timestamp is checked
times, mods = [0, -1, -4, -6, -7], [7, 13, 59, 31, 19]
assert crt(times, mods)[0] == 1068781

Returning remainder and modulo. First valid number is the remainder
Returning remainder and modulo. First valid number is the remainder


In [None]:
# (remainders, moduli, expected (remainder, modulus)) triples
crt_cases = [
    ((3, 5, 2), (4, 6, 5), (47, 60)),
    ((1, 0, 1, 3), (4, 3, 5, 7), (381, 420)),
    ((1, 1, 0, 3), (4, 5, 3, 7), (381, 420)),
]
for case_rems, case_mods, expected in crt_cases:
    assert crt(case_rems, case_mods) == expected

Returning remainder and modulo. First valid number is the remainder
Returning remainder and modulo. First valid number is the remainder
Returning remainder and modulo. First valid number is the remainder


In [None]:
# | export
def mul_inv(a, b):
    """
    Multiplicative inverse of ``a`` modulo ``b``: the x in [0, b) with a * x ≡ 1 (mod b).

    Example: mul_inv(40, 7) == 3, since 40 ≡ 5 (mod 7) and 5 * 3 = 15 ≡ 1 (mod 7).
    Equivalent to the built-in ``pow(a, -1, b)``. Uses the extended Euclidean algorithm;
    ``b`` is assumed to be a positive modulus, and mul_inv(a, 1) == 1 by convention here.

    Raises ValueError when gcd(a, b) != 1, i.e. no inverse exists.
    """
    if b == 1:
        return 1
    # invariant: old_r ≡ old_s * a (mod b) and r ≡ s * a (mod b);
    # reducing a first makes negative and oversized inputs work
    old_r, r = a % b, b
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    if old_r != 1:  # old_r is now gcd(a, b)
        raise ValueError(f"{a} has no multiplicative inverse modulo {b}")
    return old_s % b

In [None]:
assert (mul_inv(17, 29), mul_inv(40, 7)) == (12, 3)
# the built-in pow with exponent -1 computes the same modular inverses
pow(17, -1, 29), pow(40, -1, 7)

(12, 3)

In [None]:
# | export

# first try at implementing a segment tree
class Segment:
    """
    Bottom-up (iterative) segment tree answering range queries over ``array``
    with the binary function ``func`` (e.g. ``min``, ``max``, ``operator.add``).

    Storage: ``data[n:]`` holds a copy of the ``n`` input values (the leaves) and
    ``data[i] == func(data[2*i], data[2*i + 1])`` for ``1 <= i < n``; ``data[0]``
    is filler outside the tree. ``func`` must be associative.
    """

    def __init__(self, array, func):
        self.length = len(array)
        self.func = func
        # list() accepts tuples/other sequences and never aliases the caller's list
        self.data = [0] * self.length + list(array)
        # fill internal nodes bottom-up; stop at 1 because slot 0 is not a tree node
        for idx in range(self.length - 1, 0, -1):
            self.data[idx] = self.func(self.data[2 * idx], self.data[2 * idx + 1])

    def _leaf(self, idx):
        """Position of element ``idx`` in ``data``; raises IndexError when out of range."""
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} out of range for {self.length} elements")
        return idx + self.length

    def update(self, idx, val):
        """
        Fold ``val`` into element ``idx``: the element becomes ``func(element, val)``.

        With ``min`` this lowers the element to ``val`` when ``val`` is smaller; with
        ``operator.add`` it adds ``val``. Ancestors are updated as ``func(node, val)``,
        which is only correct when ``func`` is associative and commutative; use
        ``assign`` to overwrite an element instead.
        """
        idx = self._leaf(idx)
        while idx > 0:
            self.data[idx] = self.func(self.data[idx], val)
            idx //= 2

    def assign(self, idx, val):
        """Set element ``idx`` to ``val`` and recompute each ancestor from its two children."""
        idx = self._leaf(idx)
        self.data[idx] = val
        idx //= 2
        while idx > 0:
            self.data[idx] = self.func(self.data[2 * idx], self.data[2 * idx + 1])
            idx //= 2

    def __call__(self, leftidx, rightidx):
        """Shorthand for ``query(leftidx, rightidx)``."""
        return self.query(leftidx, rightidx)

    def query(self, leftidx, rightidx):
        """
        Combine the elements of the half-open range ``[leftidx, rightidx)`` with ``func``.

        Every element of the range contributes exactly once, so non-idempotent
        functions such as ``operator.add`` give correct results.

        Raises IndexError when a bound lies outside ``[0, len(array)]`` and
        ValueError for an empty range.
        """
        if not (0 <= leftidx <= self.length and 0 <= rightidx <= self.length):
            raise IndexError(f"query bounds ({leftidx}, {rightidx}) out of range")
        if leftidx >= rightidx:
            raise ValueError(f"empty query range [{leftidx}, {rightidx})")
        missing = object()  # sentinel: nothing accumulated yet (func may have no identity)
        left_res = right_res = missing
        l = leftidx + self.length
        r = rightidx + self.length
        while l < r:
            if l % 2:  # l is a right child: its parent extends past the range, take l itself
                node = self.data[l]
                left_res = node if left_res is missing else self.func(left_res, node)
                l += 1
            if r % 2:  # r - 1 is a left child inside the range: take it
                r -= 1
                node = self.data[r]
                right_res = node if right_res is missing else self.func(node, right_res)
            l, r = l // 2, r // 2
        if left_res is missing:
            return right_res
        if right_res is missing:
            return left_res
        return self.func(left_res, right_res)

In [None]:
# Small example: a range-minimum segment tree over 8 values
array = [1, 2, 3, 0, 10, 100, 5, 5]
s = Segment(array, func=min)

In [None]:
# Raw tree storage: the last 8 slots are a copy of `array` (the leaves); each earlier slot i >= 1
# holds min(data[2*i], data[2*i + 1]); slot 0 is filler outside the tree
s.data


[0, 0, 0, 5, 1, 0, 10, 5, 1, 2, 3, 0, 10, 100, 5, 5]

In [None]:
# Fold 2 into element 7 (min(5, 2) -> 2) and show it propagating up to the ancestors
s.update(idx=7, val=2)
s.data

[0, 0, 0, 2, 1, 0, 10, 2, 1, 2, 3, 0, 10, 100, 5, 2]

In [None]:
def lis(nums, increase=True, verbose=False):
    """
    Longest strictly increasing (or, with ``increase=False``, strictly decreasing)
    subsequence of ``nums`` in O(n log n), following
    https://en.wikipedia.org/wiki/Longest_increasing_subsequence

    Returns the subsequence itself as a list ([] for empty input).
    Decreasing mode negates the values, so it assumes numeric elements.
    With ``verbose=True`` the tails / tail-index / predecessor arrays are printed
    after each element to show how the algorithm evolves.

    Not included in the module: the next, simpler O(n^2) implementation is exported instead.
    """
    if len(nums) == 0:
        return []
    # previousidx[i]: index of the element before nums[i] in the best subsequence ending at i
    previousidx = [-1] * len(nums)
    # current[k]: smallest (sign-adjusted) tail of any subsequence of length k + 1 so far;
    # currentidx[k]: position of that tail in nums
    currentidx = []
    current = []

    for i, num in enumerate(nums):
        key = num if increase else -num
        # bisect_left: an equal value replaces instead of extending -> strict ordering
        idx = bisect_left(current, key)
        previousidx[i] = currentidx[idx - 1] if idx else -1
        if idx == len(current):
            current.append(key)
            currentidx.append(i)
        else:
            current[idx] = key
            currentidx[idx] = i
        if verbose:
            print(current, currentidx, previousidx)

    # follow the predecessor chain back from the tail of the longest subsequence
    result = []
    idx = currentidx[-1]
    while idx != -1:
        result.append(nums[idx])
        idx = previousidx[idx]
    return result[::-1]

nums = [2, 8, 9, 5, 6, 7, 1]
lis(nums, verbose=True)

[2] [0] [-1, -1, -1, -1, -1, -1, -1]
[2, 8] [0, 1] [-1, 0, -1, -1, -1, -1, -1]
[2, 8, 9] [0, 1, 2] [-1, 0, 1, -1, -1, -1, -1]
[2, 5, 9] [0, 3, 2] [-1, 0, 1, 0, -1, -1, -1]
[2, 5, 6] [0, 3, 4] [-1, 0, 1, 0, 3, -1, -1]
[2, 5, 6, 7] [0, 3, 4, 5] [-1, 0, 1, 0, 3, 4, -1]
[1, 5, 6, 7] [6, 3, 4, 5] [-1, 0, 1, 0, 3, 4, -1]


[2, 5, 6, 7]

In [None]:
# | export
def lis(nums, func=operator.ge):
    """
    Longest subsequence of ``nums`` in which each element is related to its predecessor
    by ``func``, via O(n^2) dynamic programming.

    ``func(later, earlier)`` must return True when ``later`` may follow ``earlier``:
      - operator.ge (default): longest non-decreasing subsequence
      - operator.gt: longest strictly increasing subsequence
      - operator.le / operator.lt: non-increasing / strictly decreasing

    Returns the subsequence as a list ([] for empty input). Ties go to the subsequence
    that ends earliest, built from the earliest possible predecessors.
    """
    n = len(nums)
    if n == 0:
        return []
    best = [1] * n          # best[i]: length of the longest valid subsequence ending at i
    parents = [None] * n    # parents[i]: index of the previous element in that subsequence
    bestidx = 0
    for i in range(1, n):
        for j in range(i):
            # strict '>' keeps the first j that achieves the maximum
            if func(nums[i], nums[j]) and best[j] + 1 > best[i]:
                best[i] = best[j] + 1
                parents[i] = j
        if best[i] > best[bestidx]:
            bestidx = i
    ans = []
    idx = bestidx
    while idx is not None:
        ans.append(nums[idx])
        idx = parents[idx]
    return ans[::-1]

In [None]:
# | export
def angle(a, b):
    """
    Angle in degrees of the vector from point ``a`` to point ``b``.

    Computed as atan2(dx, dy) -- x-difference first -- so the angle is measured from
    the +y direction towards the +x direction; the result lies in [-180, 180].
    """
    delta_x = b[0] - a[0]
    delta_y = b[1] - a[1]
    radians = math.atan2(delta_x, delta_y)
    return math.degrees(radians)
    



In [None]:
a, b = (1, 1), (2, -1)
angle(a, b)

153.434948822922

In [None]:
nums = [2, 8, 9, 5, 5, 6, 7, 1]
lis(nums, func=operator.ge)

[2, 5, 5, 6, 7]

In [None]:
# | export
def all_combinations(it, start=None, end=None):
    """
    Yield every combination of the elements of ``it`` whose size is between ``start``
    and ``end`` (both inclusive), smallest sizes first, each size in
    ``itertools.combinations`` order.

    ``start`` defaults to 1 and ``end`` to len(it); a falsy value (None or 0) selects
    the default. ``it`` may be any finite iterable; it is materialized once.

    Raises ValueError (on first iteration) unless 1 <= start <= end <= len(it).
    """
    pool = tuple(it)
    if not start:
        start = 1
    if not end:
        end = len(pool)
    # explicit check instead of assert: asserts vanish under `python -O`
    if not 1 <= start <= end <= len(pool):
        raise ValueError(f"need 1 <= start <= end <= {len(pool)}, got start={start}, end={end}")
    for size in range(start, end + 1):
        yield from combinations(pool, size)

In [None]:
[*all_combinations([1, 2, 3], start=1, end=3)]

[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

In [None]:
# | export
def all_permutations(it, start=None, end=None):
    """
    Yield every permutation of the elements of ``it`` whose length is between ``start``
    and ``end`` (both inclusive), shortest first, each length in
    ``itertools.permutations`` order.

    ``start`` defaults to 1 and ``end`` to len(it); a falsy value (None or 0) selects
    the default. ``it`` may be any finite iterable; it is materialized once.

    Raises ValueError (on first iteration) unless 1 <= start <= end <= len(it).
    """
    pool = tuple(it)
    if not start:
        start = 1
    if not end:
        end = len(pool)
    # explicit check instead of assert: asserts vanish under `python -O`
    if not 1 <= start <= end <= len(pool):
        raise ValueError(f"need 1 <= start <= end <= {len(pool)}, got start={start}, end={end}")
    for size in range(start, end + 1):
        yield from permutations(pool, size)

In [None]:
# all 2- and 3-element combinations of 0..3, in itertools order
expected = [
    (0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3),
    (0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3),
]
assert list(all_combinations(range(4), 2, 3)) == expected

In [None]:
# | export
def mst(edges):
    """
    Total cost of a minimum spanning tree, using Kruskal's algorithm with union-find.

    ``edges`` is a sequence of ``(a, b, cost)`` tuples; the node set is every endpoint
    that appears in it. The caller's ``edges`` is not modified.

    Returns the summed cost of the chosen edges, or False when the edges do not connect
    all nodes (as reported by ``UnionFind.is_spanning()``).
    """
    nodes = set(chain.from_iterable((a, b) for a, b, _cost in edges))
    uf = UnionFind(nodes)
    totalcost = 0
    # cheapest edges first; sorted() works on a copy, unlike the old in-place sort + pop
    for a, b, cost in sorted(edges, key=lambda edge: edge[2]):
        if uf.is_spanning():
            break
        # skip edges whose endpoints are already connected (they would close a cycle)
        if uf.get_parent(a) != uf.get_parent(b):
            uf.union(a, b)
            totalcost += cost

    return totalcost if uf.is_spanning() else False

In [None]:
# Kruskal keeps 2-3 (1), 1-2 (3), skips 1-3 (cycle) and keeps 0-1 (10): 1 + 3 + 10 == 14
edges = [
    (1, 2, 3),
    (1, 3, 5),
    (2, 3, 1),
    (0, 1, 10),
]
assert mst(edges) == 14