### Skip List Join with Skewed Data

In [1]:
import random
import time
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Any, Optional


class SkipNode:
    """Single entry of a SkipList.

    Stores a key, its payload and one forward pointer for each level the node
    takes part in (levels ``0`` through ``level`` inclusive).
    """

    def __init__(self, key: Any, value: Any, level: int):
        self.key = key
        # Payload; SkipList.insert replaces it with a list when a key repeats.
        self.value = value
        # forward[i] is the next node on level i, or None at the end of that level.
        self.forward = [None for _ in range(level + 1)]

    def __repr__(self):
        key, value = self.key, self.value
        return f"SkipNode(key={key}, value={value})"


class SkipList:
    """Ordered key -> value map implemented as a skip list.

    Supports point search, inclusive range search, insert, delete and ordered
    iteration. Keys must support ``<``, ``<=`` and ``==`` against each other.
    Inserting a key that is already present does not add a node; instead the
    node's value becomes a list of every value stored under that key, in
    insertion order.
    """

    class _Duplicates(list):
        """Container for the values of a key that was inserted more than once.

        Subclassing ``list`` keeps ``isinstance(value, list)`` checks in callers
        working. It also lets ``insert`` tell its own container apart from a
        caller's value that merely happens to be a list.
        """

    def __init__(self, max_level: int = 16, p: float = 0.5):
        self.max_level = max_level  # Highest level index any node may reach
        self.p = p  # Probability of promoting a node to the next level
        self.level = 0  # Highest level index currently in use

        # Head sentinel spans every level; its None key is never compared.
        self.head = SkipNode(None, None, max_level)

    def random_level(self) -> int:
        """Draw a level for a new node.

        Each extra level is kept with probability ``p``, capped at ``max_level``.
        """
        level = 0
        while random.random() < self.p and level < self.max_level:
            level += 1
        return level

    def _descend(self, key: Any, update: Optional[list] = None) -> "SkipNode":
        """Return the last node whose key is strictly less than ``key``.

        Walks from the highest used level down to level 0. If ``update`` is
        given, ``update[i]`` receives the predecessor found on each level
        ``i`` in ``0..self.level``; insert/delete relink those predecessors.
        """
        current = self.head
        for i in range(self.level, -1, -1):
            while current.forward[i] is not None and current.forward[i].key < key:
                current = current.forward[i]
            if update is not None:
                update[i] = current
        return current

    def search(self, key: Any) -> Optional[Any]:
        """Return the value stored under ``key``, or None if the key is absent.

        For a key inserted more than once, the value is a list of all its values.
        """
        candidate = self._descend(key).forward[0]
        if candidate is not None and candidate.key == key:
            return candidate.value
        return None

    def search_range(self, start_key: Any, end_key: Any) -> List[Tuple[Any, Any]]:
        """Return ``(key, value)`` pairs with ``start_key <= key <= end_key``, in key order."""
        result = []
        current = self._descend(start_key).forward[0]
        while current is not None and current.key <= end_key:
            result.append((current.key, current.value))
            current = current.forward[0]
        return result

    def insert(self, key: Any, value: Any):
        """Insert ``value`` under ``key``.

        A new key gets its own node with a random level. For an existing key,
        the stored value becomes (or is extended as) a list of all values
        inserted under that key. A caller-supplied list value is stored as a
        single element of that list; it is never mutated or flattened.
        """
        update = [None] * (self.max_level + 1)
        existing = self._descend(key, update).forward[0]

        if existing is not None and existing.key == key:
            if isinstance(existing.value, self._Duplicates):
                existing.value.append(value)
            else:
                # Use a fresh container so a list-valued payload is neither
                # mutated nor merged into the collection of duplicates.
                existing.value = self._Duplicates([existing.value, value])
            return

        new_level = self.random_level()

        if new_level > self.level:
            # Levels that were empty until now are linked straight from the head.
            for i in range(self.level + 1, new_level + 1):
                update[i] = self.head
            self.level = new_level

        new_node = SkipNode(key, value, new_level)
        for i in range(new_level + 1):
            new_node.forward[i] = update[i].forward[i]
            update[i].forward[i] = new_node

    def delete(self, key: Any):
        """Remove ``key`` together with every value stored under it.

        Deleting a key that is not present does nothing.
        """
        update = [None] * (self.max_level + 1)
        target = self._descend(key, update).forward[0]

        if target is not None and target.key == key:
            for i in range(self.level + 1):
                # Above the target's own height, predecessors point elsewhere.
                if update[i].forward[i] is not target:
                    break
                update[i].forward[i] = target.forward[i]

            # Shrink the active height while the top level is empty.
            while self.level > 0 and self.head.forward[self.level] is None:
                self.level -= 1

    def __iter__(self):
        """Yield ``(key, value)`` pairs in ascending key order."""
        current = self.head.forward[0]
        while current is not None:
            yield (current.key, current.value)
            current = current.forward[0]

    def __repr__(self):
        """String representation listing all ``(key, value)`` pairs in order."""
        elements = list(self)
        return f"SkipList({elements})"


