Feature parity for standalone demo application #11

Merged
merged 15 commits into from
Jan 6, 2022
18 changes: 17 additions & 1 deletion README.md
@@ -31,7 +31,11 @@ Although the full application is designed to work as a Jupyter widget, you can r
python -m emblaze.server
```

Visit `localhost:5000` to see the running application. This will allow you to view two demo datasets: one showing five different t-SNE projections of a subset of MNIST digits, and one showing embeddings of the same 5,000 words according to three different data sources (Google News, Wikipedia, and Twitter). To add your own datasets (for example, to host an instance of Emblaze showing a custom dataset), see the standalone app development instructions below.
Visit `localhost:5000` to see the running application. This will allow you to view two demo datasets: one showing five different t-SNE projections of a subset of MNIST digits, and one showing embeddings of the same 5,000 words according to three different data sources (Google News, Wikipedia, and Twitter). To add your own datasets to the standalone app, you can create a directory containing your saved comparison JSON files (see Saving and Loading below), then pass it as a command-line argument:

```bash
python -m emblaze.server /path/to/comparisons
```
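
As a minimal sketch of the step above (not part of this PR), the helper below checks that a directory holds at least one non-hidden `.json` file, mirroring the filter that the new `_get_all_datasets()` in `emblaze/server.py` applies, and then builds the launch command. The names `find_comparisons` and `build_server_command` are hypothetical, and the files written in the demonstration are empty placeholders rather than real comparisons.

```python
import os
import sys
import tempfile


def find_comparisons(directory):
    """Return sorted paths of the non-hidden .json files in `directory`."""
    return [os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith(".") and name.endswith(".json")]


def build_server_command(directory):
    """Build the argv list for `python -m emblaze.server <directory>`."""
    if not find_comparisons(directory):
        raise ValueError("No datasets (.json files) found in " + directory)
    return [sys.executable, "-m", "emblaze.server", directory]


# Demonstration with a throwaway directory and placeholder files.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("b.json", "a.json", ".hidden.json", "notes.txt"):
        with open(os.path.join(demo_dir, name), "w") as f:
            f.write("{}")
    demo_files = [os.path.basename(p) for p in find_comparisons(demo_dir)]
    demo_command = build_server_command(demo_dir)
```

The resulting list could be handed to `subprocess.run` to start the server. Checking the directory first gives a clearer error than the assertion raised inside `run_server`.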

## Examples

@@ -111,6 +115,18 @@ Once you have loaded a `Viewer` instance in the notebook, you can read and write
- `colorScheme` (`string`) The name of a color scheme to use to render the points. A variety of color schemes are available, listed in `src/colorschemes.ts`. This property can also be changed in the Settings panel of the widget.
- `previewMode` (`string`) The method to use to generate preview lines, which should be one of the values in `utils.

### Saving and Loading

You can save the data used to make comparisons to JSON, so that they are easy to load again in Jupyter or the standalone application without re-running the embedding/projection code. Comparisons consist of an `EmbeddingSet` (containing the positions of the points in each 2D projection), a `Thumbnails` object (dictating how to display each point), and one or more `NeighborSet`s (which contain the nearest-neighbor sets used for comparison and display).

To save a comparison, call the `save_comparison()` method on the `Viewer`. If you are using high-dimensional nearest neighbors (most use cases), this method by default saves both the high-dimensional coordinates and the nearest-neighbor IDs, which can produce files ranging from hundreds of megabytes to gigabytes. To store only the nearest-neighbor IDs, pass `ancestor_data=False` as a keyword argument. Without the high-dimensional coordinates, however, you will be unable to use tools that depend on _distances_ in high-dimensional space (such as the high-dimensional radius select).

To load a comparison, simply initialize the `Viewer` as follows:

```python
w = emblaze.Viewer(file="/path/to/comparison.json")
```

---

## Development Installation
2 changes: 1 addition & 1 deletion emblaze/_version.py
@@ -4,5 +4,5 @@
# Copyright (c) venkatesh-sivaraman.
# Distributed under the terms of the Modified BSD License.

version_info = (0, 9, 4)
version_info = (0, 10, 0)
__version__ = ".".join(map(str, version_info))
1 change: 1 addition & 0 deletions emblaze/data/MNIST.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions emblaze/data/Word Embeddings.json

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion emblaze/data/corpus_aligned_umap/data.json

This file was deleted.

1 change: 0 additions & 1 deletion emblaze/data/corpus_aligned_umap/thumbnails.json

This file was deleted.

1 change: 0 additions & 1 deletion emblaze/data/mnist-tsne/data.json

This file was deleted.

1 change: 0 additions & 1 deletion emblaze/data/mnist-tsne/thumbnails.json

This file was deleted.

7 changes: 6 additions & 1 deletion emblaze/public/index.html
@@ -28,7 +28,12 @@
integrity="sha384-ho+j7jyWK8fNQe+A12Hb8AhRq26LrZ/JpcUGGOn+Y7RsweNrtN/tE3MoK7ZeZDyx"
crossorigin="anonymous"
></script>
<script
src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"
integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA=="
crossorigin="anonymous"
></script>
</head>

<body></body>
<body style="padding: 0"></body>
</html>
205 changes: 113 additions & 92 deletions emblaze/server.py
@@ -1,16 +1,35 @@
import eventlet
eventlet.monkey_patch()

from flask import request, Flask, send_from_directory, jsonify, send_file
from flask_socketio import SocketIO, send, emit
from engineio.payload import Payload
import os
import json
import numpy as np
from .datasets import EmbeddingSet
from .utils import affine_to_matrix, matrix_to_affine
import sys

from .viewer import Viewer

EXCLUDE_TRAITLETS = set([
    'comm', 'count', 'keys', 'layout', 'log',
    'message', 'json', 'connect', 'disconnect'
])

Payload.max_decode_packets = 200

app = Flask(__name__)
socketio = SocketIO(app,
                    async_mode='eventlet',
                    message_queue=os.environ.get('REDIS_URL', 'redis://'))

parent_dir = os.path.dirname(__file__)
public_dir = os.path.join(parent_dir, "public")
data_dir = os.path.join(parent_dir, "data")

user_data = {}

def socketio_thread_starter(fn, args=[], kwargs={}):
    socketio.start_background_task(fn, *args, **kwargs)

# Path for our main Svelte page
@app.route("/")
def base():
@@ -21,97 +40,99 @@ def base():
def home(path):
    return send_from_directory(public_dir, path)

# Example
def _get_all_datasets():
    return [os.path.join(data_dir, f)
            for f in sorted(os.listdir(data_dir))
            if not f.startswith('.') and f.endswith('.json')]

@app.route("/datasets")
def list_datasets():
    return jsonify([f for f in sorted(os.listdir(data_dir)) if os.path.isdir(os.path.join(data_dir, f))])

@app.route("/datasets/<dataset_name>/data")
def get_data(dataset_name):
    dataset_base = os.path.join(data_dir, dataset_name)
    if not os.path.exists(dataset_base) or not os.path.isdir(dataset_base):
        return app.response_class(response="The dataset does not exist", status=404)

    with open(os.path.join(dataset_base, "data.json"), "r") as file:
        return app.response_class(
            response=file.read(),
            status=200,
            mimetype='application/json'
        )

@app.route("/datasets/<dataset_name>/thumbnails")
def get_thumbnails(dataset_name):
    dataset_base = os.path.join(data_dir, dataset_name)
    if not os.path.exists(dataset_base) or not os.path.isdir(dataset_base):
        return app.response_class(response="The dataset does not exist", status=404)

    if not os.path.exists(os.path.join(dataset_base, "thumbnails.json")):
        return app.response_class(response="The dataset has no thumbnails", status=404)

    with open(os.path.join(dataset_base, "thumbnails.json"), "r") as file:
        return app.response_class(
            response=file.read(),
            status=200,
            mimetype='application/json'
        )

@app.route("/datasets/<dataset_name>/supplementary/<filename>")
def get_supplementary_file(dataset_name, filename):
    dataset_base = os.path.join(data_dir, dataset_name)
    if not os.path.exists(dataset_base) or not os.path.isdir(dataset_base):
        return app.response_class(response="The dataset does not exist", status=404)

    return send_file(os.path.join(dataset_base, "supplementary", filename))

@app.route("/align/<dataset_name>/<current_frame>", methods=['GET', 'POST'])
def align_frames(dataset_name, current_frame):
    dataset_base = os.path.join(data_dir, dataset_name)
    if not os.path.exists(dataset_base) or not os.path.isdir(dataset_base):
        return app.response_class(response="The dataset does not exist", status=404)

    # Allow client-supplied initial transform as 3x3 nested list
    base_transform = None
    ids = None
    if request.method == 'POST':
        body = request.json
        if body:
            if "initialTransform" in body:
                base_transform = matrix_to_affine(np.array(body["initialTransform"]))
            if "ids" in body:
                ids = body["ids"]

    with open(os.path.join(dataset_base, "data.json"), "r") as file:
        data = json.load(file)
        emb_set = EmbeddingSet.from_json(data)

    if not ids:
        return jsonify({"transformations": [
            np.eye(3).tolist()
            for i in range(len(emb_set))
        ]})

    try:
        current_frame = int(current_frame)
    except:
        current_frame = None
    return jsonify(_get_all_datasets())

@socketio.on('connect')
def connect():
    print('connected', request.sid)
    widget = Viewer(file=_get_all_datasets()[0], thread_starter=socketio_thread_starter)
    user_data[request.sid] = widget
    for trait_name in widget.trait_names(sync=lambda x: x):
        if trait_name in EXCLUDE_TRAITLETS: continue

        # Register callbacks for getting and setting from frontend
        socketio.on_event('get:' + trait_name, _read_value_handler(trait_name))
        socketio.on_event('set:' + trait_name, _write_value_handler(trait_name))

        # Emit responses when backend state changes
        widget.observe(_emit_value_handler(trait_name, request.sid), trait_name)

    if current_frame is None or current_frame < 0 or current_frame >= len(emb_set):
        return app.response_class(response="Invalid frame number", status=400)

    print(base_transform, current_frame, ids, emb_set)
@socketio.on('disconnect')
def disconnect():
    print('disconnected', request.sid)
    del user_data[request.sid]

def _read_value_handler(name):
    def handle_msg():
        if request.sid not in user_data:
            print("Missing request SID:", request.sid)
            return None
        return getattr(user_data[request.sid], name)
    return handle_msg

def _write_value_handler(name):
    def handle_msg(data):
        if request.sid not in user_data:
            print("Missing request SID:", request.sid)
            return
        setattr(user_data[request.sid], name, data)
    return handle_msg

def _emit_value_handler(name, sid):
    def handle_msg(change):
        with app.app_context():
            emit('change:' + name, change.new, room=sid, namespace='/')
    return handle_msg

def run_server(start_redis=False, data_directory=None, debug=False):
    """
    Starts the Flask server. If start_redis is True, automatically starts a
    Redis instance (on port 6379). If not, the server expects a Redis instance
    at the URL in the environment variable REDIS_URL.
    """
    global data_dir
    if data_directory: data_dir = data_directory
    assert len(_get_all_datasets()) > 0, "No datasets (.json files) found in data directory"

    redis_pid = None

    if start_redis:
        import redis_server
        import subprocess
        import tempfile

        temp_dir = tempfile.mkdtemp()
        pid_path = os.path.join(temp_dir, 'redis.pid')

        server_path = redis_server.REDIS_SERVER_PATH
        subprocess.check_call(
            '{} --daemonize yes --pidfile {} --logfile {}'.format(
                server_path,
                pid_path,
                os.path.join(temp_dir, 'redis.log')),
            shell=True)

        # If there is a pid file, read it to know what to shut down when the server stops
        if os.path.exists(pid_path):
            with open(pid_path, 'r') as file:
                redis_pid = file.read().strip()
            print("Started redis server (pid {})".format(redis_pid))

    transformations = []
    for frame in emb_set.embeddings:
        transformations.append(affine_to_matrix(frame.align_to(
            emb_set.embeddings[current_frame],
            ids=list(set(ids)) if ids else None,
            base_transform=base_transform,
            return_transform=True,
            allow_flips=False)).tolist())

    return jsonify({"transformations": transformations})

if __name__ == "__main__":
    print("Running Flask server with public dir '{}' and data dir '{}'".format(public_dir, data_dir))
    app.run(debug=True)
    try:
        socketio.run(app, debug=debug)
    except KeyboardInterrupt as e:
        if redis_pid is not None:
            print("Shutting down redis server")
            # Pass an argument list: a single command string without shell=True fails on POSIX
            subprocess.check_call(['kill', redis_pid])
        raise e

if __name__ == "__main__":
    run_server(start_redis=True, data_directory=sys.argv[1] if len(sys.argv) > 1 else None, debug=True)
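
The event naming used by the new socket handlers (`get:<trait>`, `set:<trait>`, `change:<trait>`) can be illustrated without Flask or Socket.IO. The sketch below is an assumption-laden stand-in, not the PR's code: `FakeWidget` and `build_handlers` are hypothetical names, the trait values are placeholders, and the `change:` message is appended to a list on write, whereas the server emits it from a traitlets observer.

```python
# Trait names the server skips, copied from the diff above.
EXCLUDE_TRAITLETS = {
    "comm", "count", "keys", "layout", "log",
    "message", "json", "connect", "disconnect",
}


class FakeWidget:
    """Hypothetical stand-in for a Viewer with a few synced values."""

    def __init__(self):
        self.colorScheme = "tableau"  # placeholder value
        self.previewMode = ""         # placeholder value
        self.layout = None            # excluded, as in the server


def build_handlers(widget, trait_names, sent):
    """Map event names to callables using the get:/set:/change: scheme."""
    handlers = {}
    for name in trait_names:
        if name in EXCLUDE_TRAITLETS:
            continue

        # Default arguments bind the current trait name to each closure.
        def read(name=name):
            return getattr(widget, name)

        def write(value, name=name):
            setattr(widget, name, value)
            # Simplification: record the outgoing message directly.
            sent.append(("change:" + name, value))

        handlers["get:" + name] = read
        handlers["set:" + name] = write
    return handlers


sent = []
widget = FakeWidget()
handlers = build_handlers(widget, ["colorScheme", "previewMode", "layout"], sent)
handlers["set:colorScheme"]("plasma")
```

One difference worth noting: the server keeps one widget per session ID in `user_data` and looks it up inside each handler, while this sketch closes over a single widget for brevity.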
8 changes: 4 additions & 4 deletions emblaze/utils.py
@@ -39,10 +39,10 @@ class PreviewMode:

class SidebarPane:
    """Indexes of sidebar panes in the widget."""
    CURRENT = 0
    SAVED = 1
    RECENT = 2
    SUGGESTED = 3
    CURRENT = 1
    SAVED = 2
    RECENT = 3
    SUGGESTED = 4

FLIP_FACTORS = [
    np.array([1, 1, 1]),