Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten DefaultClassifier interface #2978

Merged
merged 13 commits into from
Nov 6, 2022
Merged

Conversation

alanakbik
Copy link
Collaborator

@alanakbik alanakbik commented Nov 3, 2022

This PR attempts to flatten the DefaultClassifier interface. We remove the distinction between tensor and non-tensor operations, temporarily removing JIT support until we find a more comprehensive solution.

Changes:

  • The DefaultClassifier now has an embeddings argument in the init method, since all current implementations use embeddings. This removes the need for the _inner_embeddings property which no longer needs to be implemented in the subclasses.

  • All implementing classes now have their embeddings argument called embeddings (rather than word_embeddings/document_embeddings etc.). Hopefully this introduces a bit more consistency across model classes.

  • Previously, the get_prediction_data_points was called twice in the forward pass: once in forward_loss, then again in _prepare_tensors. It is now only called once, directly at the beginning of the forward_loss.

  • Renames _embed_prediction_data_point to _get_embedding_for_data_point (in most cases, the embedding already exists)

  • Changes _get_prediction_data_points to _get_data_points_from_sentence, changing the scope to extracting points for a since Sentence rather than the full batch

Additional changes:

  • Removes the expectation maximization approach from the PrototypicalDecoder as it never really worked and required a special function in the model
  • DataPoint now has __len__

@alanakbik alanakbik changed the title [WIP] Flatten DefaultClassifier interface Flatten DefaultClassifier interface Nov 6, 2022
@alanakbik alanakbik merged commit 615c1ad into master Nov 6, 2022
@alanakbik alanakbik deleted the default-classifier-speedup-2 branch November 6, 2022 21:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant