forked from rbgo404/Mixral-8x7B
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
34 lines (30 loc) · 1.17 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import contextlib
from get_model import model_initialize,encode_tokens,generate
from huggingface_hub import snapshot_download
import os
class InferlessPythonModel:
    """Serve the int8 GPTQ Mixtral-8x7B checkpoint pulled from the Hugging Face Hub.

    Exposes the ``initialize`` / ``infer`` / ``finalize`` lifecycle methods,
    presumably invoked by the Inferless runtime (TODO confirm against the
    Inferless custom-runtime docs).
    """

    # (request key, converter, default) for each sampling argument forwarded
    # to generate(). The defaults are the values this endpoint originally
    # hard-coded, so requests that omit these keys behave exactly as before.
    _GENERATION_DEFAULTS = (
        ("max_new_tokens", int, 256),
        ("temperature", float, 0.8),
        ("top_k", int, 200),
    )

    def initialize(self):
        """Download the model snapshot and load the tokenizer and model.

        Sets ``self.tokenizer``, ``self.model`` and ``self.callback``.
        """
        repo_id = "Inferless/Mixtral-8x7B-v0.1-int8-GPTQ"
        model_store = f"/home/{repo_id}"
        # repo_id contains a "/", so this creates a nested directory tree.
        os.makedirs(model_store, exist_ok=True)
        snapshot_download(repo_id, local_dir=model_store)
        self.tokenizer, self.model = model_initialize(f"{model_store}/model_int8.pth")
        # No-op callback handed to generate(); it returns its argument unchanged.
        self.callback = lambda x: x

    @classmethod
    def _generation_params(cls, inputs):
        """Build the sampling keyword arguments for generate() from a request.

        For each key in ``_GENERATION_DEFAULTS``, use ``inputs[key]`` converted
        to its expected type when present and not None; otherwise use the default.

        Raises:
            ValueError / TypeError: if a supplied value cannot be converted.
        """
        params = {}
        for key, convert, default in cls._GENERATION_DEFAULTS:
            value = inputs.get(key)
            params[key] = default if value is None else convert(value)
        return params

    def infer(self, inputs):
        """Generate text for ``inputs['prompt']``.

        Args:
            inputs: request dict. ``'prompt'`` is required. Optional keys
                ``'max_new_tokens'``, ``'temperature'`` and ``'top_k'`` override
                the sampling defaults (256, 0.8, 200); a missing key or a None
                value falls back to the default.

        Returns:
            ``{'generated_result': <decoded text>}``.

        Raises:
            KeyError: if ``'prompt'`` is missing from ``inputs``.
            ValueError / TypeError: if a sampling override is not numeric.
        """
        prompt = inputs['prompt']
        encoded = encode_tokens(self.tokenizer, prompt, bos=True, device="cuda")
        y, _metrics = generate(
            self.model,
            encoded,
            draft_model=None,
            # NOTE(review): presumably ignored when draft_model is None — confirm in get_model.
            speculate_k=5,
            interactive=False,
            callback=self.callback,
            **self._generation_params(inputs),
        )
        # NOTE(review): if generate() returns prompt tokens followed by the
        # completion (gpt-fast style), the decoded text includes the prompt — confirm.
        return {'generated_result': self.tokenizer.decode(y.tolist())}

    def finalize(self):
        """Release resources at shutdown; nothing needs explicit cleanup here."""
        pass