# Word Embeddings in Cassandra

This example uses the Cassandra driver within Python3 to store and retrieve various amounts of Word Embeddings.

You will need to install Cassandra on a local machine:

pip3 install cassandra-driver

In [39]:
from cassandra.cluster import Cluster
import numpy
import io
import time
import plotly
import pickle
from tqdm import tqdm_notebook as tqdm
import threading

Dummy Embeddings

For testing purposes we will use randomly generated numpy arrays as dummy embbeddings.

In [57]:
def embeddings(n=1000, dim=512):
    """
    Yield n tuples of random numpy arrays of *dim* length indexed by *n*
    """
    idx = 0
    while idx < n:
        yield (str(idx), numpy.random.rand(dim))
        idx += 1

# Conversion Functions

Since we can't just save a NumPy array into the database, we will convert it into a BLOB.

In [7]:
def adapt_array(array):
    """
    Using the numpy.save function to save a binary version of the array,
    and BytesIO to catch the stream of data and convert it into a BLOB.
    """
    out = io.BytesIO()
    numpy.save(out, array)
    out.seek(0)

    return out.read()


def convert_array(blob):
    """
    Using BytesIO to convert the binary version of the array back into a numpy array.
    """
    out = io.BytesIO(blob)
    out.seek(0)

    return numpy.load(out)

In [67]:
cluster = Cluster()

session = cluster.connect()

session.execute('DROP KEYSPACE IF EXISTS embeddings_ks')
session.execute('CREATE KEYSPACE IF NOT EXISTS '
                'embeddings_ks with replication = '
                '{\'class\':\'SimpleStrategy\', '
                '\'replication_factor\': 1}')

session = cluster.connect('embeddings_ks')

write_times = []
read_times = []
counts = [500, 1000, 2000, 3000, 4000, 5000, 50000, 100000, 1000000, 10000000]

def insert_embedding(key, arr):
    arr = adapt_array(emb)
    query = "INSERT INTO embeddings (key, embedding) VALUES (%s, %s)"
    session.execute_async(query, (key,arr))

for c in counts:
    print(c)
    session.execute('DROP TABLE IF EXISTS embeddings;')
    session.execute('CREATE TABLE IF NOT EXISTS embeddings (key TEXT PRIMARY KEY, embedding BLOB);')

    start_time_write = time.time()
    threads = []
    for key, emb in embeddings(c):
        t = threading.Thread(target=insert_embedding, args=(key, arr))
        t.start()
        threads.append(t)

    for i in range (c):
        threads[i].join()
    
    write_time = time.time() - start_time_write
    write_times.append(write_time)
    print(write_time)

    start_time_read = time.time()
    for key, emb in embeddings(c):
        obj = session.execute('SELECT embedding FROM embeddings WHERE key=%s;', (key,))
        emb = convert_array(obj.one().embedding)
    read_time = time.time() - start_time_read
    read_times.append(read_time)
    print(read_time)

print('DONE')

500
2.329406261444092
0.7636280059814453
1000
1.077176570892334
1.5462610721588135
2000
1.968094825744629
2.8464019298553467
3000
2.858471155166626
3.993593692779541
4000
3.840188503265381
5.770720958709717
5000
4.471283674240112
6.828747987747192
50000


KeyboardInterrupt: 

In [64]:
# save times for later plotting

with open('./collected_times/cassandra-write-times.pickle', 'wb') as f:
    pickle.dump(write_times, f)
    
with open('./collected_times/cassandra-read-times.pickle', 'wb') as f:
    pickle.dump(read_times, f)

In [65]:
# Write Times
plotly.offline.init_notebook_mode(connected=True)
trace = plotly.graph_objs.Scatter(
    y = write_times,
    x = counts,
    mode = 'lines+markers'
)
layout = plotly.graph_objs.Layout(title="Cassandra Write Times",
                xaxis=dict(title='Time in Seconds'),
                yaxis=dict(title='Embedding Count'))
data = [trace]
fig = plotly.graph_objs.Figure(data=data, layout=layout)
plotly.offline.iplot(fig, filename='jupyter-basic-scatter')

In [66]:
plotly.offline.init_notebook_mode(connected=True)
trace = plotly.graph_objs.Scatter(
    y = read_times,
    x = counts,
    mode = 'lines+markers'
)
layout = plotly.graph_objs.Layout(title="Cassandra Read Times",
                yaxis=dict(title='Time in Seconds'),
                xaxis=dict(title='Embedding Count'))
data = [trace]
fig = plotly.graph_objs.Figure(data=data, layout=layout)
plotly.offline.iplot(fig, filename='jupyter-scatter-read')