In [None]:
import random
import string
import time
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Any, Optional


class SkipNode:
    """A node in the Skip List.

    ``forward[i]`` points to the next node at level ``i``. A node created with
    ``level`` has ``level + 1`` forward pointers (levels ``0..level``).
    """

    # One node is allocated per distinct key, so dropping the per-instance
    # __dict__ noticeably reduces memory and speeds up attribute access.
    __slots__ = ("key", "value", "forward")

    def __init__(self, key: Any, value: Any, level: int):
        self.key = key
        self.value = value  # Could be a list of values if there are duplicates
        self.forward = [None] * (level + 1)  # Array of pointers for each level

    def __repr__(self):
        return f"SkipNode(key={self.key}, value={self.value})"


class SkipList:
    """Skip list implementation with search, insert, and delete operations.

    An ordered map from keys to values. Keys are compared with ``<`` and ``>``
    (see ``_compare_keys``); ``None`` is accepted as a key and sorts before
    every other key. Inserting a key that is already present does not
    overwrite: the node's value becomes a list holding every value inserted
    for that key, in insertion order.
    """

    class _Duplicates(list):
        """List of all values inserted under one key.

        A dedicated subclass lets ``insert`` tell a duplicates list it created
        apart from a single value that merely happens to be a list, so a
        caller-supplied list is never extended in place. Being a ``list``, it
        still satisfies ``isinstance(value, list)`` and compares/prints like a
        plain list.
        """

    def __init__(self, max_level: int = 16, p: float = 0.5):
        self.max_level = max_level  # highest level index any node may reach
        self.p = p  # probability of promoting a new node one more level
        self.level = 0  # highest level index currently in use
        # Sentinel head: carries no data, has a forward slot for every level.
        self.head = SkipNode(None, None, max_level)

    def random_level(self) -> int:
        """Draw a level index for a new node.

        Starting at 0, add a level while ``random.random() < self.p``, never
        exceeding ``self.max_level``.
        """
        level = 0
        while random.random() < self.p and level < self.max_level:
            level += 1
        return level

    def search(self, key: Any) -> Optional[Any]:
        """Return the value stored under ``key``, or ``None`` if it is absent.

        For a key inserted more than once the result is a list of all of its
        values in insertion order.
        """
        # Kept separate from _descend: lookups are the hot path and need no
        # predecessor array.
        current = self.head
        for i in range(self.level, -1, -1):
            while current.forward[i] and self._compare_keys(current.forward[i].key, key) < 0:
                current = current.forward[i]
        current = current.forward[0]
        if current and self._compare_keys(current.key, key) == 0:
            return current.value
        return None

    def _descend(self, key: Any) -> Tuple[List[Any], Any]:
        """Walk from the top level down towards ``key``.

        Returns ``(update, candidate)``: ``update[i]`` is the last node at
        level ``i`` whose key is strictly less than ``key`` (only indices
        ``0..self.level`` are filled), and ``candidate`` is the level-0
        successor of ``update[0]`` -- the node holding ``key`` if present,
        otherwise the next larger node or ``None``.
        """
        update = [None] * (self.max_level + 1)
        current = self.head
        for i in range(self.level, -1, -1):
            while current.forward[i] is not None and self._compare_keys(current.forward[i].key, key) < 0:
                current = current.forward[i]
            update[i] = current
        return update, current.forward[0]

    def insert(self, key: Any, value: Any):
        """Insert ``value`` under ``key``.

        If ``key`` already exists, the value is accumulated: the stored value
        becomes a list of every value inserted for that key.
        """
        update, node = self._descend(key)
        if node is not None and self._compare_keys(node.key, key) == 0:
            if isinstance(node.value, SkipList._Duplicates):
                node.value.append(value)
            else:
                # Wrap rather than append, so a list-valued first insert is
                # not mutated and stays a single element.
                node.value = SkipList._Duplicates([node.value, value])
            return
        new_level = self.random_level()
        if new_level > self.level:
            # Levels above the current top only have the head as predecessor.
            for i in range(self.level + 1, new_level + 1):
                update[i] = self.head
            self.level = new_level
        new_node = SkipNode(key, value, new_level)
        for i in range(new_level + 1):
            new_node.forward[i] = update[i].forward[i]
            update[i].forward[i] = new_node

    def delete(self, key: Any):
        """Remove ``key`` and all values stored under it; no-op if absent."""
        update, node = self._descend(key)
        if node is None or self._compare_keys(node.key, key) != 0:
            return
        for i in range(self.level + 1):
            if update[i].forward[i] is not node:
                break  # the node does not reach this level or any above it
            update[i].forward[i] = node.forward[i]
        # Shrink the list's level while the top levels are empty.
        while self.level > 0 and self.head.forward[self.level] is None:
            self.level -= 1

    def _compare_keys(self, key1: Any, key2: Any) -> int:
        """Three-way compare: -1, 0 or 1. ``None`` sorts first; two ``None``s are equal."""
        if key1 is None and key2 is None:
            return 0
        elif key1 is None:
            return -1
        elif key2 is None:
            return 1
        else:
            return (key1 > key2) - (key1 < key2)

    def __iter__(self):
        """Yield ``(key, value)`` pairs in ascending key order."""
        current = self.head.forward[0]
        while current:
            yield (current.key, current.value)
            current = current.forward[0]

    def __repr__(self):
        elements = list(self)
        return f"SkipList({elements})"
    
