Skip to content

Commit

Permalink
Handle multiple neighbor sets (#10)
Browse files Browse the repository at this point in the history
* Check if descriptions are available before loading; add save() and load() convenience methods

* Add separate NeighborSet class in backend and frontend; display ancestor neighbors when serializing for widget
(Does not serialize the correct neighbor sets to file yet)

* Separate Neighbors/NeighborSet into a module; serialization of Viewer comparison config

* Projections no longer retain their parents' neighbor sets; serialize recent_neighbors separately

* Update readme examples
  • Loading branch information
venkatesh-sivaraman committed Jan 6, 2022
1 parent 71e2239 commit d0bf117
Show file tree
Hide file tree
Showing 10 changed files with 743 additions and 106 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ X, Y = ...

# Represent the high-dimensional embedding
emb = emblaze.Embedding({Field.POSITION: X, Field.COLOR: Y})
# Compute nearest neighbors in the high-D space
emb.compute_neighbors(metric='euclidean')
# Compute nearest neighbors in the high-D space (for display)
emb.compute_neighbors(metric='cosine')

# Generate UMAP 2D representations - you can pass UMAP parameters to project()
variants = emblaze.EmbeddingSet([
emb.project(method=ProjectionTechnique.UMAP) for _ in range(10)
])
# Compute neighbors again (to indicate that we want to compare projections)
variants.compute_neighbors(metric='euclidean')

w = emblaze.Viewer(embeddings=variants)
w
Expand All @@ -76,7 +78,7 @@ embeddings = emblaze.EmbeddingSet([
emblaze.Embedding({Field.POSITION: X, Field.COLOR: Y}, label=emb_name)
for X, emb_name in zip(Xs, embedding_names)
])
embeddings.compute_neighbors()
embeddings.compute_neighbors(metric='cosine')

# Make aligned UMAP
reduced = embeddings.project(method=ProjectionTechnique.ALIGNED_UMAP)
Expand Down
432 changes: 355 additions & 77 deletions emblaze/datasets.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions emblaze/frame_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def compute_colors(frames, ids_of_interest=None, scale_factor=1.0):
outer_jaccard_distances = np.zeros((len(frames), len(frames)))
inner_jaccard_distances = np.zeros((len(frames), len(frames)))
for i in range(len(frames)):
frame_1_neighbors = frames[i].field(Field.NEIGHBORS, distance_sample)
frame_1_neighbors = frames[i].get_recent_neighbors()[distance_sample]
for j in range(len(frames)):
frame_2_neighbors = frames[j].field(Field.NEIGHBORS, distance_sample)
frame_2_neighbors = frames[j].get_recent_neighbors()[distance_sample]
# If the id set is the entire frame, there will be no outer neighbors
# so we can just leave this at zero
if ids_of_interest is not None and len(ids_of_interest):
Expand Down
176 changes: 176 additions & 0 deletions emblaze/neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import numpy as np
from sklearn.neighbors import NearestNeighbors
from .utils import *

class Neighbors:
"""
An object representing a serializable set of nearest neighbors within an
embedding. The Neighbors object simply stores a matrix of integer IDs, where rows
correspond to points in the embedding and columns are IDs of neighbors in
order of proximity to each point.
"""
def __init__(self, values, ids=None, metric='euclidean', n_neighbors=100, clf=None):
"""
pos: Matrix of n x D vectors indicating high-dimensional positions
ids: If supplied, a list of IDs for the points in the matrix
"""
super().__init__()
self.values = values
self.ids = ids
self._id_index = {id: i for i, id in enumerate(self.ids)}
self.metric = metric
self.n_neighbors = n_neighbors
self.clf = clf

@classmethod
def compute(cls, pos, ids=None, metric='euclidean', n_neighbors=100):
ids = ids if ids is not None else np.arange(len(pos))
neighbor_clf = NearestNeighbors(metric=metric,
n_neighbors=n_neighbors + 1).fit(pos)
_, neigh_indexes = neighbor_clf.kneighbors(pos)

return cls(ids[neigh_indexes[:,1:]], ids=ids, metric=metric, n_neighbors=n_neighbors, clf=neighbor_clf)

def index(self, id_vals):
"""
Returns the index(es) of the given IDs.
"""
if isinstance(id_vals, (list, np.ndarray, set)):
return [self._id_index[int(id_val)] for id_val in id_vals]
else:
return self._id_index[int(id_vals)]

def __getitem__(self, ids):
"""ids can be a single ID or a sequence of IDs"""
if ids is None: return self.values
return self.values[self.index(ids)]

def __eq__(self, other):
if isinstance(other, NeighborSet): return other == self
if not isinstance(other, Neighbors): return False
return np.allclose(self.ids, other.ids) and np.allclose(self.values, other.values)

def __ne__(self, other):
return not (self == other)

def __len__(self):
return len(self.values)

def calculate_neighbors(self, pos, return_distance=True, n_neighbors=None):
if self.clf is None:
raise ValueError(
("Cannot compute neighbors because the Neighbors was not "
"initialized with a neighbor classifier - was it deserialized "
"from JSON without saving the original coordinates or "
"concatenated to another Neighbors?"))
neigh_dists, neigh_indexes = self.clf.kneighbors(pos, n_neighbors=n_neighbors or self.n_neighbors)
if return_distance:
return neigh_dists, neigh_indexes
return neigh_indexes

def concat(self, other):
"""Concatenates the two Neighbors together, discarding the original
classifier."""
assert not (set(self.ids.tolist()) & set(other.ids.tolist())), "Cannot concatenate Neighbors objects with overlapping ID values"
assert self.metric == other.metric, "Cannot concatenate Neighbors objects with different metrics"
return Neighbors(
np.concatenate(self.values, other.values),
ids=np.concatenate(self.ids, other.ids),
metric=self.metric,
n_neighbors = max(self.n_neighbors, other.n_neighbors)
)

def to_json(self, compressed=True, num_neighbors=None):
"""Serializes the neighbors to a JSON object."""
result = {}
result["metric"] = self.metric
result["n_neighbors"] = self.n_neighbors

neighbors = self.values
if num_neighbors is not None:
neighbors = neighbors[:,:min(num_neighbors, neighbors.shape[1])]

if compressed:
result["_format"] = "compressed"
# Specify the type name that will be used to encode the point IDs.
# This is important because the highlight array takes up the bulk
# of the space when transferring to file/widget.
dtype, type_name = choose_integer_type(self.ids)
result["_idtype"] = type_name
result["_length"] = len(self)
result["ids"] = encode_numerical_array(self.ids, dtype)

result["neighbors"] = encode_numerical_array(neighbors.flatten(),
astype=dtype,
interval=neighbors.shape[1])
else:
result["_format"] = "expanded"
result["neighbors"] = {}
indexes = self.index(self.ids)
for id_val, index in zip(self.ids, indexes):
result["neighbors"][id_val] = neighbors[index].tolist()
return result

@classmethod
def from_json(cls, data):
if data.get("_format", "expanded") == "compressed":
dtype = np.dtype(data["_idtype"])
ids = decode_numerical_array(data["ids"], dtype)
neighbors = decode_numerical_array(data["neighbors"], dtype)
else:
neighbor_dict = data["neighbors"]
try:
ids = [int(id_val) for id_val in list(neighbor_dict.keys())]
neighbor_dict = {int(k): v for k, v in neighbor_dict.items()}
except:
ids = list(neighbor_dict.keys())
ids = sorted(ids)
neighbors = np.array([neighbor_dict[id_val] for id_val in ids])

return cls(neighbors, ids=ids, metric=data["metric"], n_neighbors=data["n_neighbors"])

class NeighborSet:
"""
An object representing a serializable collection of Neighbors objects.
"""
def __init__(self, neighbor_objects):
super().__init__()
self._neighbors = neighbor_objects

def __getitem__(self, slice):
return self._neighbors[slice]

def __setitem__(self, slice, val):
self._neighbors[slice] = val

def __len__(self):
return len(self._neighbors)

def __iter__(self):
return iter(self._neighbors)

def __eq__(self, other):
if isinstance(other, NeighborSet):
return len(other) == len(self) and all(n1 == n2 for n1, n2 in zip(self, other))
elif isinstance(other, Neighbors):
return all(n1 == other for n1 in self)
return False

def __ne__(self, other):
return not (self == other)

def to_json(self, compressed=True, num_neighbors=None):
"""
Serializes the list of Neighbors objects to JSON.
"""
return [n.to_json(compressed=compressed, num_neighbors=num_neighbors)
for n in self]

@classmethod
def from_json(cls, data):
return [Neighbors.from_json(d) for d in data]

def identical(self):
"""Returns True if all Neighbors objects within this NeighborSet are equal to each other."""
if len(self) == 0: return True
return all(n == self[0] for n in self)
12 changes: 6 additions & 6 deletions emblaze/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def _make_neighbor_changes(self, idx_1, idx_2, filter_points=None):
"""
frame_1 = self.embeddings[idx_1]
frame_2 = self.embeddings[idx_2]
frame_1_neighbors = frame_1.field(Field.NEIGHBORS, ids=filter_points or None)
frame_2_neighbors = frame_2.field(Field.NEIGHBORS, ids=filter_points or None)
frame_1_neighbors = frame_1.get_recent_neighbors()[filter_points or None]
frame_2_neighbors = frame_2.get_recent_neighbors()[filter_points or None]
gained_ids = [set(frame_2_neighbors[i]) - set(frame_1_neighbors[i]) for i in range(len(filter_points or frame_1))]
lost_ids = [set(frame_1_neighbors[i]) - set(frame_2_neighbors[i]) for i in range(len(filter_points or frame_1))]

Expand All @@ -79,15 +79,15 @@ def _consistency_score(self, ids, frame):
Computes the consistency between the neighbors for the given set of IDs
in the given frame.
"""
return (np.sum(1 - self._pairwise_jaccard_distances(frame.field(Field.NEIGHBORS, ids=ids))) - len(ids)) / (len(ids) * (len(ids) - 1))
return (np.sum(1 - self._pairwise_jaccard_distances(frame.get_recent_neighbors()[ids])) - len(ids)) / (len(ids) * (len(ids) - 1))

def _inner_change_score(self, ids, frame_1, frame_2):
"""
Computes the inverse intersection of the neighbor sets in the given
two frames.
"""
return np.mean(inverse_intersection(frame_1.field(Field.NEIGHBORS, ids=ids),
frame_2.field(Field.NEIGHBORS, ids=ids),
return np.mean(inverse_intersection(frame_1.get_recent_neighbors()[ids],
frame_2.get_recent_neighbors()[ids],
List(ids),
False))

Expand Down Expand Up @@ -180,7 +180,7 @@ def query(self, ids_of_interest=None, filter_ids=None, frame_idx=None, preview_f

# Assemble a list of candidates
if ids_of_interest is not None:
neighbor_ids = set([n for n in self.embeddings[frame_key[0]].field(Field.NEIGHBORS, ids_of_interest)[:,:NUM_NEIGHBORS_FOR_SEARCH].flatten()])
neighbor_ids = set([n for n in self.embeddings[frame_key[0]].get_recent_neighbors()[ids_of_interest][:,:NUM_NEIGHBORS_FOR_SEARCH].flatten()])
else:
neighbor_ids = None

Expand Down
18 changes: 17 additions & 1 deletion emblaze/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class Field:
POSITION = "position"
COLOR = "color"
RADIUS = "r"
NEIGHBORS = "highlight"
ALPHA = "alpha"

# Thumbnail fields
Expand Down Expand Up @@ -174,6 +173,19 @@ def choose_integer_type(values):
return np.uint16, "u2"
return np.uint8, "u1"

def _detect_numerical_sequence(arr):
"""
Detects a numerical sequence to compress large arrays of integer IDs when
they are regularly spaced. If a sequence is detected, returns the start, end,
and step such that using np.arange() with these three arguments yields the
appropriate result. If no sequence is detected, returns None.
"""
diffs = arr[1:] - arr[:-1]
if np.allclose(diffs, diffs[0]):
step = diffs[0]
return (arr[0], arr[-1] + step, step)
return None

def encode_numerical_array(arr, astype=np.float32, positions=None, interval=None):
"""
Encodes the given numpy array into a base64 representation for fast transfer
Expand All @@ -187,6 +199,10 @@ def encode_numerical_array(arr, astype=np.float32, positions=None, interval=None
If interval is not None, it is passed into the result object directly (and
signifies the same as positions, but with a regularly spaced interval).
"""
# TODO support saving arrays as numerical sequence metadata
# sequence_info = _detect_numerical_sequence(arr)
# if sequence_info is not None:
# result = { ""}
result = { "values": base64.b64encode(arr.astype(astype)).decode('ascii') }
if positions is not None:
result["positions"] = base64.b64encode(positions.astype(np.int32)).decode('ascii')
Expand Down

0 comments on commit d0bf117

Please sign in to comment.