Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle multiple neighbor sets #10

Merged
merged 6 commits into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ X, Y = ...

# Represent the high-dimensional embedding
emb = emblaze.Embedding({Field.POSITION: X, Field.COLOR: Y})
# Compute nearest neighbors in the high-D space
emb.compute_neighbors(metric='euclidean')
# Compute nearest neighbors in the high-D space (for display)
emb.compute_neighbors(metric='cosine')

# Generate UMAP 2D representations - you can pass UMAP parameters to project()
variants = emblaze.EmbeddingSet([
emb.project(method=ProjectionTechnique.UMAP) for _ in range(10)
])
# Compute neighbors again (to indicate that we want to compare projections)
variants.compute_neighbors(metric='euclidean')

w = emblaze.Viewer(embeddings=variants)
w
Expand All @@ -76,7 +78,7 @@ embeddings = emblaze.EmbeddingSet([
emblaze.Embedding({Field.POSITION: X, Field.COLOR: Y}, label=emb_name)
for X, emb_name in zip(Xs, embedding_names)
])
embeddings.compute_neighbors()
embeddings.compute_neighbors(metric='cosine')

# Make aligned UMAP
reduced = embeddings.project(method=ProjectionTechnique.ALIGNED_UMAP)
Expand Down
432 changes: 355 additions & 77 deletions emblaze/datasets.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions emblaze/frame_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def compute_colors(frames, ids_of_interest=None, scale_factor=1.0):
outer_jaccard_distances = np.zeros((len(frames), len(frames)))
inner_jaccard_distances = np.zeros((len(frames), len(frames)))
for i in range(len(frames)):
frame_1_neighbors = frames[i].field(Field.NEIGHBORS, distance_sample)
frame_1_neighbors = frames[i].get_recent_neighbors()[distance_sample]
for j in range(len(frames)):
frame_2_neighbors = frames[j].field(Field.NEIGHBORS, distance_sample)
frame_2_neighbors = frames[j].get_recent_neighbors()[distance_sample]
# If the id set is the entire frame, there will be no outer neighbors
# so we can just leave this at zero
if ids_of_interest is not None and len(ids_of_interest):
Expand Down
176 changes: 176 additions & 0 deletions emblaze/neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import numpy as np
from sklearn.neighbors import NearestNeighbors
from .utils import *

class Neighbors:
"""
An object representing a serializable set of nearest neighbors within an
embedding. The Neighbors object simply stores a matrix of integer IDs, where rows
correspond to points in the embedding and columns are IDs of neighbors in
order of proximity to each point.
"""
def __init__(self, values, ids=None, metric='euclidean', n_neighbors=100, clf=None):
"""
pos: Matrix of n x D vectors indicating high-dimensional positions
ids: If supplied, a list of IDs for the points in the matrix
"""
super().__init__()
self.values = values
self.ids = ids
self._id_index = {id: i for i, id in enumerate(self.ids)}
self.metric = metric
self.n_neighbors = n_neighbors
self.clf = clf

@classmethod
def compute(cls, pos, ids=None, metric='euclidean', n_neighbors=100):
ids = ids if ids is not None else np.arange(len(pos))
neighbor_clf = NearestNeighbors(metric=metric,
n_neighbors=n_neighbors + 1).fit(pos)
_, neigh_indexes = neighbor_clf.kneighbors(pos)

return cls(ids[neigh_indexes[:,1:]], ids=ids, metric=metric, n_neighbors=n_neighbors, clf=neighbor_clf)

def index(self, id_vals):
"""
Returns the index(es) of the given IDs.
"""
if isinstance(id_vals, (list, np.ndarray, set)):
return [self._id_index[int(id_val)] for id_val in id_vals]
else:
return self._id_index[int(id_vals)]

def __getitem__(self, ids):
"""ids can be a single ID or a sequence of IDs"""
if ids is None: return self.values
return self.values[self.index(ids)]

def __eq__(self, other):
if isinstance(other, NeighborSet): return other == self
if not isinstance(other, Neighbors): return False
return np.allclose(self.ids, other.ids) and np.allclose(self.values, other.values)

def __ne__(self, other):
return not (self == other)

