# [Explore Amazon](https://leetcode.com/explore/interview/card/amazon/)

Top interview questions asked by Amazon as voted by the community.

We compiled this list thoroughly so you can save time and get well-prepared for an Amazon interview.

Completing this card should give you a good idea of the type of questions you would encounter in your Amazon interview.

## Arrays and strings

Amazon likes to ask simple, basic array questions. We highly recommend practicing First Unique Character in a String, which is frequently asked. We also recommend Integer to English Words.

---
**[1. Two Sum](https://leetcode.com/problems/two-sum/)**

Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.

You may assume that each input would have exactly one solution, and you may not use the same element twice.

You can return the answer in any order.

In [1]:
import string
from typing import List, Optional
from collections import Counter

In [2]:
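def two_sum(nums: List[int], x: int) -> List[int]:
    # One pass with a hash map from value to index: O(n) time, O(n) space.
    # Sorting would lose the original indices, which the problem asks for.
    seen = {}

    for i, num in enumerate(nums):
        complement = x - num
        if complement in seen:
            return [seen[complement], i]
        seen[num] = i

    # No pair found (the problem guarantees this does not happen).
    return []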

In [3]:
assert two_sum(nums=[0, 2, -4, 4, 6, 11, -7], x=7)
assert two_sum(nums=[0, -1, 2, 4, 12, -6, 8, 15], x=8)

---
**[273. Integer to English Words](https://leetcode.com/problems/integer-to-english-words/) (Recommended)**

Convert a non-negative integer num to its English words representation.

In [4]:
digits = {'1': 'One', '2': 'Two', '3': 'Three', '4': 'Four', '5': 'Five',
          '6': 'Six', '7': 'Seven', '8': 'Eight', '9': 'Nine'}

def numberToWordsUpToTwoDigits(str_num: str) -> str:
    """Convert a two-character digit string ('00'..'99') to words."""
    teens = {'10': 'Ten', '11': 'Eleven', '12': 'Twelve', '13': 'Thirteen', '14': 'Fourteen',
             '15': 'Fifteen', '16': 'Sixteen', '17': 'Seventeen', '18': 'Eighteen',
             '19': 'Nineteen'}

    decades = {'2': 'Twenty', '3': 'Thirty', '4': 'Forty', '5': 'Fifty', '6': 'Sixty',
               '7': 'Seventy', '8': 'Eighty', '9': 'Ninety'}

    if set(str_num) == {'0'}:
        return ""
    if str_num[0] == '0':
        return digits[str_num[1]]
    if str_num[0] == '1':
        return teens[str_num]
    if str_num[1] == '0':
        return decades[str_num[0]]
    return decades[str_num[0]] + " " + digits[str_num[1]]

def numberToWordsUpToThreeDigits(str_num: str) -> str:
    """Convert a three-character digit string ('000'..'999') to words."""
    if set(str_num) == {'0'}:
        return ""
    if str_num[0] == '0':
        return numberToWordsUpToTwoDigits(str_num[1:])

    res = digits[str_num[0]] + " Hundred " + numberToWordsUpToTwoDigits(str_num[1:])
    return res.strip()

def numberToWords(num: int) -> str:
    # Zero is the only input that cannot be built from non-zero digit groups.
    if num == 0:
        return "Zero"

    str_num = str(num)
    rev_str_num = str_num[::-1]
    res = []

    scales = {0: '', 3: 'Thousand', 6: 'Million', 9: 'Billion'}

    # Walk the number in groups of three digits, least significant group first.
    for i in range(0, len(str_num), 3):
        group = rev_str_num[i: i+3][::-1]

        # Only the leading group can be shorter than three digits.
        if len(group) == 1:
            curr_res = digits[group]
        elif len(group) == 2:
            curr_res = numberToWordsUpToTwoDigits(group)
        else:
            curr_res = numberToWordsUpToThreeDigits(group)

        # All-zero groups produce "" and are skipped (no "Zero Thousand").
        if curr_res:
            res.append(f"{curr_res} {scales[i]}".strip())

    return " ".join(reversed(res))

In [5]:
assert numberToWords(2000564321) == 'Two Billion Five Hundred Sixty Four Thousand Three Hundred Twenty One'
assert numberToWords(100000) == 'One Hundred Thousand'

---
**[819. Most Common Word](https://leetcode.com/problems/most-common-word/)**

Given a string paragraph and a string array of the banned words banned, return the most frequent word that is not banned. It is guaranteed there is at least one word that is not banned, and that the answer is unique.

The words in paragraph are case-insensitive and the answer should be returned in lowercase.

 

In [6]:
def mostCommonWord(paragraph: str, banned: List[str]) -> str:
    
    for i in string.punctuation:
        paragraph = paragraph.replace(i, ' ')
        
    freqs = Counter([el.strip(string.punctuation).lower() for el in paragraph.split()])

    for word, freq in sorted(freqs.items(), key=lambda p: -p[1]):
        if word not in banned:
            return word

In [7]:
assert mostCommonWord(paragraph="Bob hit a ball, the hit BALL flew far after it was hit.", banned = ["hit"]) == 'ball'
assert mostCommonWord(paragraph="a, a, a, a, b,b,b,c, c", banned=["a"]) == 'b'
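A regex-based variant, sketched under the assumption that words consist only of English letters (as the problem's constraints suggest). It lowercases once, extracts runs of letters with `re.findall`, and keeps the banned words in a set for constant-time lookups. `most_common_word_regex` is a hypothetical name.

```python
import re
from collections import Counter
from typing import List


def most_common_word_regex(paragraph: str, banned: List[str]) -> str:
    # Hypothetical helper name. Assumes words contain only English letters,
    # so spaces and punctuation both act as separators.
    words = re.findall(r"[a-z]+", paragraph.lower())
    banned_set = set(banned)  # O(1) membership checks
    counts = Counter(w for w in words if w not in banned_set)
    # most_common(1) returns [(word, count)] for the highest count
    return counts.most_common(1)[0][0]
```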

---
**[15. 3sum](https://leetcode.com/problems/3sum/)**

Given an integer array `nums`, return all the triplets `[nums[i], nums[j], nums[k]]` such that `i != j, i != k`, and `j != k`, and `nums[i] + nums[j] + nums[k] == 0`.

Notice that the solution set must not contain duplicate triplets.

In [8]:
def threeSum(nums: List[int]) -> List[List[int]]:

    # Map each value to every index where it occurs.
    inv_num_dict = {}
    for i, el in enumerate(nums):
        inv_num_dict.setdefault(el, []).append(i)

    res, dups = set(), set()

    for i in range(len(nums)):
        # Triplets anchored at a value already used as nums[i] were found earlier.
        if nums[i] in dups:
            continue
        dups.add(nums[i])

        for j in range(i + 1, len(nums)):
            complement = -nums[i] - nums[j]

            for id_k in inv_num_dict.get(complement, []):
                if id_k != i and id_k != j:
                    # Sorting collapses permutations of the same triplet in the set.
                    res.add(tuple(sorted((nums[i], nums[j], nums[id_k]))))
                    break  # any valid k yields the same triplet

    # Return only after every anchor i has been scanned.
    return [list(triple) for triple in res]

In [9]:
assert sorted(threeSum(nums=[-1,0,1,2,-1,-4])) == [[-1, -1, 2], [-1, 0, 1]]
assert threeSum(nums=[0,1,1]) == []
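The more common $O(n^2)$ approach sorts the array first and then, for each anchor element, runs the same two-pointer scan used for Two Sum II in this notebook. Skipping repeated values removes duplicate triplets without needing a set. A sketch, with `three_sum_two_pointers` as a hypothetical name:

```python
from typing import List


def three_sum_two_pointers(nums: List[int]) -> List[List[int]]:
    # Hypothetical helper name; works on a sorted copy, the input is untouched.
    nums = sorted(nums)
    res = []
    n = len(nums)
    for i in range(n - 2):
        # Skip repeated anchor values to avoid duplicate triplets.
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        lo, hi = i + 1, n - 1
        while lo < hi:
            total = nums[i] + nums[lo] + nums[hi]
            if total < 0:
                lo += 1
            elif total > 0:
                hi -= 1
            else:
                res.append([nums[i], nums[lo], nums[hi]])
                lo += 1
                hi -= 1
                # Skip repeated values on the left; the matching right value
                # would then have to repeat too, so no duplicates slip through.
                while lo < hi and nums[lo] == nums[lo - 1]:
                    lo += 1
    return res
```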

---
**[347. Top K Frequent Elements](https://leetcode.com/problems/top-k-frequent-elements/)**

Given an integer array nums and an integer k, return the k most frequent elements. You may return the answer in any order.

First solution $O(n \log n)$

In [10]:
def topKFrequent(nums: List[int], k: int) -> List[int]:
        
    freqs = Counter(nums)

    sorted_freqs = sorted(freqs.items(), key=lambda p: -p[1])

    return [num for (num, _) in sorted_freqs[:k]]

In [11]:
assert topKFrequent(nums=[1,1,1,2,2,3], k=2) == [1, 2]
assert topKFrequent(nums=[1], k=1) == [1]

Better: heapify all counts and pop $k$ times, $O(n + k \log n)$

In [12]:
import heapq

In [13]:
def topKFrequent(nums: List[int], k: int) -> List[int]:
        
    res = []
    freqs = Counter(nums)
    # Negate counts so Python's min-heap pops the most frequent element first.
    inv_freqs = [(-count, num) for (num, count) in freqs.items()]
    
    heapq.heapify(inv_freqs)
    
    for _ in range(k):
        
        res.append(heapq.heappop(inv_freqs)[1])
        
    return res

In [14]:
assert topKFrequent(nums=[1,1,1,2,2,3], k=2) == [1, 2]
assert topKFrequent(nums=[1], k=1) == [1]
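Since a frequency can never exceed `len(nums)`, a bucket sort over frequencies avoids both sorting and the heap and runs in $O(n)$ time. A sketch, with `top_k_frequent_buckets` as a hypothetical name; like the versions above, it returns the most frequent values first.

```python
from collections import Counter
from typing import List


def top_k_frequent_buckets(nums: List[int], k: int) -> List[int]:
    # Hypothetical helper name.
    freqs = Counter(nums)
    # buckets[f] holds every value that occurs exactly f times (1 <= f <= len(nums)).
    buckets = [[] for _ in range(len(nums) + 1)]
    for num, f in freqs.items():
        buckets[f].append(num)

    res = []
    # Walk from the highest possible frequency down until k values are collected.
    for f in range(len(nums), 0, -1):
        for num in buckets[f]:
            res.append(num)
            if len(res) == k:
                return res
    return res
```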

---
**[167. Two Sum II - Input Array Is Sorted](https://leetcode.com/problems/two-sum-ii-input-array-is-sorted/)** 

Input array is sorted.

Given a 1-indexed array of integers numbers that is already sorted in non-decreasing order, find two numbers such that they add up to a specific target number. Let these two numbers be numbers[index1] and numbers[index2] where 1 <= index1 < index2 <= numbers.length.

Return the indices of the two numbers, index1 and index2, added by one as an integer array [index1, index2] of length 2.

The tests are generated such that there is exactly one solution. You may not use the same element twice.

Your solution must use only constant extra space.

In [15]:
def twoSum(numbers: List[int], target: int) -> List[int]:

    left, right = 0, len(numbers) - 1

    while left < right:

        if numbers[left] + numbers[right] == target:
            return [left + 1, right + 1]
        elif numbers[left] + numbers[right] < target:
            left += 1
        else:
            right -= 1

    return []  # not reached: the problem guarantees exactly one solution

In [16]:
assert twoSum(numbers=[2,7,11,15], target=9) == [1, 2]
assert twoSum(numbers=[2,3,4], target=6) == [1, 3]
assert twoSum(numbers=[-1,0], target=-1) == [1, 2]

---
**[12. Integer to Roman](https://leetcode.com/problems/integer-to-roman/)**

Roman numerals are represented by seven different symbols: I, V, X, L, C, D and M.

For example, 2 is written as II in Roman numeral, just two ones added together. 12 is written as XII, which is simply X + II. The number 27 is written as XXVII, which is XX + V + II.

Roman numerals are usually written largest to smallest from left to right. However, the numeral for four is not IIII. Instead, the number four is written as IV. Because the one is before the five, we subtract it, making four. The same principle applies to the number nine, which is written as IX. There are six instances where subtraction is used:

- I can be placed before V (5) and X (10) to make 4 and 9.
- X can be placed before L (50) and C (100) to make 40 and 90.
- C can be placed before D (500) and M (1000) to make 400 and 900.

Given an integer, convert it to a Roman numeral.

In [17]:
def intToRoman(num: int) -> str:

    symb_dict = {1: 'I', 5: 'V', 10: 'X', 50: 'L', 100: 'C', 500: 'D', 1000: 'M'}
    combs_dict = {4: 'IV', 9: 'IX', 40: 'XL', 90: 'XC', 400: 'CD', 900: 'CM'}

    res = ""

    for i, digit in enumerate(str(num)):

        digit = int(digit)

        scale = len(str(num)) - i - 1

        if digit * 10 ** scale in combs_dict:
            res += combs_dict[digit * 10 ** scale]
        elif digit * 10 ** scale in symb_dict:
            res += symb_dict[digit * 10 ** scale]
        else:
            if digit >= 5:
                res += symb_dict[5 * 10 ** scale] 
                res += symb_dict[10 ** scale] * (digit - 5)
            else:
                res += symb_dict[10 ** scale] * digit

    return res

In [18]:
assert intToRoman(3) == 'III'
assert intToRoman(58) == "LVIII"
assert intToRoman(1994) == 'MCMXCIV'
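The digit-by-digit method above can also be written as a greedy pass over a table that lists the six subtractive pairs next to the seven basic symbols: repeatedly take the largest value that still fits. `int_to_roman_greedy` is a hypothetical name for this sketch, which assumes the problem's input range of 1 to 3999.

```python
def int_to_roman_greedy(num: int) -> str:
    # Hypothetical helper name. Value/symbol pairs in descending order,
    # including the six subtractive forms.
    table = [
        (1000, "M"), (900, "CM"), (500, "D"), (400, "CD"),
        (100, "C"), (90, "XC"), (50, "L"), (40, "XL"),
        (10, "X"), (9, "IX"), (5, "V"), (4, "IV"), (1, "I"),
    ]
    parts = []
    for value, symbol in table:
        # Use this value as many times as it fits, then continue with the remainder.
        count, num = divmod(num, value)
        parts.append(symbol * count)
    return "".join(parts)
```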

---
**[937. Reorder Data in Log Files](https://leetcode.com/problems/reorder-data-in-log-files/)**

You are given an array of logs. Each log is a space-delimited string of words, where the first word is the identifier.

There are two types of logs:

- Letter-logs: All words (except the identifier) consist of lowercase English letters.
- Digit-logs: All words (except the identifier) consist of digits.

Reorder these logs so that:

- The letter-logs come before all digit-logs.
- The letter-logs are sorted lexicographically by their contents. If their contents are the same, then sort them lexicographically by their identifiers.
- The digit-logs maintain their relative ordering.

Return the final order of the logs.

In [19]:
def reorderLogFiles(logs: List[str]) -> List[str]:

    def is_digit_log(log: str):
        for el in log.split()[1:]:
            if el.isalpha():
                return False
        return True


    digit_logs, letter_logs = [], []

    for log in logs:
        if is_digit_log(log):
            digit_logs.append(log)
        else:
            letter_logs.append(log)

    res = sorted(letter_logs, key=lambda s: (" ".join(s.split()[1:]), s.split()[0])) + digit_logs

    return res


In [20]:
assert reorderLogFiles(logs=["dig1 8 1 5 1","let1 art can","dig2 3 6","let2 own kit dig","let3 art zero"]) == ["let1 art can","let3 art zero","let2 own kit dig","dig1 8 1 5 1","dig2 3 6"]

_Nicer_

In [21]:
def reorderLogFiles(logs: List[str]) -> List[str]:

    def get_key(log):
        _id, rest = log.split(" ", maxsplit=1)
        return (0, rest, _id) if rest[0].isalpha() else (1, )

    return sorted(logs, key=get_key)

In [22]:
assert reorderLogFiles(logs=["dig1 8 1 5 1","let1 art can","dig2 3 6","let2 own kit dig","let3 art zero"]) == ["let1 art can","let3 art zero","let2 own kit dig","dig1 8 1 5 1","dig2 3 6"]

---
**[42. Trapping Rain Water](https://leetcode.com/problems/trapping-rain-water/)**

Given n non-negative integers representing an elevation map where the width of each bar is 1, compute how much water it can trap after raining.

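Works correctly, but exceeds the time limit: the inner scan for the right-hand maximum can rerun for almost every index (for example, with strictly decreasing heights), giving $O(n^2)$ time in the worst case.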

In [23]:
def trap(height: List[int]) -> int:

    max_level_left, max_level_right = 0, 0
    max_level_right_pos = 0
    trapped = 0

    for i, h in enumerate(height):
        if h > max_level_left:
            max_level_left = h

        else:
            if max_level_right_pos > i:
                if max_level_right >= max_level_left:
                    trapped += max_level_left - h
                else:
                    trapped += max(max_level_right - h, 0)

            else:
                max_level_right, max_level_right_pos = 0, i + 1
                for j in range(i + 1, len(height)):
                    if height[j] > max_level_right:
                        max_level_right = height[j]
                        max_level_right_pos = j

                    if max_level_right >= max_level_left:
                        trapped += max_level_left - h
                        break

                if max_level_right < max_level_left:
                    trapped += max(max_level_right - h, 0)

    return trapped

In [24]:
assert trap([0,1,0,2,1,0,1,3,2,1,2,1]) == 6
assert trap([4,2,0,3,2,5]) == 9

_Precomputed prefix and suffix maxima: $O(n)$ time, $O(n)$ space_

In [25]:
def trap(height: List[int]) -> int:
    
    res = 0
    left_maxes, right_maxes = [height[0]] * len(height), [height[-1]] * len(height)
    
    for i, h in enumerate(height):
        left_maxes[i] = max(h, left_maxes[i-1])
    for i in range(len(height) - 2, -1, -1):
        right_maxes[i] = max(height[i], right_maxes[i+1]) 
        
    for i in range(len(height)):
        res += min(left_maxes[i], right_maxes[i]) - height[i]
        
    return res

In [26]:
assert trap([0,1,0,2,1,0,1,3,2,1,2,1]) == 6
assert trap([4,2,0,3,2,5]) == 9
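The two arrays of maxima can be replaced by two pointers, bringing extra space down to $O(1)$. Whichever side currently has the lower bar is bounded by its own running maximum, because the opposite side is known to have a bar at least as tall. A sketch under that reasoning, with `trap_two_pointers` as a hypothetical name:

```python
from typing import List


def trap_two_pointers(height: List[int]) -> int:
    # Hypothetical helper name.
    left, right = 0, len(height) - 1
    left_max = right_max = 0
    trapped = 0
    while left < right:
        # Move the pointer on the lower side: its water level is set by its own max.
        if height[left] < height[right]:
            left_max = max(left_max, height[left])
            trapped += left_max - height[left]
            left += 1
        else:
            right_max = max(right_max, height[right])
            trapped += right_max - height[right]
            right -= 1
    return trapped
```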

## Linked List

These are some of the must-practice linked list questions asked by Amazon. We recommend you practice all of these questions.

---
**[138. Copy List with Random Pointer](https://leetcode.com/problems/copy-list-with-random-pointer/)**

A linked list of length n is given such that each node contains an additional random pointer, which could point to any node in the list, or null.

Construct a deep copy of the list. The deep copy should consist of exactly n brand new nodes, where each new node has its value set to the value of its corresponding original node. Both the next and random pointer of the new nodes should point to new nodes in the copied list such that the pointers in the original list and copied list represent the same list state. None of the pointers in the new list should point to nodes in the original list.

For example, if there are two nodes X and Y in the original list, where X.random --> Y, then for the corresponding two nodes x and y in the copied list, x.random --> y.

Return the head of the copied linked list.

The linked list is represented in the input/output as a list of n nodes. Each node is represented as a pair of [val, random_index] where:

- val: an integer representing Node.val
- random_index: the index of the node (range from 0 to n-1) that the random pointer points to, or null if it does not point to any node.

Your code will only be given the head of the original linked list.

In [27]:
class Node:
    def __init__(self, x: int, next: 'Node' = None, random: 'Node' = None):
        self.val = int(x)
        self.next = next
        self.random = random

def copyRandomList(head: 'Optional[Node]') -> 'Optional[Node]':

    if not head:
        return None

    curr_node = head

    randoms = {}
    old_new_node_map, new_old_node_map = {}, {}

    new_head = Node(x=curr_node.val, next=None, random=None)
    new_head_bkp = new_head

    while curr_node:

        if curr_node.next:
            new_head.next = Node(x=curr_node.next.val, next=None, random=None)
        else:
            new_head.next = None

        old_new_node_map[curr_node] = new_head
        new_old_node_map[new_head] = curr_node
        randoms[curr_node] = curr_node.random

        curr_node = curr_node.next
        new_head = new_head.next

    curr_node = new_head_bkp

    while curr_node:

        old_random = randoms[new_old_node_map[curr_node]]

        curr_node.random = old_new_node_map[old_random] if old_random else None

        curr_node = curr_node.next

    return new_head_bkp
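The solution above keeps three dictionaries, but a single old-to-new map is enough: first create a copy of every node, then wire `next` and `random` through the map in a second pass. This sketch re-declares the notebook's `Node` class so it runs on its own; `copy_random_list_map` is a hypothetical name, and the usage builds the list `[[7,null],[13,0],[11,4],[10,2],[1,0]]`.

```python
from typing import Dict, Optional


class Node:
    def __init__(self, x: int, next: "Optional[Node]" = None, random: "Optional[Node]" = None):
        self.val = int(x)
        self.next = next
        self.random = random


def copy_random_list_map(head: Optional[Node]) -> Optional[Node]:
    # Hypothetical helper name.
    # Pass 1: one new node per original node, pointers left unset.
    old_to_new: Dict[Node, Node] = {}
    node = head
    while node:
        old_to_new[node] = Node(node.val)
        node = node.next

    # Pass 2: wire next/random through the map; .get(None) returns None.
    node = head
    while node:
        copy = old_to_new[node]
        copy.next = old_to_new.get(node.next)
        copy.random = old_to_new.get(node.random)
        node = node.next

    return old_to_new.get(head)


# Usage: build [[7,null],[13,0],[11,4],[10,2],[1,0]] and copy it.
nodes = [Node(v) for v in [7, 13, 11, 10, 1]]
for a, b in zip(nodes, nodes[1:]):
    a.next = b
for node, r in zip(nodes, [None, 0, 4, 2, 0]):
    node.random = nodes[r] if r is not None else None

copied = copy_random_list_map(nodes[0])
```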

## Trees and Graphs

As you can see, Amazon likes to ask questions related to the tree data structure. We highly recommend Number of Islands, which seems to be Amazon's favorite.

---
**[103. Binary Tree Zigzag Level Order Traversal](https://leetcode.com/problems/binary-tree-zigzag-level-order-traversal/)**

Given the root of a binary tree, return the zigzag level order traversal of its nodes' values. (i.e., from left to right, then right to left for the next level and alternate between).



In [28]:
# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def zigzagLevelOrder(root: Optional[TreeNode]) -> List[List[int]]:
        
    if not root:
        return []

    visited = [[root.val]]
    queue = [[root]]

    level = 0

    while True:

        curr_visited = []
        curr_queue = []

        for node in queue[level]:
            if node.left:
                curr_queue.append(node.left)
                curr_visited.append(node.left.val)
            if node.right:
                curr_queue.append(node.right)
                curr_visited.append(node.right.val)

        if curr_queue:
            level += 1
            if level % 2 == 1:
                visited.append(curr_visited[::-1])
            else:
                visited.append(curr_visited)
            queue.append(curr_queue)
        else:
            break

    return visited    

In [29]:
tree = TreeNode(val=3, 
                left=TreeNode(val=9), 
                right=TreeNode(val=20, 
                               left=TreeNode(val=15), 
                               right=TreeNode(val=7)))

In [30]:
assert zigzagLevelOrder(tree) == [[3], [20, 9], [15, 7]]
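A more compact variant uses `collections.deque` and processes exactly one level per iteration, flipping a direction flag instead of checking level parity. `TreeNode` is re-declared so the sketch runs on its own; `zigzag_level_order_deque` is a hypothetical name.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def zigzag_level_order_deque(root: Optional[TreeNode]) -> List[List[int]]:
    # Hypothetical helper name.
    if not root:
        return []
    res = []
    queue = deque([root])
    left_to_right = True
    while queue:
        # Everything currently in the queue belongs to one level.
        level = [node.val for node in queue]
        res.append(level if left_to_right else level[::-1])
        left_to_right = not left_to_right
        # Replace this level with the next one, children in left-to-right order.
        for _ in range(len(queue)):
            node = queue.popleft()
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
    return res
```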

## Sorting and Searching

We highly recommend Kth Largest Element in an Array, which has been asked many times in Amazon phone interviews.

---
**[973. K Closest Points to Origin](https://leetcode.com/problems/k-closest-points-to-origin/)**

Given an array of points where $points[i] = [x_i, y_i]$ represents a point on the X-Y plane and an integer `k`, return the `k` closest points to the origin (0, 0).

The distance between two points on the X-Y plane is the Euclidean distance (i.e., $\sqrt{{(x_1 - x_2)}^2 + {(y_1 - y_2)}^2}$).

You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).



In [31]:
def kClosest(points: List[List[int]], k: int) -> List[List[int]]:

    return sorted(points, key=lambda pair: pair[0] ** 2 + pair[1] ** 2)[:k]

In [32]:
assert kClosest(points=[[1,3],[-2,2]], k=1) == [[-2,2]]
assert kClosest(points=[[3,3],[5,-1],[-2,4]], k=2) == [[3,3],[-2,4]]

- Time complexity: $O(N \cdot \log N)$ for sorting the points.
- Space complexity: $O(N)$, since `sorted` builds a new list of all points.

_Using a heap_

In [33]:
import heapq

In [34]:
def squared_distance(point: List[int]):
    return point[0] ** 2 + point[1] ** 2 

In [35]:
def kClosest(points: List[List[int]], k: int) -> List[List[int]]:
    # Max-heap of at most k entries, simulated by negating distances:
    # heap[0] is always the farthest of the closest points kept so far.
    heap = []

    for i, point in enumerate(points):
        entry = (-squared_distance(point), i)
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif entry[0] > heap[0][0]:
            # Strictly closer than the farthest kept point: swap it in.
            heapq.heapreplace(heap, entry)

    # Keeping indices (not distances) returns exactly k points even when distances tie;
    # sorting the indices preserves the input order of the selected points.
    return [points[i] for i in sorted(i for (_, i) in heap)]

In [36]:
assert kClosest(points=[[1,3],[-2,2]], k=1) == [[-2,2]]
assert kClosest(points=[[3,3],[5,-1],[-2,4]], k=2) == [[3,3],[-2,4]]

- Time complexity: $O(N \cdot \log k)$. Adding to or removing from the heap (or priority queue) takes only $O(\log k)$ time when its size is capped at $k$ elements.
- Space complexity: $O(k)$. The heap (or priority queue) holds at most $k$ elements.
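The standard library can do the bounded-heap bookkeeping itself: `heapq.nsmallest(k, iterable, key=...)` returns the `k` smallest items by key, in ascending order. A one-line sketch (hypothetical name `k_closest_nsmallest`), again using the squared distance, since taking the square root does not change the ordering:

```python
import heapq
from typing import List


def k_closest_nsmallest(points: List[List[int]], k: int) -> List[List[int]]:
    # Hypothetical helper name; results come back closest first.
    return heapq.nsmallest(k, points, key=lambda p: p[0] ** 2 + p[1] ** 2)
```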

---
**[215. Kth Largest Element in an Array](https://leetcode.com/problems/kth-largest-element-in-an-array/)**

Given an integer array nums and an integer k, return the kth largest element in the array.

Note that it is the kth largest element in the sorted order, not the kth distinct element.

In [37]:
from random import randint

def findKthLargest(nums: List[int], k: int) -> int:
    
    if len(nums) == 0:
        return -1
    
    if len(nums) == 1:
        return nums[0]
    
    pivot = nums[randint(0, len(nums) - 1)]
    
    left = [el for el in nums if el < pivot]
    right = [el for el in nums if el > pivot]
    same = [el for el in nums if el == pivot]
    
    pivot_rank = len(right) + 1
    
    if pivot_rank > k:
        return findKthLargest(right, k)
    elif pivot_rank < k:
        return findKthLargest(left + same[:-1], k-pivot_rank)
    else:
        return pivot

In [38]:
assert findKthLargest(nums=[3, 2, 1, 5, 6, 4], k=2) == 5
assert findKthLargest(nums=[3, 2, 3, 1, 2, 4, 5, 5, 6], k=4) == 4

In [53]:
def findKthLargest(nums: List[int], k: int) -> int:
        
    def partition(start, end):
        ran = randint(start, end)
        pivot = end
        nums[pivot], nums[ran] = nums[ran], nums[pivot]

        border = start
        for cur in range(start, end):
            if nums[cur] >= nums[pivot]:
                nums[cur], nums[border] = nums[border], nums[cur]
                border += 1

        nums[border], nums[pivot] = nums[pivot], nums[border]
        return border

    def quick_select(start, end, k_largest):
        res = None
        while start <= end:
            p = partition(start, end)
            if p == k_largest:
                res = nums[k_largest]
                break
            elif p > k_largest:
                end = p - 1
            else:
                start = p + 1
        return res

    return quick_select(0, len(nums)-1, k-1)

In [54]:
assert findKthLargest(nums=[3, 2, 1, 5, 6, 4], k=2) == 5
assert findKthLargest(nums=[3, 2, 3, 1, 2, 4, 5, 5, 6], k=4) == 4
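Both versions above are quickselect, with $O(n)$ average but $O(n^2)$ worst-case time. A heap-based alternative guarantees $O(n \log k)$ time and $O(k)$ space: keep a min-heap of the `k` largest values seen so far, so its root is the answer. A sketch assuming `1 <= k <= len(nums)`; `find_kth_largest_heap` is a hypothetical name.

```python
import heapq
from typing import List


def find_kth_largest_heap(nums: List[int], k: int) -> int:
    # Hypothetical helper name. Assumes 1 <= k <= len(nums); nums is not modified.
    heap = nums[:k]
    heapq.heapify(heap)  # min-heap of the first k values
    for num in nums[k:]:
        if num > heap[0]:
            # num beats the current k-th largest: drop the root and insert num.
            heapq.heapreplace(heap, num)
    return heap[0]
```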

## Dynamic Programming

Amazon does not ask a whole lot of Dynamic Programming questions. We recommend Best Time to Buy and Sell Stock.

---
**[5. Longest Palindromic Substring](https://leetcode.com/problems/longest-palindromic-substring/)**

Given a string s, return the longest palindromic substring in s.


_Brute-force_

In [160]:
def longestPalindrome(s: str) -> str:
    
    if len(s) == 1:
        return s
    
    curr_longest = s[0]
    
    for left in range(len(s)): 
        for right in range(left + len(curr_longest), len(s)):
            if s[left:right+1] == s[left:right+1][::-1]:
                curr_longest = s[left:right+1]
    
    return curr_longest
            

In [161]:
assert longestPalindrome('babad') in ['aba', 'bab']
assert longestPalindrome('cbbd') == 'bb'
assert longestPalindrome('babab') == 'babab'
assert longestPalindrome('bb') == 'bb'
assert longestPalindrome('eabcb') == 'bcb'

- Time complexity : $O(n^3)$
- Space complexity : $O(n)$ for the temporary substring slices

_Smarter_

In [177]:
def longestPalindrome(s: str) -> str:
    
    if len(s) == 1:
        return s

    curr_longest = s[0]

    for i in range(len(s)):

        odd_palindrome = expand_around_center(s, center_index=i, inbetween=False)
        even_palindrome = expand_around_center(s, center_index=i, inbetween=True)

        if len(odd_palindrome) > len(curr_longest):
            curr_longest = odd_palindrome

        if len(even_palindrome) > len(curr_longest):
            curr_longest = even_palindrome

    return curr_longest
            

def expand_around_center(s, center_index, inbetween=False):
    if inbetween:
        left, right = center_index - 1, center_index
    else:
        left, right = center_index - 1, center_index + 1

    while left >= 0 and right < len(s) and s[left] == s[right]:
        left -= 1
        right += 1

    return s[left+1:right]

In [178]:
assert longestPalindrome('babad') in ['aba', 'bab']
assert longestPalindrome('cbbd') == 'bb'
assert longestPalindrome('babab') == 'babab'
assert longestPalindrome('bb') == 'bb'
assert longestPalindrome('eabcb') == 'bcb'

- Time complexity : $O(n^2)$
- Space complexity : $O(1)$
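Although this problem sits under Dynamic Programming, neither version above builds a table. The classic DP formulation marks `is_pal[i][j]` as true when `s[i] == s[j]` and the inner substring `s[i+1..j-1]` is itself a palindrome, filling the table by increasing substring length. It takes $O(n^2)$ time and $O(n^2)$ space, so expanding around centers remains the leaner choice. `longest_palindrome_dp` is a hypothetical name for this sketch.

```python
def longest_palindrome_dp(s: str) -> str:
    # Hypothetical helper name.
    n = len(s)
    if n == 0:
        return ""
    # is_pal[i][j] is True when s[i..j] (inclusive) is a palindrome.
    is_pal = [[False] * n for _ in range(n)]
    best_start, best_len = 0, 1

    for i in range(n):
        is_pal[i][i] = True  # every single character is a palindrome

    # Fill by increasing length so the inner entries are ready when needed.
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            # Ends must match; the inside must be a palindrome (or empty for length 2).
            if s[i] == s[j] and (length == 2 or is_pal[i + 1][j - 1]):
                is_pal[i][j] = True
                if length > best_len:
                    best_start, best_len = i, length

    return s[best_start:best_start + best_len]
```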

## Design

These are some design questions for you to practice for your Amazon interview. We highly recommend LRU Cache.

---
**[146. LRU Cache](https://leetcode.com/problems/lru-cache/)**

Design a data structure that follows the constraints of a Least Recently Used (LRU) cache.

Implement the `LRUCache` class:

- `LRUCache(int capacity)` Initialize the LRU cache with positive size capacity.
- `int get(int key)` Return the value of the key if the key exists, otherwise return -1.
- `void put(int key, int value)` Update the value of the key if the key exists. Otherwise, add the key-value pair to the cache. If the number of keys exceeds the capacity from this operation, evict the least recently used key.

The functions get and put must each run in $O(1)$ average time complexity.

_Slow: `list.remove` and `list.pop(0)` make each operation $O(\text{capacity})$_

In [39]:
class LRUCache:

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data = {}
        self.queue = []

    def get(self, key: int) -> int:
        
        if key in self.data:
            self.queue.remove(key)
            self.queue.append(key)
            return self.data[key]
        else:
            return -1
        

    def put(self, key: int, value: int) -> None:

        if len(self.data) == self.capacity:
            if key not in self.data:
                old_key = self.queue.pop(0)
                del self.data[old_key]
        
        
        self.data[key] = value
        if key in self.queue:
            self.queue.remove(key)
        self.queue.append(key)

In [40]:
lRUCache = LRUCache(capacity=2)
lRUCache.put(1, 1)
lRUCache.put(2, 2)
assert lRUCache.get(1) == 1
lRUCache.put(3, 3)   # evicts key 2
assert lRUCache.get(2) == -1
lRUCache.put(4, 4)   # evicts key 1
assert lRUCache.get(1) == -1
assert lRUCache.get(3) == 3
assert lRUCache.get(4) == 4

In [41]:
lRUCache = LRUCache(capacity=2)
assert lRUCache.get(2) == -1
lRUCache.put(2, 6)
assert lRUCache.get(1) == -1
lRUCache.put(1, 5)
lRUCache.put(1, 2)   # updates key 1, no eviction
assert lRUCache.get(1) == 2
assert lRUCache.get(2) == 6

In [42]:
lRUCache = LRUCache(capacity=2)
lRUCache.put(2, 1)
lRUCache.put(1, 1)
lRUCache.put(1, 2)
lRUCache.put(2, 3)
lRUCache.put(4, 1)   # evicts key 1, the least recently used
assert lRUCache.get(1) == -1
assert lRUCache.get(2) == 3

_Faster: $O(\log n)$ per operation, using a `SortedDict` that maps each key's last access time to the key_

In [43]:
from sortedcontainers import SortedDict  # third-party package: pip install sortedcontainers

In [44]:
class LRUCache:

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data = {}
        self.freqs = {}
        self.inv_freqs = SortedDict({})

    def get(self, key: int) -> int:
        
        if key in self.data:
            
            old_freq = self.freqs[key]
            if self.inv_freqs:
                self.freqs[key] = self.inv_freqs.keys()[-1] + 1
            else:
                self.freqs[key] = 1
            if old_freq in self.inv_freqs:
                del self.inv_freqs[old_freq]
            self.inv_freqs[self.freqs[key]] = key
            return self.data[key]
        else:
            return -1
        

    def put(self, key: int, value: int) -> None:

        if len(self.data) == self.capacity:
            if key not in self.data:
                least_freq = self.inv_freqs.keys()[0]
                least_freq_elem = self.inv_freqs[least_freq]
                del self.data[least_freq_elem]
                del self.inv_freqs[least_freq]
                del self.freqs[least_freq_elem]
                              
        
        
        self.data[key] = value
        
        old_freq = self.freqs.get(key, -1)
        if self.inv_freqs:
            self.freqs[key] = self.inv_freqs.keys()[-1] + 1
        else:
            self.freqs[key] = 1
        if old_freq != -1:
            del self.inv_freqs[old_freq]
        self.inv_freqs[self.freqs[key]] = key

In [45]:
lRUCache = LRUCache(capacity=2)
lRUCache.put(1, 1)
lRUCache.put(2, 2)
assert lRUCache.get(1) == 1
lRUCache.put(3, 3)   # evicts key 2
assert lRUCache.get(2) == -1
lRUCache.put(4, 4)   # evicts key 1
assert lRUCache.get(1) == -1
assert lRUCache.get(3) == 3
assert lRUCache.get(4) == 4

In [46]:
lRUCache = LRUCache(capacity=2)
assert lRUCache.get(2) == -1
lRUCache.put(2, 6)
assert lRUCache.get(1) == -1
lRUCache.put(1, 5)
lRUCache.put(1, 2)   # updates key 1, no eviction
assert lRUCache.get(1) == 2
assert lRUCache.get(2) == 6

_Much nicer: `OrderedDict` provides $O(1)$ `move_to_end` and `popitem(last=False)`_

In [47]:
from collections import OrderedDict

class LRUCache(OrderedDict):
    """LRU cache on top of OrderedDict: the first key is the least recently used."""

    def __init__(self, capacity: int):
        super().__init__()
        self.capacity = capacity

    def get(self, key: int) -> int:
        if key not in self:
            return -1
        # Mark the key as most recently used.
        self.move_to_end(key)
        return self[key]

    def put(self, key: int, value: int) -> None:
        if key in self:
            self.move_to_end(key)
        self[key] = value
        if len(self) > self.capacity:
            # Evict the least recently used entry, i.e. the first one.
            self.popitem(last=False)

_Hash map plus a doubly linked list: $O(1)$ per operation without relying on `OrderedDict`_

In [48]:
class DLinkedNode(): 
    def __init__(self):
        self.key = 0
        self.value = 0
        self.prev = None
        self.next = None
            
class LRUCache():
    def _add_node(self, node):
        """
        Always add the new node right after head.
        """
        node.prev = self.head
        node.next = self.head.next

        self.head.next.prev = node
        self.head.next = node

    def _remove_node(self, node):
        """
        Remove an existing node from the linked list.
        """
        prev = node.prev
        new = node.next

        prev.next = new
        new.prev = prev

    def _move_to_head(self, node):
        """
        Move certain node in between to the head.
        """
        self._remove_node(node)
        self._add_node(node)

    def _pop_tail(self):
        """
        Pop the current tail.
        """
        res = self.tail.prev
        self._remove_node(res)
        return res

    def __init__(self, capacity):
        """
        :type capacity: int
        """
        self.cache = {}
        self.size = 0
        self.capacity = capacity
        self.head, self.tail = DLinkedNode(), DLinkedNode()

        self.head.next = self.tail
        self.tail.prev = self.head
        

    def get(self, key):
        """
        :type key: int
        :rtype: int
        """
        node = self.cache.get(key, None)
        if not node:
            return -1

        # move the accessed node to the head;
        self._move_to_head(node)

        return node.value

    def put(self, key, value):
        """
        :type key: int
        :type value: int
        :rtype: void
        """
        node = self.cache.get(key)

        if not node: 
            newNode = DLinkedNode()
            newNode.key = key
            newNode.value = value

            self.cache[key] = newNode
            self._add_node(newNode)

            self.size += 1

            if self.size > self.capacity:
                # pop the tail
                tail = self._pop_tail()
                del self.cache[tail.key]
                self.size -= 1
        else:
            # update the value.
            node.value = value
            self._move_to_head(node)

In [49]:
lRUCache = LRUCache(capacity=2)
lRUCache.put(1, 1)
lRUCache.put(2, 2)
assert lRUCache.get(1) == 1
lRUCache.put(3, 3)   # evicts key 2
assert lRUCache.get(2) == -1
lRUCache.put(4, 4)   # evicts key 1
assert lRUCache.get(1) == -1
assert lRUCache.get(3) == 3
assert lRUCache.get(4) == 4