Guided project to implement a B-tree class for storing indexes and values. The goal is to implement bracket indexing and range queries, with support for non-integer keys.

In [1]:
from btree import BTree

class KVStore(BTree):
    """Key-value store built on ``BTree`` with dict-style access and range queries.

    Supports ``kv[key] = value``, ``kv[key]``, ``key in kv`` and
    ``kv.range_query(start, end)``. All keys in one store must be mutually
    comparable (for example all numbers, or all strings).
    """

    def __init__(self):
        # 2 is the split threshold handed to BTree. Nodes are still treated as
        # holding several keys (see _range_query), so this is a B-tree, not a
        # binary tree.
        super().__init__(2)

    def add(self, key, value):
        """Store ``value`` under ``key``, replacing the value if ``key`` exists."""
        existing_node = self._find_node(self.root, key)
        if not existing_node:
            super().add(key, value)
            return
        # Key already present: overwrite in place instead of inserting a duplicate.
        for i, k in enumerate(existing_node.keys):
            if k == key:
                existing_node.values[i] = value
                break

    def __setitem__(self, key, value):
        """Bracket assignment; same as ``add(key, value)``."""
        self.add(key, value)

    def __getitem__(self, key):
        """Bracket lookup; returns whatever ``get_value(key)`` returns."""
        return self.get_value(key)

    def __contains__(self, key):
        """Support ``key in kv`` by delegating to ``contains(key)``."""
        return self.contains(key)

    def _range_query(self, range_start, range_end, current_node, min_key, max_key):
        """Collect values under ``current_node`` with keys in the closed range.

        Args:
            range_start: Inclusive lower bound of the query.
            range_end: Inclusive upper bound of the query.
            current_node: Root of the subtree being searched.
            min_key: Lower bound on keys in this subtree, or ``None`` when
                unbounded below.
            max_key: Upper bound on keys in this subtree, or ``None`` when
                unbounded above.

        Returns:
            A list of matching values: this node's matches first, followed by
            the matches of each child in turn (not necessarily in key order).
        """
        # Prune subtrees whose key interval cannot intersect the query range.
        if max_key is not None and range_start > max_key:
            return []
        if min_key is not None and range_end < min_key:
            return []

        keys = current_node.keys
        results = [
            value
            for key, value in zip(keys, current_node.values)
            if range_start <= key <= range_end
        ]
        if not current_node.is_leaf():
            for i, child in enumerate(current_node.children):
                # Child i sits between keys[i - 1] and keys[i]; the outermost
                # children inherit this subtree's own bounds.
                child_min = keys[i - 1] if i > 0 else min_key
                child_max = keys[i] if i < len(keys) else max_key
                results += self._range_query(
                    range_start, range_end, child, child_min, child_max
                )
        return results

    def range_query(self, range_start, range_end):
        """Return the values of all entries with ``range_start <= key <= range_end``.

        Works for any key type that supports ordering comparisons. The search
        starts unbounded (``None`` on both sides) rather than with
        type-specific sentinels, so no key can fall outside the initial bounds.
        Returns an empty list when the store is empty.
        """
        if not self.root.keys:
            return []
        return self._range_query(range_start, range_end, self.root, None, None)


In [2]:
# Unit testing of the KVStore class
def test_split_threshold():
    """A freshly built KVStore reports a split threshold of 2."""
    store = KVStore()
    assert store.split_threshold == 2

def test_add_value():
    """A value stored with add() is read back by get_value() for the same key."""
    store = KVStore()
    store.add(3, 'hello world')
    assert store.get_value(3) == 'hello world'

def test_add_overwrite():
    """Calling add() twice with one key keeps only the second value."""
    store = KVStore()
    for text in ('hello world', 'goodbye'):
        store.add(3, text)
    assert store.get_value(3) == 'goodbye'
    
def test_set_and_get():
    """Bracket assignment followed by bracket lookup round-trips a value."""
    store = KVStore()
    store[5] = 'hello world'
    assert store[5] == 'hello world'

def test_setget_overwrite():
    """Bracket-assigning an existing key replaces the stored value."""
    store = KVStore()
    for text in ('hello world', 'goodbye'):
        store[5] = text
    assert store[5] == 'goodbye'

def test_contains():
    """The ``in`` operator finds an integer key that was stored."""
    store = KVStore()
    store[5] = 'hello world'
    assert 5 in store
    
def test_contains_str():
    """The ``in`` operator finds a string key among other string keys."""
    # Only string keys are used here: a store cannot currently mix str and int.
    store = KVStore()
    for name, text in (('hello', 'world'), ('foo', 'bar')):
        store[name] = text
    assert 'hello' in store