def __len__(self):
return len(self.values)

def calculate_neighbors(self, pos, return_distance=True, n_neighbors=None):
if self.clf is None:
raise ValueError(
("Cannot compute neighbors because the Neighbors was not "
"initialized with a neighbor classifier - was it deserialized "
"from JSON without saving the original coordinates or "
"concatenated to another Neighbors?"))
neigh_dists, neigh_indexes = self.clf.kneighbors(pos, n_neighbors=n_neighbors or self.n_neighbors)
if return_distance:
return neigh_dists, neigh_indexes
return neigh_indexes

def concat(self, other):
"""Concatenates the two Neighbors together, discarding the original
classifier."""
assert not (set(self.ids.tolist()) & set(other.ids.tolist())), "Cannot concatenate Neighbors objects with overlapping ID values"
assert self.metric == other.metric, "Cannot concatenate Neighbors objects with different metrics"
return Neighbors(
np.concatenate(self.values, other.values),
ids=np.concatenate(self.ids, other.ids),
metric=self.metric,
n_neighbors = max(self.n_neighbors, other.n_neighbors)
)

def to_json(self, compressed=True, num_neighbors=None):
"""Serializes the neighbors to a JSON object."""
result = {}
result["metric"] = self.metric
result["n_neighbors"] = self.n_neighbors

neighbors = self.values
if num_neighbors is not None:
neighbors = neighbors[:,:min(num_neighbors, neighbors.shape[1])]

if compressed:
result["_format"] = "compressed"
# Specify the type name that will be used to encode the point IDs.
# This is important because the highlight array takes up the bulk
# of the space when transferring to file/widget.
dtype, type_name = choose_integer_type(self.ids)
result["_idtype"] = type_name
result["_length"] = len(self)
result["ids"] = encode_numerical_array(self.ids, dtype)

result["neighbors"] = encode_numerical_array(neighbors.flatten(),
astype=dtype,
interval=neighbors.shape[1])
else:
result["_format"] = "expanded"
result["neighbors"] = {}
indexes = self.index(self.ids)
for id_val, index in zip(self.ids, indexes):
result["neighbors"][id_val] = neighbors[index].tolist()
return result

@classmethod
def from_json(cls, data):
if data.get("_format", "expanded") == "compressed":
dtype = np.dtype(data["_idtype"])
ids = decode_numerical_array(data["ids"], dtype)
neighbors = decode_numerical_array(data["neighbors"], dtype)
else:
neighbor_dict = data["neighbors"]
try:
ids = [int(id_val) for id_val in list(neighbor_dict.keys())]
neighbor_dict = {int(k): v for k, v in neighbor_dict.items()}
except:
ids = list(neighbor_dict.keys())
ids = sorted(ids)
neighbors = np.array([neighbor_dict[id_val] for id_val in ids])

return cls(neighbors, ids=ids, metric=data["metric"], n_neighbors=data["n_neighbors"])

class NeighborSet:
"""
An object representing a serializable collection of Neighbors objects.
"""
def __init__(self, neighbor_objects):
super().__init__()
self._neighbors = neighbor_objects

def __getitem__(self, slice):
return self._neighbors[slice]

def __setitem__(self, slice, val):
self._neighbors[slice] = val

def __len__(self):
return len(self._neighbors)

def __iter__(self):
return iter(self._neighbors)

def __eq__(self, other):
if isinstance(other, NeighborSet):
return len(other) == len(self) and all(n1 == n2 for n1, n2 in zip(self, other))
elif isinstance(other, Neighbors):
return all(n1 == other for n1 in self)
return False

def __ne__(self, other):
return not (self == other)

def to_json(self, compressed=True, num_neighbors=None):
"""
Serializes the list of Neighbors objects to JSON.
"""
return [n.to_json(compressed=compressed, num_neighbors=num_neighbors)
for n in self]

@classmethod
def from_json(cls, data):
return [Neighbors.from_json(d) for d in data]

