# Question 358

## Description

Create a data structure that performs all of the following operations in O(1) time:

* plus: Add a key with value 1. If the key already exists, increment its value by one.
* minus: Decrement the value of a key. If the key's value is currently 1, remove it.
* get_max: Return a key with the highest value.
* get_min: Return a key with the lowest value.

## Solution

Use a combination of three data structures:

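* A hash map (a dictionary in Python) that stores the current value (count) of each key.
* A doubly linked list of nodes kept in increasing order of value. Each node stores one value and the set of keys that currently have that value.
* A hash map (dictionary) from each value to the linked-list node for that value, so the node for any value is found without walking the list.

Because `plus` and `minus` change a key's value by exactly one, the key always moves to a node adjacent to its old one (created in place if it does not exist yet). `get_min` and `get_max` simply read the head and tail of the list, so every operation runs in constant (average) time.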

In [1]:
class Node:
    """One bucket of the doubly linked list: a count plus every key that has it."""

    def __init__(self, value):
        # A fresh bucket is detached; the owning list links it in afterwards.
        self.prev = self.next = None
        # Every key whose current count equals ``value``.
        self.keys = set()
        self.value = value


class DataStructure:
    """Key counter supporting plus/minus/get_max/get_min in O(1) average time.

    Internal layout:

    * ``key_value_map`` maps each key to its current count (always >= 1;
      a key is dropped when ``minus`` is applied at count 1).
    * A doubly linked list of ``Node`` buckets ordered from the smallest
      count (``head``) to the largest (``tail``). Each bucket holds every
      key whose count equals the bucket's ``value``; a bucket is unlinked
      as soon as its key set becomes empty.
    * ``value_node_map`` maps each count to its bucket, so the bucket for a
      count is found without walking the list.

    A key's count only ever moves by one, so its destination bucket is always
    the immediate neighbour of its current one; every update is local.
    """

    def __init__(self):
        # The bucket list starts out empty at both ends.
        self.head = self.tail = None
        self.key_value_map = {}   # key -> current count
        self.value_node_map = {}  # count -> Node bucket in the list

    def _remove_node(self, node):
        """Unlink ``node`` from the list and forget its count."""
        before, after = node.prev, node.next
        if before is None:
            self.head = after
        else:
            before.next = after
        if after is None:
            self.tail = before
        else:
            after.prev = before
        del self.value_node_map[node.value]

    def _insert_after(self, node, new_node):
        """Link ``new_node`` directly after ``node``, updating ``tail`` if needed."""
        successor = node.next
        new_node.prev, new_node.next = node, successor
        if successor is None:
            self.tail = new_node
        else:
            successor.prev = new_node
        node.next = new_node

    def _insert_before(self, node, new_node):
        """Link ``new_node`` directly before ``node``, updating ``head`` if needed."""
        predecessor = node.prev
        new_node.prev, new_node.next = predecessor, node
        if predecessor is None:
            self.head = new_node
        else:
            predecessor.next = new_node
        node.prev = new_node

    def _bucket_beside(self, node, value, after):
        """Return the bucket for ``value``, creating it next to ``node`` if absent.

        ``after`` selects which side of ``node`` a newly created bucket is
        linked on (True for the successor side, False for the predecessor).
        """
        bucket = self.value_node_map.get(value)
        if bucket is None:
            bucket = Node(value)
            if after:
                self._insert_after(node, bucket)
            else:
                self._insert_before(node, bucket)
            self.value_node_map[value] = bucket
        return bucket

    def plus(self, key):
        """Increment ``key``'s count, adding it with count 1 if it is new."""
        current = self.key_value_map.get(key)
        if current is None:
            self.key_value_map[key] = 1
            bucket = self.value_node_map.get(1)
            if bucket is None:
                bucket = Node(1)
                # Count 1 is the smallest possible, so it always goes at the front.
                if self.head is None:
                    self.head = self.tail = bucket
                else:
                    self._insert_before(self.head, bucket)
                self.value_node_map[1] = bucket
            bucket.keys.add(key)
            return

        self.key_value_map[key] = current + 1
        source = self.value_node_map[current]
        source.keys.remove(key)
        self._bucket_beside(source, current + 1, after=True).keys.add(key)
        if not source.keys:
            self._remove_node(source)

    def minus(self, key):
        """Decrement ``key``'s count, removing the key when it was at 1.

        Keys that are not present are ignored.
        """
        if key not in self.key_value_map:
            return
        current = self.key_value_map[key]
        source = self.value_node_map[current]
        source.keys.remove(key)

        if current == 1:
            del self.key_value_map[key]
        else:
            self.key_value_map[key] = current - 1
            self._bucket_beside(source, current - 1, after=False).keys.add(key)

        if not source.keys:
            self._remove_node(source)

    def get_max(self):
        """Return some key with the largest count, or None when empty.

        Among tied keys, the one returned follows the key set's iteration order.
        """
        if self.tail is None:
            return None
        return next(iter(self.tail.keys))

    def get_min(self):
        """Return some key with the smallest count, or None when empty.

        Among tied keys, the one returned follows the key set's iteration order.
        """
        if self.head is None:
            return None
        return next(iter(self.head.keys))

In [2]:
# Testing: "a" is incremented twice and "b" once, so "a" leads 2 to 1.
ds = DataStructure()
for key in ("a", "a", "b"):
    ds.plus(key)
print(ds.get_max())  # expected "a" (value 2)
print(ds.get_min())  # expected "b" (value 1)

ds.minus("a")  # both keys now hold value 1, so either may be reported
print(ds.get_max())  # "a" or "b"
print(ds.get_min())  # "a" or "b"

a
b
b
b
