Commit 2c379ed

add stream support
1 parent d36f7aa commit 2c379ed

7 files changed: 160 additions & 54 deletions

chat/package.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "name": "chat",
-  "version": "0.1.0",
+  "version": "1.0.0",
   "private": true,
   "dependencies": {
     "@chatui/core": "^2.4.2",

chat/src/App.js

Lines changed: 55 additions & 20 deletions
@@ -1,7 +1,9 @@
 import './App.css';
-import Chat, { Bubble, useMessages } from '@chatui/core';
+import Chat, { Bubble, useMessages, Progress } from '@chatui/core';
 import '@chatui/core/dist/index.css';
-import React, { useEffect } from 'react'
+import '@chatui/core/es/styles/index.less';
+import React, { useEffect, useState } from 'react';
+import './chatui-theme.css';

 const defaultQuickReplies = [
   {
@@ -48,19 +50,20 @@ const initialMessages = [

 function App() {
   const { messages, appendMsg, setTyping } = useMessages(initialMessages);
+  const [percentage, setPercentage] = useState(0);

   function handleSend(type, val) {
     if (type === 'text' && val.trim()) {
       appendMsg({
         type: 'text',
         content: { text: val },
-        position: 'right',
+        position: 'left',
         user: { avatar: '//gitclone.com/download1/user.png' },
       });

       setTyping(true);
-
-      onGenCode(val);
+      setPercentage(10);
+      onGenCode(val, val, 0);
     }
   }

@@ -90,7 +93,7 @@ function App() {
     } else if (item.name === "Java") {
       content = "int add(int x,int y){";
     } else if (item.name === "javascript") {
-      content = "function Add(x,y,z){";
+      content = "function Add(x,y){";
     } else if (item.name === "golang") {
       content = "func IsBlacklist(bl []string,url string) bool{";
     } else {
@@ -99,26 +102,57 @@ function App() {
     handleSend('text', content);
   }

-  function onGenCode(context) {
-    var sl = context.trim().split("\n");
-    context = sl[sl.length - 1];
-    if (context.trim() === "") {
-      alert("输入不能为空!")
+  function onGenCode(context_en, context_ch, count) {
+    if (count >= 5) {
+      setPercentage(0);
       return;
     }
     let xhr = new XMLHttpRequest();
-    xhr.open('post', 'https://gitclone.com/aiit/codegen');
+    xhr.open('post', 'https://gitclone.com/aiit/codegen_stream');
     xhr.setRequestHeader('Content-Type', 'application/json');
     xhr.onload = function () {
       var json = JSON.parse(xhr.response);
-      context = context + "\n" + json.result;
-      appendMsg({
-        type: 'text',
-        content: { text: context },
-        user: { avatar: '//gitclone.com/download1/gitclone.png' },
-      });
+      if (count === 0) {
+        context_en = context_en + "\n" + json.result_en;
+        context_ch = context_ch + "\n" + json.result_ch;
+        appendMsg({
+          type: 'text',
+          content: { text: context_ch },
+          user: { avatar: '//gitclone.com/download1/gitclone.png' },
+        });
+      } else {
+        if (("" === json.result_en.trim()) || json.result_en.trim().startsWith("A:") || json.result_en.trim().endsWith("A:")) {
+          setPercentage(0);
+          return;
+        }
+        context_en = context_en + json.result_en;
+        context_ch = context_ch + json.result_ch;
+        if (context_ch === context_en) {
+          updateMsg(context_en);
+        } else {
+          updateMsg(context_ch + "\n" + context_en);
+        }
+
+      }
+      count++;
+      setPercentage(count * 20);
+      onGenCode(context_en, context_ch, count);
+    }
+    xhr.send(JSON.stringify({
+      "context": context_en,
+      "maxlength": 16,
+      "modelname": "codegen"
+    }));

+    function updateMsg(context_ch) {
+      var oUl = document.getElementById('root');
+      var aBox = getByClass(oUl, 'Bubble text');
+      if (aBox.length > 0) {
+        aBox[aBox.length - 1].innerHTML = "<p>" + context_ch + "</p>";
+        var msgList = getByClass(oUl, "PullToRefresh")[0];
+        msgList.scrollTo(0, msgList.scrollHeight);
+      }
     }
-    xhr.send('{"context":"' + context + '","maxlength":32}');
   }

   function findInArr(arr, n) {
@@ -169,14 +203,15 @@ function App() {
           title: 'More',
         },
       ],
-      title: '基于Salesforce codegen和GPTJ、GPT-neo的AI代码生成',
+      title: '基于Salesforce codegen和GPTJ的AI代码生成',
     }}
     messages={messages}
     renderMessageContent={renderMessageContent}
     quickReplies={defaultQuickReplies}
     onQuickReplyClick={handleQuickReplyClick}
     onSend={handleSend}
   />
+  <Progress value={percentage} />
   </div>
 );
}
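
The client-side change above replaces the old one-shot request (maxlength 32 against /codegen) with a chunked loop: each round POSTs the accumulated English context to /codegen_stream with maxlength 16, appends the returned chunk, advances the new Progress bar in 20% steps, and recurses, stopping after 5 rounds or when a chunk comes back empty or drifts into Q&A-style text ("A:"). A minimal Python sketch of the same loop, assuming only the endpoint URL and JSON fields visible in the diff (the helper name and prompt are illustrative):

import requests

def gen_code_stream(prompt, max_rounds=5, chunk_tokens=16):
    # Accumulated English context; the JS client tracks a Chinese copy too.
    context_en = prompt
    for count in range(max_rounds):
        resp = requests.post(
            "https://gitclone.com/aiit/codegen_stream",
            json={"context": context_en, "maxlength": chunk_tokens,
                  "modelname": "codegen"},
        )
        chunk = resp.json()["result_en"]
        stripped = chunk.strip()
        # Mirror the client's stop guard: after the first round, bail on an
        # empty chunk or one that starts or ends with "A:".
        if count > 0 and (not stripped or stripped.startswith("A:")
                          or stripped.endswith("A:")):
            break
        # The first chunk starts a new line; later chunks append in place.
        context_en += ("\n" + chunk) if count == 0 else chunk
        print("progress:", (count + 1) * 20, "%")
    return context_en

print(gen_code_stream("def quicksort(arr):"))

Note that this is polling-style pseudo-streaming: the server still returns one complete JSON response per request, and the streaming effect comes from the client re-requesting with a small maxlength.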

chat/src/chatui-theme.css

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+:root {
+  font-size: 16px;
+  line-height: 14px;
+}
+.ChatApp,
+.MessageContainer,
+.Navbar,
+.Message .Bubble,
+.QuickReplies,
+.ChatFooter {
+  background-repeat: no-repeat;
+  background-size: cover;
+}

codegen.py

Lines changed: 6 additions & 19 deletions
@@ -7,7 +7,8 @@
 from functools import lru_cache
 from aiohttp import web
 from jaxformer.hf.sample import load_model,sampling
-from gpt_neo import gpt_load_model,gpt_generate
+from gpt_j import gpt_load_model,gpt_generate
+from codegen_stream import codegen_stream

 ROOT = os.path.dirname(__file__)

@@ -18,24 +19,9 @@ async def index(request):
     return web.Response(content_type="text/html", text=content)

 @lru_cache(maxsize=1024, typed=False)
-def getAnswerFromChatGPT(context):
-    url = 'http://chatgptserver.com:5000/chat'
-    data = '{"message":"' + context + '", "user": "gitclone"}'
-    headers = {'content-type': 'application/json;charset=utf-8'}
-    r = requests.post(url,data= data.encode(), headers=headers)
-    res = r.json()
-    return res['response']
-
-@lru_cache(maxsize=1024, typed=False)
-def getAnswerFromChatGPTJ(context):
-    #url = 'http://52.82.67.116:8081/generate/'
-    #data = '{' + '"text": "' + context + '",' + '"generate_tokens_limit": 40,'+ '"top_p": 0.7,'+'"top_k": 0,' + '"temperature":1.0' +'}' ;
-    #headers = {'content-type': 'application/json;charset=utf-8'}
-    #r = requests.post(url,data= data.encode(), headers=headers)
-    #res = r.json()
-    #return res['completion']
+def getAnswerFromChatGPTJ(context,maxlength):
     gpt_load_model()
-    return gpt_generate(context,128)
+    return gpt_generate(context,maxlength)

 async def codegen(request):
     params = await request.json()
@@ -56,7 +42,7 @@ async def codegen(request):
     print(time.strftime("%Y-%m-%d %H:%M:%S",time.localtime()),"context : " + context)
     context = context.replace("//","").replace("#","").strip()
     if flag_chs :#or content.startwith('gpt-j') :
-        result = getAnswerFromChatGPTJ(context).replace(context,"")
+        result = getAnswerFromChatGPTJ(context,maxlength).replace(context,"")
     else:
         result = sampling(context,maxlength)
     end = time.perf_counter()
@@ -73,6 +59,7 @@ async def codegen(request):
 app.router.add_get("/", index)
 app.router.add_get("/codegen", index)
 app.router.add_post("/codegen", codegen)
+app.router.add_post("/codegen_stream", codegen_stream)

 for route in list(app.router.routes()):
     cors.add(route, {
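
Server-side, the commit threads maxlength through the cached GPT-J helper (previously hard-coded to 128; @lru_cache now keys on the (context, maxlength) pair, so identical repeat requests still hit the cache) and registers the new handler as a second aiohttp POST route. A stripped-down sketch of that wiring, with stub handlers standing in for the real ones:

import json
from aiohttp import web

async def codegen(request):
    # Stand-in for the existing one-shot handler.
    params = await request.json()
    return web.Response(content_type="application/json",
                        text=json.dumps({"result": "..."}))

async def codegen_stream(request):
    # Stand-in for the new chunked handler from codegen_stream.py.
    params = await request.json()
    return web.Response(content_type="application/json",
                        text=json.dumps({"result_en": "...", "result_ch": "..."}))

app = web.Application()
app.router.add_post("/codegen", codegen)
app.router.add_post("/codegen_stream", codegen_stream)

if __name__ == "__main__":
    web.run_app(app, port=8080)  # port is illustrative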

codegen_stream.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import requests
+import time
+from aiohttp import web
+import json
+from jaxformer.hf.sample import load_model, sampling
+from gpt_j import gpt_load_model, gpt_generate_stream
+
+
+def sampling_gptj(context, maxlength):
+    gpt_load_model()
+    return gpt_generate_stream(context, maxlength)
+
+
+async def codegen_stream(request):
+    params = await request.json()
+    context = params["context"]
+    maxlength = params["maxlength"]
+    modelname = params["modelname"]
+    start = time.perf_counter()
+    print(time.strftime("%Y-%m-%d %H:%M:%S",
+                        time.localtime()), "context : " + context)
+    context = context.strip()
+    f = lambda x='ddd': sum(
+        [1 if u'\u4e00' <= i <= u'\u9fff' else 0 for i in x]) > 0
+    flag_chs = f(context)
+    if flag_chs:
+        results = sampling_gptj(context, maxlength)
+        results = json.loads(results)
+        result_en = results["result_en"]
+        result_ch = results["result_ch"]
+    else:
+        result_en = sampling(context, maxlength)
+        result_ch = result_en
+    end = time.perf_counter()
+    print(time.strftime("%Y-%m-%d %H:%M:%S",
+                        time.localtime()), "result : " + result_ch)
+    return web.Response(
+        content_type="application/json",
+        text=json.dumps(
+            {"result_en": result_en, "result_ch": result_ch, "time": end-start}
+        ),
+    )
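
The new handler answers one chunk per request: Chinese prompts go through the GPT-J translate-and-generate path, everything else through CodeGen's sampling. The dispatch test is the lambda above, which looks for any character in the CJK Unified Ideographs block (U+4E00 to U+9FFF); an equivalent, more readable form:

def contains_chinese(text):
    # True if any character falls in the CJK Unified Ideographs block.
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)

print(contains_chinese("def add(x, y):"))  # False -> CodeGen sampling path
print(contains_chinese("写一个排序函数"))     # True  -> GPT-J + translation path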

gpt_neo.py renamed to gpt_j.py

Lines changed: 42 additions & 13 deletions
@@ -3,17 +3,18 @@
 import gradio as gr
 import torch
 import requests
+import json

-#generator = None
+# generator = None
 translator_zh2en = None
 translator_en2zh = None


 def gpt_load_model():
-    #global generator
+    # global generator
     global translator_zh2en
     global translator_en2zh
-    #if generator is None:
+    # if generator is None:
     #     #torch.cuda.set_device('cuda:1')
     #     generator = pipeline(
     #         'text-generation', model='EleutherAI/gpt-neo-1.3B')
@@ -24,17 +25,21 @@ def gpt_load_model():
     translator_en2zh = pipeline(
         "translation", model="Helsinki-NLP/opus-mt-en-zh")

-def getAnswerFromChatGPTJ6B(context):
-    url = 'http://127.0.0.1:8081/generate/'
-    data = '{' + '"text": "' + context + '",' + '"generate_tokens_limit": 64,'+ '"top_p": 0.7,'+'"top_k": 0,' + '"temperature":0.9' +'}' ;
+
+def getAnswerFromChatGPTJ6B(context, maxlength):
+    url = 'http://172.16.62.66:8081/generate/'
+    data = '{' + '"text": "' + context + '",' + '"generate_tokens_limit": ' + \
+        str(maxlength) + ',' + '"top_p": 0.7,' + \
+        '"top_k": 0,' + '"temperature":0.9' + '}'
     headers = {'content-type': 'application/json;charset=utf-8'}
-    r = requests.post(url,data= data.encode(), headers=headers)
+    r = requests.post(url, data=data.encode(), headers=headers)
     res = r.json()
     return res['completion']

+
 @lru_cache(maxsize=1024, typed=False)
 def gpt_generate(inputs, maxlength):
-    #global generator
+    # global generator
     global translator_zh2en
     global translator_en2zh
     f = lambda x='ddd': sum(
@@ -44,25 +49,48 @@ def gpt_generate(inputs, maxlength):
     if flag_chs:
         inputs = translator_zh2en(inputs)[0]['translation_text']
         print("zh2en: ", inputs)
-    #results = generator(inputs, max_length=int(maxlength),
-    #                    do_sample=True, temperature=0.9)
-    results = getAnswerFromChatGPTJ6B(inputs)
+    results = getAnswerFromChatGPTJ6B(inputs, maxlength)
     print("output: ", results)
     if flag_chs:
-        #results = translator_en2zh(results[0]['generated_text'])
         results_en = results
         results = translator_en2zh(results)
         print("en2zh:", results)
         return results_en + '\n' + results[0]['translation_text']
     else:
         return results

+
+def gpt_generate_stream(inputs, maxlength):
+    # global generator
+    global translator_zh2en
+    global translator_en2zh
+    f = lambda x='ddd': sum(
+        [1 if u'\u4e00' <= i <= u'\u9fff' else 0 for i in x]) > 0
+    print("inputs: ", inputs)
+    flag_chs = f(inputs)
+    if flag_chs:
+        inputs = translator_zh2en(inputs)[0]['translation_text']
+        print("zh2en: ", inputs)
+    results = getAnswerFromChatGPTJ6B(inputs,maxlength)
+    print("output: ", results)
+    if flag_chs:
+        results_en = results
+        results = translator_en2zh(results)
+        print("en2zh:", results)
+        return json.dumps(
+            {"result_en": results_en, "result_ch": results[0]['translation_text']})
+    else:
+        return json.dumps(
+            {"result_en": results, "result_ch": results})
+
+
 def chat(message, history):
     history = history or []
-    response = gpt_generate(message,128)
+    response = gpt_generate(message, 128)
     history.append((message, response))
     return history, history

+
 def create_ui():
     chatbot = gr.Chatbot().style(color_map=("green", "gray"))
     interface = gr.Interface(
@@ -73,6 +101,7 @@ def create_ui():
     )
     interface.launch(server_name='0.0.0.0')

+
 if __name__ == "__main__":
     torch.cuda.set_device(1)
     print("torch gpu: ", torch.cuda.is_available())

jaxformer/hf/sample.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def cast(model, fp16=True):

 def create_model(ckpt, fp16=True):
     if fp16:
-        return CodeGenForCausalLM.from_pretrained(ckpt, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        return CodeGenForCausalLM.from_pretrained(ckpt, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=False)
     else:
         return CodeGenForCausalLM.from_pretrained(ckpt)
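
low_cpu_mem_usage is a standard transformers from_pretrained flag: True loads checkpoint weights incrementally to reduce peak host RAM, while False (restored here) uses the default full-state-dict load. A hedged sketch of the resulting call, with an assumed checkpoint name:

import torch
from transformers import CodeGenForCausalLM

model = CodeGenForCausalLM.from_pretrained(
    "Salesforce/codegen-2B-mono",   # assumed checkpoint
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,        # default loader; higher peak host RAM
)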