def identical(self):
"""Returns True if all Neighbors objects within this NeighborSet are equal to each other."""
if len(self) == 0: return True
return all(n == self[0] for n in self)
12 changes: 6 additions & 6 deletions emblaze/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def _make_neighbor_changes(self, idx_1, idx_2, filter_points=None):
"""
frame_1 = self.embeddings[idx_1]
frame_2 = self.embeddings[idx_2]
frame_1_neighbors = frame_1.field(Field.NEIGHBORS, ids=filter_points or None)
frame_2_neighbors = frame_2.field(Field.NEIGHBORS, ids=filter_points or None)
frame_1_neighbors = frame_1.get_recent_neighbors()[filter_points or None]
frame_2_neighbors = frame_2.get_recent_neighbors()[filter_points or None]
gained_ids = [set(frame_2_neighbors[i]) - set(frame_1_neighbors[i]) for i in range(len(filter_points or frame_1))]
lost_ids = [set(frame_1_neighbors[i]) - set(frame_2_neighbors[i]) for i in range(len(filter_points or frame_1))]

Expand All @@ -79,15 +79,15 @@ def _consistency_score(self, ids, frame):
Computes the consistency between the neighbors for the given set of IDs
in the given frame.
"""
return (np.sum(1 - self._pairwise_jaccard_distances(frame.field(Field.NEIGHBORS, ids=ids))) - len(ids)) / (len(ids) * (len(ids) - 1))
return (np.sum(1 - self._pairwise_jaccard_distances(frame.get_recent_neighbors()[ids])) - len(ids)) / (len(ids) * (len(ids) - 1))

def _inner_change_score(self, ids, frame_1, frame_2):
"""
Computes the inverse intersection of the neighbor sets in the given
two frames.
"""
return np.mean(inverse_intersection(frame_1.field(Field.NEIGHBORS, ids=ids),
frame_2.field(Field.NEIGHBORS, ids=ids),
return np.mean(inverse_intersection(frame_1.get_recent_neighbors()[ids],
frame_2.get_recent_neighbors()[ids],
List(ids),
False))

Expand Down Expand Up @@ -180,7 +180,7 @@ def query(self, ids_of_interest=None, filter_ids=None, frame_idx=None, preview_f

# Assemble a list of candidates
if ids_of_interest is not None:
neighbor_ids = set([n for n in self.embeddings[frame_key[0]].field(Field.NEIGHBORS, ids_of_interest)[:,:NUM_NEIGHBORS_FOR_SEARCH].flatten()])
neighbor_ids = set([n for n in self.embeddings[frame_key[0]].get_recent_neighbors()[ids_of_interest][:,:NUM_NEIGHBORS_FOR_SEARCH].flatten()])
else:
neighbor_ids = None

Expand Down
18 changes: 17 additions & 1 deletion emblaze/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class Field:
POSITION = "position"
COLOR = "color"
RADIUS = "r"
NEIGHBORS = "highlight"
ALPHA = "alpha"

# Thumbnail fields
Expand Down Expand Up @@ -174,6 +173,19 @@ def choose_integer_type(values):
return np.uint16, "u2"
return np.uint8, "u1"

def _detect_numerical_sequence(arr):
"""
Detects a numerical sequence to compress large arrays of integer IDs when
they are regularly spaced. If a sequence is detected, returns the start, end,
and step such that using np.arange() with these three arguments yields the
appropriate result. If no sequence is detected, returns None.
"""
diffs = arr[1:] - arr[:-1]
if np.allclose(diffs, diffs[0]):
step = diffs[0]
return (arr[0], arr[-1] + step, step)
return None

def encode_numerical_array(arr, astype=np.float32, positions=None, interval=None):
"""
Encodes the given numpy array into a base64 representation for fast transfer
Expand All @@ -187,6 +199,10 @@ def encode_numerical_array(arr, astype=np.float32, positions=None, interval=None
If interval is not None, it is passed into the result object directly (and
signifies the same as positions, but with a regularly spaced interval).
"""
# TODO support saving arrays as numerical sequence metadata
# sequence_info = _detect_numerical_sequence(arr)
# if sequence_info is not None:
# result = { ""}
result = { "values": base64.b64encode(arr.astype(astype)).decode('ascii') }
if positions is not None:
result["positions"] = base64.b64encode(positions.astype(np.int32)).decode('ascii')
Expand Down