[coati] fix inference output (#3285)
* [coati] fix inference requirements

* [coati] add output postprocess

* [coati] update inference readme

* [coati] fix inference requirements
ver217 committed Mar 28, 2023
1 parent bb6196e commit 4905b21
Showing 4 changed files with 19 additions and 3 deletions.
6 changes: 6 additions & 0 deletions applications/Chat/inference/README.md
@@ -36,6 +36,12 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
 | LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
 | LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
 
+## General setup
+
+```shell
+pip install -r requirements.txt
+```
+
 ## 8-bit setup
 
 8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.
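Note: the README's "8-bit setup" section (unchanged context in the hunk above) asks for transformers installed from source. A minimal sketch of that step, assuming pip's Git support is used; the same source pin already appears in requirements.txt below, so `pip install -r requirements.txt` covers it as well:

```shell
# Install transformers from source, as the 8-bit setup section requests.
# This mirrors the git+ pin already listed in requirements.txt.
pip install git+https://github.com/huggingface/transformers
```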
4 changes: 3 additions & 1 deletion applications/Chat/inference/requirements.txt
@@ -1,5 +1,5 @@
 fastapi
-locustio
+locust
 numpy
 pydantic
 safetensors
@@ -8,3 +8,5 @@ sse_starlette
 torch
 uvicorn
 git+https://github.com/huggingface/transformers
+accelerate
+bitsandbytes
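accelerate and bitsandbytes are the packages transformers relies on for its 8-bit loading path, which is what the README's 8-bit setup refers to. A hedged sketch of how they are typically exercised; the checkpoint path is a placeholder and not part of this commit:

```python
# Sketch only: load a causal LM in 8-bit via transformers + bitsandbytes + accelerate.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/llama-7b"  # placeholder path, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_8bit=True,   # needs bitsandbytes
    device_map="auto",   # needs accelerate
)
```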
4 changes: 2 additions & 2 deletions applications/Chat/inference/server.py
@@ -17,7 +17,7 @@
 from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
 
 CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
-MAX_LEN = 2048
+MAX_LEN = 512
 running_lock = Lock()
 
 
@@ -116,7 +116,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt_len = inputs['input_ids'].size(1)
     response = output[0, prompt_len:]
     out_string = tokenizer.decode(response, skip_special_tokens=True)
-    return out_string.lstrip()
+    return prompt_processor.postprocess_output(out_string)
 
 
 if __name__ == '__main__':
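For context on the hunk above: `generate` for a decoder-only model returns the prompt tokens followed by the newly generated ones, so the handler slices off the first `prompt_len` ids before decoding, and now routes the decoded string through `postprocess_output` instead of a bare `lstrip()`. A small sketch of the slicing step with stand-in ids (illustrative only):

```python
# Illustrative only: generate() echoes the prompt, so slice it off before decoding.
import torch

input_ids = torch.tensor([[101, 102, 103]])         # stand-in prompt ids
output = torch.tensor([[101, 102, 103, 201, 202]])  # prompt + newly generated ids
prompt_len = input_ids.size(1)
response = output[0, prompt_len:]                   # keep only the new tokens
assert response.tolist() == [201, 202]
```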
8 changes: 8 additions & 0 deletions applications/Chat/inference/utils.py
@@ -1,3 +1,4 @@
+import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
 
@@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''):
     return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
 
 
+STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
+
+
 class ChatPromptProcessor:
 
     def __init__(self, tokenizer, context: str, max_len: int = 2048):
@@ -164,6 +168,10 @@ def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str
         prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
         return prompt
 
+    def postprocess_output(self, output: str) -> str:
+        output = STOP_PAT.sub('', output)
+        return output.strip()
+
 
 class LockedIterator:
 
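The effect of the new `postprocess_output` is to cut the response at the first section marker the model invents, which is why the server no longer relies on `lstrip()` alone. A self-contained sketch using the regex and method body exactly as added above (the sample string is illustrative):

```python
# Demonstrates the STOP_PAT truncation added in this commit: everything from the
# first '###' or 'instruction:' (case-insensitive, spanning newlines) is dropped.
import re

STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))

def postprocess_output(output: str) -> str:
    output = STOP_PAT.sub('', output)
    return output.strip()

raw = "Sure, here is the answer.\n\n### Instruction:\nWrite another task."
print(postprocess_output(raw))  # -> "Sure, here is the answer."
```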