def skip_list_join(left_data: List[Dict], right_data: List[Dict], 
                   left_key: str, right_key: str) -> List[Dict]:
    """
    Perform an inner join using a skip list to index the right table.

    Every right record is inserted into a SkipList keyed on ``right_key``.
    Each left record then looks up its ``left_key`` value in that index. A
    joined row is ``{**left_record, **right_record}``, so when a column name
    clashes the right record's value wins. Build, probe and total times are
    printed.

    Returns:
        The joined dicts, grouped by left record in ``left_data`` order.
    """
    skip_list = SkipList()
    # perf_counter is monotonic and high resolution. time.time() can jump
    # with clock adjustments and is coarser on some platforms.
    start_build = time.perf_counter()

    for record in right_data:
        skip_list.insert(record[right_key], record)

    build_time = time.perf_counter() - start_build
    print(f"Skip list build time: {build_time:.6f} seconds")

    # Probe the index once per left record.
    start_join = time.perf_counter()
    result = []

    for left_record in left_data:
        matches = skip_list.search(left_record[left_key])
        if matches is None:
            continue
        # A key seen once maps to a single record; a repeated key maps to a list.
        if not isinstance(matches, list):
            matches = [matches]
        for right_record in matches:
            result.append({**left_record, **right_record})

    join_time = time.perf_counter() - start_join
    print(f"Skip list join time: {join_time:.6f} seconds")
    print(f"Total skip list join time: {build_time + join_time:.6f} seconds")

    return result


def hash_join(left_data: List[Dict], right_data: List[Dict], 
              left_key: str, right_key: str) -> List[Dict]:
    """
    Perform an inner join using a hash table to index the right table.

    Right records are grouped into a dict of lists keyed on ``right_key``.
    Each left record then fetches the bucket for its ``left_key`` value. A
    joined row is ``{**left_record, **right_record}``, so when a column name
    clashes the right record's value wins. Build, probe and total times are
    printed.

    Returns:
        The joined dicts, grouped by left record in ``left_data`` order.
        Within a group, matches follow their order in ``right_data``.
    """
    # perf_counter is the monotonic, high-resolution clock intended for timing.
    start_build = time.perf_counter()
    hash_table = {}

    for record in right_data:
        key = record[right_key]
        bucket = hash_table.get(key)
        if bucket is None:
            hash_table[key] = [record]
        else:
            bucket.append(record)

    build_time = time.perf_counter() - start_build
    print(f"Hash table build time: {build_time:.6f} seconds")

    start_join = time.perf_counter()
    result = []

    for left_record in left_data:
        # An unmatched key yields an empty tuple, so no rows are emitted.
        for right_record in hash_table.get(left_record[left_key], ()):
            result.append({**left_record, **right_record})

    join_time = time.perf_counter() - start_join
    print(f"Hash join time: {join_time:.6f} seconds")
    print(f"Total hash join time: {build_time + join_time:.6f} seconds")

    return result


def pandas_join(left_data: List[Dict], right_data: List[Dict], 
                left_key: str, right_key: str) -> List[Dict]:
    """
    Perform an inner join using pandas.

    Both inputs are converted to DataFrames and joined with ``pd.merge``.
    Unlike the dict-based joins, pandas keeps both copies of overlapping
    non-key columns, renamed with its default ``_x``/``_y`` suffixes. The
    elapsed time, including conversion in both directions, is printed.

    Returns:
        The joined rows as a list of dicts. The list is empty when either
        input is empty.
    """
    start_time = time.perf_counter()

    if not left_data or not right_data:
        # pd.DataFrame([]) has no columns, so merge would raise KeyError on the
        # join key. An inner join with an empty side is simply empty.
        result = []
    else:
        left_df = pd.DataFrame(left_data)
        right_df = pd.DataFrame(right_data)
        result_df = pd.merge(left_df, right_df, how="inner",
                             left_on=left_key, right_on=right_key)
        result = result_df.to_dict('records')

    end_time = time.perf_counter()
    print(f"Pandas join time: {end_time - start_time:.6f} seconds")

    return result


def nested_loop_join(left_data: List[Dict], right_data: List[Dict], 
                     left_key: str, right_key: str) -> List[Dict]:
    """
    Perform a simple nested loop join.

    Compares every left record with every right record, which takes
    len(left_data) * len(right_data) comparisons. A joined row is
    ``{**left_record, **right_record}``, so when a column name clashes the
    right record's value wins. The elapsed time is printed.

    Returns:
        The joined dicts, ordered by left record, then by right record.
    """
    # perf_counter is the monotonic, high-resolution clock intended for timing.
    start_time = time.perf_counter()
    result = []

    for left_record in left_data:
        # Loop-invariant for the inner scan; read it once per left record.
        left_val = left_record[left_key]
        for right_record in right_data:
            if left_val == right_record[right_key]:
                result.append({**left_record, **right_record})

    end_time = time.perf_counter()
    print(f"Nested loop join time: {end_time - start_time:.6f} seconds")

    return result