def test_range_query_str():
    """A string range query from "banana" to "date" returns the three matches."""
    store = KVStore()
    produce = [
        ("apple", 'fruit'),
        ("banana", 'fruit'),
        ("carrot", 'vegetable'),
        ("date", 'fruit'),
    ]
    for name, kind in produce:
        store[name] = kind
    assert store.range_query("banana", "date") == ['fruit', 'vegetable', 'fruit']

# Run every unit test in definition order; the first failing assert stops the cell.
for run_test in (
    test_split_threshold,
    test_add_value,
    test_add_overwrite,
    test_set_and_get,
    test_setget_overwrite,
    test_contains,
    test_contains_str,
    test_range_query_str,
):
    run_test()

In [3]:
class DictKVStore(dict):
    """Plain ``dict`` with a linear-scan ``range_query``, used as a baseline."""

    def range_query(self, range_start, range_end):
        """Return the values whose keys satisfy ``range_start <= key <= range_end``.

        Every entry is examined, so values come back in the dict's iteration
        order rather than sorted by key.
        """
        # items() yields key and value together, avoiding a second lookup per key.
        return [
            value
            for key, value in self.items()
            if range_start <= key <= range_end
        ]
    
# Cross-check KVStore against the dict-based baseline on a small data set.
dict_kv = DictKVStore()
our_kv = KVStore()
for i in range(10):
    dict_kv[i] = i
    our_kv[i] = i

# Results are sorted before comparing because neither store promises key order.
for range_start, range_end in [(1, 3), (4, 6), (1, 10), (5, 5)]:
    dict_res = sorted(dict_kv.range_query(range_start, range_end))
    our_res = sorted(our_kv.range_query(range_start, range_end))
    # The message describes the failure, since it is only shown when the check fails.
    assert dict_res == our_res, (
        f"range_query({range_start}, {range_end}) mismatch: "
        f"DictKVStore returned {dict_res}, KVStore returned {our_res}"
    )

In [11]:
# performance comparison of KVStore versus DictKVStore
import functools
import time
from random import randint

import matplotlib.pyplot as plt

