# Word Embeddings in KeyDB

We will use a local KeyDB server (a Redis-compatible database, accessed here through the `redis` Python client) running on localhost:6379. To start the database run:

keydb-server --protected-mode no --daemonize yes


In [1]:
import redis
import time
import numpy 
import plotly
import pickle
import struct
import numpy as np
from tqdm import tqdm_notebook as tqdm

# Dummy Embeddings

For testing purposes we will use randomly generated NumPy arrays as dummy embeddings.

In [2]:
def embeddings(n=1000, dim=512):
    """
    Yield *n* dummy embeddings as ``(key, vector)`` tuples.

    The key is the running index ``0 .. n-1`` rendered as a string; the
    vector is a fresh array of *dim* values from ``numpy.random.rand``.

    :param int n: number of embeddings to generate
    :param int dim: length of every generated vector
    :return: generator of ``(str, numpy.ndarray)`` tuples
    """
    for idx in range(n):
        yield (str(idx), numpy.random.rand(dim))

# Conversion Functions

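Since we can't just save a NumPy array into the database, we will convert NumPy arrays into a byte string (a 4-byte length header followed by the float32 data) before writing, and decode that byte string back into an array after reading.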

In [3]:
def adapt_array(array):
    """
    Adapt a one-dimensional numpy array for saving into redisDB

    The result is a 4-byte big-endian unsigned length header followed by the
    bytes of the values cast to float32 (as returned by ``tobytes()``).

    :param numpy.array array: one-dimensional NumPy array (or array-like) to encode
    :return: encoded NumPy array
    :rtype: bytes
    :raises ValueError: if *array* is not one-dimensional
    """
    # asarray also accepts plain sequences and skips the copy when the input
    # already is a float32 array.
    array = np.asarray(array, dtype=np.float32)
    if array.ndim != 1:
        # The header stores a single length, so any other shape could not be
        # restored from the encoded bytes.
        raise ValueError(
            'expected a one-dimensional array, got shape %r' % (array.shape,)
        )

    header = struct.pack('>I', array.shape[0])
    return header + array.tobytes()


def convert_array(encoded):
    """
    Decode a byte string made of a 4-byte big-endian length header followed
    by float32 data back into a NumPy array

    :param BLOB encoded: encoded NumPy array
    :return: the decoded float32 array, reshaped to the length from the header
    :rtype: numpy.array
    """
    # The first four bytes hold the element count.
    (length,) = struct.unpack_from('>I', encoded)

    # Everything after the header is interpreted as float32 values.
    values = np.frombuffer(encoded, dtype=np.float32, offset=4)
    return values.reshape((length,))

In [9]:
# Client for the local server; every read and write below goes through it.
db = redis.Redis(
    host='localhost',
    port=6379,
    db=0,
)

# Sample some data

To test the I/O we will write and read some data from the database. This may take a while.

In [18]:
# Benchmark: for every entry in `counts`, wipe the server, write that many
# embeddings one key at a time, then read all of them back. Elapsed seconds
# are appended to `write_times` / `read_times` in the same order as `counts`.
write_times = []
read_times = []
counts = [500, 1000, 2000, 3000, 4000, 5000, 50000, 100000, 1000000, 10000000]

for c in counts:
    # NOTE(review): flushall() clears every database on the server, not only
    # db 0 -- do not point this at a server that holds other data.
    db.flushall()

    # perf_counter() is a monotonic, high-resolution clock; time.time() can
    # jump when the system clock is adjusted.
    # The write timing includes generating and encoding the vectors.
    start_time_write = time.perf_counter()
    for key, emb in tqdm(embeddings(c), total=c):
        arr = adapt_array(emb)
        db.set(key, arr)
    write_times.append(time.perf_counter() - start_time_write)

    start_time_read = time.perf_counter()
    # The keys are the stringified indices 0..c-1; build them directly rather
    # than generating c throw-away random vectors inside the timed read loop.
    for idx in range(c):
        obj = db.get(str(idx))
        emb = convert_array(obj)
        assert isinstance(emb, numpy.ndarray)
    read_times.append(time.perf_counter() - start_time_read)

print('DONE')

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10000000), HTML(value='')))

KeyboardInterrupt: 

In [13]:
# Persist both timing lists so they can be plotted later.
for path, timings in (
    ('./collected_times/keydb-write-times.pickle', write_times),
    ('./collected_times/keydb-read-times.pickle', read_times),
):
    with open(path, 'wb') as f:
        pickle.dump(timings, f)

In [14]:
# Plot the measured write durations against the number of embeddings.
plotly.offline.init_notebook_mode(connected=True)
# The benchmark may have been interrupted before every entry of `counts` was
# measured, so only the counts that have a timing are used for the x axis.
trace = plotly.graph_objs.Scatter(
    y=write_times,
    x=counts[:len(write_times)],
    mode='lines+markers'
)
layout = plotly.graph_objs.Layout(
    title="KeyDB Write Times",
    yaxis=dict(title='Time in Seconds'),
    xaxis=dict(title='Embedding Count'),
)
data = [trace]
fig = plotly.graph_objs.Figure(data=data, layout=layout)
plotly.offline.iplot(fig, filename='jupyter-scatter-write')

In [16]:
# Plot the measured read durations against the number of embeddings.
plotly.offline.init_notebook_mode(connected=True)
# The benchmark may have been interrupted before every entry of `counts` was
# measured, so only the counts that have a timing are used for the x axis.
trace = plotly.graph_objs.Scatter(
    y=read_times,
    x=counts[:len(read_times)],
    mode='lines+markers'
)
layout = plotly.graph_objs.Layout(
    title="KeyDB Read Times",
    yaxis=dict(title='Time in Seconds'),
    xaxis=dict(title='Embedding Count'),
)
data = [trace]
fig = plotly.graph_objs.Figure(data=data, layout=layout)
plotly.offline.iplot(fig, filename='jupyter-scatter-read')