Skip to content

Commit

Permalink
fix: add fashion examples (#2898)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbp committed Jul 9, 2021
1 parent a55432d commit 3598413
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 5 deletions.
4 changes: 2 additions & 2 deletions jina/helloworld/fashion/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
index_generator,
query_generator,
)
from my_executors import MyEncoder, MyIndexer, MyEvaluator
from my_executors import MyEncoder, MyIndexer, MyEvaluator, MyConverter
else:
from .helper import (
print_result,
Expand All @@ -21,7 +21,7 @@
index_generator,
query_generator,
)
from .my_executors import MyEncoder, MyIndexer, MyEvaluator
from .my_executors import MyEncoder, MyIndexer, MyEvaluator, MyConverter

cur_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down
12 changes: 9 additions & 3 deletions jina/helloworld/fashion/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def index_generator(num_docs: int, target: dict):
:yields: index data
"""
for internal_doc_id in range(num_docs):
d = Document(content=target['index']['data'][internal_doc_id])
# x_blackwhite.shape is (28,28)
x_blackwhite = target['index']['data'][internal_doc_id]
# x_color.shape is (28,28,3)
x_color = np.stack((x_blackwhite,) * 3, axis=-1).astype(np.uint8)
d = Document(content=x_color)
d.tags['id'] = internal_doc_id
yield d

Expand All @@ -70,7 +74,9 @@ def query_generator(num_docs: int, target: dict, with_groundtruth: bool = True):
for _ in range(num_docs):
num_data = len(target['query-labels']['data'])
idx = random.randint(0, num_data - 1)
d = Document(content=(target['query']['data'][idx]))
x = target['query']['data'][idx]
x_stacked = np.stack((x,) * 3, axis=-1).astype(np.uint8)
d = Document(content=x_stacked)

if with_groundtruth:
gt = gts[target['query-labels']['data'][idx][0]]
Expand Down Expand Up @@ -182,7 +188,7 @@ def load_mnist(path):
"""

with gzip.open(path, 'rb') as fp:
return np.frombuffer(fp.read(), dtype=np.uint8, offset=16).reshape([-1, 784])
return np.frombuffer(fp.read(), dtype=np.uint8, offset=16).reshape([-1, 28, 28])


def load_labels(path: str):
Expand Down
67 changes: 67 additions & 0 deletions jina/helloworld/fashion/my_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,31 @@


class MyIndexer(Executor):
"""
Executor with basic exact search using cosine distance
"""

def __init__(self, **kwargs):
    """Set up the indexer and open its on-disk Document storage.

    :param kwargs: keyword arguments forwarded to the Executor base class
    """
    super().__init__(**kwargs)
    # DocumentArrayMemmap persists indexed Documents under this
    # Executor's workspace directory
    self._docs = DocumentArrayMemmap(self.workspace + '/indexer')

@requests(on='/index')
def index(self, docs: 'DocumentArray', **kwargs):
    """Extend self._docs with the incoming Documents.

    :param docs: DocumentArray containing Documents to add to the index
    :param kwargs: other keyword arguments
    """
    self._docs.extend(docs)

@requests(on=['/search', '/eval'])
def search(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
"""Append best matches to each document in docs
:param docs: documents that are searched
:param parameters: dictionary of pairs (parameter,value)
:param kwargs: other keyword arguments
"""
a = np.stack(docs.get_attributes('embedding'))
b = np.stack(self._docs.get_attributes('embedding'))
q_emb = _ext_A(_norm(a))
Expand All @@ -33,6 +48,11 @@ def search(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
def _get_sorted_top_k(
dist: 'np.array', top_k: int
) -> Tuple['np.ndarray', 'np.ndarray']:
"""Sort and select top k distances
:param dist: array of distances
:param top_k: number of values to retrieve
:return: indices and distances
"""
if top_k >= dist.shape[1]:
idx = dist.argsort(axis=1)[:, :top_k]
dist = np.take_along_axis(dist, idx, axis=1)
Expand All @@ -47,6 +67,10 @@ def _get_sorted_top_k(


class MyEncoder(Executor):
"""
Encode data using SVD decomposition
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
np.random.seed(1337)
Expand All @@ -57,15 +81,40 @@ def __init__(self, **kwargs):

@requests
def encode(self, docs: 'DocumentArray', **kwargs):
    """Encode the documents' image content with a random orthogonal projection.

    Each Document gets an ``embedding`` set, its blob rendered as an image
    URI, and the raw blob removed.

    :param docs: input documents to update with an embedding
    :param kwargs: other keyword arguments
    """
    # reduce dimension to 50 by random orthogonal projection
    content = np.stack(docs.get_attributes('content'))
    # content.shape=(request_size, 28, 28, 3); the three channels carry the
    # same grayscale values, so keep one channel and flatten each image
    content = content[:, :, :, 0].reshape(-1, 784)
    # content.shape=(request_size, 784); scale pixels to [0, 1] and project
    # (the original reshaped a second time here, which was a no-op)
    embeds = (content / 255) @ self.oth_mat
    for doc, embed in zip(docs, embeds):
        doc.embedding = embed
        doc.convert_image_blob_to_uri(width=28, height=28)
        doc.pop('blob')


class MyConverter(Executor):
    """
    Turn each Document's blob into an image URI and drop the raw blob
    """

    @requests
    def convert(self, docs: 'DocumentArray', **kwargs):
        """
        Replace every Document's blob with a 28x28 image URI, in place.

        :param docs: documents to convert
        :param kwargs: other keyword arguments
        """
        for document in docs:
            document.convert_image_blob_to_uri(width=28, height=28)
            document.pop('blob')


def _get_ones(x, y):
return np.ones((x, y))

Expand Down Expand Up @@ -101,6 +150,10 @@ def _cosine(A_norm_ext, B_norm_ext):


class MyEvaluator(Executor):
"""
Executor that evaluates precision and recall
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.eval_at = 50
Expand All @@ -110,10 +163,18 @@ def __init__(self, **kwargs):

@property
def avg_precision(self):
    """
    Average precision over all evaluated documents

    :return: total accumulated precision divided by the number of
        documents evaluated so far
    """
    return self.total_precision / self.num_docs

@property
def avg_recall(self):
    """
    Average recall over all evaluated documents

    :return: total accumulated recall divided by the number of
        documents evaluated so far
    """
    return self.total_recall / self.num_docs

def _precision(self, actual, desired):
Expand All @@ -133,6 +194,12 @@ def _recall(self, actual, desired):

@requests(on='/eval')
def evaluate(self, docs: 'DocumentArray', groundtruths: 'DocumentArray', **kwargs):
"""Evaluate documents using the class values from ground truths
:param docs: documents to evaluate
:param groundtruths: ground truth for the documents
:param kwargs: other keyword arguments
"""
for doc, groundtruth in zip(docs, groundtruths):
self.num_docs += 1
actual = [match.tags['id'] for match in doc.matches]
Expand Down

0 comments on commit 3598413

Please sign in to comment.