web_server_single.py
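"""Single-process LLaMA web server.

Loads a LLaMA checkpoint on a single GPU (no model parallelism) and exposes a
POST /llama/ text-generation endpoint via FastAPI, served by uvicorn on port 8080.
"""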
import json
from pathlib import Path
from typing import List

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

from llama import ModelArgs, Transformer, Tokenizer, LLaMA


def get_args():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_dir", type=str, default="/llama_data/7B")
    parser.add_argument(
        "--tokenizer_path", type=str, default="/llama_data/tokenizer.model"
    )
    parser.add_argument("--max_seq_len", type=int, default=512)
    parser.add_argument("--max_batch_size", type=int, default=1)
    return parser.parse_args()


app = FastAPI()


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    # One checkpoint shard per model-parallel rank; this server uses rank 0 of a world of size 1.
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words

    # Build the transformer with fp16 parameters on the GPU, then restore the default tensor type.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    return generator


def init_generator(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int = 512,
    max_batch_size: int = 1,
):
    # Single-process setup: rank 0 in a world of size 1, i.e. no model parallelism.
    local_rank, world_size = 0, 1
    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )
    return generator


if __name__ == "__main__":
    args = get_args()
    generator = init_generator(
        args.ckpt_dir,
        args.tokenizer_path,
        args.max_seq_len,
        args.max_batch_size,
    )

    class Config(BaseModel):
        prompts: List[str]
        max_gen_len: int
        temperature: float = 0.8
        top_p: float = 0.95

    @app.post("/llama/")
    def generate(config: Config):
        if len(config.prompts) > args.max_batch_size:
            return {"error": "too many prompts: batch exceeds max_batch_size."}
        for prompt in config.prompts:
            # Rough guard: compares prompt length in characters (not tokens) against max_seq_len.
            if len(prompt) + config.max_gen_len > args.max_seq_len:
                return {"error": "prompt plus max_gen_len exceeds max_seq_len."}
        results = generator.generate(
            config.prompts,
            max_gen_len=config.max_gen_len,
            temperature=config.temperature,
            top_p=config.top_p,
        )
        return {"responses": results}

    uvicorn.run(app, host="0.0.0.0", port=8080)
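
# Example request against the running server (illustrative sketch; assumes the
# server is reachable at localhost:8080 and that the `requests` package is installed):
#
#   import requests
#   payload = {
#       "prompts": ["The capital of France is"],
#       "max_gen_len": 64,
#       "temperature": 0.8,
#       "top_p": 0.95,
#   }
#   print(requests.post("http://localhost:8080/llama/", json=payload).json())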