-
Notifications
You must be signed in to change notification settings - Fork 402
/
generate.py
85 lines (68 loc) · 2.22 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import subprocess, os
from serge.models.chat import Chat, ChatParameters
import asyncio
import logging
logger = logging.getLogger(__name__)
async def generate(
    prompt: str,
    params: ChatParameters
):
    """Stream text produced by the llama CLI binary for *prompt*.

    Spawns ``llama`` with the sampling settings taken from *params*
    (model path, temperature, top_k/top_p, repeat penalty, context
    window, thread count) and yields decoded stdout text as it arrives.

    Args:
        prompt: Full prompt text passed to llama via ``--prompt``.
        params: ChatParameters document; its linked fields are fetched
            before the settings are read.

    Yields:
        str: Decoded UTF-8 text chunks (up to CHUNK_SIZE raw bytes each).

    Raises:
        ValueError: If the llama process exits with a non-zero return
            code; the message includes the captured stderr output.
    """
    import codecs  # local import: only needed by this streaming loop

    CHUNK_SIZE = 64
    await params.fetch_all_links()

    args = (
        "llama",
        "--model",
        "/usr/src/app/weights/" + params.model + ".bin",
        "--prompt",
        prompt,
        "--n_predict",
        str(params.max_length),
        "--temp",
        str(params.temperature),
        "--top_k",
        str(params.top_k),
        "--top_p",
        str(params.top_p),
        "--repeat_last_n",
        str(params.repeat_last_n),
        "--repeat_penalty",
        str(params.repeat_penalty),
        "--ctx_size",
        str(params.context_window),
        "--threads",
        str(params.n_threads),
        "--n_parts",
        "1",
    )

    # BUG FIX: the original passed an f-string with no placeholder plus a
    # stray positional arg, so `args` was never rendered into the message.
    # Use logging's lazy %-style formatting instead.
    logger.debug("Calling LLaMa with arguments: %s", args)

    procLlama = await asyncio.create_subprocess_exec(
        *args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )

    # BUG FIX: a multi-byte UTF-8 character split across a CHUNK_SIZE
    # boundary made the original's chunk.decode() raise and silently
    # truncate the stream; an incremental decoder carries partial byte
    # sequences over to the next chunk. Invalid bytes are dropped rather
    # than aborting the stream.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")

    try:
        while True:
            chunk = await procLlama.stdout.read(CHUNK_SIZE)

            if not chunk:
                return_code = await procLlama.wait()

                if return_code != 0:
                    error_output = await procLlama.stderr.read()
                    logger.error(error_output.decode("utf-8"))
                    raise ValueError(f"RETURN CODE {return_code}\n\n" + error_output.decode("utf-8"))
                # Flush any bytes the incremental decoder is still holding.
                tail = decoder.decode(b"", final=True)
                if tail:
                    yield tail
                return

            text = decoder.decode(chunk)
            if text:
                yield text
    finally:
        # Don't leak the subprocess if the consumer stops iterating early
        # (generator closed / task cancelled before the process exited).
        if procLlama.returncode is None:
            procLlama.kill()
async def get_full_prompt_from_chat(chat: Chat, simple_prompt: str) -> str:
    """Build the full Alpaca-style prompt for a chat turn.

    Fetches the chat's linked documents, then concatenates the init
    prompt, every previously answered (non-errored) question/answer pair
    as ``### Instruction:`` / ``### Response:`` sections, and finally the
    new *simple_prompt* with a trailing ``### Response:`` header for the
    model to complete.

    Args:
        chat: Chat document whose history and parameters are used.
        simple_prompt: The user's new question for this turn.

    Returns:
        The assembled prompt string.
    """
    await chat.fetch_all_links()
    await chat.parameters.fetch_link(ChatParameters.init_prompt)

    # Collect pieces and join once rather than repeated += concatenation.
    parts = [chat.parameters.init_prompt + "\n\n"]

    if chat.questions is not None:
        for question in chat.questions:
            # Skip prompts that previously errored out.
            if question.error is not None:
                continue
            parts.append("### Instruction:\n" + question.question + "\n")
            parts.append("### Response:\n" + question.answer + "\n")

    parts.append("### Instruction:\n" + simple_prompt + "\n")
    parts.append("### Response:\n")

    return "".join(parts)