
Commit 6d5d615

change gpt-j to THUDM/ChatGLM-6B

1 parent 4a1633f

5 files changed: +27 −7 lines

ChatGLM_6b.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+import requests
+import json
+
+def getAnswerFromChatGLM6b(context):
+    url = 'http://172.16.62.136:8000/'
+    data = '{"prompt": "' + context + '", "history": []}'
+    headers = {'content-type': 'application/json;charset=utf-8'}
+    r = requests.post(url, data=data.encode(), headers=headers)
+    res = r.json()
+    if r.status_code == 200 :
+        return res['response']
+    else:
+        return ""
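Note: the new helper builds its JSON body by string concatenation, so any prompt containing a double quote, backslash, or newline produces invalid JSON, and r.json() is called before the status code is checked. A more defensive sketch, assuming the server is the api.py bundled with THUDM/ChatGLM-6B (it accepts {"prompt", "history"} and returns a "response" field):

    import requests

    def get_answer_from_chatglm6b(context, url='http://172.16.62.136:8000/', timeout=60):
        """Query a ChatGLM-6B HTTP server; return '' on any failure."""
        try:
            # json= serializes and escapes the payload correctly
            r = requests.post(url, json={'prompt': context, 'history': []}, timeout=timeout)
        except requests.RequestException:
            return ""
        if r.status_code != 200:
            return ""
        return r.json().get('response', "")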

chat/src/App.js

Lines changed: 1 addition & 1 deletion

@@ -212,7 +212,7 @@ function App() {
           title: 'More',
         },
       ],
-      title: '基于Salesforce codegen和GPT-J-6B的AI代码生成',
+      title: '基于Salesforce codegen和清华THUDM/ChatGLM-6B的AI代码生成',
     }}
     messages={messages}
     renderMessageContent={renderMessageContent}

(The old title string reads "AI code generation based on Salesforce codegen and GPT-J-6B"; the new one reads "AI code generation based on Salesforce codegen and Tsinghua's THUDM/ChatGLM-6B".)

codegen.py

Lines changed: 3 additions & 1 deletion

@@ -9,6 +9,7 @@
 from jaxformer.hf.sample import load_model,sampling
 from gpt_j import gpt_load_model,gpt_generate
 from codegen_stream import codegen_stream
+from ChatGLM_6b import getAnswerFromChatGLM6b

 ROOT = os.path.dirname(__file__)

@@ -43,7 +44,8 @@ async def codegen(request):
     context = context.replace("//","").replace("#","").strip()
     stop = False
     if flag_chs :#or content.startwith('gpt-j') :
-        result = getAnswerFromChatGPTJ(context,maxlength).replace(context,"")
+        # result = getAnswerFromChatGPTJ(context,maxlength).replace(context,"")
+        result = getAnswerFromChatGLM6b(context)
     else:
         result,stop = sampling(context,maxlength)
     end = time.perf_counter()
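The routing here relies on flag_chs to send Chinese prompts to ChatGLM-6B instead of the CodeGen sampler; the detection function itself is not part of this commit. Purely as an illustration, a minimal check of this kind might look like the following (the repository's actual logic behind flag_chs may differ):

    def contains_chinese(text):
        # Treat the prompt as Chinese if it contains any CJK Unified Ideograph
        return any('\u4e00' <= ch <= '\u9fff' for ch in text)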

codegen_stream.py

Lines changed: 8 additions & 4 deletions

@@ -4,6 +4,7 @@
 import json
 from jaxformer.hf.sample import load_model, sampling
 from gpt_j import gpt_load_model, gpt_generate_stream
+from ChatGLM_6b import getAnswerFromChatGLM6b


 def sampling_gptj(context, maxlength):

@@ -25,10 +26,13 @@ async def codegen_stream(request):
     flag_chs = f(context)
     stop = False
     if flag_chs:
-        results = sampling_gptj(context, maxlength)
-        results = json.loads(results)
-        result_en = results["result_en"]
-        result_ch = results["result_ch"]
+        # results = sampling_gptj(context, maxlength)
+        # results = json.loads(results)
+        # result_en = results["result_en"]
+        # result_ch = results["result_ch"]
+        result_en = getAnswerFromChatGLM6b(context)
+        result_ch = result_en
+        stop = True
     else:
         result_en,stop = sampling(context, maxlength)
         result_ch = result_en
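Because getAnswerFromChatGLM6b returns the whole answer in one HTTP round trip, the commit sets stop = True so the streaming endpoint emits a single final chunk rather than polling for more tokens. If token-by-token streaming were wanted, ChatGLM-6B exposes a stream_chat generator when the model is loaded in-process; a rough sketch, assuming a local GPU and the standard transformers loading path from the model card:

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()

    def stream_chatglm(context):
        # stream_chat yields progressively longer (response, history) pairs
        for response, _history in model.stream_chat(tokenizer, context, history=[]):
            yield response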

jaxformer/hf/sample.py

Lines changed: 2 additions & 1 deletion

@@ -207,7 +207,8 @@ def load_model():
     device = torch.device('cuda:0')
     use_fp16 = True
     model_name = "codegen-6B-mono"
-    # model_name = "codegen-350M-mono" # test on windows
+    if os.name == 'nt':
+        model_name = "codegen-350M-mono" # test on windows
     ckpt = f'./checkpoints/{model_name}'
     # (3) load
     with print_time('loading parameters'):