def skip_list_join(left_data: List[Dict], right_data: List[Dict],
                   left_key: str, right_key: str) -> List[Dict]:
    """
    Perform an inner join using a skip list to index the right table.

    Every right record is inserted into a ``SkipList`` keyed by
    ``record[right_key]``; each left record is then probed with
    ``left_record[left_key]``. Each match yields ``{**left_record, **right_record}``,
    so on overlapping column names the right record's value wins. Records whose
    key is ``None`` match each other (the skip list treats two ``None`` keys as
    equal). Build and probe times are printed.

    Returns:
        Joined records in left-table order; the matches for one left record
        appear in the order the right records were inserted.
    """
    # Build a skip list index on the right table.
    skip_list = SkipList()
    start_build = time.perf_counter()  # monotonic, high-resolution benchmark clock
    for record in right_data:
        skip_list.insert(record[right_key], record)
    build_time = time.perf_counter() - start_build
    print(f"Skip list build time: {build_time:.6f} seconds")

    # Probe the index once per left record.
    start_join = time.perf_counter()
    result = []
    for left_record in left_data:
        matches = skip_list.search(left_record[left_key])
        if matches is None:
            continue
        # A key inserted more than once comes back as a list of records.
        if not isinstance(matches, list):
            matches = [matches]
        for right_record in matches:
            result.append({**left_record, **right_record})

    join_time = time.perf_counter() - start_join
    print(f"Skip list join time: {join_time:.6f} seconds")
    print(f"Total skip list join time: {build_time + join_time:.6f} seconds")

    return result


def hash_join(left_data: List[Dict], right_data: List[Dict],
              left_key: str, right_key: str) -> List[Dict]:
    """
    Perform an inner join using a hash table to index the right table.

    Right records are grouped by ``record[right_key]`` in a dict; each left
    record is then looked up by ``left_record[left_key]``. Each match yields
    ``{**left_record, **right_record}`` (right values win on overlapping
    column names). ``None`` keys match each other, since ``None`` is an
    ordinary dict key. Build and probe times are printed.

    Returns:
        Joined records in left-table order; the matches for one left record
        appear in right-table order.
    """
    from collections import defaultdict  # local import keeps the block self-contained

    # Build a hash table index on the right table.
    start_build = time.perf_counter()  # monotonic, high-resolution benchmark clock
    hash_table: Dict[Any, List[Dict]] = defaultdict(list)
    for record in right_data:
        hash_table[record[right_key]].append(record)
    build_time = time.perf_counter() - start_build
    print(f"Hash table build time: {build_time:.6f} seconds")

    # Probe the table once per left record; .get does not insert missing keys.
    start_join = time.perf_counter()
    result = []
    for left_record in left_data:
        for right_record in hash_table.get(left_record[left_key], ()):
            result.append({**left_record, **right_record})

    join_time = time.perf_counter() - start_join
    print(f"Hash join time: {join_time:.6f} seconds")
    print(f"Total hash join time: {build_time + join_time:.6f} seconds")

    return result


