Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/GetStarted/界面训练推理.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ web-ui没有传入参数,所有可控部分都在界面中。但是有几个

> WEBUI_SHARE=1 控制gradio是否是share状态
> SWIFT_UI_LANG=en/zh 控制web-ui界面语言
> WEBUI_SERVER server_name参数, web-ui host ip,0.0.0.0代表所有ip均可访问,127.0.0.1代表只允许本机访问
> WEBUI_PORT web-ui的端口号
6 changes: 5 additions & 1 deletion swift/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,9 @@ def run_ui():
LLMTrain.build_ui(LLMTrain)
LLMInfer.build_ui(LLMInfer)

port = os.environ.get('WEBUI_PORT', None)
app.queue().launch(
height=800, share=bool(int(os.environ.get('WEBUI_SHARE', '0'))))
server_name=os.environ.get('WEBUI_SERVER', None),
server_port=port if port is None else int(port),
height=800,
share=bool(int(os.environ.get('WEBUI_SHARE', '0'))))
28 changes: 21 additions & 7 deletions swift/ui/llm_infer/llm_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class LLMInfer(BaseUI):
'en': 'Start to load model, please wait'
}
},
'loaded_alert': {
'value': {
'zh': '模型加载完成',
'en': 'Model loaded'
}
},
'chatbot': {
'value': {
'zh': '对话框',
Expand Down Expand Up @@ -117,19 +123,27 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
clear_history.click(
fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot])
cls.element('load_checkpoint').click(
cls.reset_memory, [], [model_and_template],
show_progress=False).then(
cls.reset_memory, [], [model_and_template])\
.then(cls.reset_loading_button, [], [cls.element('load_checkpoint')]).then(
cls.prepare_checkpoint, [
value for value in cls.elements().values()
if not isinstance(value, (Tab, Accordion))
], [model_and_template],
show_progress=True).then(cls.change_interactive, [],
[prompt])
cls.element('load_checkpoint').click(
], [model_and_template]).then(cls.change_interactive, [],
[prompt]).then( # noqa
cls.clear_session,
inputs=[],
outputs=[prompt, chatbot],
queue=True)
queue=True).then(cls.reset_load_button, [], [cls.element('load_checkpoint')])

@classmethod
def reset_load_button(cls):
gr.Info(cls.locale('loaded_alert', cls.lang)['value'])
return gr.update(
value=cls.locale('load_checkpoint', cls.lang)['value'])

@classmethod
def reset_loading_button(cls):
return gr.update(value=cls.locale('load_alert', cls.lang)['value'])

@classmethod
def reset_memory(cls):
Expand Down
64 changes: 28 additions & 36 deletions swift/ui/llm_train/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
base_tab.element('show_log').click(
Runtime.update_log, [], [cls.element('log')]).then(
Runtime.wait, [base_tab.element('logging_dir')],
[cls.element('log')],
show_progress=True,
queue=True)
[cls.element('log')])

base_tab.element('start_tb').click(
Runtime.start_tb,
Expand All @@ -147,40 +145,34 @@ def wait(cls, logging_dir):
latest_data = ''
lines = collections.deque(
maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
while True:
try:
with open(log_file) as input:
input.seek(offset)
fail_cnt = 0
while True:
try:
with open(log_file, 'r') as input:
input.seek(offset)
fail_cnt = 0
while True:
try:
latest_data += input.read()
offset = input.tell()
if not latest_data:
time.sleep(0.5)
fail_cnt += 1
if fail_cnt > 5:
break

if '\n' not in latest_data:
continue
latest_lines = latest_data.split('\n')
if latest_data[-1] != '\n':
latest_data = latest_lines[-1]
latest_lines = latest_lines[:-1]
else:
latest_data = ''
lines.extend(latest_lines)
yield '\n'.join(lines)
except IOError:
pass

process_name = 'swift'
process_find = False
for proc in psutil.process_iter():
if proc.name() == process_name:
process_find = proc.pid
if not process_find:
break
except UnicodeDecodeError:
continue
offset = input.tell()
if not latest_data:
time.sleep(0.5)
fail_cnt += 1
if fail_cnt > 20:
break

if '\n' not in latest_data:
continue
latest_lines = latest_data.split('\n')
if latest_data[-1] != '\n':
latest_data = latest_lines[-1]
latest_lines = latest_lines[:-1]
else:
latest_data = ''
lines.extend(latest_lines)
yield '\n'.join(lines)
except IOError:
pass

@classmethod
def show_log(cls, logging_dir):
Expand Down