def timit(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's arguments and return value pass through unchanged.
    """
    @functools.wraps(func)  # keep func's __name__ and __doc__ on the wrapper
    def wrapper(*args, **kwargs):
        # perf_counter() is the high-resolution clock intended for measuring
        # intervals; time.time() is wall-clock time.
        start = time.perf_counter()
        ret = func(*args, **kwargs)
        end = time.perf_counter()
        print(f'Function {func.__name__} took {end-start} seconds.')
        return ret
    return wrapper

def time_run(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return the elapsed time in seconds.

    The call's own return value is discarded; only the duration is reported.
    """
    # perf_counter() has the highest available resolution, so very fast calls
    # are far less likely to be reported as taking 0 seconds than with time.time().
    start = time.perf_counter()
    func(*args, **kwargs)
    return time.perf_counter() - start

@timit
def insert_randomvals(struct, num_vals):
    """Build a ``struct()`` store filled with ``num_vals`` random integers.

    Keys are ``0 .. num_vals - 1``; each value is drawn with
    ``randint(0, 10**5)``. Returns the populated store.
    """
    store = struct()
    upper_bound = 10**5
    for key in range(num_vals):
        store[key] = randint(0, upper_bound)
    return store

In [12]:
import csv

def entries_generator():
    """Yield ``(key, value)`` integer pairs read from ``entries.csv``.

    The first row is treated as a header and skipped. The first two columns
    of every other row are converted with ``int``. Blank rows are ignored and
    an empty file yields nothing.
    """
    # newline='' lets the csv module do its own newline handling.
    with open('entries.csv', 'r', newline='') as file:
        reader = csv.reader(file)
        # Skip the header. The default keeps an empty file from raising
        # StopIteration, which would become RuntimeError inside a generator.
        next(reader, None)
        for row in reader:
            if not row:  # csv.reader returns [] for a blank line
                continue
            yield int(row[0]), int(row[1])

def queries_generator():
    """Yield ``(range_start, range_end)`` integer pairs read from ``queries.csv``.

    The first row is treated as a header and skipped. The first two columns
    of every other row are converted with ``int``. Blank rows are ignored and
    an empty file yields nothing.
    """
    # newline='' lets the csv module do its own newline handling.
    with open('queries.csv', 'r', newline='') as file:
        reader = csv.reader(file)
        # Skip the header. The default keeps an empty file from raising
        # StopIteration, which would become RuntimeError inside a generator.
        next(reader, None)
        for row in reader:
            if not row:  # csv.reader returns [] for a blank line
                continue
            yield int(row[0]), int(row[1])

@timit
def insert_entries_csv(struct):
    """Create a ``struct()`` store and load every pair from ``entries_generator()``.

    Each pair is stored with bracket assignment. Returns the populated store.
    """
    store = struct()
    for entry in entries_generator():
        key, val = entry
        store[key] = val
    return store

@timit
def query_ranges_csv(struct):
    """Run ``struct.range_query`` for every range from ``queries_generator()``.

    Query results are discarded; the function exists to be timed.
    """
    for bounds in queries_generator():
        range_start, range_end = bounds
        struct.range_query(range_start, range_end)

# Load entries.csv into each kind of store; @timit prints the time per load.
our_kv = insert_entries_csv(KVStore)
dict_kv = insert_entries_csv(DictKVStore)

# Replay the queries from queries.csv against each store, B-tree store first.
for store in (our_kv, dict_kv):
    query_ranges_csv(store)

Function insert_entries_csv took 3.520458936691284 seconds.
Function insert_entries_csv took 0.05988717079162598 seconds.
Function query_ranges_csv took 1.2807974815368652 seconds.
Function query_ranges_csv took 4.563742160797119 seconds.


In [15]:
# Graph the KVStore / DictKVStore runtime ratio for each range query.
# The earlier cells must have been executed first: this cell uses our_kv,
# dict_kv, time_run, queries_generator and plt from them.

time_ratios = []

# Reuse queries_generator() instead of parsing queries.csv a second time here.
for range_start, range_end in queries_generator():
    our_time = time_run(our_kv.range_query, range_start, range_end)
    dict_time = time_run(dict_kv.range_query, range_start, range_end)
    # A measured time of 0 would raise ZeroDivisionError; record NaN for that
    # query instead so the remaining queries are still measured.
    ratio = our_time / dict_time if dict_time > 0 else float('nan')
    time_ratios.append(ratio)

# The x axis is the position of each query in queries.csv.
# NOTE(review): the x label assumes the file is ordered by result size -- confirm.
plt.plot(time_ratios)
plt.xlabel('Query range result size')
plt.ylabel('Runtime ratio')
plt.show()

AttributeError: 'list' object has no attribute 'tolist'

In [23]:
# Display the collected ratios (KVStore time / DictKVStore time, one per query).
time_ratios

[0.05597233306730513,
 0.025516018741294163,
 0.03250329687378791,
 0.022041828738936807,
 0.03111035692064031,
 0.02397909514781985,
 0.02533792465436604,
 0.03221772710555351,
 0.02442616290879545,
 0.02500254608412262,
 0.02754534587599397,
 0.02852523386072295,
 0.02919351689123218,
 0.03459489767311466,
 0.0347065299264152,
 0.039003550821127386,
 0.03203943314849045,
 0.03451735902067028,
 0.036377795082879076,
 0.039911456935873356,
 0.03716437459070072,
 0.04003096317593719,
 0.04561962091438237,
 0.03193520094928546,
 0.03614764963801367,
 0.04433017591339648,
 0.043007825791085405,
 0.04429281831849196,
 0.036213696665471494,
 0.03746562786434464,
 0.03402954661232807,
 0.047542702102731704,
 0.04588143805706765,
 0.043000914913083256,
 0.05073957322987391,
 0.046493649917172836,
 0.05340744320927984,
 0.04479253981559095,
 0.04299321159816871,
 0.06084304371627434,
 0.052926703667677516,
 0.042572594192464605,
 0.05228866434763403,
 0.05601634067838574,
 0.04431247144814984,

In [22]:
# Plot the per-query runtime ratios (KVStore time / DictKVStore time) against
# each query's position in the list.
plt.plot(time_ratios)
plt.xlabel('Query range result size')
plt.ylabel('Runtime ratio')

Text(0, 0.5, 'Runtime ratio')

SUGGESTIONS FOR IMPROVEMENT FROM DATAQUEST:
    Implement the __iter__() method to make the store iterable. With this method, users will be able to iterate over all keys using `for key in kv:`, where kv is an instance of KVStore.

Implement the save() and load() methods to save the KVStore to a file and load it back.

Make each node keep track of the number of keys it stores in its subtree. Use that to implement a range_count() method that counts the number of entries in a range. It's possible to make this query run in O(log(n)) time.

Implement a B+ tree to replace the underlying B-tree.