Allow vector queries too #37

Merged
merged 2 commits into from
Mar 28, 2023

Conversation

koaning
Owner

@koaning koaning commented Mar 28, 2023

Seems useful.

@koaning koaning merged commit 403c6f7 into main Mar 28, 2023
@koaning koaning deleted the query_vector branch March 28, 2023 15:17
@twielfaert

twielfaert commented Mar 29, 2023

I think there's a little error in your example in README.md. You probably wanted to show this?

v_pork = encoder.transform("pork")
index.query_vector(v_pork)

@koaning
Owner Author

koaning commented Mar 29, 2023

Per scikit-learn conventions, dataprep.transform() returns a two-dimensional object. In this demo I'm only interested in the first item, which is a vector.

Or does this code create an error on your end?

@twielfaert

When I run your example, I get this:

IndexError                                Traceback (most recent call last)
<ipython-input-4-118d6f1b28f1> in <cell line: 18>()
     16 # You can also query using vectors
     17 v_pork = encoder.transform("pork")[0]
---> 18 texts, dists = index.query(v_pork)

5 frames
/usr/local/lib/python3.9/dist-packages/simsity/__init__.py in query(self, query, n)
     28         The object handles the encoder/data from disk.
     29         """
---> 30         arr = self.encoder.transform(query)
     31         return self.query_vector(query=arr, n=n)
     32 

/usr/local/lib/python3.9/dist-packages/sklearn/utils/_set_output.py in wrapped(self, X, *args, **kwargs)
    138     @wraps(f)
    139     def wrapped(self, X, *args, **kwargs):
--> 140         data_to_wrap = f(self, X, *args, **kwargs)
    141         if isinstance(data_to_wrap, tuple):
    142             # only wrap the first output for cross decomposition

/usr/local/lib/python3.9/dist-packages/embetter/text/_sbert.py in transform(self, X, y)
     81             X = X.to_numpy()
     82 
---> 83         return self.tfm.encode(X)

/usr/local/lib/python3.9/dist-packages/sentence_transformers/SentenceTransformer.py in encode(self, sentences, batch_size, show_progress_bar, output_value, convert_to_numpy, convert_to_tensor, device, normalize_embeddings)
    159         for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
    160             sentences_batch = sentences_sorted[start_index:start_index+batch_size]
--> 161             features = self.tokenize(sentences_batch)
    162             features = batch_to_device(features, device)
    163 

/usr/local/lib/python3.9/dist-packages/sentence_transformers/SentenceTransformer.py in tokenize(self, texts)
    317         Tokenizes the texts
    318         """
--> 319         return self._first_module().tokenize(texts)
    320 
    321     def get_sentence_features(self, *features):

/usr/local/lib/python3.9/dist-packages/sentence_transformers/models/Transformer.py in tokenize(self, texts)
    100             batch1, batch2 = [], []
    101             for text_tuple in texts:
--> 102                 batch1.append(text_tuple[0])
    103                 batch2.append(text_tuple[1])
    104             to_tokenize = [batch1, batch2]

IndexError: invalid index to scalar variable.

The result of transform() seems to be a vector already:

v_pork = encoder.transform("pork")
v_pork.shape
(384,)
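
The shape difference can be reproduced without downloading a model. The sketch below uses a hypothetical `ToyEncoder` (a stand-in, not embetter's encoder) and assumes the input handling that the output above suggests: a bare string yields a single flat vector, while a list of strings yields one row per string. Under that assumption, `[0]` on the bare-string result picks out a single number rather than a vector.

```python
class ToyEncoder:
    """Hypothetical stand-in for a sentence encoder; not the embetter API."""

    dim = 4  # the real model in this thread produced 384 dimensions

    def _embed(self, text):
        # Deterministic fake embedding: scaled character codes, zero-padded to `dim`.
        codes = [ord(c) / 1000 for c in text[: self.dim]]
        return codes + [0.0] * (self.dim - len(codes))

    def transform(self, X):
        if isinstance(X, str):
            return self._embed(X)           # bare string -> one flat vector
        return [self._embed(x) for x in X]  # list of strings -> one row per string


encoder = ToyEncoder()

flat = encoder.transform("pork")    # 1-D: a single vector
rows = encoder.transform(["pork"])  # 2-D: one row per input text

scalar = flat[0]  # indexing the flat vector yields one number, not a vector
v_pork = rows[0]  # indexing the 2-D result yields the full vector

print(type(scalar).__name__, len(v_pork))  # prints: float 4
```

Passing that single number on to a method that expects text or a vector is consistent with the "invalid index to scalar variable" error in the traceback.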

@koaning
Owner Author

koaning commented Mar 31, 2023

Found the issue! It should be this:

v_pork = encoder.transform(["pork"])

@twielfaert

Makes sense.

But then the next line should be query_vector() rather than query(), no?

@koaning
Owner Author

koaning commented Mar 31, 2023

d0h!

Again, well spotted!
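
To make the two fixes from this thread concrete, here is a minimal sketch of the text and vector entry points side by side. `ToyEncoder` and `ToyIndex` are hypothetical stand-ins, and the brute-force Euclidean search is an assumption for illustration, not simsity's method. Only the delegation inside `query()` follows what the traceback shows: `query()` encodes its argument and hands the result to `query_vector()`. Wrapping the text in a list inside the toy `query()` is likewise an assumption.

```python
import math


class ToyEncoder:
    """Hypothetical encoder: maps each text to a tiny 2-D vector."""

    def transform(self, X):
        # Expects a list of strings and returns one row per string.
        return [[float(len(x)), float(x.count("o"))] for x in X]


class ToyIndex:
    """Hypothetical index; brute-force search is an assumption, not simsity's method."""

    def __init__(self, encoder, texts):
        self.encoder = encoder
        self.texts = texts
        self.vectors = encoder.transform(texts)

    def query_vector(self, query, n=2):
        # Takes an already-encoded vector and ranks stored texts by distance.
        dists = [math.dist(query, v) for v in self.vectors]
        order = sorted(range(len(dists)), key=dists.__getitem__)[:n]
        return [self.texts[i] for i in order], [dists[i] for i in order]

    def query(self, query, n=2):
        # Takes raw text, encodes it, then delegates to query_vector.
        arr = self.encoder.transform([query])[0]
        return self.query_vector(query=arr, n=n)


encoder = ToyEncoder()
index = ToyIndex(encoder, ["pork", "beef", "tofu burger"])

v_pork = encoder.transform(["pork"])[0]        # wrap in a list, then take row 0
texts_a, dists_a = index.query_vector(v_pork)  # vector in -> query_vector
texts_b, dists_b = index.query("pork")         # text in -> query
```

Both calls return the same neighbours here; the point is only that a pre-computed vector goes to `query_vector()`, while `query()` expects something the encoder can still transform.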
