5 files changed: +27 −7 lines
ChatGLM_6b.py (new file):
+import requests
+import json
+
+def getAnswerFromChatGLM6b(context):
+    url = 'http://172.16.62.136:8000/'
+    # json.dumps escapes quotes/newlines in the prompt; naive string
+    # concatenation would produce invalid JSON for many inputs
+    data = json.dumps({"prompt": context, "history": []})
+    headers = {'content-type': 'application/json;charset=utf-8'}
+    r = requests.post(url, data=data.encode(), headers=headers)
+    if r.status_code == 200:
+        # parse the body only on success; error responses may not be JSON
+        return r.json()['response']
+    else:
+        return ""
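For context, this client matches the request/response shape of the api.py demo in the THUDM/ChatGLM-6B repository. Below is a minimal sketch of a compatible server, assuming that pattern; the actual service behind 172.16.62.136:8000 is not part of this commit.

# Minimal sketch of the assumed server side, following the api.py pattern
# from THUDM/ChatGLM-6B. Illustration only; the real server is not in this diff.
import uvicorn
from fastapi import FastAPI, Request
from transformers import AutoModel, AutoTokenizer

app = FastAPI()
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.eval()

@app.post("/")
async def answer(request: Request):
    body = await request.json()
    # model.chat returns the reply text plus the updated conversation history
    response, history = model.chat(tokenizer, body["prompt"], history=body.get("history", []))
    return {"response": response, "history": history, "status": 200}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)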
@@ -212,7 +212,7 @@ function App() {
         title: 'More',
       },
     ],
-    title: '基于Salesforce codegen和GPT-J-6B的AI代码生成',
+    title: '基于Salesforce codegen和清华THUDM/ChatGLM-6B的AI代码生成',
   }}
   messages={messages}
   renderMessageContent={renderMessageContent}

(The chat UI title changes from "AI code generation based on Salesforce codegen and GPT-J-6B" to "AI code generation based on Salesforce codegen and Tsinghua THUDM/ChatGLM-6B".)
@@ -9,6 +9,7 @@
 from jaxformer.hf.sample import load_model, sampling
 from gpt_j import gpt_load_model, gpt_generate
 from codegen_stream import codegen_stream
+from ChatGLM_6b import getAnswerFromChatGLM6b

 ROOT = os.path.dirname(__file__)

@@ -43,7 +44,8 @@ async def codegen(request):
     context = context.replace("//", "").replace("#", "").strip()
     stop = False
     if flag_chs:  # or content.startswith('gpt-j')
-        result = getAnswerFromChatGPTJ(context, maxlength).replace(context, "")
+        # result = getAnswerFromChatGPTJ(context,maxlength).replace(context,"")
+        result = getAnswerFromChatGLM6b(context)
     else:
         result, stop = sampling(context, maxlength)
     end = time.perf_counter()
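Both handlers branch on flag_chs (computed as f(context) in the stream handler below), a helper this diff does not include. A plausible reconstruction, assuming it simply tests for CJK characters; the name contains_chinese is hypothetical:

# Hypothetical reconstruction of the language check that selects the model
# branch; the actual helper f() is not shown in this diff.
def contains_chinese(text: str) -> bool:
    # CJK Unified Ideographs occupy U+4E00..U+9FFF
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)

print(contains_chinese("写一个快速排序"))    # True  -> ChatGLM-6B path
print(contains_chinese("def quicksort():"))  # False -> CodeGen sampling path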
codegen_stream.py:
@@ -4,6 +4,7 @@
 import json
 from jaxformer.hf.sample import load_model, sampling
 from gpt_j import gpt_load_model, gpt_generate_stream
+from ChatGLM_6b import getAnswerFromChatGLM6b


 def sampling_gptj(context, maxlength):
@@ -25,10 +26,13 @@ async def codegen_stream(request):
     flag_chs = f(context)
     stop = False
     if flag_chs:
-        results = sampling_gptj(context, maxlength)
-        results = json.loads(results)
-        result_en = results["result_en"]
-        result_ch = results["result_ch"]
+        # results = sampling_gptj(context, maxlength)
+        # results = json.loads(results)
+        # result_en = results["result_en"]
+        # result_ch = results["result_ch"]
+        result_en = getAnswerFromChatGLM6b(context)
+        result_ch = result_en
+        stop = True
     else:
         result_en, stop = sampling(context, maxlength)
         result_ch = result_en
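Note the new stop = True: unlike the token-by-token CodeGen path, getAnswerFromChatGLM6b returns one complete answer, so the stream finishes after a single chunk; result_ch is aliased to result_en because ChatGLM-6B already answers Chinese prompts in Chinese. A hypothetical consumer of this contract (field names taken from the handler above; the actual transport framing is not shown in this diff):

# Hypothetical client-side loop for the stream contract implied above:
# read chunks until one arrives with stop set.
def consume(chunks):
    parts = []
    for chunk in chunks:          # each chunk: {"result_en", "result_ch", "stop"}
        parts.append(chunk["result_en"])
        if chunk.get("stop"):     # the ChatGLM path sets stop on its only chunk
            break
    return "".join(parts)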
jaxformer/hf/sample.py (inferred from the imports above):
@@ -207,7 +207,8 @@ def load_model():
     device = torch.device('cuda:0')
     use_fp16 = True
     model_name = "codegen-6B-mono"
-    # model_name = "codegen-350M-mono" # test on windows
+    if os.name == 'nt':
+        model_name = "codegen-350M-mono"  # test on Windows
     ckpt = f'./checkpoints/{model_name}'
     # (3) load
     with print_time('loading parameters'):
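This hunk turns a commented-out override into an automatic one: on Windows (os.name == 'nt') the 350M checkpoint is loaded instead of the 6B one, presumably because the test machine cannot hold the 6B model even in fp16. For reference, a rough Hugging Face Hub equivalent of the same selection; this is an assumption, since the repo actually loads local checkpoints through jaxformer.hf.sample:

# Rough Hub-based equivalent of the platform-dependent load above.
# Illustration only; not the repo's actual jaxformer loading path.
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "codegen-350M-mono" if os.name == 'nt' else "codegen-6B-mono"
use_fp16 = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(f"Salesforce/{model_name}")
model = AutoModelForCausalLM.from_pretrained(
    f"Salesforce/{model_name}",
    torch_dtype=torch.float16 if use_fp16 else torch.float32,
)
if use_fp16:
    model = model.to(torch.device('cuda:0'))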