<div class="elfjS" data-track-load="description_content"><p>Given an <code>n x n</code> <code>matrix</code> where each of the rows and columns is sorted in ascending order, return <em>the</em> <code>k<sup>th</sup></code> <em>smallest element in the matrix</em>.</p>

<p>Note that it is the <code>k<sup>th</sup></code> smallest element <strong>in the sorted order</strong>, not the <code>k<sup>th</sup></code> <strong>distinct</strong> element.</p>

<p>You must find a solution with a memory complexity better than <code>O(n<sup>2</sup>)</code>.</p>

<p>&nbsp;</p>
<p><strong class="example">Example 1:</strong></p>

<pre><strong>Input:</strong> matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
<strong>Output:</strong> 13
<strong>Explanation:</strong> The elements in the matrix are [1,5,9,10,11,12,13,<u><strong>13</strong></u>,15], and the 8<sup>th</sup> smallest number is 13.
</pre>

<p><strong class="example">Example 2:</strong></p>

<pre><strong>Input:</strong> matrix = [[-5]], k = 1
<strong>Output:</strong> -5
</pre>

<p>&nbsp;</p>
<p><strong>Constraints:</strong></p>

<ul>
	<li><code>n == matrix.length == matrix[i].length</code></li>
	<li><code>1 &lt;= n &lt;= 300</code></li>
	<li><code>-10<sup>9</sup> &lt;= matrix[i][j] &lt;= 10<sup>9</sup></code></li>
	<li>All the rows and columns of <code>matrix</code> are <strong>guaranteed</strong> to be sorted in <strong>non-decreasing order</strong>.</li>
	<li><code>1 &lt;= k &lt;= n<sup>2</sup></code></li>
</ul>

<p>&nbsp;</p>
<p><strong>Follow up:</strong></p>

<ul>
	<li>Could you solve the problem using constant memory (i.e., <code>O(1)</code> memory complexity)?</li>
	<li>Could you solve the problem in <code>O(n)</code> time complexity? The solution may be too advanced for an interview, but you may find reading <a href="http://www.cse.yorku.ca/~andy/pubs/X+Y.pdf" target="_blank">this paper</a> fun.</li>
</ul>
</div>
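The memory requirement rules out flattening and sorting all n² values. One standard technique that stays within it is a k-way merge of the sorted rows using a min-heap. It is sketched below as a hypothetical helper, `kth_smallest_heap`, and is not the two-pointer `Solution` this notebook goes on to try. The heap holds at most one entry per row, so extra memory is O(n) and time is O(k log n).

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    """Return the k-th smallest value (1-indexed, duplicates counted).

    Sketch: k-way merge of the sorted rows with a min-heap holding at
    most one (value, row, col) frontier entry per row.
    """
    n = len(matrix)
    # Only the first min(n, k) rows are needed: every entry in a lower
    # row has at least k entries at or above it in its column, so it
    # cannot be smaller than the k-th smallest of the seeded rows.
    heap = [(matrix[r][0], r, 0) for r in range(min(n, k))]
    heapq.heapify(heap)

    # Discard the k - 1 smallest values. After each pop, push the next
    # element of the same row so that row keeps a frontier in the heap.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))

    # The heap minimum is now the k-th smallest value.
    return heap[0][0]
```

On Example 1 the loop pops 1, 5, 9, 10, 11, 12 and 13, leaving the second 13 at the top of the heap, which is returned.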

from typing import List
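For the constant-memory follow-up, one well-known technique is to binary search over the value range rather than over positions. This is an alternative to the `Solution` class in this notebook, not a description of it. The sketch assumes integer entries, which the constraints guarantee, and the helper names `count_less_equal` and `kth_smallest_binary_search` are hypothetical. For a candidate value x, a staircase walk from the bottom-left corner counts the entries ≤ x in O(n) time. The search returns the smallest x with at least k such entries. Extra memory is O(1) and time is O(n log(max − min)).

```python
from typing import List


def count_less_equal(matrix: List[List[int]], x: int) -> int:
    """Count entries <= x with a staircase walk from the bottom-left corner.

    Rows and columns are non-decreasing, so each step discards either a
    whole column (counted) or the rest of a row, giving at most 2n - 1
    steps and O(1) extra memory.
    """
    n = len(matrix)
    row, col = n - 1, 0
    count = 0
    while row >= 0 and col < n:
        if matrix[row][col] <= x:
            # Every entry in this column from row 0 down to `row` is <= x.
            count += row + 1
            col += 1
        else:
            # This entry and everything to its right in this row are > x.
            row -= 1
    return count


def kth_smallest_binary_search(matrix: List[List[int]], k: int) -> int:
    """Return the k-th smallest value using O(1) extra memory."""
    n = len(matrix)
    lo, hi = matrix[0][0], matrix[n - 1][n - 1]
    # Invariant: the answer lies in [lo, hi] (count_less_equal(hi) == n*n >= k).
    while lo < hi:
        mid = (lo + hi) // 2
        if count_less_equal(matrix, mid) >= k:
            hi = mid       # at least k entries are <= mid
        else:
            lo = mid + 1   # fewer than k entries are <= mid
    # lo is the smallest value with at least k entries <= it; the count
    # only changes at matrix values, so lo is always an actual entry.
    return lo
```

On Example 1, 6 entries are ≤ 12 and 8 are ≤ 13, so with k = 8 the search settles on 13.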

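class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int, verbose: bool = True) -> int:
        """
        Two-pointer scan from the bottom-right corner toward the top-left:
        values are emitted from largest to smallest while a counter c counts
        down, and the value emitted when c reaches k is returned.

        Caveat: comparing only the two pointer candidates does not guarantee
        that values are emitted in sorted order, so this scan can return the
        wrong value (test case 3 in main() fails).
        """

        def _update_ptr(ptr_a: bool):
            """
            Helper function to update the pointers
            as we traverse the matrix.
            """
            nonlocal i, j, p, q
            if ptr_a:
                # update ptr a (i, j), which starts at the bottom row
                # and traverses right to left; past column 0 it wraps
                # to the row above, restarting just left of ptr b's column
                j -= 1
                if j < 0:
                    j = q - 1
                    i -= 1
            else:
                # update ptr b (p, q), which starts at the right column
                # and traverses bottom to top; past row 0 it wraps to the
                # column on the left, restarting just above ptr a's row
                p -= 1
                if p < 0:
                    p = i - 1
                    q -= 1

        n = len(matrix)

        # both pointers start at the bottom-right (largest) element
        i, j, p, q = [n - 1] * 4

        # after the m-th emitted value, c == n**2 + 1 - m, so if values were
        # emitted in descending order, val would be the k-th smallest at c == k
        c = n**2 + 1
        val = None
        while c > 0:
            if verbose:
                print(f"c: {c}\n(i,j): {(i,j)}\n(p,q): {(p,q)}\nval: {val}")
            if c == k:
                # val is the k-th smallest element (given sorted emission)
                return val
            bottom_val = matrix[i][j]
            right_val = matrix[p][q]

            if (i, j) == (p, q):
                # both ptrs are at the same position: emit this value
                # once and advance both ptrs past it
                val = bottom_val
                _update_ptr(False)
                _update_ptr(True)
            elif bottom_val >= right_val:
                val = bottom_val
                _update_ptr(True)
            else:
                val = right_val
                _update_ptr(False)

            c -= 1

        # unreachable for 1 <= k <= n**2: c counts down through every such k
        raise AssertionError("We should never reach this point.")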


def main():
    test_cases = {
        "1": {
            "matrix": [[1,5,9],[10,11,13],[12,13,15]],
            "k": 8,
            "expected": 13,
        },
        "2": {
            "matrix": [[-5]],
            "k": 1,
            "expected": -5,
        },
        "3": {
            "matrix": [[1,3,5],[6,7,12],[11,14,14]],
            "k": 1,
            "expected": 1,
        },
        "4": {
            "matrix": [[1,3,5],[6,7,12],[11,14,14]],
            "k": 2,
            "expected": 3,
        },
    }

    solution = Solution()

    for tk, targs in test_cases.items():
        expected = targs.pop("expected", None)
        ret = solution.kthSmallest(**targs, verbose=True)
        if expected is not None:
            passed = ret == expected
        else:
            passed = None
        print(f"test case {tk}: {targs}\nReturned: {ret}, Expected: {expected}\nPassed:{passed}\n")


main()

c: 10
(i,j): (2, 2)
(p,q): (2, 2)
val: None
c: 9
(i,j): (2, 1)
(p,q): (1, 2)
val: 15
c: 8
(i,j): (2, 0)
(p,q): (1, 2)
val: 13
test case 1: {'matrix': [[1, 5, 9], [10, 11, 13], [12, 13, 15]], 'k': 8}
Returned: 13, Expected: 13
Passed:True

c: 2
(i,j): (0, 0)
(p,q): (0, 0)
val: None
c: 1
(i,j): (-1, -2)
(p,q): (-1, -1)
val: -5
test case 2: {'matrix': [[-5]], 'k': 1}
Returned: -5, Expected: -5
Passed:True

c: 10
(i,j): (2, 2)
(p,q): (2, 2)
val: None
c: 9
(i,j): (2, 1)
(p,q): (1, 2)
val: 14
c: 8
(i,j): (2, 0)
(p,q): (1, 2)
val: 14
c: 7
(i,j): (2, 0)
(p,q): (0, 2)
val: 12
c: 6
(i,j): (1, 1)
(p,q): (0, 2)
val: 11
c: 5
(i,j): (1, 0)
(p,q): (0, 2)
val: 7
c: 4
(i,j): (0, 1)
(p,q): (0, 2)
val: 6
c: 3
(i,j): (0, 1)
(p,q): (-1, 1)
val: 5
c: 2
(i,j): (0, 1)
(p,q): (-1, 0)
val: 14
c: 1
(i,j): (0, 1)
(p,q): (-1, -1)
val: 11
test case 3: {'matrix': [[1, 3, 5], [6, 7, 12], [11, 14, 14]], 'k': 1}
Returned: 11, Expected: 1
Passed:False

c: 10
(i,j): (2, 2)
(p,q): (2, 2)
val: None
c: 9
(i,j): (2, 1)
(p,q)