feat: proposal to generalize to multiple sparse types
JoanFM committed Apr 15, 2021
1 parent cc2e563 commit c8f6089
Showing 3 changed files with 101 additions and 13 deletions.
37 changes: 32 additions & 5 deletions jina/types/document/__init__.py
@@ -7,7 +7,7 @@
import urllib.request
import warnings
from hashlib import blake2b
from typing import Union, Dict, Optional, TypeVar, Any, Tuple, List
from typing import Union, Dict, Optional, TypeVar, Any, Tuple, List, Type

import numpy as np
from google.protobuf import json_format
@@ -29,6 +29,34 @@

if False:
from scipy.sparse import coo_matrix
# fix type-hint complaints from sphinx and flake8
from typing import TypeVar
import numpy as np
import scipy
import tensorflow as tf
import torch

EmbeddingType = TypeVar(
'EmbeddingType',
np.ndarray,
scipy.sparse.csr_matrix,
scipy.sparse.coo_matrix,
scipy.sparse.bsr_matrix,
scipy.sparse.csc_matrix,
torch.sparse_coo_tensor,
tf.SparseTensor,
)

SparseEmbeddingType = TypeVar(
'SparseEmbeddingType',
np.ndarray,
scipy.sparse.csr_matrix,
scipy.sparse.coo_matrix,
scipy.sparse.bsr_matrix,
scipy.sparse.csc_matrix,
torch.sparse_coo_tensor,
tf.SparseTensor,
)

__all__ = ['Document', 'DocumentContentType', 'DocumentSourceType']
DIGEST_SIZE = 8
@@ -470,20 +498,19 @@ def blob(self, value: Union['np.ndarray', 'jina_pb2.NdArrayProto', 'NdArray']):
self._update_ndarray('blob', value)

@property
def embedding(self) -> 'np.ndarray':
def embedding(self) -> 'EmbeddingType':
"""Return ``embedding`` of the content of a Document.
:return: the embedding from the proto
"""
return NdArray(self._pb_body.embedding).value

@property
def sparse_embedding(self) -> 'coo_matrix':
def get_sparse_embedding(self, sparse_ndarray_cls_type: Type['BaseSparseNdArray'], **kwargs) -> 'SparseEmbeddingType':
"""Return ``embedding`` of the content of a Document as an sparse array.
:return: the embedding from the proto as an sparse array
"""
return NdArray(self._pb_body.embedding, is_sparse=True).value
return NdArray(self._pb_body.embedding, sparse_cls=sparse_ndarray_cls_type, is_sparse=True, **kwargs).value

@embedding.setter
def embedding(self, value: Union['np.ndarray', 'jina_pb2.NdArrayProto', 'NdArray']):
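As a rough usage sketch (not part of the commit): the new Document.get_sparse_embedding could be driven as below, assuming the scipy-backed SparseNdArray accepts the sp_format keyword that the DocumentSet changes further down pass to it. Whether the value actually comes back sparse still depends on how the embedding is stored in the underlying proto, which this commit does not change.

import numpy as np

from jina import Document
from jina.types.ndarray.sparse.scipy import SparseNdArray

doc = Document()
doc.embedding = np.array([0.0, 0.0, 1.5, 0.0])

# request the stored embedding through a specific sparse ndarray class;
# extra keyword arguments (here the scipy sparse format) are forwarded to it
emb = doc.get_sparse_embedding(sparse_ndarray_cls_type=SparseNdArray, sp_format='csr')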
2 changes: 1 addition & 1 deletion jina/types/ndarray/generic.py
@@ -114,7 +114,7 @@ def value(self):
if stype == 'dense':
return self.dense_cls(self._pb_body.dense).value
elif stype == 'sparse':
return self.sparse_cls(self._pb_body.sparse).value
return self.sparse_cls(self._pb_body.sparse, **self._kwargs).value

@value.setter
def value(self, value):
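The single changed line above is what lets per-backend options reach the sparse wrapper: the keyword arguments NdArray was constructed with (kept in self._kwargs) are now forwarded to sparse_cls when the value is resolved. A minimal sketch of that flow, assuming NdArray keeps its extra constructor kwargs as the Document change above implies:

from jina.proto import jina_pb2
from jina.types.ndarray.generic import NdArray
from jina.types.ndarray.sparse.scipy import SparseNdArray

proto = jina_pb2.NdArrayProto()

# the extra kwarg (sp_format) travels with the NdArray wrapper and, with this
# change, is handed to SparseNdArray when .value materializes the sparse field
nd = NdArray(proto, sparse_cls=SparseNdArray, is_sparse=True, sp_format='coo')
value = nd.value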
75 changes: 68 additions & 7 deletions jina/types/sets/document.py
@@ -26,6 +26,37 @@

__all__ = ['DocumentSet']

if False:
from scipy.sparse import coo_matrix
# fix type-hint complaints from sphinx and flake8
from typing import TypeVar
import numpy as np
import scipy
import tensorflow as tf
import torch

EmbeddingType = TypeVar(
'EmbeddingType',
np.ndarray,
scipy.sparse.csr_matrix,
scipy.sparse.coo_matrix,
scipy.sparse.bsr_matrix,
scipy.sparse.csc_matrix,
torch.sparse_coo_tensor,
tf.SparseTensor,
)

SparseEmbeddingType = TypeVar(
'SparseEmbeddingType',
np.ndarray,
scipy.sparse.csr_matrix,
scipy.sparse.coo_matrix,
scipy.sparse.bsr_matrix,
scipy.sparse.csc_matrix,
torch.sparse_coo_tensor,
tf.SparseTensor,
)


class DocumentSet(TraversableSequence, MutableSequence):
"""
@@ -163,21 +194,51 @@ def all_embeddings(self) -> Tuple['np.ndarray', 'DocumentSet']:
"""
return self.extract_docs('embedding', stack_contents=True)

@property
def all_sparse_embeddings(self) -> Tuple['coo_matrix', 'DocumentSet']:
"""Return all embeddings from every document in this set as a ndarray
def get_all_sparse_embeddings(self, sparse_cls_type: 'Type[SparseEmbeddingType]', **kwargs) -> Tuple['SparseEmbeddingType', 'DocumentSet']:
"""Return all embeddings from every document in this set as a sparse array
:param sparse_cls_type: the sparse array class (a scipy sparse matrix class, torch sparse tensor or tf.SparseTensor) the embeddings should be returned as
:param kwargs: extra keyword arguments forwarded to the underlying sparse ndarray class
:return: the sparse embeddings stacked together, and the documents that do have an embedding in a :class:`DocumentSet`
:rtype: A tuple of embedding in :class:`np.ndarray`
:rtype: A tuple of embedding and DocumentSet as sparse arrays
"""
import scipy

def stack_embeddings(embeddings):
    # dispatch on the framework the requested sparse class belongs to, keeping
    # this check consistent with the class-based dispatch further down
    framework = sparse_cls_type.__module__.split('.')[0]
    if framework == 'scipy':
        return scipy.sparse.vstack(embeddings)
    elif framework == 'torch':
        import torch
        return torch.vstack(embeddings)
    elif framework == 'tensorflow':
        import tensorflow as tf
        # stack along the first axis to mirror vstack in the other branches
        return tf.sparse.concat(axis=0, sp_inputs=embeddings)

def get_sparse_ndarray_type_kwargs():
    if sparse_cls_type == scipy.sparse.coo_matrix:
        from jina.types.ndarray.sparse.scipy import SparseNdArray
        return SparseNdArray, {'sp_format': 'coo'}
    elif sparse_cls_type == scipy.sparse.csr_matrix:
        from jina.types.ndarray.sparse.scipy import SparseNdArray
        return SparseNdArray, {'sp_format': 'csr'}
    elif sparse_cls_type == scipy.sparse.bsr_matrix:
        from jina.types.ndarray.sparse.scipy import SparseNdArray
        return SparseNdArray, {'sp_format': 'bsr'}
    elif sparse_cls_type == scipy.sparse.csc_matrix:
        from jina.types.ndarray.sparse.scipy import SparseNdArray
        return SparseNdArray, {'sp_format': 'csc'}
    elif sparse_cls_type.__module__.split('.')[0] == 'torch':
        from jina.types.ndarray.sparse.pytorch import SparseNdArray
        return SparseNdArray, {}
    elif sparse_cls_type.__module__.split('.')[0] == 'tensorflow':
        from jina.types.ndarray.sparse.tensorflow import SparseNdArray
        return SparseNdArray, {}

embeddings = []
docs_pts = []
bad_docs = []
sparse_ndarray_type, sparse_kwargs = get_sparse_ndarray_type_kwargs()
for doc in self:
embedding = doc.sparse_embedding
embedding = doc.get_sparse_embedding(sparse_ndarray_cls_type=sparse_ndarray_type, **sparse_kwargs)
if embedding is None:
bad_docs.append(doc)
continue
@@ -189,7 +250,7 @@ def all_sparse_embeddings(self) -> Tuple['coo_matrix', 'DocumentSet']:
f'found {len(bad_docs)} docs at granularity {bad_docs[0].granularity} are missing sparse_embedding'
)

return scipy.sparse.vstack(embeddings), docs_pts
return stack_embeddings(embeddings), docs_pts

@property
def all_contents(self) -> Tuple['np.ndarray', 'DocumentSet']:
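Putting the DocumentSet side together, the proposed method could be exercised roughly as follows. This is an illustrative sketch rather than code from the commit, and the choice of scipy.sparse.csr_matrix as the requested type is only an example.

import numpy as np
import scipy.sparse

from jina import Document
from jina.types.sets.document import DocumentSet

docs = DocumentSet([Document(embedding=np.random.random(8)) for _ in range(3)])

# the requested sparse class picks the SparseNdArray backend (and, for scipy,
# the sp_format) used to convert every embedding before they are stacked
embeddings, docs_with_embeddings = docs.get_all_sparse_embeddings(
    sparse_cls_type=scipy.sparse.csr_matrix
)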
