## 378. Kth Smallest Element in a Sorted Matrix [problem](https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/)

**Similar problems: [No.373 Find K Pairs with Smallest Sums](./find_k_pairs_with_smallest_sums.ipynb)**

---

Given an ```n x n``` matrix where each of the rows and columns is **sorted in ascending order**, return the ```k```th smallest element in the matrix.

Note that it is the ```k```th smallest element in the sorted order, not the ```k```th distinct element.

You must find a solution with a memory complexity better than $O(N^2)$.
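This brute-force sketch flattens and sorts the matrix, which makes the "sorted order, not distinct" definition concrete. It can also serve as a test oracle. It needs $O(N^2)$ memory, which is exactly what the requirement above rules out. The function name `kth_smallest_brute_force` is hypothetical.

```python
from typing import List

def kth_smallest_brute_force(matrix: List[List[int]], k: int) -> int:
    # Flatten all n*n values and sort them. Duplicates are kept,
    # so repeated values occupy separate positions in the order.
    values = sorted(x for row in matrix for x in row)
    return values[k - 1]  # k is 1-indexed

matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_brute_force(matrix, 7))  # 13
print(kth_smallest_brute_force(matrix, 8))  # 13 (the duplicate 13 is counted twice)
```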

---

**Constraints:**

* ```n == matrix.length == matrix[i].length```
* ```1 <= n <= 300```
* ```-10^9 <= matrix[i][j] <= 10^9```
* All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
* ```1 <= k <= n^2```

---

**Follow-up:**

* Could you solve the problem with a constant memory (i.e., $O(1)$ memory complexity)?
* Could you solve the problem in $O(N)$ time complexity? 
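One common approach to the first follow-up differs from the priority-queue method below. It binary-searches over the range of values instead of over positions. For a candidate value `x`, it counts the entries `<= x` with a staircase walk from the bottom-left corner, which takes $O(N)$ time. The sketch below assumes this approach; the name `kth_smallest_binary_search` is hypothetical. It uses $O(1)$ extra memory and roughly $O(N \log(\max - \min))$ time. It does not answer the $O(N)$-time follow-up.

```python
from typing import List

def kth_smallest_binary_search(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)

    def count_le(x: int) -> int:
        # Staircase walk from the bottom-left corner: O(n) time, O(1) extra memory.
        count = 0
        row, col = n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                # Columns are sorted, so rows 0..row of this column are all <= x.
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    # Invariant: the answer lies in [lo, hi].
    lo, hi = matrix[0][0], matrix[n - 1][n - 1]
    while lo < hi:
        mid = (lo + hi) // 2  # floor division also behaves correctly for negative values
        if count_le(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    # lo is the smallest x with at least k entries <= x, i.e. the k-th smallest entry.
    return lo
```

On the example matrix `[[1, 5, 9], [10, 11, 13], [12, 13, 15]]` with `k = 8`, this returns `13`.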

### 1. Priority queue
* Time complexity: $O(K \log K)$
* Space complexity: $O(K)$
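The bounds follow from counting heap operations. Each of the $K$ loop iterations pops once and pushes at most twice, so the heap grows by at most one element per iteration. In the derivation below, $H$ denotes the heap and $V$ the visited set:

$$
|H| \le K + 1, \qquad |V| \le 2K + 1
$$

$$
T(K) = \underbrace{K \cdot O(\log |H|)}_{\text{pops}} + \underbrace{2K \cdot O(\log |H|)}_{\text{pushes}} = O(K \log K), \qquad S(K) = O(|H| + |V|) = O(K)
$$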

**Key point: pop the smallest element from the heap, then immediately push its right and down neighbors and mark them as visited, so that no cell is pushed twice.**

```python
import heapq
from typing import List

def kthSmallest(matrix: List[List[int]], k: int) -> int:
    """
    Args:
        matrix: a 2D integer array, sorted along the rows and columns
        k: an integer

    Return:
        the kth smallest integer in matrix
    """

    heap = []
    visited = set()

    heapq.heappush(heap, (matrix[0][0], 0, 0))
    visited.add((0, 0))

    ret = [] # a list is not necessary; tracking the last popped value is enough
    while len(ret) < k: # K iterations
        num, row, col = heapq.heappop(heap) # current smallest: O(logK)
        ret.append(num)

        # push the down neighbor if it exists and has not been pushed yet
        if row + 1 < len(matrix) and (row + 1, col) not in visited:
            heapq.heappush(heap, (matrix[row + 1][col], row + 1, col)) # O(logK)
            visited.add((row + 1, col))

        # push the right neighbor if it exists and has not been pushed yet
        if col + 1 < len(matrix[0]) and (row, col + 1) not in visited:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1)) # O(logK)
            visited.add((row, col + 1))

    return ret[-1]
```
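A common variant of the same priority-queue idea seeds the heap with the first element of each of the first $\min(N, K)$ rows. After each pop, it pushes only the next element of the popped row. The same pattern is commonly used for No.373. Every cell then has exactly one predecessor, so no visited set is needed. In the sketch below, the name `kth_smallest_row_heads` is hypothetical. Extra memory is $O(\min(N, K))$ and time is $O(\min(N, K) + K \log \min(N, K))$.

```python
import heapq
from typing import List

def kth_smallest_row_heads(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Seed with the first element of each of the first min(n, k) rows. The k values
    # matrix[0][0] <= ... <= matrix[k-1][0] are all <= every entry in rows >= k,
    # so those later rows cannot change the k-th smallest value.
    heap = [(matrix[row][0], row, 0) for row in range(min(n, k))]
    heapq.heapify(heap)

    value = matrix[0][0]
    for _ in range(k):
        value, row, col = heapq.heappop(heap)
        # Each cell is reached only from its left neighbor, so no visited set is required.
        if col + 1 < n:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
    return value
```

The trade-off against the version above is that it pops along rows only, so the heap size is capped by the number of seeded rows rather than growing with $K$.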