### 4. Median of Two Sorted Arrays

Given two sorted arrays `nums1` and `nums2` of size `m` and `n` respectively, return the median of the two sorted arrays.

The overall run time complexity should be `O(log(m+n))`.

<ins>Logic</ins>

**General Idea**: Partition `nums1` and `nums2` at one cutoff index each, so that all elements up to and including the cutoff index in `nums1` and `nums2` together make up the first half of the 'merged' array. The median can then be computed from the elements next to the two cutoff indices, without actually merging the arrays.
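To make the idea concrete, here is a linear-time sketch (not the `O(log(m+n))` solution) that walks a standard two-pointer merge for `n_half` steps and reports how many elements of each array land in the left half. The function name `cutoff_by_merge` is hypothetical; it returns cutoff indices in the same convention used below (last index included in the left partition, `-1` if the array contributes nothing).

```python
def cutoff_by_merge(a, b):
    """Return (mid1, mid2): last indices of a and b inside the left half of the merged array.

    Linear-time illustration only; assumes a and b are sorted ascending.
    """
    n_half = (len(a) + len(b)) // 2
    i = j = 0  # number of elements taken from a and b so far
    for _ in range(n_half):
        # take the smaller head element, preferring a on ties
        if j == len(b) or (i < len(a) and a[i] <= b[j]):
            i += 1
        else:
            j += 1
    return i - 1, j - 1


print(cutoff_by_merge([1, 3, 5, 7], [2, 4, 6, 8]))  # (1, 1): left half is 1, 2, 3, 4
print(cutoff_by_merge([5], [1, 2, 3]))              # (-1, 1): no element of a on the left
```

The binary search described next finds the same pair of indices without walking the merge.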

<br><br>

1. Start with the **shorter array**, `a`, and use its midpoint `mid1` as the initial cutoff index. Since the left partition must contain exactly `n_half = (len(a) + len(b)) // 2` elements, the cutoff index for the longer array, `b`, follows as `mid2 = n_half - mid1 - 2`.

    Suppose `mid1` is the correct partition cutoff index, then

    `a[0], ..., a[mid1], b[0], ..., b[n_half - mid1 - 2]` form the left partition
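A worked instance of Step 1, using small example arrays chosen for illustration. Note that the left partition is not itself sorted; only its size and its boundary elements matter.

```python
a = [1, 3, 5]          # shorter array
b = [2, 4, 6, 8, 10]   # longer array

n_half = (len(a) + len(b)) // 2   # 8 // 2 = 4 elements in the left partition
mid1 = (0 + len(a) - 1) // 2      # midpoint of a: 1
mid2 = n_half - mid1 - 2          # 4 - 1 - 2 = 1

left = a[:mid1 + 1] + b[:mid2 + 1]  # a[0..mid1] followed by b[0..mid2]
print(mid1, mid2, left)  # 1 1 [1, 3, 2, 4]
```

For these particular arrays the initial guess already passes the check in Step 2 (`a[2] = 5 >= b[1] = 4` and `b[2] = 6 >= a[1] = 3`).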

<br><br>

2. To determine whether the current cutoff index `mid1` is correct, every element of the left partition must be less than or equal to every element of the right partition:

    `min(a[mid1 + 1], b[mid2 + 1]) >= max(a[mid1], b[mid2])`

    Since `a` and `b` are sorted, `a[mid1] <= a[mid1 + 1]` and `b[mid2] <= b[mid2 + 1]` always hold, so this reduces to $\Rightarrow$ `a[mid1 + 1] >= b[mid2] and b[mid2 + 1] >= a[mid1]`
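The reduction in Step 2 can be checked numerically. The sketch below (helper names `full_condition` and `reduced_condition` are hypothetical) evaluates both forms for every interior cutoff of an example, where all four indices are in bounds; out-of-bound indices need the sentinels described in Remark 3.

```python
def full_condition(a, b, mid1, mid2):
    # smallest element on the right >= largest element on the left
    return min(a[mid1 + 1], b[mid2 + 1]) >= max(a[mid1], b[mid2])


def reduced_condition(a, b, mid1, mid2):
    # a[mid1] <= a[mid1 + 1] and b[mid2] <= b[mid2 + 1] hold because both arrays are sorted,
    # so only the two cross comparisons remain
    return a[mid1 + 1] >= b[mid2] and b[mid2 + 1] >= a[mid1]


a, b = [1, 3, 5, 7], [2, 4, 6, 8]
n_half = (len(a) + len(b)) // 2
results = []
for mid1 in range(len(a) - 1):  # interior cutoffs only, so mid1 + 1 < len(a)
    mid2 = n_half - mid1 - 2
    assert full_condition(a, b, mid1, mid2) == reduced_condition(a, b, mid1, mid2)
    results.append((mid1, mid2, reduced_condition(a, b, mid1, mid2)))
print(results)  # [(0, 2, False), (1, 1, True), (2, 0, False)]
```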

<br><br>

3. Then run binary search on `a`, starting from `start, end = 0, len(a) - 1`, and use the check from Step 2 in each iteration to find the correct cutoff index
    
    - Case 1: `a[mid1 + 1] < b[mid2]`: `b[mid2]` is too large, so the left partition should include more elements from `a` and fewer elements from `b`

        $\Rightarrow$ `start = mid1 + 1`
    
    - Case 2: `b[mid2 + 1] < a[mid1]`: `a[mid1]` is too large, so the left partition should include fewer elements from `a` and more elements from `b`

        $\Rightarrow$ `end = mid1 - 1`

    - Case 3: `a[mid1 + 1] >= b[mid2] and b[mid2 + 1] >= a[mid1]`: `mid1` and `mid2` are the correct cutoff indices, and the median can be calculated in the next step
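The three cases can be traced on small inputs. This is a sketch with a hypothetical `trace_search` helper that records `(start, end, mid1, mid2, action)` per iteration; it assumes `a` is the shorter array, uses `-inf`/`inf` sentinels for out-of-bound indices (Remark 3), and uses the loop guard `start <= end + 1` (Remark 4).

```python
import math


def trace_search(a, b):
    """Record (start, end, mid1, mid2, action) for each iteration of the binary search on a.

    Assumes len(a) <= len(b) and both arrays sorted ascending.
    """
    n_half = (len(a) + len(b)) // 2
    start, end = 0, len(a) - 1
    steps = []
    while start <= end + 1:
        mid1 = (start + end) // 2
        mid2 = n_half - mid1 - 2
        # boundary values, with sentinels for out-of-bound indices
        a_l = a[mid1] if mid1 >= 0 else -math.inf
        b_l = b[mid2] if mid2 >= 0 else -math.inf
        a_r = a[mid1 + 1] if mid1 + 1 < len(a) else math.inf
        b_r = b[mid2 + 1] if mid2 + 1 < len(b) else math.inf
        if b_l > a_r:    # Case 1: a[mid1 + 1] < b[mid2]
            steps.append((start, end, mid1, mid2, "start = mid1 + 1"))
            start = mid1 + 1
        elif a_l > b_r:  # Case 2: b[mid2 + 1] < a[mid1]
            steps.append((start, end, mid1, mid2, "end = mid1 - 1"))
            end = mid1 - 1
        else:            # Case 3: correct cutoff found
            steps.append((start, end, mid1, mid2, "found"))
            break
    return steps


for step in trace_search([1, 2], [3, 4, 5, 6, 7]):
    print(step)
# (0, 1, 0, 1, 'start = mid1 + 1')
# (1, 1, 1, 0, 'found')

for step in trace_search([6, 7], [1, 2, 3, 4, 5]):
    print(step)
# (0, 1, 0, 1, 'end = mid1 - 1')
# (0, -1, -1, 2, 'found')
```

The second example ends with `mid1 = -1`, i.e. no element of `a` is in the left partition.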

<br><br>

4. Calculate the median based on whether `len(a) + len(b)` is odd or even

    - Odd length: the median is the first (smallest) element of the right partition (since `n_half = n_total // 2`, the left partition has one element fewer than the right partition when the total length is odd)
    
        i.e., `min(a[mid1 + 1], b[mid2 + 1])`

    - Even length: the median is the mean of the last (largest) element of the left partition and the first (smallest) element of the right partition
    
        i.e., `(max(a[mid1], b[mid2]) + min(a[mid1 + 1], b[mid2 + 1])) / 2`
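Step 4 only needs the four boundary values. A minimal sketch, with a hypothetical `median_from_edges` helper that accepts sentinel values:

```python
import math


def median_from_edges(a_l, b_l, a_r, b_r, n_total):
    """Median from the four boundary values of a correct partition (sentinels allowed)."""
    if n_total % 2:
        # odd: the right partition holds one extra element, and its smallest is the median
        return min(a_r, b_r)
    # even: average the largest left element and the smallest right element
    return (max(a_l, b_l) + min(a_r, b_r)) / 2


# a = [1, 3, 5], b = [2, 4, 6, 8, 10], mid1 = mid2 = 1 -> left {1, 3, 2, 4}, right {5, 6, 8, 10}
print(median_from_edges(3, 4, 5, 6, 8))         # 4.5
# a = [1, 2], b = [3, 4, 5, 6, 7], mid1 = 1, mid2 = 0 -> a_r is the +inf sentinel
print(median_from_edges(2, 3, math.inf, 4, 7))  # 4
```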

<ins>Remarks</ins>

1. Why start with the shorter array?

    If we start with the longer array, `mid2` can fall outside the valid range `[-1, len(b) - 1]`

    e.g., `a = [2, 3, 4], b = []` $\Rightarrow$ `n_half = 1, mid1 = 1, mid2 = 1 - 1 - 2 = -2`, which would require additional work to handle

    If `a` is the shorter array, this cannot happen: `len(a) <= n_half <= len(b)`, so as `mid1` ranges over `[-1, len(a) - 1]`, `mid2 = n_half - mid1 - 2` stays within `[n_half - len(a) - 1, n_half - 1]`, which lies inside `[-1, len(b) - 1]`
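The bound on `mid2` can also be checked by brute force over small sizes. This is a finite sanity check rather than a proof; `mid2_bounds` and `mid2_always_valid` are hypothetical helper names.

```python
def mid2_bounds(len_a, len_b):
    """Smallest and largest mid2 reached as mid1 ranges over [-1, len_a - 1]."""
    n_half = (len_a + len_b) // 2
    lowest = n_half - (len_a - 1) - 2   # mid1 = len_a - 1 (all of a on the left)
    highest = n_half - (-1) - 2         # mid1 = -1 (none of a on the left)
    return lowest, highest


def mid2_always_valid(len_a, len_b):
    lowest, highest = mid2_bounds(len_a, len_b)
    # valid cutoff indices for b run from -1 (none of b) to len_b - 1 (all of b)
    return -1 <= lowest and highest <= len_b - 1


print(mid2_bounds(3, 0), mid2_always_valid(3, 0))  # (-3, 0) False  (longer array first)
print(mid2_bounds(0, 3), mid2_always_valid(0, 3))  # (0, 0) True    (shorter array first)

# shorter-first is always safe for these sizes; longer-first always overflows
assert all(mid2_always_valid(m, n) for n in range(20) for m in range(n + 1))
assert not any(mid2_always_valid(m, n) for n in range(20) for m in range(n + 1, 25))
```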

<br><br>

2. Why `mid2 = n_half - mid1 - 2`?

    The left partition must hold `n_half` elements. It already holds `mid1 + 1` elements from `a`, so it needs `n_half - (mid1 + 1)` elements from `b`; the index of the last of these is `n_half - (mid1 + 1) - 1 = n_half - mid1 - 2`
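The same count written as an equation, counting elements rather than indices:

$$
\underbrace{(\text{mid1} + 1)}_{\text{elements from } a} + \underbrace{(\text{mid2} + 1)}_{\text{elements from } b} = n_{\text{half}}
\quad\Longrightarrow\quad
\text{mid2} = n_{\text{half}} - \text{mid1} - 2
$$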

<br><br>

3. How to handle out-of-bounds indices

    It's possible that none or all of the elements in `a` belong to the left partition,
    
    i.e., `mid1 = -1` or `mid1 + 1 = len(a)` (likewise for `mid2` and `b`)

    When checking whether the partition is correct (Step 2 in Logic), these indices are out of bounds, so we replace the missing values with sentinels

    i.e., treat `a[-1]` as `-inf` and `a[len(a)]` as `inf`, which keeps the comparisons valid. In Python these cases must be checked explicitly, since `a[-1]` would silently return the last element instead of raising an error.
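A minimal sketch of the sentinel lookups, with hypothetical helper names `left_edge` and `right_edge`, also showing why plain Python indexing is not enough here:

```python
import math


def left_edge(arr, i):
    """arr[i], or -inf when i == -1 (no element of arr in the left partition)."""
    return arr[i] if i >= 0 else -math.inf


def right_edge(arr, i):
    """arr[i + 1], or +inf when i + 1 == len(arr) (all of arr in the left partition)."""
    return arr[i + 1] if i + 1 < len(arr) else math.inf


a = [2, 3, 4]
print(a[-1])             # 4: plain indexing wraps around instead of failing
print(left_edge(a, -1))  # -inf
print(right_edge(a, 2))  # inf
```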

<br><br>    

4. How to set the while-loop guard

    Normally, the guard is `start <= end`. However, the loop can then exit (`start > end`) before `mid1` has been evaluated at the correct position.

    e.g., suppose the last iteration sets `start, end = 0, -1`, which implies that none of the elements in `a` belong to the left partition. The loop then terminates without ever setting `mid1` to `-1`.

    To fix this, run one extra iteration to compute the final `mid1` and `mid2`, either with `while start <= end + 1` (equivalent to adjusting `mid1` and `mid2` manually one more time) or with `while True`, since a correct partition always exists. When `start == end + 1`, `mid1 = (start + end) // 2` equals `end`; for `start, end = 0, -1` this relies on Python's floor division, `(0 + -1) // 2 == -1`.
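The effect of the guard can be seen by running the same search with both guards. This sketch uses a hypothetical `find_mid1` that returns the cutoff index instead of the median, and an `extra_iteration` flag (also hypothetical) to switch between `start <= end + 1` and `start <= end`.

```python
import math


def find_mid1(a, b, extra_iteration=True):
    """Return the cutoff index mid1 in a, or None if the loop exits before finding it.

    Assumes len(a) <= len(b) and both arrays sorted ascending.
    extra_iteration=True uses the guard `start <= end + 1`; False uses `start <= end`.
    """
    n_half = (len(a) + len(b)) // 2
    start, end = 0, len(a) - 1
    while start <= end + (1 if extra_iteration else 0):
        mid1 = (start + end) // 2  # floor division: (0 + -1) // 2 == -1
        mid2 = n_half - mid1 - 2
        a_l = a[mid1] if mid1 >= 0 else -math.inf
        b_l = b[mid2] if mid2 >= 0 else -math.inf
        a_r = a[mid1 + 1] if mid1 + 1 < len(a) else math.inf
        b_r = b[mid2 + 1] if mid2 + 1 < len(b) else math.inf
        if a_l > b_r:
            end = mid1 - 1
        elif b_l > a_r:
            start = mid1 + 1
        else:
            return mid1
    return None


print(find_mid1([5], [1, 2, 3], extra_iteration=False))  # None: loop ends at start, end = 0, -1
print(find_mid1([5], [1, 2, 3], extra_iteration=True))   # -1
print(find_mid1([], [1, 2], extra_iteration=False))      # None: loop body never runs
```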


```python
def get_median(nums1, nums2):
    # let a be the shorter array so that mid2 never goes out of bounds (Remark 1)
    a, b = (nums1, nums2) if len(nums1) <= len(nums2) else (nums2, nums1)

    # total length and size of the left partition
    n_total = len(a) + len(b)
    n_half = n_total // 2

    # binary search for the cutoff index in a; "end + 1" allows one extra iteration (Remark 4)
    start, end = 0, len(a) - 1
    while start <= end + 1:
        # cutoff indices (last index of each array inside the left partition)
        mid1 = (start + end) // 2
        mid2 = n_half - mid1 - 2

        # 4 edge values, with -inf / inf sentinels for out-of-bounds indices (Remark 3)
        a_l = a[mid1] if mid1 >= 0 else float('-inf')
        b_l = b[mid2] if mid2 >= 0 else float('-inf')
        a_r = a[mid1 + 1] if mid1 + 1 < len(a) else float('inf')
        b_r = b[mid2 + 1] if mid2 + 1 < len(b) else float('inf')

        # determine partition
        if a_l > b_r:
            # a_l is too large: put fewer elements of a in the left partition
            end = mid1 - 1
        elif b_l > a_r:
            # b_l is too large: put more elements of a in the left partition
            start = mid1 + 1
        else:
            # partition is correct
            if n_total % 2:
                # odd: 1st element of the right partition
                return min(a_r, b_r)
            else:
                # even: mean of the last element of the left partition and the 1st of the right partition
                return (max(a_l, b_l) + min(a_r, b_r)) / 2
```

```python
# brute-force reference: merge, sort, and take the middle element(s)
def median(nums1, nums2):
    nums = [*nums1, *nums2]
    nums.sort()
    cutoff = (len(nums) - 1) // 2
    if len(nums) % 2:
        return nums[cutoff]
    else:
        return (nums[cutoff] + nums[cutoff + 1]) / 2

# compare get_median against the brute-force reference
def test(test_name, nums1, nums2):
    print(test_name)
    print(nums1, nums2)
    print('Pass!' if median(nums1, nums2) == get_median(nums1, nums2) else 'Fail!')
    print('\n')


# t0
nums1 = [1,2,3,4,5,6,7,8,9]
nums2 = [10, 11]
test('t0', nums1, nums2)

# t1
nums1 = [1,2,3,4,5,6,7,8,9]
nums2 = [-1, 0]
test('t1', nums1, nums2)

# t2
nums1 = [2,3,4]
nums2 = []
test('t2', nums1, nums2)

# t3
nums1 = [1,3,5,7]
nums2 = [2,4,6,8]
test('t3', nums1, nums2)
```

```
t0
[1, 2, 3, 4, 5, 6, 7, 8, 9] [10, 11]
Pass!


t1
[1, 2, 3, 4, 5, 6, 7, 8, 9] [-1, 0]
Pass!


t2
[2, 3, 4] []
Pass!


t3
[1, 3, 5, 7] [2, 4, 6, 8]
Pass!
```