def pandas_join(left_data: List[Dict], right_data: List[Dict],
                left_key: str, right_key: str) -> List[Dict]:
    """
    Perform an inner join using pandas.

    Both inputs are converted to DataFrames and joined with ``pd.merge``
    (default ``how="inner"``). Unlike the dict-merging joins, non-key columns
    present on both sides are kept twice with pandas' default ``_x``/``_y``
    suffixes. pandas also matches null keys to each other. The elapsed time is
    printed.

    Returns:
        The merged rows as a list of dicts (``DataFrame.to_dict('records')``);
        ``[]`` if either input is empty.
    """
    start_time = time.perf_counter()  # monotonic, high-resolution benchmark clock

    if left_data and right_data:
        left_df = pd.DataFrame(left_data)
        right_df = pd.DataFrame(right_data)
        result_df = pd.merge(left_df, right_df, left_on=left_key, right_on=right_key)
        result = result_df.to_dict('records')
    else:
        # pd.DataFrame([]) has no columns, so merging on the key would raise
        # KeyError; an inner join with an empty side is simply empty.
        result = []

    elapsed = time.perf_counter() - start_time
    print(f"Pandas join time: {elapsed:.6f} seconds")

    return result


def nested_loop_join(left_data: List[Dict], right_data: List[Dict],
                     left_key: str, right_key: str) -> List[Dict]:
    """
    Perform a simple nested loop join.

    Compares every left record with every right record (O(len(left) *
    len(right))). It serves as the reference result for the indexed joins.
    Each match yields ``{**left_record, **right_record}`` (right values win on
    overlapping column names). Keys are compared with ``==``, so ``None`` keys
    match each other. The elapsed time is printed.

    Returns:
        Joined records in left-table order, then right-table order.
    """
    start_time = time.perf_counter()  # monotonic, high-resolution benchmark clock
    result = []

    for left_record in left_data:
        left_val = left_record[left_key]  # invariant across the inner loop
        for right_record in right_data:
            if left_val == right_record[right_key]:
                result.append({**left_record, **right_record})

    elapsed = time.perf_counter() - start_time
    print(f"Nested loop join time: {elapsed:.6f} seconds")

    return result


