# Word Embeddings in PostgreSQL

This example uses the psycopg2 PostgreSQL adapter for Python 3 to store and retrieve varying numbers of word embeddings, represented as NumPy arrays.

We use a local PostgreSQL database running in a Docker container for testing. To start the database, run:

```
docker run -ti --rm --name word_psql -e POSTGRES_PASSWORD=mikolov -p 5433:5432 postgres:10.5
```

In [2]:
import io
import os
import time
import numpy
import plotly
from tqdm import tqdm_notebook as tqdm
import psycopg2
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json

# Dummy Embeddings

For testing purposes, we use randomly generated NumPy arrays as dummy embeddings.

In [3]:
def embeddings(n=10, dim=300):
    """
    Generate *n* dummy embeddings.

    Yields (key, vector) pairs. The key is the position as a string
    ('0', '1', ..., str(n - 1)). The vector is a fresh numpy.random.rand(dim)
    array of uniform samples in [0, 1). Nothing is yielded when n <= 0.
    """
    for position in range(n):
        vector = numpy.random.rand(dim)
        yield str(position), vector

In [4]:
def adapt_numpy_ndarray(numpy_ndarray):
    """
    psycopg2 adapter for numpy.ndarray query parameters.

    Converts the array to (nested) Python lists with ndarray.tolist() and
    wraps the result in psycopg2.extras.Json, so the value is sent to
    PostgreSQL as a JSON document (e.g. for a jsonb column).
    """
    as_list = numpy_ndarray.tolist()
    return Json(as_list)

In [6]:
connection = psycopg2.connect("host=localhost user=postgres password=mikolov port=5433")
# Send every numpy.ndarray query parameter through adapt_numpy_ndarray (as JSON).
register_adapter(numpy.ndarray, adapt_numpy_ndarray)
cursor = connection.cursor()
# Recreate the table so this cell can be re-run without failing on
# "relation already exists". Rows from any previous run are discarded.
cursor.execute('DROP TABLE IF EXISTS embeddings')
cursor.execute('CREATE TABLE embeddings (key varchar, embedding jsonb)')
connection.commit()

In [7]:
%%time
# Insert n = 1000 dummy embeddings into the database
for key, emb in embeddings():
    cursor.execute('INSERT INTO embeddings (key, embedding) VALUES (%s, %s)', [key, emb])
    connection.commit()

CPU times: user 13.7 ms, sys: 4.26 ms, total: 17.9 ms
Wall time: 297 ms


In [8]:
%%time
# Select n = 1000 dummy embeddings from the database
for key, _ in embeddings():
    cursor.execute('SELECT key, embedding FROM embeddings WHERE key=%s', (key,))
    data = cursor.fetchone()
    value = numpy.array(data[1])
    assert type(value) is numpy.ndarray
    assert len(value) == 300

CPU times: user 3.81 ms, sys: 0 ns, total: 3.81 ms
Wall time: 6.42 ms


# Sample some data

To test I/O performance, we write and then read back increasing numbers of embeddings (500 to 5,000) and record how long each phase takes. This may take a while.

In [10]:
write_times = []  # seconds spent inserting c rows, one entry per count
read_times = []   # seconds spent selecting c rows by key, one entry per count
db_sizes = []     # pg_total_relation_size('embeddings') in bytes after each run
counts = [500, 1000, 2000, 3000, 4000, 5000]

for c in counts:
    # Start every run from an empty table. IF EXISTS keeps the first
    # iteration from failing when the table does not exist yet.
    cursor.execute('DROP TABLE IF EXISTS embeddings')
    cursor.execute('CREATE TABLE IF NOT EXISTS embeddings (key varchar, embedding jsonb)')
    connection.commit()

    # Write phase: one INSERT + COMMIT per embedding.
    # perf_counter is the clock intended for measuring intervals.
    start_time_write = time.perf_counter()
    for key, emb in tqdm(embeddings(c), total=c):
        cursor.execute('INSERT INTO embeddings (key, embedding) VALUES (%s, %s)', [key, emb])
        connection.commit()
    write_times.append(time.perf_counter() - start_time_write)

    # Read phase: look up every key individually. Only the keys are needed,
    # so build them directly (embeddings() uses str(idx) as the key) instead
    # of generating and discarding random arrays inside the timed loop.
    # NOTE(review): no index is defined on `key`, so each lookup is likely a
    # sequential scan and read time may grow faster than linearly with c.
    start_time_read = time.perf_counter()
    for key in map(str, range(c)):
        cursor.execute('SELECT * FROM embeddings WHERE key=%s', (key,))
        data = cursor.fetchone()
    read_times.append(time.perf_counter() - start_time_read)

    # Record the table's total on-disk size (heap + TOAST + indexes), measured
    # outside both timed phases so it does not affect the timings.
    cursor.execute("SELECT pg_total_relation_size('embeddings')")
    db_sizes.append(cursor.fetchone()[0])

print('DONE')

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

DONE


# Results

In [11]:
# Write Times: elapsed insert time (x) against the number of embeddings written (y)
plotly.offline.init_notebook_mode(connected=True)
layout = plotly.graph_objs.Layout(
    title="Postgres Write Times",
    xaxis={'title': 'Time in Seconds'},
    yaxis={'title': 'Embedding Count'},
)
trace = plotly.graph_objs.Scatter(mode='lines+markers', x=write_times, y=counts)
data = [trace]
fig = plotly.graph_objs.Figure(layout=layout, data=data)
plotly.offline.iplot(fig, filename='jupyter-basic-scatter')

In [14]:
# Read Times: elapsed select time (x) against the number of embeddings read (y)
plotly.offline.init_notebook_mode(connected=True)
layout = plotly.graph_objs.Layout(
    title="Postgres Read Times",
    xaxis={'title': 'Time in Seconds'},
    yaxis={'title': 'Embedding Count'},
)
trace = plotly.graph_objs.Scatter(mode='lines+markers', x=read_times, y=counts)
data = [trace]
fig = plotly.graph_objs.Figure(layout=layout, data=data)
plotly.offline.iplot(fig, filename='jupyter-basic-scatter')