def generate_test_data(size_left: int, size_right: int, 
                       key_range: int, seed: int = 42) -> Tuple[List[Dict], List[Dict]]:
    """
    Build two reproducible synthetic tables for the join benchmarks.

    Reseeds both the global ``random`` and ``numpy.random`` generators with
    ``seed``, which also affects any later code that uses them. Join keys are
    drawn uniformly from ``1..key_range`` inclusive. All left keys are drawn
    first, then all right keys.

    Returns:
        ``(left_data, right_data)``. ``left_data`` holds ``size_left`` records
        with ``id``/``join_key``/``value_left``; ``right_data`` holds
        ``size_right`` records with ``id``/``join_key``/``value_right``.
        ``id`` is the record's 0-based position.
    """
    random.seed(seed)
    np.random.seed(seed)

    left_data = [
        {'id': i, 'join_key': random.randint(1, key_range), 'value_left': f"left_value_{i}"}
        for i in range(size_left)
    ]
    right_data = [
        {'id': i, 'join_key': random.randint(1, key_range), 'value_right': f"right_value_{i}"}
        for i in range(size_right)
    ]

    return left_data, right_data


def _report_join_sizes(left_data: List[Dict], right_data: List[Dict],
                       key: str = 'join_key') -> Tuple[int, int, int]:
    """Run the skip-list, hash and pandas joins on one dataset and print result sizes.

    Each join result is dropped as soon as its length is taken, so at most one
    full result is held in memory at a time. This matters for skewed inputs,
    whose joins can be very large. A warning is printed if the sizes disagree.

    Returns:
        ``(skip_size, hash_size, pandas_size)``.
    """
    skip_size = len(skip_list_join(left_data, right_data, key, key))
    hash_size = len(hash_join(left_data, right_data, key, key))
    pandas_size = len(pandas_join(left_data, right_data, key, key))

    print(f"Skip list join result size: {skip_size}")
    print(f"Hash join result size: {hash_size}")
    print(f"Pandas join result size: {pandas_size}")

    if len({skip_size, hash_size, pandas_size}) != 1:
        print("WARNING: join result sizes differ")

    return skip_size, hash_size, pandas_size


def test_joins():
    """
    Test different join algorithms with various data sizes and distributions.

    For each dataset, the joins' timings and result sizes are printed.
    Consistency across algorithms is checked on result size only.
    """
    print("=== Small Dataset Test ===")
    left_data, right_data = generate_test_data(1000, 1000, 500)
    _report_join_sizes(left_data, right_data)

    # Nested loop join is quadratic, so only run it on a 100x100 subset.
    nested_result = nested_loop_join(left_data[:100], right_data[:100], 'join_key', 'join_key')
    print(f"Nested loop join result size (on 100x100 subset): {len(nested_result)}")

    print("\n=== Medium Dataset Test ===")
    _report_join_sizes(*generate_test_data(10000, 10000, 5000))

    print("\n=== Large Dataset Test ===")
    _report_join_sizes(*generate_test_data(100000, 100000, 50000))

    print("\n=== Skewed Data Distribution Test ===")

    # Skewed keys: 90% fall in 1..100, the rest in 101..5000. With 50,000 rows
    # per side, each hot key gets about 450 rows per side. Each join therefore
    # emits roughly 20 million rows (expected value).
    # This section does not reseed: draws continue from the global random
    # state left by the earlier sections.
    def skewed_key() -> int:
        return random.randint(1, 100) if random.random() < 0.9 else random.randint(101, 5000)

    left_skewed = []
    right_skewed = []

    for i in range(50000):
        # Left then right per row, keeping the original random draw order.
        left_skewed.append({
            'id': i,
            'join_key': skewed_key(),
            'value_left': f"left_value_{i}"
        })
        right_skewed.append({
            'id': i,
            'join_key': skewed_key(),
            'value_right': f"right_value_{i}"
        })

    _report_join_sizes(left_skewed, right_skewed)
    

if __name__ == "__main__":
    # Run the benchmark suite only when executed directly, not when imported.
    test_joins()

=== Small Dataset Test ===
Skip list build time: 0.001834 seconds
Skip list join time: 0.001843 seconds
Total skip list join time: 0.003677 seconds
Hash table build time: 0.000106 seconds
Hash join time: 0.000483 seconds
Total hash join time: 0.000589 seconds
Pandas join time: 0.021457 seconds
Skip list join result size: 1982
Hash join result size: 1982
Pandas join result size: 1982
Nested loop join time: 0.000356 seconds
Nested loop join result size (on 100x100 subset): 27

=== Medium Dataset Test ===
Skip list build time: 0.023480 seconds
Skip list join time: 0.051116 seconds
Total skip list join time: 0.074596 seconds
Hash table build time: 0.001619 seconds
Hash join time: 0.005374 seconds
Total hash join time: 0.006993 seconds
Pandas join time: 0.042420 seconds
Skip list join result size: 20346
Hash join result size: 20346
Pandas join result size: 20346

=== Large Dataset Test ===
Skip list build time: 0.379243 seconds
Skip list join time: 0.539289 seconds
Total skip list join time