-
Notifications
You must be signed in to change notification settings - Fork 456
/
fill_mask.py
50 lines (40 loc) · 1.67 KB
/
fill_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import Any, List, Optional
from .base import BaseInferenceType
@dataclass
class FillMaskParameters(BaseInferenceType):
    """Additional inference parameters for the Fill Mask task.

    All fields are optional; when omitted the model's defaults apply.
    """

    targets: Optional[List[str]] = None
    """When passed, the model will limit the scores to the passed targets instead of looking up
    in the whole vocabulary. If the provided targets are not in the model vocab, they will be
    tokenized and the first resulting token will be used (with a warning, and that might be
    slower).
    """
    top_k: Optional[int] = None
    """When passed, overrides the number of predictions to return."""
@dataclass
class FillMaskInput(BaseInferenceType):
    """Inputs for Fill Mask inference.

    The input text must contain the model's mask token(s) to be filled in.
    """

    inputs: str
    """The text with masked tokens"""
    parameters: Optional[FillMaskParameters] = None
    """Additional inference parameters"""
@dataclass
class FillMaskOutputElement(BaseInferenceType):
    """Outputs of inference for the Fill Mask task.

    One element describes a single candidate prediction for a masked position.
    """

    score: float
    """The corresponding probability"""
    sequence: str
    """The corresponding input with the mask token prediction."""
    token: int
    """The predicted token id (to replace the masked one)."""
    # NOTE(review): codegen quirk — the string form of the predicted token is
    # carried in `fill_mask_output_token_str` below (optional, so it can sit
    # after defaulted fields), while `token_str` is left as `Any`. Presumably
    # deserialization populates one of the two; verify against the JSON schema
    # spec referenced in the file header before relying on either field.
    token_str: Any
    fill_mask_output_token_str: Optional[str] = None
    """The predicted token (to replace the masked one)."""