Commit c9448b8

support ChatGLM_6b chat stream
1 parent d7934d0 commit c9448b8

5 files changed: +72 -26 lines

ChatGLM_6b.py

Lines changed: 5 additions & 1 deletion
@@ -1,8 +1,12 @@
 import requests
 import json
+import os

 def getAnswerFromChatGLM6b(context):
-    url = 'http://172.16.62.136:8000/'
+    if os.name == 'nt':
+        url = 'http://172.16.62.66:8000/stream'
+    else:
+        url = 'http://172.16.62.136:8000/stream'
     data = '{"prompt": "' + context + '", "history": []}'
     headers = {'content-type': 'application/json;charset=utf-8'}
     r = requests.post(url, data=data.encode(), headers=headers)
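The hunk ends at the POST call, so the rest of getAnswerFromChatGLM6b is not visible here. A minimal sketch of how the remainder might look, assuming the ChatGLM-6b API server replies with JSON whose "response" field carries the (partial) answer — that field name is an assumption, not shown in this commit; what the commit does show is that the callers in codegen.py and codegen_stream.py expect the returned string to end with a "[stop]" sentinel once generation is complete:

import os
import requests

def getAnswerFromChatGLM6b(context):
    # Host selection as in the diff: a different LAN address when running on Windows.
    if os.name == 'nt':
        url = 'http://172.16.62.66:8000/stream'
    else:
        url = 'http://172.16.62.136:8000/stream'
    data = '{"prompt": "' + context + '", "history": []}'
    headers = {'content-type': 'application/json;charset=utf-8'}
    r = requests.post(url, data=data.encode(), headers=headers)
    # Assumption: the server returns JSON like {"response": "..."}; downstream code
    # expects the text to carry a trailing "[stop]" marker when generation is done.
    return r.json().get("response", "")

Building the payload with json.dumps({"prompt": context, "history": []}) would survive quotes and newlines in context better than string concatenation, but the sketch keeps the commit's approach.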

chat/package.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "name": "chat",
-  "version": "1.0.0",
+  "version": "1.0.1",
   "private": true,
   "dependencies": {
     "@chatui/core": "^2.4.2",

chat/src/App.js

Lines changed: 58 additions & 16 deletions
@@ -30,7 +30,7 @@ const defaultQuickReplies = [
   {
     icon: 'message',
     name: 'golang',
-  },
+  },
 ];


@@ -52,10 +52,10 @@ function App() {
   const { messages, appendMsg, setTyping } = useMessages(initialMessages);
   const [percentage, setPercentage] = useState(0);

-  function handleSend(type, val) {
-    if (percentage>0) {
-      alert("正在生成,请稍候!") ;
-      return ;
+  function handleSend(type, val, item_name) {
+    if (percentage > 0) {
+      alert("正在生成,请稍候!");
+      return;
     }
     if (type === 'text' && val.trim()) {
       appendMsg({
@@ -64,10 +64,14 @@ function App() {
         position: 'left',
         user: { avatar: '//gitclone.com/download1/user.png' },
       });
-
       setTyping(true);
       setPercentage(10);
-      onGenCode(val, val, 0);
+      if (item_name === undefined) {
+        if (isChinese(val)) {
+          item_name = "GPT";
+        }
+      }
+      onGenCode(val, val, 0, item_name);
     }
   }

@@ -103,23 +107,38 @@ function App() {
     } else {
       content = "写一个python版的数组排序";
     }
-    handleSend('text', content);
+    handleSend('text', content, item.name);
+  }
+
+  function Sleep(ms) {
+    return new Promise(resolve => setTimeout(resolve, ms))
   }

-  function onGenCode(context_en, context_ch, count) {
-    if (count >= 5) {
+  async function onGenCode(context_en, context_ch, count, item_name) {
+    var context_gpt = context_en;
+    var stop = false;
+    var x = 5;
+    if (item_name === "GPT") {
+      x = 1000;
+      await Sleep(500);
+    }
+    if (count >= x) {
       setPercentage(0);
       return;
     }
     let xhr = new XMLHttpRequest();
     xhr.open('post', 'https://gitclone.com/aiit/codegen_stream');
+    //xhr.open('post', 'http://localhost:5000/codegen_stream');
     xhr.setRequestHeader('Content-Type', 'application/json');
     xhr.onload = function () {
       var json = JSON.parse(xhr.response);
       if (count === 0) {
         context_en = context_en + "\n" + json.result_en;
         context_ch = context_ch + "\n" + json.result_ch;
-        var stop = json.stop ;
+        stop = json.stop;
+        if (item_name === "GPT") {
+          context_ch = json.result_ch;
+        }
         appendMsg({
           type: 'text',
           content: { text: context_ch },
@@ -132,20 +151,34 @@ function App() {
         }
         context_en = context_en + json.result_en;
         context_ch = context_ch + json.result_ch;
+        stop = json.stop;
         if (context_ch === context_en) {
-          updateMsg(context_en);
+          if (item_name === "GPT") {
+            updateMsg(json.result_en);
+          }
+          else {
+            updateMsg(context_en);
+          }
         } else {
-          updateMsg(context_ch + "\n" + context_en);
+          if (item_name === "GPT") {
+            updateMsg(json.result_en);
+          } else {
+            updateMsg(context_ch + "\n" + context_en);
+          }
         }

       }
       count++;
       setPercentage(count * 20);
-      if(stop){
+      if (stop) {
         setPercentage(0);
         return;
-      }
-      onGenCode(context_en, context_ch, count);
+      }
+      if (item_name === "GPT") {
+        onGenCode(context_gpt, context_gpt, count, item_name);
+      } else {
+        onGenCode(context_en, context_ch, count, item_name);
+      }
     }
     xhr.send(JSON.stringify({
       "context": context_en,
@@ -187,6 +220,15 @@ function App() {
     }
   }

+  function isChinese(s) {
+    let reg = new RegExp("[\\u4E00-\\u9FFF]+", "g")
+    if (reg.test(s)) {
+      return true;
+    } else {
+      return false;
+    }
+  }
+
   useEffect(() => {
     var oUl = document.getElementById('root');
     var aBox = getByClass(oUl, 'Input Input--outline Composer-input');
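In plain terms, the client now detects Chinese input (isChinese), tags the request as "GPT", and polls the codegen_stream endpoint in a loop: it waits 500 ms per round, resends the original prompt (context_gpt), and replaces the displayed message with the latest result until the server sets stop. (The alert string means "Generating, please wait"; the quick-reply prompt means "write an array sort in Python".) Below is a rough Python mirror of that polling loop, for illustration only — is_chinese, chat_stream, max_rounds and delay are names invented here, and the request body is reduced to the "context" field the diff shows:

import re
import time
import requests

STREAM_URL = "https://gitclone.com/aiit/codegen_stream"   # endpoint used in App.js

def is_chinese(s):
    # Same test as the client's isChinese(): any CJK ideograph present.
    return re.search(r"[\u4E00-\u9FFF]", s) is not None

def chat_stream(prompt, max_rounds=1000, delay=0.5):
    # Mirror of the client's "GPT" mode: resend the original prompt each round
    # and show the latest partial answer until the server reports stop.
    latest = ""
    for _ in range(max_rounds):
        time.sleep(delay)                      # the client awaits Sleep(500) per round
        r = requests.post(STREAM_URL, json={"context": prompt})
        chunk = r.json()                       # expected keys: result_en, result_ch, stop
        latest = chunk.get("result_ch") or chunk.get("result_en", "")
        print("\r" + latest, end="", flush=True)
        if chunk.get("stop"):
            break
    return latest

if __name__ == "__main__":
    question = "写一个python版的数组排序"        # "write an array sort in Python"
    if is_chinese(question):                    # Chinese input -> ChatGLM ("GPT") mode
        chat_stream(question)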

codegen.py

Lines changed: 3 additions & 2 deletions
@@ -43,9 +43,10 @@ async def codegen(request):
     print(time.strftime("%Y-%m-%d %H:%M:%S",time.localtime()),"context : " + context)
     context = context.replace("//","").replace("#","").strip()
     stop = False
-    if flag_chs :#or content.startwith('gpt-j') :
-    # result = getAnswerFromChatGPTJ(context,maxlength).replace(context,"")
+    if flag_chs :
         result = getAnswerFromChatGLM6b(context)
+        stop = result.endswith("[stop]")
+        result = result.replace("[stop]", "")
     else:
         result,stop = sampling(context,maxlength)
     end = time.perf_counter()
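The non-streaming handler now relies on a simple convention: ChatGLM-6b's reply carries a literal "[stop]" marker once generation has finished, which the handler strips after recording the flag. The same two lines recur in codegen_stream.py below; expressed as a tiny helper for clarity (the helper name is illustrative, the handlers inline it):

def split_stop_sentinel(result):
    # Split a reply that may end with the "[stop]" marker into (text, stop),
    # matching what codegen.py and codegen_stream.py now do inline.
    stop = result.endswith("[stop]")
    return result.replace("[stop]", ""), stop

# split_stop_sentinel("print('hi')[stop]") -> ("print('hi')", True)
# split_stop_sentinel("print('hi')")       -> ("print('hi')", False)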

codegen_stream.py

Lines changed: 5 additions & 6 deletions
@@ -26,13 +26,12 @@ async def codegen_stream(request):
     flag_chs = f(context)
     stop = False
     if flag_chs:
-        # results = sampling_gptj(context, maxlength)
-        # results = json.loads(results)
-        # result_en = results["result_en"]
-        # result_ch = results["result_ch"]
         result_en = getAnswerFromChatGLM6b(context)
-        result_ch = result_en
-        stop = True
+        stop = result_en.endswith("[stop]")
+        result_ch = result_en.replace("[stop]", "")
+        if result_ch == "" :
+            result_ch = "思考中"
+        result_en = result_ch
     else:
         result_en,stop = sampling(context, maxlength)
         result_ch = result_en
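For Chinese prompts the streaming handler now reports real progress: stop is derived from the "[stop]" sentinel, and an empty partial answer is replaced with the placeholder 思考中 ("thinking") instead of ending the stream immediately as before. A framework-agnostic sketch of the per-poll payload this produces — the function name is invented here, but the result_en / result_ch / stop keys are exactly what App.js reads:

def build_stream_chunk(reply):
    # reply: the (possibly partial) string returned by getAnswerFromChatGLM6b.
    stop = reply.endswith("[stop]")
    text = reply.replace("[stop]", "")
    if text == "":
        text = "思考中"          # "thinking" placeholder while nothing has streamed yet
    return {"result_en": text, "result_ch": text, "stop": stop}

# build_stream_chunk("")           -> {"result_en": "思考中", "result_ch": "思考中", "stop": False}
# build_stream_chunk("done[stop]") -> {"result_en": "done", "result_ch": "done", "stop": True}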
