Skip to content

Commit 760c25b

Browse files
committed
✨ implement the back-end call to codegen and return the result
1 parent 65c7c35 commit 760c25b

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

codegen_paddle/codegen.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import time
2+
import paddle
3+
from paddlenlp.utils.log import logger
4+
from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
5+
6+
from .config import DefaultConfig
7+
8+
# Shared generation settings, instantiated once when this module is imported.
generate_config = DefaultConfig()

# Apply the configured device and default dtype to Paddle ahead of the
# tokenizer/model loading below.
paddle.set_device(generate_config.device)
paddle.set_default_dtype(generate_config.default_dtype)

# The tokenizer and the model are both loaded from the same configured
# model name/path and kept as module-level objects.
tokenizer = CodeGenTokenizer.from_pretrained(
    generate_config.model_name_or_path,
)
model = CodeGenForCausalLM.from_pretrained(
    generate_config.model_name_or_path,
    load_state_as_np=generate_config.load_state_as_np,
)
16+
17+
18+
def gen_code(prompt: str) -> str:
    """Generate a completion for ``prompt`` with the module-level model.

    The prompt is tokenized with the module-level ``tokenizer``, passed to
    ``model.generate`` with the settings held in ``generate_config``, and the
    first returned sequence is decoded back to text.

    Args:
        prompt: Text to feed to the model.

    Returns:
        The decoded first generated sequence, with special tokens skipped.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments.
    started_at = time.perf_counter()
    logger.info("Start generating code")

    tokenized = tokenizer(prompt, truncation=True, return_tensors='pd')

    # Every generation knob is read from generate_config so that changing the
    # config class is enough to change behavior (max_length was previously a
    # literal 16 here, and top_p was defined in the config but never passed).
    output, _ = model.generate(
        tokenized["input_ids"],
        max_length=generate_config.max_length,
        min_length=generate_config.min_length,
        decode_strategy=generate_config.decode_strategy,
        top_k=generate_config.top_k,
        top_p=generate_config.top_p,
        repetition_penalty=generate_config.repetition_penalty,
        temperature=generate_config.temperature,
        use_faster=generate_config.use_faster,
        use_fp16_decoding=generate_config.use_fp16_decoding,
    )

    logger.info("Finish generating code")
    elapsed = time.perf_counter() - started_at
    logger.info(f"Time cost: {elapsed}")

    # Only the first returned sequence is decoded and handed back.
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    logger.info(f"Generated code: {output}")
    return output

codegen_paddle/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
class DefaultConfig:
    """Default settings, exposed as plain class attributes."""

    # Model identifier and the device name.
    model_name_or_path = "Salesforce/codegen-350M-mono"
    device = "cpu"

    # Sampling / decoding parameters.
    temperature = 0.5
    top_k = 10
    top_p = 1.0
    repetition_penalty = 1.0
    min_length = 0
    max_length = 16
    decode_strategy = "greedy_search"

    # Loading and acceleration switches.
    load_state_as_np = True
    use_faster = False
    use_fp16_decoding = False

    # float16 is selected only when both acceleration switches are enabled;
    # otherwise the default dtype is float32.
    if use_faster and use_fp16_decoding:
        default_dtype = "float16"
    else:
        default_dtype = "float32"

codegen_paddle/handlers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import tornado
88
from tornado.web import StaticFileHandler
99

10+
from .codegen import gen_code
11+
1012

1113
class RouteHandler(APIHandler):
1214
# The following decorator should be present on all verb methods (head, get, post,
@@ -19,7 +21,9 @@ def get(self):
1921
@tornado.web.authenticated
def post(self):
    """Run code generation on the ``code`` field of the JSON request body.

    Finishes the request with a JSON object whose ``received`` key holds the
    text returned by ``gen_code``.

    Raises:
        tornado.web.HTTPError: 400 when the body is not a JSON object or
            does not carry a string ``code`` field.
    """
    input_data = self.get_json_body()

    # Reject malformed requests explicitly: indexing a missing body or a
    # missing key would otherwise surface as an unhandled TypeError/KeyError.
    prompt = input_data.get("code") if isinstance(input_data, dict) else None
    if not isinstance(prompt, str):
        raise tornado.web.HTTPError(
            400, "Request body must be a JSON object with a string 'code' field"
        )

    res = gen_code(prompt)
    data = {"received": "{}".format(res)}
    self.finish(json.dumps(data))
2428

2529

@@ -45,4 +49,3 @@ def setup_handlers(web_app, url_path):
4549
)
4650
handlers = [("{}/(.*)".format(doc_url), StaticFileHandler, {"path": doc_dir})]
4751
web_app.add_handlers(".*$", handlers)
48-

0 commit comments

Comments
 (0)