/
text_to_audio.py
105 lines (90 loc) · 4.71 KB
/
text_to_audio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from .base import BaseInferenceType
# The only string literal accepted by `early_stopping`, alongside plain booleans.
EarlyStoppingEnum = Literal["never"]
@dataclass
class TextToAudioGenerationParameters(BaseInferenceType):
    """Parametrization of the text generation process

    Ad-hoc parametrization of the text generation process. Every field is optional and
    defaults to None.
    """

    do_sample: Optional[bool] = None
    """Whether to use sampling instead of greedy decoding when generating new tokens."""
    early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None
    """Controls the stopping condition for beam-based methods."""
    epsilon_cutoff: Optional[float] = None
    """If set to float strictly between 0 and 1, only tokens with a conditional probability
    greater than epsilon_cutoff will be sampled. In the paper, suggested values range from
    3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language
    Model Desmoothing](https://hf.co/papers/2210.15191) for more details.
    """
    eta_cutoff: Optional[float] = None
    """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to
    float strictly between 0 and 1, a token is only considered if it is greater than either
    eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter
    term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In
    the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
    See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191)
    for more details.
    """
    max_length: Optional[int] = None
    """The maximum length (in tokens) of the generated text, including the input."""
    max_new_tokens: Optional[int] = None
    """The maximum number of tokens to generate. Takes precedence over max_length."""
    min_length: Optional[int] = None
    """The minimum length (in tokens) of the generated text, including the input."""
    min_new_tokens: Optional[int] = None
    """The minimum number of tokens to generate. Takes precedence over min_length."""
    num_beam_groups: Optional[int] = None
    """Number of groups to divide num_beams into in order to ensure diversity among different
    groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details.
    """
    num_beams: Optional[int] = None
    """Number of beams to use for beam search."""
    penalty_alpha: Optional[float] = None
    """The value balances the model confidence and the degeneration penalty in contrastive
    search decoding.
    """
    temperature: Optional[float] = None
    """The value used to modulate the next token probabilities."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
    top_p: Optional[float] = None
    """If set to float < 1, only the smallest set of most probable tokens with probabilities
    that add up to top_p or higher are kept for generation.
    """
    typical_p: Optional[float] = None
    """Local typicality measures how similar the conditional probability of predicting a target
    token next is to the expected conditional probability of predicting a random token next,
    given the partial text already generated. If set to float < 1, the smallest set of the
    most locally typical tokens with probabilities that add up to typical_p or higher are
    kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details.
    """
    use_cache: Optional[bool] = None
    """Whether the model should use the past last key/values attentions to speed up decoding."""
@dataclass
class TextToAudioParameters(BaseInferenceType):
    """Additional inference parameters for Text To Audio."""

    generate: Optional[TextToAudioGenerationParameters] = None
    """Settings that parametrize the text generation process."""
@dataclass
class TextToAudioInput(BaseInferenceType):
    """Inputs for Text To Audio inference."""

    inputs: str
    """The text data supplied as input."""
    parameters: Optional[TextToAudioParameters] = None
    """Optional additional inference parameters."""
@dataclass
class TextToAudioOutput(BaseInferenceType):
    """Outputs of inference for the Text To Audio task."""

    audio: Any
    """The generated audio waveform."""
    # NOTE(review): `sampling_rate` has no description in the generated code and looks
    # like it overlaps with `text_to_audio_output_sampling_rate` below — confirm against
    # the task spec which of the two is actually populated.
    sampling_rate: Any
    text_to_audio_output_sampling_rate: Optional[float] = None
    """The sampling rate of the generated audio waveform."""