# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import Any, List, Literal, Optional

from .base import BaseInferenceType


TypeEnum = Literal["json", "regex"]


@dataclass
class TextGenerationInputGrammarType(BaseInferenceType):
    type: "TypeEnum"
    value: Any
    """A string that represents a [JSON Schema](https://json-schema.org/).

    JSON Schema is a declarative language that allows annotating JSON documents
    with types and descriptions.
    """


@dataclass
class TextGenerationInputGenerateParameters(BaseInferenceType):
    best_of: Optional[int] = None  # generate `best_of` sequences, return the highest-scoring one
    decoder_input_details: Optional[bool] = None  # return ids/logprobs for the prompt tokens
    details: Optional[bool] = None  # return detailed generation info (tokens, finish reason, ...)
    do_sample: Optional[bool] = None  # sample instead of greedy decoding
    frequency_penalty: Optional[float] = None  # penalize tokens by their frequency so far
    grammar: Optional[TextGenerationInputGrammarType] = None  # constrain output to a grammar
    max_new_tokens: Optional[int] = None  # maximum number of tokens to generate
    repetition_penalty: Optional[float] = None  # penalize repeated tokens
    return_full_text: Optional[bool] = None  # prepend the prompt to the generated text
    seed: Optional[int] = None  # random seed for reproducible sampling
    stop: Optional[List[str]] = None  # stop generating when one of these sequences appears
    temperature: Optional[float] = None  # sampling temperature
    top_k: Optional[int] = None  # sample only from the k most likely tokens
    top_n_tokens: Optional[int] = None  # number of top tokens to return at each step
    top_p: Optional[float] = None  # nucleus sampling probability mass
    truncate: Optional[int] = None  # truncate the prompt to this many input tokens
    typical_p: Optional[float] = None  # typical decoding mass
    watermark: Optional[bool] = None  # apply watermarking to the generated text
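

# A hedged configuration sketch (parameter values are illustrative, not
# defaults from the spec): nucleus sampling with a fixed seed so repeated
# calls reproduce the same generation.
_example_sampling_params = TextGenerationInputGenerateParameters(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    seed=42,
    max_new_tokens=128,
    stop=["\n\n"],
)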


@dataclass
class TextGenerationInput(BaseInferenceType):
    """Text Generation Input.

    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    inputs: str
    parameters: Optional[TextGenerationInputGenerateParameters] = None
    stream: Optional[bool] = None
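

# A minimal construction sketch (prompt and values are illustrative
# assumptions): a non-streaming request reusing the sampling parameters
# sketched above.
_example_input = TextGenerationInput(
    inputs="What is Deep Learning?",
    parameters=_example_sampling_params,
    stream=False,
)

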
TextGenerationOutputFinishReason = Literal["length", "eos_token", "stop_sequence"]


@dataclass
class TextGenerationOutputPrefillToken(BaseInferenceType):
    id: int
    logprob: float
    text: str


@dataclass
class TextGenerationOutputToken(BaseInferenceType):
    id: int
    logprob: float
    special: bool
    text: str


@dataclass
class TextGenerationOutputBestOfSequence(BaseInferenceType):
    finish_reason: "TextGenerationOutputFinishReason"
    generated_text: str
    generated_tokens: int
    prefill: List[TextGenerationOutputPrefillToken]
    tokens: List[TextGenerationOutputToken]
    seed: Optional[int] = None
    top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None


@dataclass
class TextGenerationOutputDetails(BaseInferenceType):
    finish_reason: "TextGenerationOutputFinishReason"
    generated_tokens: int
    prefill: List[TextGenerationOutputPrefillToken]
    tokens: List[TextGenerationOutputToken]
    best_of_sequences: Optional[List[TextGenerationOutputBestOfSequence]] = None
    seed: Optional[int] = None
    top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None


@dataclass
class TextGenerationOutput(BaseInferenceType):
    """Text Generation Output.

    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    generated_text: str
    details: Optional[TextGenerationOutputDetails] = None
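

# A hedged helper sketch (not part of the generated module): sum the
# per-token log-probabilities of a detailed output, e.g. to score a
# completion. Only meaningful when the request was sent with `details=True`.
def _total_logprob(output: TextGenerationOutput) -> Optional[float]:
    if output.details is None:
        return None
    return sum(token.logprob for token in output.details.tokens)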


@dataclass
class TextGenerationStreamOutputStreamDetails(BaseInferenceType):
    finish_reason: "TextGenerationOutputFinishReason"
    generated_tokens: int
    seed: Optional[int] = None


@dataclass
class TextGenerationStreamOutputToken(BaseInferenceType):
    id: int
    logprob: float
    special: bool
    text: str


@dataclass
class TextGenerationStreamOutput(BaseInferenceType):
    """Text Generation Stream Output.

    Auto-generated from TGI specs.
    For more details, check out
    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts.
    """

    index: int
    token: TextGenerationStreamOutputToken
    details: Optional[TextGenerationStreamOutputStreamDetails] = None
    generated_text: Optional[str] = None
    top_tokens: Optional[List[TextGenerationStreamOutputToken]] = None
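

# A hedged consumption sketch (not part of the generated module): accumulate
# streamed tokens into the final text. `chunks` stands for any iterable of
# TextGenerationStreamOutput, e.g. parsed from a server-sent-event stream.
def _collect_stream_text(chunks) -> str:
    text = ""
    for chunk in chunks:
        if not chunk.token.special:  # skip control tokens such as EOS
            text += chunk.token.text
        if chunk.details is not None:  # the last chunk carries finish details
            break
    return text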