
Commit 7163356

Committed Jul 13, 2024
Merge commit '3605cb8defd37b4c26ed50287996be1fd3871f86' into release/2.2
* commit '3605cb8defd37b4c26ed50287996be1fd3871f86':
  fix dataset_sample & deploy stop_words (#1385)
  update discord and fix documentation link
  update discord and fix documentation link error
  Update discord url which will never expire in README.md
2 parents: 2c8753a + 3605cb8

File tree: 6 files changed (+21, -16 lines)
 

Diff for: README.md (+2, -2)

@@ -50,7 +50,7 @@ SWIFT has rich documentations for users, please feel free to check our documenta
 You can contact us and communicate with us by adding our group:
 
 
-[Discord Group](https://discord.gg/qQXTzNUp) | 微信群
+[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
 :-------------------------:|:-------------------------:
 <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

@@ -226,7 +226,7 @@ docker pull registry.us-west-1.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.
 
 ## 🚀 Getting Started
 
-This section introduces basic usage, see the [Documentation](#-documentation) section for more ways to use.
+This section introduces basic usage, see the [Documentation](https://swift.readthedocs.io/en/latest/) section for more ways to use.
 
 ### Web-UI
Diff for: README_CN.md (+2, -2)

@@ -50,7 +50,7 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站:
 
 请扫描下面的二维码来加入我们的交流群:
 
-[Discord Group](https://discord.gg/qQXTzNUp) | 微信群
+[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
 :-------------------------:|:-------------------------:
 <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

@@ -228,7 +228,7 @@ docker pull registry.us-west-1.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.
 
 ## 🚀 快速开始
 
-本章节介绍基本使用,更丰富的使用方式请查看[文档部分](#-文档)
+本章节介绍基本使用,更丰富的使用方式请查看[文档部分](https://swift.readthedocs.io/zh-cn/latest/)
 
 ### Web-UI
Diff for: asset/discord_qr.jpg

Binary image file, 37.2 KB (preview not rendered).

Diff for: swift/llm/deploy.py (+12, -5)

@@ -184,14 +184,15 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
     request_id = request_info['request_id']
 
     kwargs = {'max_new_tokens': request.max_tokens}
-    for key in ['n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
+    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
         kwargs[key] = getattr(request, key)
     for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+    kwargs['stop'] = (llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])
 
     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search and request.stream:

@@ -343,7 +344,7 @@ def __repr__(self) -> str:
 
 @torch.inference_mode()
 async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
-    global model, template
+    global model, template, _args
     result = await _prepare_request(request)
     if isinstance(result, JSONResponse):
         return result

@@ -359,8 +360,13 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(model.generation_config, key)
+            if key == 'temperature':
+                do_sample = getattr(model.generation_config, 'do_sample')
+                if not do_sample:
+                    kwargs[key] = 0
         else:
             kwargs[key] = new_value
+
     if kwargs['temperature'] == 0:
         kwargs['do_sample'] = False
         kwargs['temperature'] = 1

@@ -374,7 +380,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
     set_generation_config(model, generation_config)  # inplace
     model.generation_config = _old_generation_config
     request_info['generation_config'] = generation_config
-    request_info.update({'seed': request.seed, 'stop': request.stop, 'stream': request.stream})
+    stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
+    request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
     logger.info(request_info)
 
     created_time = int(time.time())

@@ -397,7 +404,7 @@ async def _generate_full():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)

@@ -441,7 +448,7 @@ def _generate_stream():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
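With this change, both the vLLM and the PyTorch deployment paths combine the server-side stop words (from the engine's generation config or from `--stop_words`) with whatever the client sends in `request.stop`, instead of letting the request silently replace the deployment defaults. In the same file, a `temperature` left unset by the client is forced to 0 when the model's generation config has `do_sample=False`, which the existing code below that hunk then normalizes to `do_sample=False, temperature=1`. A minimal sketch of the merging rule, with hypothetical stop strings and a helper name not taken from the repo:

```python
# Sketch of the stop-word merge introduced above (hypothetical values and helper name).
from typing import List, Optional

def merge_stop_words(server_stop: Optional[List[str]], request_stop: Optional[List[str]]) -> List[str]:
    """Server-side stop words always apply; per-request stop words are appended."""
    return (server_stop or []) + (request_stop or [])

print(merge_stop_words(['Observation:'], None))           # ['Observation:']
print(merge_stop_words(None, ['###']))                    # ['###']
print(merge_stop_words(['Observation:'], ['</answer>']))  # ['Observation:', '</answer>']
```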

Diff for: swift/llm/utils/argument.py (+1, -1)

@@ -1117,7 +1117,7 @@ class InferArguments(ArgumentsBase):
     top_p: float = 0.7
     repetition_penalty: float = 1.
     num_beams: int = 1
-    stop_words: List[str] = None
+    stop_words: List[str] = field(default_factory=list)
 
     # rope-scaling
     rope_scaling: Literal['linear', 'dynamic'] = None
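Changing the default from `None` to `field(default_factory=list)` means `stop_words` is always a list (dataclasses forbid a plain mutable default such as `[]`), which keeps expressions like `(_args.stop_words or []) + ...` in deploy.py well defined. A self-contained illustration of the dataclass rule, using a stand-in class rather than `InferArguments` itself:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ArgsSketch:
    # stop_words: List[str] = []  # rejected by dataclasses: mutable default not allowed
    stop_words: List[str] = field(default_factory=list)  # fresh list per instance

a, b = ArgsSketch(), ArgsSketch()
a.stop_words.append('Observation:')
print(a.stop_words)  # ['Observation:']
print(b.stop_words)  # [] -- instances do not share state
```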

Diff for: swift/llm/utils/dataset.py (+4, -6)

@@ -341,12 +341,10 @@ def sample_dataset(dataset: HfDataset, dataset_sample: int, random_state: Option
         return dataset
     if random_state is None:
         random_state = RandomState()
-    # Sample the part that exceeds the length of the dataset.
-    idx = random_state.permutation(len(dataset))[:dataset_sample]
-    dataset_sample -= len(idx)
-    if dataset_sample > 0:
-        idx2 = random_state.choice(len(dataset), dataset_sample)
-        idx = np.concatenate([idx, idx2], axis=0)
+
+    idx_repeat = np.tile(range(len(dataset)), dataset_sample // len(dataset))
+    idx_random = random_state.permutation(len(dataset))[:dataset_sample % len(dataset)]
+    idx = np.concatenate([idx_repeat, idx_random])
     dataset = dataset.select(idx)
     return dataset
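The rewritten sampling splits the requested size into whole passes over the dataset plus a random remainder: when `dataset_sample` exceeds the dataset length, every example is repeated the same number of times and only the remainder is drawn at random; when it is smaller, the tiled part is empty and the result is a plain random subset. A standalone sketch of the index construction, with hypothetical sizes and a helper name not taken from the repo:

```python
import numpy as np
from numpy.random import RandomState

def build_sample_indices(dataset_len: int, dataset_sample: int, random_state: RandomState) -> np.ndarray:
    """Indices selecting `dataset_sample` rows from a dataset of length `dataset_len`."""
    # Whole passes over the dataset (empty when dataset_sample < dataset_len).
    idx_repeat = np.tile(range(dataset_len), dataset_sample // dataset_len)
    # Random remainder, drawn without replacement.
    idx_random = random_state.permutation(dataset_len)[:dataset_sample % dataset_len]
    return np.concatenate([idx_repeat, idx_random])

rs = RandomState(42)
print(len(build_sample_indices(100, 30, rs)))   # 30  -> random subset, no repeats
print(len(build_sample_indices(100, 250, rs)))  # 250 -> two full passes + 50 random rows
```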
