chat_completion.py
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
#   - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
#   - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
from dataclasses import dataclass
from typing import List, Literal, Optional, Union

from .base import BaseInferenceType


ChatCompletionMessageRole = Literal["assistant", "system", "user"]

@dataclass
class ChatCompletionInputMessage(BaseInferenceType):
    content: str
    """The content of the message."""

    role: "ChatCompletionMessageRole"


@dataclass
class ChatCompletionInput(BaseInferenceType):
    """Inputs for Chat Completion inference"""

    messages: List[ChatCompletionInputMessage]
    frequency_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    existing frequency in the text so far, decreasing the model's likelihood to repeat
    the same line verbatim.
    """
    max_tokens: Optional[int] = None
    """The maximum number of tokens that can be generated in the chat completion."""
    seed: Optional[int] = None
    """The random sampling seed."""
    stop: Optional[Union[List[str], str]] = None
    """Stop generating tokens if a stop token is generated."""
    stream: Optional[bool] = None
    """If set, partial message deltas will be sent."""
    temperature: Optional[float] = None
    """The value used to modulate the logits distribution."""
    top_p: Optional[float] = None
    """If set to < 1, only the smallest set of most probable tokens with probabilities
    that add up to `top_p` or higher are kept for generation.
    """

ChatCompletionFinishReason = Literal["length", "eos_token", "stop_sequence"]


@dataclass
class ChatCompletionOutputChoiceMessage(BaseInferenceType):
    content: str
    """The content of the chat completion message."""

    role: "ChatCompletionMessageRole"


@dataclass
class ChatCompletionOutputChoice(BaseInferenceType):
    finish_reason: "ChatCompletionFinishReason"
    """The reason why the generation was stopped."""

    index: int
    """The index of the choice in the list of choices."""

    message: ChatCompletionOutputChoiceMessage


@dataclass
class ChatCompletionOutput(BaseInferenceType):
    """Outputs for Chat Completion inference"""

    choices: List[ChatCompletionOutputChoice]
    """A list of chat completion choices."""

    created: int
    """The Unix timestamp (in seconds) of when the chat completion was created."""

@dataclass
class ChatCompletionStreamOutputDelta(BaseInferenceType):
    """A chat completion delta generated by streamed model responses."""

    content: Optional[str] = None
    """The contents of the chunk message."""
    role: Optional[str] = None
    """The role of the author of this message."""


@dataclass
class ChatCompletionStreamOutputChoice(BaseInferenceType):
    delta: ChatCompletionStreamOutputDelta
    """A chat completion delta generated by streamed model responses."""

    index: int
    """The index of the choice in the list of choices."""

    finish_reason: Optional["ChatCompletionFinishReason"] = None
    """The reason why the generation was stopped."""


@dataclass
class ChatCompletionStreamOutput(BaseInferenceType):
    """Chat Completion Stream Output"""

    choices: List[ChatCompletionStreamOutputChoice]
    """A list of chat completion choices."""

    created: int
    """The Unix timestamp (in seconds) of when the chat completion was created. Each
    chunk has the same timestamp.
    """