# Implement stack with max API

Design a max stack that supports push, pop, peek, peek_max and pop_max.

push(x) -- Push element x onto stack.    
pop() -- Remove the element on top of the stack and return it.    
peek() -- Get the element on the top.    
peek_max() -- Retrieve the maximum element in the stack.    
pop_max() -- Retrieve the maximum element in the stack and remove it. If there is more than one maximum element, remove only the top-most one.
   
**Example 1:**  

MaxStack stack = new MaxStack()  
stack.push(5)   
stack.push(1)    
stack.push(5)    
stack.peek() -> 5   
stack.pop_max() -> 5   
stack.peek() -> 1    
stack.peek_max() -> 5    
stack.pop() -> 1   
stack.peek() -> 5   


**Reference:** 

[LeetCode: 716. Max Stack](https://leetcode.com/problems/max-stack/)   
[Elements of Programming Interviews](https://www.amazon.com/Elements-Programming-Interviews-Python-Insiders/dp/1537713949): 8. Stacks and Queues - 8.1 Implement stack with max API

## Two stacks
* Time complexity:
    * push, peek, pop, peek_max: O(1)
    * pop_max: O(n)
* Additional space complexity for the stack storing max elements: O(n)
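
The O(1) bound for peek_max rests on one invariant: the auxiliary stack only receives a value that is greater than or equal to its current top, so its top is always the maximum of the main stack. As a minimal sketch of just that rule (the helper name `build_max_elements` is hypothetical and used only for illustration), the auxiliary stack produced by a sequence of pushes can be traced like this:

```python
def build_max_elements(values):
    """Return the auxiliary stack of maxima after pushing `values` in order."""
    max_elements = []
    for val in values:
        # >= rather than > so that every duplicate of the maximum gets its
        # own entry; otherwise popping one copy would lose track of the other.
        if not max_elements or val >= max_elements[-1]:
            max_elements.append(val)
    return max_elements


print(build_max_elements([5, 1, 5]))           # [5, 5]
print(build_max_elements([1, 2, 3, 4, 1, 2]))  # [1, 2, 3, 4]
```

In the worst case (a non-decreasing input) every pushed value is also stored in the auxiliary stack, which is where the O(n) additional space comes from.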

```python
class MaxStack:

    def __init__(self):
        self.data = []          # all elements; the top of the stack is data[-1]
        self.max_elements = []  # non-decreasing maxima; the current max is max_elements[-1]

    # O(1)
    def push(self, val):
        self.data.append(val)

        # >= (not >) so that duplicates of the maximum get their own entry
        if len(self.max_elements) == 0 or val >= self.max_elements[-1]:
            self.max_elements.append(val)

    # O(1)
    def peek(self):
        if len(self.data) == 0:
            raise IndexError('peek(): empty stack')

        return self.data[-1]

    # O(1)
    def pop(self):
        if len(self.data) == 0:
            raise IndexError('pop(): empty stack')

        val = self.data.pop()
        if val == self.max_elements[-1]:
            self.max_elements.pop()

        return val

    # O(1)
    def peek_max(self):
        if len(self.data) == 0:
            raise IndexError('peek_max(): empty stack')

        return self.max_elements[-1]

    # O(n), n - number of stack elements
    def pop_max(self):
        if len(self.data) == 0:
            raise IndexError('pop_max(): empty stack')

        max_val = self.max_elements[-1]

        # Move everything above the top-most maximum to a temporary stack
        tmp_stack = []
        while self.data[-1] != max_val:
            tmp_stack.append(self.data.pop())

        # Remove the maximum from both stacks
        self.data.pop()
        self.max_elements.pop()

        # Push the saved elements back in their original order;
        # push() also updates max_elements for them
        for element in reversed(tmp_stack):
            self.push(element)

        return max_val
```
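
A common variant of the same idea keeps a single list of `(value, max so far)` pairs instead of a second stack. The sketch below is an illustration under that assumption, not the implementation used in this notebook; the class name `PairMaxStack` and the `_items` attribute are hypothetical. It has the same time bounds, and it replays Example 1 from the problem statement as assertions.

```python
class PairMaxStack:
    """Max stack storing (value, max_of_stack_up_to_this_entry) pairs."""

    def __init__(self):
        self._items = []

    def push(self, val):
        # The running max is either the new value or the max below it.
        current_max = val if not self._items else max(val, self._items[-1][1])
        self._items.append((val, current_max))

    def peek(self):
        if not self._items:
            raise IndexError('peek(): empty stack')
        return self._items[-1][0]

    def pop(self):
        if not self._items:
            raise IndexError('pop(): empty stack')
        return self._items.pop()[0]

    def peek_max(self):
        if not self._items:
            raise IndexError('peek_max(): empty stack')
        return self._items[-1][1]

    def pop_max(self):
        max_val = self.peek_max()
        buffer = []
        # Set aside everything above the top-most maximum.
        while self._items[-1][0] != max_val:
            buffer.append(self._items.pop()[0])
        self._items.pop()
        # Re-push in original order so the running maxima are recomputed.
        for val in reversed(buffer):
            self.push(val)
        return max_val


# Example 1 from the problem statement
stack = PairMaxStack()
stack.push(5)
stack.push(1)
stack.push(5)
assert stack.peek() == 5
assert stack.pop_max() == 5
assert stack.peek() == 1
assert stack.peek_max() == 5
assert stack.pop() == 1
assert stack.peek() == 5
```

The trade-off is that every entry carries its own maximum, so the extra space is always n values, whereas the two-stack version stores an extra value only when a new maximum (or a duplicate of it) is pushed.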

## Test

```python
def show(stack):
    # Print the top element, the current maximum and the underlying list
    print('peek: ' + str(stack.peek()))
    print('max: ' + str(stack.peek_max()))
    print('data:' + str(stack.data))


stack = MaxStack()

for val in [1, 2, 3, 4, 1, 2]:
    print('\npush ' + str(val))
    stack.push(val)
    show(stack)

print('\npop max (' + str(stack.peek_max()) + ')')
stack.pop_max()
show(stack)

for _ in range(3):
    print('\npop (' + str(stack.peek()) + ')')
    stack.pop()
    show(stack)
```


```
push 1
peek: 1
max: 1
data:[1]

push 2
peek: 2
max: 2
data:[1, 2]

push 3
peek: 3
max: 3
data:[1, 2, 3]

push 4
peek: 4
max: 4
data:[1, 2, 3, 4]

push 1
peek: 1
max: 4
data:[1, 2, 3, 4, 1]

push 2
peek: 2
max: 4
data:[1, 2, 3, 4, 1, 2]

pop max (4)
peek: 2
max: 3
data:[1, 2, 3, 1, 2]

pop (2)
peek: 1
max: 3
data:[1, 2, 3, 1]

pop (1)
peek: 3
max: 3
data:[1, 2, 3]

pop (3)
peek: 2
max: 2
data:[1, 2]
```
