In [1]:
import hashlib
from bisect import bisect_right
from typing import List, Dict, Any, Optional

class ConsistentHash:
    """Consistent-hash ring that maps string keys onto physical nodes.

    Each physical node is placed on the ring at the MD5 hash of its own name
    and at the hashes of ``virtual_nodes`` derived names of the form
    ``"<node>_<i>"``. A key is served by the owner of the first ring position
    strictly greater than the key's hash, wrapping around to the lowest
    position when the key hashes past the last one.

    Attributes:
        virtual_nodes: Number of virtual points added per physical node.
        ring: Maps each ring position (integer hash) to the physical node
            that owns it.
        sorted_keys: All ring positions in ascending order; always mirrors
            the keys of ``ring``.

    Note:
        A physical node whose name equals another node's virtual name
        (e.g. ``"a_0"`` next to ``"a"``) hashes to the same position; the
        node added last owns that position.
    """

    def __init__(self, physical_nodes: List[str], virtual_nodes: int = 3) -> None:
        """
        Initialize the consistent hash ring.

        Args:
            physical_nodes: List of physical node names/IDs (strings)
            virtual_nodes: Number of virtual nodes per physical node
        """
        self.virtual_nodes = virtual_nodes
        self.ring: Dict[int, str] = {}  # Maps hash positions to node names
        self.sorted_keys: List[int] = []  # Maintains sorted hash positions

        # Add all physical nodes and their virtual nodes to the ring
        self._add_nodes(physical_nodes)

    def _hash(self, key: str) -> int:
        """
        Generate a hash value for a given key using MD5.

        MD5 is used here only to spread positions over the ring.

        Args:
            key: String to be hashed
        Returns:
            The MD5 hex digest of the UTF-8 encoded key, read as a base-16 integer
        """
        md5_hash = hashlib.md5(key.encode('utf-8'))
        return int(md5_hash.hexdigest(), 16)

    def _add_nodes(self, nodes: List[str]) -> None:
        """
        Add physical nodes and their virtual nodes to the hash ring.

        Args:
            nodes: List of physical node names/IDs
        """
        for node in nodes:
            # The physical node itself occupies one ring position ...
            self._add_node(node)

            # ... and each virtual node adds one more position owned by it.
            for i in range(self.virtual_nodes):
                self._add_node(f"{node}_{i}", node)

    def _add_node(self, node: str, physical_node: Optional[str] = None) -> None:
        """
        Add a single node to the hash ring.

        Args:
            node: Node name/ID to add; its hash decides the ring position
            physical_node: Physical node name if this is a virtual node
        """
        position = self._hash(node)
        if position not in self.ring:
            # Insert at the sorted location instead of re-sorting the whole
            # ring on every call. The membership check keeps re-adding an
            # existing node from duplicating its position.
            self.sorted_keys.insert(bisect_right(self.sorted_keys, position), position)
        # Compare against None explicitly: `physical_node or node` would
        # attribute the position to the virtual name when the physical
        # node's name is falsy (e.g. "").
        self.ring[position] = node if physical_node is None else physical_node

    def get_node(self, key: str) -> str:
        """
        Get the node responsible for the given key.

        Args:
            key: Key to look up
        Returns:
            Node name/ID responsible for the key
        Raises:
            ValueError: If the ring contains no nodes
        """
        if not self.ring:
            raise ValueError("Hash ring is empty")

        # Generate hash of the key
        hash_value = self._hash(key)

        # First position strictly after the key's hash; the modulo wraps a
        # key that hashes past the last position back to the first one.
        index = bisect_right(self.sorted_keys, hash_value) % len(self.sorted_keys)
        return self.ring[self.sorted_keys[index]]

    def add_node(self, node: str) -> None:
        """
        Add a new physical node and its virtual nodes to the ring.

        Adding a node that is already present leaves the ring unchanged.

        Args:
            node: Node name/ID to add
        """
        self._add_nodes([node])

    def remove_node(self, node: str) -> None:
        """
        Remove a physical node and its virtual nodes from the ring.

        Removing a node that is not on the ring is a no-op.

        Args:
            node: Node name/ID to remove
        """
        # Every position, physical or virtual, is stored under the owning
        # physical node's name, so an exact match finds all of them. A
        # prefix match would also delete unrelated physical nodes whose
        # names merely start with "<node>_".
        positions = [pos for pos, owner in self.ring.items() if owner == node]
        if not positions:
            return

        for pos in positions:
            del self.ring[pos]

        self.sorted_keys = sorted(self.ring)

In [10]:
def count_keys_per_node(hash_ring, keys):
    """Tally how many of ``keys`` the ring routes to each node.

    Args:
        hash_ring: Object exposing ``get_node(key)``.
        keys: Iterable of keys to look up.
    Returns:
        Dict mapping node name to the number of keys routed to it, with
        nodes listed in the order they are first encountered.
    """
    counts = {}
    for key in keys:
        node = hash_ring.get_node(key)
        counts[node] = counts.get(node, 0) + 1
    return counts


# Initialize with some physical nodes
physical_nodes = ["node1", "node2", "node3"]
ch = ConsistentHash(physical_nodes, virtual_nodes=50)

# Test some key distributions
test_keys = [f"user{n}" for n in range(1000)]

print("Initial distribution:")
keycounts = count_keys_per_node(ch, test_keys)
print(keycounts)

# Add a new node
print("\nAdding node4...")
ch.add_node("node4")

print("\nDistribution after adding node4:")
keycounts2 = count_keys_per_node(ch, test_keys)
print(keycounts2)

# Remove a node
print("\nRemoving node2...")
ch.remove_node("node2")

print("\nDistribution after removing node2:")
keycounts3 = count_keys_per_node(ch, test_keys)
print(keycounts3)

Initial distribution:
{'node3': 294, 'node1': 357, 'node2': 349}

Adding node4...

Distribution after adding node4:
{'node4': 233, 'node1': 265, 'node2': 281, 'node3': 221}

Removing node2...

Distribution after removing node2:
{'node4': 301, 'node1': 383, 'node3': 316}
