In [None]:
from bisect import bisect_left as bl, bisect_right as br
from collections import defaultdict, namedtuple
from functools import reduce
from heapq import heappush, heappop, heapify
from itertools import product, permutations, combinations
from math import gcd, floor, ceil, sqrt, log
import operator as op
import string
import sys

# Lift Python's recursion cap so deeply recursive helpers do not hit
# RecursionError. Per the `sys` docs, a limit this high can crash the
# interpreter (C stack overflow) instead of raising.
sys.setrecursionlimit(10 ** 8)

MOD = 10 ** 9 + 7  # the usual prime modulus for modular-arithmetic answers

# Array

In [None]:
def next_permutation(L):
    """
    Permute the list L in-place to generate the next lexicographic permutation.

    Return True if such a permutation exists. Otherwise return False: L is then
    already the last permutation (non-increasing, including the empty and
    one-element cases) and is left unchanged.

    Repeated elements are handled. Starting from sorted order, repeated calls
    visit each distinct arrangement exactly once.
    """
    n = len(L)

    #------------------------------------------------------------
    # Step 1: find rightmost position i such that L[i] < L[i+1]
    i = n - 2
    while i >= 0 and L[i] >= L[i + 1]:
        i -= 1
    if i < 0:
        # `i < 0` rather than `i == -1`: for an empty list i starts at -2.
        return False
    #------------------------------------------------------------
    # Step 2: find rightmost position j to the right of i such that L[j] > L[i].
    # L[i+1:] is non-increasing, so the elements greater than L[i] form its
    # prefix; j is the last element of that prefix.
    j = i + 1
    while j < n and L[j] > L[i]:
        j += 1
    j -= 1
    #------------------------------------------------------------
    # Step 3: swap L[i] and L[j]
    L[i], L[j] = L[j], L[i]
    #------------------------------------------------------------
    # Step 4: the suffix is still non-increasing; reversing it yields its
    # smallest (non-decreasing) arrangement.
    left, right = i + 1, n - 1
    while left < right:
        L[left], L[right] = L[right], L[left]
        left += 1
        right -= 1

    return True

# Graph

In [None]:
from heapq import heappush, heappop

def dijkstra(n, graph, start, end):
    """Single-pair shortest path (lazy Dijkstra), O(E log V).

    Args:
        n: number of vertices; vertices are the integers 0 .. n-1.
        graph: adjacency list; ``graph[v]`` iterates over ``(w, edge_cost)``
            pairs. Edge costs must be non-negative, because a cost of -1
            marks "not yet settled".
        start: source vertex.
        end: target vertex; the search stops as soon as it is settled.

    Returns:
        ``(cost, path)``. ``cost`` is the shortest distance from ``start`` to
        ``end``, or -1 if ``end`` is unreachable. ``path`` lists vertices from
        ``end`` back to ``start`` (reversed order); it is just ``[end]`` when
        ``end`` is unreachable.
    """
    priority_queue = [(0, start, -1)]
    costs, prev = [-1] * n, [-1] * n

    while priority_queue:
        cost, v, p = heappop(priority_queue)
        if costs[v] >= 0:
            continue  # stale heap entry: v was already settled more cheaply
        costs[v] = cost
        prev[v] = p
        if v == end:
            break
        for w, edge_cost in graph[v]:
            # A settled vertex can never be improved; pushing it only grows the heap.
            if costs[w] < 0:
                heappush(priority_queue, (cost + edge_cost, w, v))

    # Walk predecessor links from end; prev[start] == -1 terminates the walk.
    path = [end]
    while prev[path[-1]] >= 0:
        path.append(prev[path[-1]])
    return costs[end], path

In [None]:
class DisjointSet(object):
    """Union-find over the integers 0 .. n-1 with path compression and union by rank.

    Attributes:
        roots:     parent pointers; ``roots[x] == x`` marks a set representative.
        ranks:     per-representative rank used to keep trees shallow.
        nrof_sets: current number of disjoint sets.
    """

    def __init__(self, n, edge_set=()):
        """Create ``n`` singleton sets, then union the first two fields of each
        item in ``edge_set`` (extra fields, such as a weight, are ignored)."""
        # Immutable default: a shared mutable `[]` default is a latent hazard.
        self.roots = list(range(n))
        self.ranks = [0] * n
        self.nrof_sets = n
        for edge in edge_set:
            x, y = edge[:2]
            self.union(x, y)

    def find(self, x):
        """Return the representative of ``x``'s set.

        Every node on the path is re-pointed straight at the root (path
        compression). Iterative, so it does not depend on the recursion limit.
        """
        root = x
        while self.roots[root] != root:
            root = self.roots[root]
        while x != root:
            parent = self.roots[x]
            self.roots[x] = root
            x = parent
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Returns 1 if two distinct sets were merged, 0 if they were already one
        set, so callers can multiply by the result (e.g. to accumulate MST cost).
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return 0

        # Union by rank: attach the lower-ranked root under the higher-ranked one.
        if self.ranks[root_x] < self.ranks[root_y]:
            root_x, root_y = root_y, root_x

        self.roots[root_y] = root_x
        self.nrof_sets -= 1
        if self.ranks[root_x] == self.ranks[root_y]:
            self.ranks[root_x] += 1
        return 1

In [None]:
def kruskal_mst(n, edges_heap):
    """Total weight of a minimum spanning tree/forest via Kruskal, O(E log E).

    Use for sparse graphs.

    Args:
        n: number of vertices (0 .. n-1).
        edges_heap: iterable of ``(cost, x, y)`` edges. It need not be
            heapified, and it is left unmodified.

    Returns:
        Sum of the costs of the chosen edges. For a disconnected graph this is
        the weight of a minimum spanning forest.
    """
    set_ = DisjointSet(n)
    sum_ = 0
    # sorted() gives the same cheapest-first order as repeated heappop, but
    # does not drain the caller's list.
    for cost, x, y in sorted(edges_heap):
        sum_ += cost * set_.union(x, y)  # union() returns 1 only when it merges
        if set_.nrof_sets == 1:
            break  # already spanning: every remaining edge would close a cycle
    return sum_

In [None]:
def prim_mst(graph, start=0):
    """Minimum spanning tree via lazy Prim, O(E log V).

    Args:
        graph: adjacency list indexed by vertex 0 .. len(graph)-1; ``graph[v]``
            iterates over ``(w, edge_cost)`` pairs. For an undirected graph,
            list each edge under both endpoints.
        start: vertex to grow the tree from.

    Returns:
        Tree edges as ``(cost, parent, child)`` tuples, in the order they were
        added. Only the component containing ``start`` is spanned. An empty
        graph yields ``[]``.
    """
    if len(graph) == 0:
        return []
    visited = [False] * len(graph)
    priority_queue, edges = [(0, start, -1)], []

    while priority_queue:
        cost, v, prev = heappop(priority_queue)
        if visited[v]:
            continue  # stale entry: v was reached more cheaply earlier
        visited[v] = True
        edges.append((cost, prev, v))
        for w, edge_cost in graph[v]:
            if not visited[w]:  # edges into the tree can never be chosen
                heappush(priority_queue, (edge_cost, w, v))

    # edges[0] is the (0, -1, start) seed, not a real edge.
    return edges[1:]

# Math

In [None]:
def _try_composite(a, d, n, s):
    """Return True if base ``a`` proves the odd number ``n`` composite.

    Expects ``n - 1 == d * 2**s`` with ``d`` odd. ``n`` survives (False is
    returned) when ``a**d == 1`` or ``a**(d * 2**i) == n - 1`` (mod n) for some
    0 <= i < s.
    """
    x = pow(a, d, n)
    if x == 1:
        return False
    for _ in range(s):
        if x == n - 1:
            return False
        # One squaring steps from a**(d * 2**i) to a**(d * 2**(i+1)), which
        # avoids a fresh modular exponentiation for every i.
        x = x * x % n
    return True


def miller_rabin(n,
                 _known_primes=(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97),
                 _precision_for_huge_n=16):
    """Return True if the integer ``n`` is prime.

    Integers below 2 (0, 1 and negatives) are not prime. Small factors are
    removed by trial division by ``_known_primes``. The remaining candidates
    run Miller-Rabin with witness sets that are exact below 341550071728321
    (thresholds from http://primes.utm.edu/prove/prove2_3.html). Larger ``n``
    use the first ``_precision_for_huge_n`` entries of ``_known_primes`` as
    bases.
    """
    if n < 2:
        return False
    if n in _known_primes:
        return True
    if any(n % p == 0 for p in _known_primes):
        return False
    # Write n - 1 as d * 2**s with d odd.
    d, s = n - 1, 0
    while d % 2 == 0:
        d, s = d >> 1, s + 1
    # Returns exact according to http://primes.utm.edu/prove/prove2_3.html
    if n < 1373653:
        return not any(_try_composite(a, d, n, s) for a in (2, 3))
    if n < 25326001:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5))
    if n < 118670087467:
        if n == 3215031751:  # the lone strong pseudoprime to bases 2, 3, 5, 7 here
            return False
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7))
    if n < 2152302898747:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11))
    if n < 3474749660383:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11, 13))
    if n < 341550071728321:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11, 13, 17))
    # otherwise
    return not any(_try_composite(a, d, n, s)
                   for a in _known_primes[:_precision_for_huge_n])

In [None]:
import string
digs = string.digits + string.ascii_letters


def to_base(x, base):
    """Decimal to custom base.

    Digits come from ``digs`` ('0'-'9', then 'a'-'z', then 'A'-'Z'). Negative
    numbers get a leading '-'. A non-integral float is truncated toward zero
    first.

    Raises:
        ValueError: if ``base`` is less than 2.
    """
    if base < 2:
        # base 1 would loop forever; base 0 or negative bases are meaningless
        raise ValueError('base must be >= 2, got %r' % (base,))

    negative = x < 0
    x = int(abs(x))
    if x == 0:
        return digs[0]

    digits = []
    while x:
        # divmod keeps exact integer arithmetic; the float `x / base` loses
        # precision beyond 2**53 and produced wrong digits.
        x, digit = divmod(x, base)
        digits.append(digs[int(digit)])

    if negative:
        digits.append('-')

    return ''.join(reversed(digits))

In [None]:
def mod_mul_inverse(l, m):
    """Modular multiplicative inverses of 1 .. l-1 modulo ``m``, in O(l).

    Uses the recurrence inv(i) = -(m // i) * inv(m % i)  (mod m). It is valid
    because m % i < i, so inv(m % i) is already known.

    Args:
        l: length of the returned table; inverses are produced for 1 <= i < l.
        m: modulus. Every i in 2 .. l-1 must be invertible mod m, which holds,
           for example, when m is prime and l <= m.

    Returns:
        List ``invs`` of length ``l`` with ``invs[i] * i % m == 1`` for
        2 <= i < l, ``invs[1] == 1`` (when l > 1) and ``invs[0] == 0``.

    Raises:
        ValueError: if some i in 2 .. l-1 has no inverse modulo m.
    """
    invs = [0] * l
    if l > 1:
        invs[1] = 1
    for i in range(2, l):
        invs[i] = (m - invs[m % i] * (m // i) % m) % m
        if invs[i] * i % m != 1:
            raise ValueError('%r has no multiplicative inverse modulo %r' % (i, m))
    return invs

In [None]:
import operator as op
from functools import reduce

def ncr(n, r):
    """n choose r: the binomial coefficient C(n, r).

    Returns 0 when r is outside [0, n] (including any r when n < 0), per the
    combinatorial definition. Uses exact integer arithmetic.
    """
    if r < 0 or r > n:
        # Without this guard, min(r, n - r) goes negative, both ranges are
        # empty, and the function wrongly returned 1.
        return 0
    r = min(r, n - r)  # C(n, r) == C(n, n - r): multiply fewer factors
    numer = reduce(op.mul, range(n, n - r, -1), 1)
    denom = reduce(op.mul, range(1, r + 1), 1)
    return numer // denom

# String

# Tree

In [None]:
class SegmentTreeMinimum(object):
    """Iterative (bottom-up) segment tree answering range-minimum queries.

    Layout: ``tree[n + i]`` holds ``arr[i]``. For 1 <= i < n, ``tree[i]`` is
    the minimum of its children ``tree[2i]`` and ``tree[2i + 1]``; ``tree[0]``
    is unused.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (2 * self.n)
        for i in range(self.n):
            self.tree[self.n + i] = arr[i]
        self.build()

    def build(self):
        """Recompute every internal node from the leaves, bottom-up."""
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = min(self.tree[i << 1], self.tree[i << 1 | 1])

    def update(self, i, value):
        """Set element ``i`` to ``value`` and refresh its ancestors."""
        t = self.n + i
        self.tree[t] = value
        while t > 1:
            # t ^ 1 is t's sibling; their minimum belongs in the parent t >> 1.
            self.tree[t >> 1] = min(self.tree[t], self.tree[t ^ 1])
            t >>= 1

    def query(self, l, r):
        """Return the minimum over the half-open index range [l, r).

        Returns ``float('inf')`` for an empty range (l >= r).
        """
        res = float('inf')
        l += self.n
        r += self.n
        while l < r:
            if l & 1:
                # l is a right child: its parent also covers l - 1, so take l alone.
                res = min(res, self.tree[l])
                l += 1
            if r & 1:
                # r is exclusive; when odd, node r - 1 is inside but its parent is not.
                r -= 1
                res = min(res, self.tree[r])
            l >>= 1
            r >>= 1
        return res

In [None]:
class AVLNode(object):
    """Node of an AVLTree.

    Attributes:
        value:               the stored key.
        left, right, parent: neighbouring AVLNode objects, or None.
        bf:                  balance factor, height(left) - height(right).
        height:              height of the subtree rooted here (a leaf is 1).
    """

    def __init__(self, val, bf=0):
        self.value = val
        self.left = None
        self.right = None
        self.parent = None
        self.bf = bf
        self.height = 1

    @staticmethod
    def get_height(node):
        """Return the height of ``node``; an empty subtree (None) has height 0."""
        return 0 if node is None else node.height

    def update_height_bf(self):
        """Recompute ``height`` and ``bf`` from the children's stored heights."""
        left_height = self.get_height(self.left)
        right_height = self.get_height(self.right)
        self.height = max(left_height, right_height) + 1
        self.bf = left_height - right_height


class AVLTree(object):
    """Self-balancing binary search tree; inserting an existing value is a no-op."""

    def __init__(self):
        self.root = None

    def insert(self, value):
        """Insert ``value`` and rebalance along the insertion path."""
        self.root = AVLTree._insert(self.root, value)

    @staticmethod
    def _insert(node, value):
        """Insert into the subtree rooted at ``node``; return the subtree's new root."""
        if node is None:
            return AVLNode(value)
        if value < node.value:
            node.left = AVLTree._insert(node.left, value)
            node.left.parent = node
        elif value > node.value:
            node.right = AVLTree._insert(node.right, value)
            node.right.parent = node
        else:
            return node  # duplicate: nothing changed, no rebalancing needed

        node.update_height_bf()
        return AVLTree.rebalance(node)

    @staticmethod
    def rebalance(node):
        """Restore |bf| <= 1 at ``node`` with a single or double rotation.

        Returns the root of the (possibly rotated) subtree.
        """
        if node.bf == 2:
            if node.left.bf < 0:  # left-right case: straighten into left-left first
                node.left = AVLTree.rotate_left(node.left)
            return AVLTree.rotate_right(node)
        elif node.bf == -2:
            if node.right.bf > 0:  # right-left case: straighten into right-right first
                node.right = AVLTree.rotate_right(node.right)
            return AVLTree.rotate_left(node)
        else:
            return node

    @staticmethod
    def rotate_left(node):
        """Rotate left around ``node``; return its right child, the new subtree root.

        Parent pointers are rewired, including the link from ``node``'s old
        parent.
        """
        pivot = node.right
        tmp = pivot.left

        pivot.left = node
        pivot.parent = node.parent
        node.parent = pivot
        node.right = tmp
        if tmp is not None:
            tmp.parent = node

        if pivot.parent is not None:
            if pivot.parent.left is node:
                pivot.parent.left = pivot
            else:
                pivot.parent.right = pivot

        # node is now pivot's child, so its height must be refreshed first.
        node.update_height_bf()
        pivot.update_height_bf()

        return pivot

    @staticmethod
    def rotate_right(node):
        """Rotate right around ``node``; return its left child, the new subtree root.

        Mirror image of ``rotate_left``.
        """
        pivot = node.left
        tmp = pivot.right

        pivot.right = node
        pivot.parent = node.parent
        node.parent = pivot
        node.left = tmp
        if tmp is not None:
            tmp.parent = node

        if pivot.parent is not None:
            if pivot.parent.right is node:
                pivot.parent.right = pivot
            else:
                pivot.parent.left = pivot

        node.update_height_bf()
        pivot.update_height_bf()

        return pivot

    def find(self, target, mode='exact'):
        """Search for ``target``; return an AVLNode or None.

        ``mode`` is case-insensitive:
            'exact'      -- the node holding ``target``, or None.
            'lowerbound' -- the node with the smallest value >= ``target``, or None.
            'upperbound' -- the node with the largest value <= ``target``, or None
                            (a floor, unlike C++ ``upper_bound``).

        Raises:
            ValueError: for any other mode, even when ``target`` is present.
        """
        mode = mode.lower()
        if mode not in ('exact', 'lowerbound', 'upperbound'):
            raise ValueError('Mode %s is invalid.' % mode)

        node = self.root
        prev = None
        while node is not None:
            if target == node.value:
                return node
            prev = node
            node = node.left if target < node.value else node.right

        if mode == 'exact':
            return None
        # The nodes on the search path went left where value > target and right
        # where value < target. Climbing from the last one, the first ancestor
        # on the required side of ``target`` is the ceiling (or floor).
        if mode == 'lowerbound':
            while prev is not None and prev.value < target:
                prev = prev.parent
        else:
            while prev is not None and prev.value > target:
                prev = prev.parent
        return prev