def _generate_side(count: int, key_range: int, skewed: bool, string_keys: bool,
                   null_ratio: float, value_field: str, value_prefix: str) -> List[Dict]:
    """Build ``count`` synthetic records for one side of a join (helper for generate_test_data)."""
    alphabet = string.ascii_letters + string.digits
    # Hot key range used by skewed data; at least 1 so randint never gets an
    # empty range when key_range < 10.
    hot_range = max(1, key_range // 10)
    records = []
    for i in range(count):
        if string_keys:
            key = ''.join(random.choices(alphabet, k=10))
        elif skewed and random.random() < 0.9:
            # Skewed: 90% of keys come from the hot range.
            key = random.randint(1, hot_range)
        else:
            key = random.randint(1, key_range)

        # Always consume one random draw here so the stream (and thus the
        # generated data) does not depend on null_ratio being 0.
        if random.random() < null_ratio:
            key = None

        records.append({
            'id': i,
            'join_key': key,
            value_field: f"{value_prefix}_{i}",
            'extra_attr': random.random(),
        })
    return records


def generate_test_data(size_left: int, size_right: int,
                       key_range: int, skewed: bool = False,
                       string_keys: bool = False, null_ratio: float = 0.0,
                       seed: int = 42) -> Tuple[List[Dict], List[Dict]]:
    """Generate reproducible left/right tables for the join benchmarks.

    Args:
        size_left: Number of left records.
        size_right: Number of right records.
        key_range: Integer keys are drawn from ``1..key_range``.
        skewed: If true (integer keys only), 90% of keys are drawn from
            ``1..max(1, key_range // 10)``.
        string_keys: Use random 10-character alphanumeric keys instead.
        null_ratio: Probability that a record's key is replaced by ``None``.
        seed: Seed for ``random`` and ``numpy.random``.

    Returns:
        ``(left_data, right_data)``. Records have keys ``id``, ``join_key``,
        ``value_left``/``value_right`` and ``extra_attr``.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Left first, then right: the order matters for the shared random stream.
    left_data = _generate_side(size_left, key_range, skewed, string_keys,
                               null_ratio, 'value_left', 'left_value')
    right_data = _generate_side(size_right, key_range, skewed, string_keys,
                                null_ratio, 'value_right', 'right_value')
    return left_data, right_data

    
def test_joins():
    """
    Test different join algorithms with various data sizes and distributions.

    For every size in ``sizes`` and every data scenario, this generates data
    and runs the skip list, hash and pandas joins, then prints their result
    sizes. The quadratic nested loop join only runs for a size named "Small"
    (none is configured by default).
    """
    sizes = [
        ("Large", 100000, 100000, 50000),
    ]
    # (section title, extra generate_test_data keyword arguments), run in order.
    scenarios = [
        ("Uniform Distribution", {}),
        ("Skewed Distribution", {"skewed": True}),
        ("String Keys", {"string_keys": True}),
        ("Null Values", {"null_ratio": 0.1}),
    ]

    for size_name, left_size, right_size, key_range in sizes:
        for title, options in scenarios:
            print(f"\n=== {size_name} Dataset Test ({title}) ===")
            left_data, right_data = generate_test_data(left_size, right_size, key_range, **options)

            skip_result = skip_list_join(left_data, right_data, 'join_key', 'join_key')
            hash_result = hash_join(left_data, right_data, 'join_key', 'join_key')
            pandas_result = pandas_join(left_data, right_data, 'join_key', 'join_key')

            print(f"Skip list join result size: {len(skip_result)}")
            print(f"Hash join result size: {len(hash_result)}")
            print(f"Pandas join result size: {len(pandas_result)}")

            # O(n*m): only affordable on small inputs.
            if size_name == "Small":
                nested_result = nested_loop_join(left_data, right_data, 'join_key', 'join_key')
                print(f"Nested loop join result size: {len(nested_result)}")
            
# Run the join benchmarks when this code is executed directly (as a script or
# as top-level notebook code), but not when it is imported as a module.
if __name__ == "__main__":
    test_joins()


=== Large Dataset Test (Uniform Distribution) ===
Skip list build time: 0.664407 seconds
Skip list join time: 0.720558 seconds
Total skip list join time: 1.384965 seconds
Hash table build time: 0.050177 seconds
Hash join time: 0.123285 seconds
Total hash join time: 0.173462 seconds
Pandas join time: 0.491829 seconds
Skip list join result size: 201202
Hash join result size: 201202
Pandas join result size: 201202

=== Large Dataset Test (Skewed Distribution) ===
Skip list build time: 0.498172 seconds
Skip list join time: 1.177492 seconds
Total skip list join time: 1.675664 seconds
Hash table build time: 0.062106 seconds
Hash join time: 0.670021 seconds
Total hash join time: 0.732127 seconds
Pandas join time: 3.102796 seconds
Skip list join result size: 1656293
Hash join result size: 1656293
Pandas join result size: 1656293

=== Large Dataset Test (String Keys) ===
Skip list build time: 0.964959 seconds
Skip list join time: 0.735104 seconds
Total skip list join time: 1.700063 seconds
Has