Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of AutoModel logic #3067

Merged
merged 5 commits into from Jan 30, 2023
Merged

Refactoring of AutoModel logic #3067

merged 5 commits into from Jan 30, 2023

Conversation

alanakbik
Copy link
Collaborator

This PR adapts the auto-loading logic from #3011 such that the load function will distinguish between abstract and non-abstract classes:

  • When called on a regular class like SequenceTagger, it will fetch/load a model as always
  • When called on an abstract class like Model, it will try all non-abstract subclasses to infer the type and then load the model

This removes the need to register supported models in each class and automatically allows loading through the class hierarchy.

So the regular way for loading a model still works:

sentiment_tagger = TextClassifier.load('sentiment')

Now also these options work:

# load sentiment model through Model base class
sentiment_tagger = Model.load('sentiment')

# load sentiment model through Classifier base class
sentiment_tagger = Classifier.load('sentiment')

Since most models inherit from Classifier, you can load and run multiple different models with the same code. So, to run three different taggers for sentiment, entities and frames, do:

from flair.data import Sentence
from flair.nn import Classifier

# load three taggers to tag entities, frames and sentiment
tagger_1 = Classifier.load('ner')
tagger_2 = Classifier.load('frame')
tagger_3 = Classifier.load('sentiment')

# example sentence
sentence = Sentence('Dirk celebrated in Essen')

# predict with all three models
tagger_1.predict(sentence)
tagger_2.predict(sentence)
tagger_3.predict(sentence)

# print all predictions
for label in sentence.get_labels():
    print(label)

@alanakbik alanakbik merged commit 2001469 into master Jan 30, 2023
@alanakbik alanakbik deleted the model-loading branch January 30, 2023 14:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant