How to enable 'multi_class' predictor within Sagemaker? #123

Closed
millnerryan opened this issue Dec 9, 2021 · 1 comment

@millnerryan

I'm using the zero-shot pipeline with the `valhalla/distilbart-mnli-12-9` model. How do I enable multi-label classification? When I use the transformers pipeline with PyTorch locally in Python, I pass the argument `multi_class=True`, but I can't find the equivalent way to do this in SageMaker. See the code below:

```python
import sagemaker
from sagemaker.huggingface.model import HuggingFaceModel

# IAM role with permissions to create an endpoint
# (get_execution_role() works inside a SageMaker notebook or Studio)
role = sagemaker.get_execution_role()

# Hub model configuration: https://huggingface.co/models
model = 'valhalla/distilbart-mnli-12-9'

hub = {
    'HF_MODEL_ID': model,                  # model_id from hf.co/models
    'HF_TASK': 'zero-shot-classification'  # NLP task you want to use for predictions
}

# Create the Hugging Face Model class
huggingface_model = HuggingFaceModel(
    env=hub,                     # configuration for loading the model from the Hub
    role=role,                   # IAM role with permissions to create an endpoint
    transformers_version="4.6",  # transformers version used
    pytorch_version="1.7",       # PyTorch version used
    py_version="py36",
)

# Deploy the model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.p2.xlarge",
    # multi_label=True,  # attempted here, but this is not a deploy() option,
    #                      so it never reaches the zero-shot pipeline
)
```
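For context on what the flag changes: in the `transformers` zero-shot pipeline, each candidate label is turned into an NLI hypothesis (by default "This example is {}.") and the model returns contradiction/neutral/entailment logits for it. With `multi_label=False`, the entailment logits are softmaxed across all labels, so the scores sum to 1; with `multi_label=True`, each label is scored independently by a softmax over its own contradiction and entailment logits. The sketch below reproduces only that normalization step in plain Python. The function name `zero_shot_scores` and the logit values are hypothetical, not real model output.

```python
import math
from typing import List, Tuple


def zero_shot_scores(
    logits: List[Tuple[float, float]], multi_label: bool
) -> List[float]:
    """Normalize per-label NLI logits the way the zero-shot pipeline does.

    `logits` holds one (contradiction, entailment) pair per candidate label.
    Hypothetical helper: it mirrors only the post-processing step.
    """
    if multi_label:
        # Each label on its own: softmax over [contradiction, entailment]
        # and keep the entailment probability. Scores need not sum to 1.
        return [
            math.exp(ent) / (math.exp(con) + math.exp(ent)) for con, ent in logits
        ]
    # Single-label: softmax the entailment logits across all labels,
    # so the labels compete and the scores sum to 1.
    exps = [math.exp(ent) for _, ent in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    example = [(-2.0, 3.0), (-1.5, 2.5)]  # hypothetical logits for two labels
    print(zero_shot_scores(example, multi_label=False))
    print(zero_shot_scores(example, multi_label=True))
```

With these made-up logits, both labels get high independent scores under `multi_label=True`, whereas under `multi_label=False` they have to split the probability mass between them.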
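@millnerryan (author)

Figured it out! `multi_label` is not a deployment setting. It goes in the `parameters` of each request payload, next to `candidate_labels`, and the payload is sent with `predictor.predict`: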

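```python
# s is the text to classify
data = {
    "inputs": s,
    "parameters": {
        "candidate_labels": [
            "hostess",
            "waiter"
        ],
        "multi_label": True
    }
}

# send the request to the endpoint deployed with huggingface_model.deploy()
result = predictor.predict(data)
```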
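A minimal sketch of the same payload as a reusable helper. The name `build_zero_shot_payload` is hypothetical (it is not part of the SageMaker or Hugging Face SDKs); it only assembles the request body with `inputs` holding the text and `parameters` holding the pipeline keyword arguments, matching the payload in this thread. The resulting dict is what you would pass to `predictor.predict(...)`.

```python
import json
from typing import Any, Dict, List


def build_zero_shot_payload(
    text: str, candidate_labels: List[str], multi_label: bool = True
) -> Dict[str, Any]:
    """Assemble a zero-shot-classification request body (hypothetical helper).

    "inputs" holds the text; "parameters" holds the keyword arguments that
    the zero-shot pipeline would otherwise receive in a local call.
    """
    if not candidate_labels:
        raise ValueError("candidate_labels must not be empty")
    return {
        "inputs": text,
        "parameters": {
            "candidate_labels": list(candidate_labels),
            "multi_label": multi_label,
        },
    }


if __name__ == "__main__":
    sample_text = "Could I see the menu, please?"  # hypothetical example input
    payload = build_zero_shot_payload(sample_text, ["hostess", "waiter"])
    # Show the JSON body; against a deployed endpoint you would instead
    # call predictor.predict(payload).
    print(json.dumps(payload, indent=2))
```

When the payload is serialized to JSON, Python's `True` becomes `true`, which is the form the endpoint receives.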
