

How could I set the number of classes for the model? #13

Closed
HeroadZ opened this issue Sep 15, 2020 · 3 comments



HeroadZ commented Sep 15, 2020

Hello, thank you for this excellent work!
I'm a beginner with fastai and Hugging Face.

I'm now using your library for a sequence classification problem with more than 2 classes.
I think I need to set the number of classes for the Hugging Face model, since the IMDB example runs correctly while mine does not.
However, I couldn't find how to do this in your docs.
Could you give me some suggestions?

HeroadZ changed the title from "How could I set the number of labels for the model?" to "How could I set the number of classes for the model?" on Sep 15, 2020

HeroadZ (Author) commented Sep 15, 2020

I'm currently using a very ugly workaround 😢 Please help me:

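# First call: only needed to get the default config (it also builds a model that is then thrown away)
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    pretrained_model_name, task=task
)
hf_config.num_labels = num_labels
# Second call: rebuild everything with the updated config
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    pretrained_model_name, task=task, config=hf_config
)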
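For a single-label problem with more than two classes, as in the question, the value for num_labels is the number of distinct class values in the training labels. Below is a minimal sketch, assuming the labels are available as a plain Python list; the helper name `build_label_maps` is hypothetical and not part of blurr or transformers. It also builds the `id2label` / `label2id` mappings that Hugging Face configs can carry.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def build_label_maps(labels: Sequence[Hashable]) -> Tuple[int, Dict[int, Hashable], Dict[Hashable, int]]:
    """Hypothetical helper: derive num_labels and id<->label maps for single-label classification.

    Labels must be mutually comparable (e.g. all str or all int) so they can be sorted.
    """
    # Sorting the distinct labels makes the index assignment deterministic across runs
    classes: List[Hashable] = sorted(set(labels))
    id2label = {i: c for i, c in enumerate(classes)}
    label2id = {c: i for i, c in id2label.items()}
    return len(classes), id2label, label2id


# Example: a 3-class problem (more than 2 classes, as in the question)
num_labels, id2label, label2id = build_label_maps(["neg", "pos", "neutral", "pos", "neg"])
print(num_labels)  # 3
print(id2label)    # {0: 'neg', 1: 'neutral', 2: 'pos'}
```

The first value is what would be assigned to `hf_config.num_labels` (or `config.num_labels`) before the model is created. `id2label` and `label2id` are standard `PretrainedConfig` attributes in transformers; if readable class names are wanted, assign them after `num_labels`, because assigning `num_labels` can reset both maps to generic `LABEL_i` names.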

ohmeow (Owner) commented Sep 15, 2020

It's not that ugly... but there's no reason to call get_hf_objects twice :)

Check out the docs, in particular the multi-label example. You'll see something like this, which is a bit more efficient:

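from transformers import AutoConfig
from blurr.utils import BLURR_MODEL_HELPER, HF_TASKS_AUTO  # module path assumed for the 2020 blurr releases

task = HF_TASKS_AUTO.SequenceClassification

pretrained_model_name = "roberta-base"  # or "distilbert-base-uncased", "bert-base-uncased"
config = AutoConfig.from_pretrained(pretrained_model_name)
config.num_labels = len(lbl_cols)  # lbl_cols: the label column names from the multi-label example

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    pretrained_model_name, task=task, config=config
)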
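Judging by `len(lbl_cols)`, the multi-label example keeps each label in its own column, so num_labels there is the number of label columns and each target is a 0/1 vector of that length. In the single-label, multi-class setting from the original question, num_labels is instead the number of distinct values in one label column. A minimal pure-Python sketch of the multi-label side follows; the column names and the helper `multilabel_targets` are hypothetical.

```python
from typing import Dict, List, Sequence


def multilabel_targets(rows: Sequence[Dict[str, int]], lbl_cols: Sequence[str]) -> List[List[int]]:
    """Hypothetical helper: one 0/1 target vector per row, of length len(lbl_cols)."""
    return [[int(row[col]) for col in lbl_cols] for row in rows]


# Hypothetical multi-label data: each label is its own 0/1 column
lbl_cols = ["toxic", "insult", "threat"]
rows = [
    {"toxic": 1, "insult": 1, "threat": 0},  # a row may carry several labels
    {"toxic": 0, "insult": 0, "threat": 0},  # ...or none at all
]
targets = multilabel_targets(rows, lbl_cols)
num_labels = len(lbl_cols)  # the value the snippet above assigns to config.num_labels
print(num_labels, targets)  # 3 [[1, 1, 0], [0, 0, 0]]
```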

HeroadZ (Author) commented Sep 16, 2020

Thanks a lot! That's faster and simpler.

HeroadZ closed this as completed on Sep 16, 2020