In [None]:
## 493. Reverse Pairs
https://leetcode.com/problems/reverse-pairs/

Given an array nums, we call (i, j) an important reverse pair if i < j and nums[i] > 2*nums[j].

You need to return the number of important reverse pairs in the given array.

Example 1:

Input: [1,3,2,3,1]
Output: 2
Example 2:

Input: [2,4,3,5,1]
Output: 3
Note:

The length of the given array will not exceed 50,000.
All numbers in the input array are within the 32-bit integer range.

In [None]:
class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        """Count the important reverse pairs in ``nums``.

        A pair ``(i, j)`` is an important reverse pair when ``i < j`` and
        ``nums[i] > 2 * nums[j]``.

        Approach: rank compression plus a mirrored Binary Indexed Tree
        (Fenwick tree).
        https://en.wikipedia.org/wiki/Fenwick_tree#Applications
        https://www.geeksforgeeks.org/binary-indexed-tree-or-fenwick-tree-2/

        * ``sorted_nums`` gives every value a 1-based rank: its leftmost
          position in the sorted array plus one (``bit[0]`` is a dummy slot).
        * The textbook tree updates by walking *up* (``idx += idx & -idx``)
          and queries a prefix by walking *down* (``idx -= idx & -idx``).
          Here both directions are swapped, so ``query(r)`` returns how many
          of the values inserted so far have rank ``>= r`` (a suffix count).
        * Scanning ``nums`` left to right, for each ``num`` we first count
          the previously inserted values that are ``>= 2 * num + 1`` (that
          is, strictly greater than ``2 * num``), and only then insert
          ``num`` itself, so every counted partner has a smaller index.

        Sorting costs O(n log n) and each element needs two binary searches
        and two O(log n) tree walks, so the whole method is O(n log n) time
        with O(n) extra space.

        Args:
            nums: The input integers.

        Returns:
            The number of pairs ``(i, j)`` with ``i < j`` and
            ``nums[i] > 2 * nums[j]``; ``0`` for an empty list.
        """
        # Local import keeps this cell self-contained; bisect_left returns the
        # lowest index whose value is >= target (len(list) if there is none).
        from bisect import bisect_left

        def update(bit, idx, val):
            """Add ``val`` at rank ``idx`` (1-based).

            Walks downward by repeatedly clearing the lowest set bit, which
            touches exactly the nodes that ``query`` visits for any starting
            rank ``<= idx``.
            """
            while idx > 0:
                bit[idx] += val
                idx -= idx & (-idx)

        def query(bit, idx):
            """Return the total of the values stored at ranks ``>= idx``.

            Walks upward by repeatedly adding the lowest set bit. When
            ``idx`` is already past the end of the tree (no value is large
            enough) the loop is skipped and the result is 0. ``idx`` must be
            at least 1, otherwise the step would be zero.
            """
            total = 0
            size = len(bit)
            while idx < size:
                total += bit[idx]
                idx += idx & (-idx)
            return total

        sorted_nums = sorted(nums)
        bit = [0] * (len(nums) + 1)
        rp_count = 0
        for num in nums:
            # Earlier elements strictly greater than 2 * num; +1 converts the
            # 0-based sorted position into a 1-based tree rank.
            rp_count += query(bit, bisect_left(sorted_nums, 2 * num + 1) + 1)
            # Register num so that later elements can pair with it.
            update(bit, bisect_left(sorted_nums, num) + 1, 1)
        return rp_count
    
    
        