# Word Embeddings in scyllaDB

This example uses the official scyllaDB Connector within Python3 to store and retrieve various amounts of Word Embeddings.

We will use a local scylla database running as a Docker Container for testing purposes. To start the database run:

```
docker run --name scylla -d scylladb/scylla
```

In [20]:
from cassandra.cluster import Cluster
import numpy
import io
import time
import plotly
import pickle
from tqdm import tqdm_notebook as tqdm
import threading

In [21]:
cluster = Cluster(['172.17.0.2'])

In [22]:
def embeddings(n=1000, dim=512):
    """
    Yield n tuples of random numpy arrays of *dim* length indexed by *n*
    """
    idx = 0
    while idx < n:
        yield (str(idx), numpy.random.rand(dim))
        idx += 1

In [23]:
def adapt_array(array):
    """
    Using the numpy.save function to save a binary version of the array,
    and BytesIO to catch the stream of data and convert it into a BLOB.
    """
    out = io.BytesIO()
    numpy.save(out, array)
    out.seek(0)

    return out.read()


def convert_array(blob):
    """
    Using BytesIO to convert the binary version of the array back into a numpy array.
    """
    out = io.BytesIO(blob)
    out.seek(0)

    return numpy.load(out)

In [24]:
def insert_embedding(key, arr):
    arr = adapt_array(emb)
    query = "INSERT INTO embeddings (key, embedding) VALUES (%s, %s)"
    session.execute_async(query, (key, arr))

In [26]:
session = cluster.connect()

session.execute('DROP KEYSPACE IF EXISTS embeddings_ks')
session.execute('CREATE KEYSPACE IF NOT EXISTS '
                'embeddings_ks with replication = '
                '{\'class\':\'SimpleStrategy\', '
                '\'replication_factor\': 3}')

session = cluster.connect('embeddings_ks')

session.execute('CREATE TABLE IF NOT EXISTS embeddings (key TEXT PRIMARY KEY, embedding BLOB);')

write_times = []
read_times = []
counts = [500, 1000, 2000, 3000, 4000, 5000, 50000, 100000, 1000000, 10000000]

for c in counts:
    print(c)
    start_time_write = time.time()
    threads = []
    for key, emb in tqdm(embeddings(c), total=c):
        t = threading.Thread(target=insert_embedding, args=(key, emb))
        t.start()
        threads.append(t)

    for i in range (c):
        threads[i].join()
    
    write_time = time.time() - start_time_write
    write_times.append(write_time)

    start_time_read = time.time()
    for key, emb in tqdm(embeddings(c), total=c):
        obj = session.execute('SELECT embedding FROM embeddings WHERE key=%s;', (key,))
        emb = convert_array(obj.one().embedding)
    read_time = time.time() - start_time_read
    read_times.append(read_time)

print('DONE')

500


HBox(children=(IntProgress(value=0, max=500), HTML(value='')))




HBox(children=(IntProgress(value=0, max=500), HTML(value='')))


1000


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


2000


HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))


3000


HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))


4000


HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))


5000


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))


50000


HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))


100000


HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))


1000000


HBox(children=(IntProgress(value=0, max=1000000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000000), HTML(value='')))

AttributeError: 'NoneType' object has no attribute 'embedding'

In [27]:
# save times for later plotting

with open('./collected_times/scylladb-write-times.pickle', 'wb') as f:
    pickle.dump(write_times, f)
    
with open('./collected_times/scylladb-read-times.pickle', 'wb') as f:
    pickle.dump(read_times, f)

In [28]:
# Write Times
plotly.offline.init_notebook_mode(connected=True)
trace = plotly.graph_objs.Scatter(
    y = write_times,
    x = counts,
    mode = 'lines+markers'
)
layout = plotly.graph_objs.Layout(title="ScyllaDB Write Times",
                xaxis=dict(title='Time in Seconds'),
                yaxis=dict(title='Embedding Count'))
data = [trace]
fig = plotly.graph_objs.Figure(data=data, layout=layout)
plotly.offline.iplot(fig, filename='jupyter-basic-scatter')

In [29]:
plotly.offline.init_notebook_mode(connected=True)
trace = plotly.graph_objs.Scatter(
    y = read_times,
    x = counts,
    mode = 'lines+markers'
)
layout = plotly.graph_objs.Layout(title="ScyllaDB Read Times",
                yaxis=dict(title='Time in Seconds'),
                xaxis=dict(title='Embedding Count'))
data = [trace]
fig = plotly.graph_objs.Figure(data=data, layout=layout)
plotly.offline.iplot(fig, filename='jupyter-scatter-read')