diff --git a/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc b/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc
index 999d7b18a82..c3891440dd3 100644
--- a/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc
+++ b/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc
@@ -26,11 +26,52 @@ void removeId(std::vector<int32_t>& vec, int id)
 struct inferenceState
 {
     int prevPos{0};
+    std::string prevText;
     bool isFinished;
     std::queue<std::string> textsToStream;
     std::mutex queueMutex; // Mutex to protect access to textsToStream
+
+    size_t stopWordMatchLen = 0;
+    std::vector<std::string> sequence{"<", "|", "im", "_", "end", "|", ">"};
+
+    void reset()
+    {
+        stopWordMatchLen = 0;
+        prevText = "";
+    }
+
+    bool isComplete() const
+    {
+        return stopWordMatchLen >= sequence.size();
+    }
 };
 
+bool handleMatch(const std::string& rawText, std::shared_ptr<inferenceState> inferState)
+{
+    if (inferState->isComplete())
+    {
+        return true;
+    }
+
+    if (rawText == inferState->sequence[inferState->stopWordMatchLen])
+    {
+        inferState->stopWordMatchLen++; // Move to next state
+        inferState->prevText = rawText;
+        return true;
+    }
+    else if (inferState->stopWordMatchLen > 0 && rawText == inferState->sequence[0])
+    {
+        inferState->stopWordMatchLen = 1; // Restart from first match if sequence breaks but matches start
+        inferState->prevText = rawText;
+        return true;
+    }
+    else
+    {
+        inferState->reset();
+        return false; // Reset to start if sequence breaks
+    }
+}
+
 // Only support single token stopping point now
 std::string create_return_json(const std::string& id, const std::string& model, const std::string& content,
     Json::Value finish_reason = Json::Value())
@@ -67,6 +108,13 @@ GenerationInput::TensorPtr tensorrtllm::getTensorSingleStopWordList(int stopToke
     return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 2}), MemoryType::kGPU);
 }
 
+GenerationInput::TensorPtr tensorrtllm::getTensorChatMLStopWordList()
+{
+    std::vector<int32_t> stopWordsTokens = {28789, 28766, 321, 28730, 416, 28766, 28767, 32000, 6, 8, -1, -1, -1, -1,
+        -1, -1}; // Extend with -1 for increased length
+    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 8}), MemoryType::kGPU);
+}
+
 GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIdsHost)
 {
     int inputLen = inputIdsHost.size();
@@ -78,7 +126,7 @@ GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIds
 
     GenerationInput generationInput{0, 0, inputIds, inputLengths, modelConfig->usePackedInput()};
 
-    generationInput.stopWordsList = getTensorSingleStopWordList(32000);
+    generationInput.stopWordsList = getTensorChatMLStopWordList();
 
     return generationInput;
 }
@@ -117,35 +165,35 @@ void inferenceThread(std::shared_ptr<inferenceState> inferState, std::vector<int32_t> inputIdsHost,
     generationOutput.onTokenGenerated = [&inferState, inputLen, outputLen, self, &generationOutput](
                                             GenerationOutput::TensorPtr const& outputIds, SizeType step, bool finished)
     {
-        if (!finished)
+        // Assuming the shape of outputIds tensor is (1, 1, 160), where 160 is the number of tokens
+        int outputLength = outputIds->getShape().d[2]; // Get the length of output IDs based on the tensor shape
+        // Copy output IDs from GPU to host for printing
+        std::vector<int32_t> outputIdsHost(outputLength);
+        self->gptSession->getBufferManager().copy(*outputIds, outputIdsHost.data(), MemoryType::kCPU);
+        // Find the last non-zero value in the output IDs starting from the end of the input sequence
+        std::vector<int32_t> outputIdsHostDecode(outputIdsHost.begin() + inputLen, outputIdsHost.end());
+        removeId(outputIdsHostDecode, 0);
+        std::string text = self->nitro_tokenizer->decode(outputIdsHostDecode);
+
+        if (inferState->prevPos > 0 && inferState->prevPos < text.size())
+        {
+            // Valid prevPos, proceed with slicing the string from prevPos to the end
+            std::string stringTok(text.begin() + inferState->prevPos, text.end());
+            std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
+            inferState->textsToStream.push(stringTok);
+        }
+        else if (inferState->prevPos >= text.size())
         {
-            // Assuming the shape of outputIds tensor is (1, 1, 160), where 160 is the number of tokens
-            int outputLength = outputIds->getShape().d[2]; // Get the length of output IDs based on the tensor shape
-            // Copy output IDs from GPU to host for printing
-            std::vector<int32_t> outputIdsHost(outputLength);
-            self->gptSession->getBufferManager().copy(*outputIds, outputIdsHost.data(), MemoryType::kCPU);
-            // Find the last non-zero value in the output IDs starting from the end of the input sequence
-            std::vector<int32_t> outputIdsHostDecode(outputIdsHost.begin() + inputLen, outputIdsHost.end());
-            removeId(outputIdsHostDecode, 0);
-            removeId(outputIdsHostDecode, 32000);
-            std::string text = self->nitro_tokenizer->decode(outputIdsHostDecode);
-
-            if (inferState->prevPos > 0 && inferState->prevPos < text.size())
-            {
-                // Valid prevPos, proceed with slicing the string from prevPos to the end
-                std::string stringTok(text.begin() + inferState->prevPos, text.end());
-                std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
-                inferState->textsToStream.push(stringTok);
-            }
-            else if (inferState->prevPos >= text.size())
-            {
-                inferState->prevPos = text.size();
-            }
             inferState->prevPos = text.size();
+        }
+        inferState->prevPos = text.size();
+        if (finished)
+        {
+
+            std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
+            inferState->textsToStream.push("[DONE]");
             return;
         }
-        std::lock_guard<std::mutex> guard(inferState->queueMutex); // Protect access with a lock
-        inferState->textsToStream.push("[DONE]");
     };
     // The rest of the logic inside the `chat_completion` remains unchanged...
     // After finishing the setup, call the inference logic
@@ -243,6 +291,12 @@ void tensorrtllm::chat_completion(
         {
             std::string rawText = inferState->textsToStream.front();
 
+            inferState->textsToStream.pop();
+            if (handleMatch(rawText, inferState))
+            {
+                continue;
+            };
+
             if (rawText == "[DONE]")
             {
                 LOG_INFO << "End of result";
@@ -257,7 +311,6 @@
             }
             const std::string textToStream
                 = "data: " + create_return_json(nitro_utils::generate_random_string(20), "_", rawText) + "\n\n";
-            inferState->textsToStream.pop();
             lock.unlock(); // Unlock as soon as possible
 
             // Ensure we do not exceed the buffer size. Truncate if necessary.
@@ -265,6 +318,7 @@
 
             // Copy the text to the provided buffer
             std::memcpy(pBuffer, textToStream.data(), bytesToWrite);
+            inferState->prevText = rawText;
             return bytesToWrite; // Return the number of bytes written to the buffer
         }
         else
diff --git a/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.h b/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.h
index 0ecae873d27..40454829f6b 100644
--- a/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.h
+++ b/cpp/tensorrt_llm/nitro/controllers/tensorrtllm.h
@@ -100,6 +100,7 @@ class tensorrtllm : public drogon::HttpController<tensorrtllm>
     GenerationInput createGenerationInput(std::vector<int32_t> inputIds);
     GenerationOutput createGenerationOutput();
     std::unique_ptr<Tokenizer> nitro_tokenizer;
+    GenerationInput::TensorPtr getTensorChatMLStopWordList();
 
 private:
     GptSession::Config sessionConfig{1, 1, 1};
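
Note on the stop-words tensor layout used by getTensorChatMLStopWordList: TensorRT-LLM's stopWordsList is a [batchSize, 2, maxLen] buffer whose first row holds the concatenated token IDs of every stop sequence and whose second row holds each sequence's cumulative end offset, with both rows padded to maxLen with -1 (the existing getTensorSingleStopWordList(32000) encodes the single-token case this way as {32000, -1, 1, -1} with shape {1, 2, 2}). The sketch below shows how such a buffer can be assembled from arbitrary stop sequences; buildStopWordsList is a hypothetical helper for illustration, not part of this patch, and the hard-coded IDs in the diff are tokenizer-specific.

#include <algorithm>
#include <cstdint>
#include <vector>

// Flattens stop sequences into the [1, 2, maxLen] layout TensorRT-LLM expects:
// row 0 = concatenated token IDs, row 1 = cumulative end offsets, both padded with -1.
std::vector<int32_t> buildStopWordsList(const std::vector<std::vector<int32_t>>& stopSequences)
{
    std::vector<int32_t> ids;
    std::vector<int32_t> offsets;
    for (const auto& seq : stopSequences)
    {
        ids.insert(ids.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(ids.size())); // end offset of this sequence
    }
    const size_t maxLen = std::max(ids.size(), offsets.size());
    ids.resize(maxLen, -1);     // pad so both rows are maxLen wide
    offsets.resize(maxLen, -1);

    std::vector<int32_t> flat; // row-major [2, maxLen]; reshape to {1, 2, maxLen} on upload
    flat.reserve(2 * maxLen);
    flat.insert(flat.end(), ids.begin(), ids.end());
    flat.insert(flat.end(), offsets.begin(), offsets.end());
    return flat;
}

Building the buffer this way and then calling copyFrom with ITensor::makeShape({1, 2, maxLen}) would reproduce the hand-written literals above from the token-ID sequences alone.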
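
The handleMatch state machine added above withholds detokenized chunks from the stream while they keep matching the pieces of the ChatML stop word. Below is a minimal, self-contained sketch of the same decision logic; MatcherState and shouldWithhold are illustrative stand-ins for inferenceState and handleMatch, assuming chunks arrive split exactly as in the sequence member.

#include <cassert>
#include <string>
#include <vector>

// Trimmed copy of the matcher state from the patch, enough to exercise the logic.
struct MatcherState
{
    size_t stopWordMatchLen = 0;
    std::vector<std::string> sequence{"<", "|", "im", "_", "end", "|", ">"};

    bool isComplete() const { return stopWordMatchLen >= sequence.size(); }
    void reset() { stopWordMatchLen = 0; }
};

// Mirrors handleMatch: returns true when the chunk must be withheld from the client.
bool shouldWithhold(const std::string& chunk, MatcherState& s)
{
    if (s.isComplete())
        return true; // everything after a full stop-word match is suppressed
    if (chunk == s.sequence[s.stopWordMatchLen])
    {
        s.stopWordMatchLen++; // advance the partial match
        return true;
    }
    if (s.stopWordMatchLen > 0 && chunk == s.sequence[0])
    {
        s.stopWordMatchLen = 1; // broken match, but this chunk restarts one
        return true;
    }
    s.reset();
    return false; // no match: stream the chunk
}

int main()
{
    MatcherState s;
    assert(!shouldWithhold("Hello", s)); // ordinary chunks pass through
    const std::vector<std::string> stopPieces{"<", "|", "im", "_", "end", "|", ">"};
    for (const std::string& piece : stopPieces)
        assert(shouldWithhold(piece, s)); // each piece of "<|im_end|>" is swallowed
    assert(s.isComplete());

    MatcherState s2;
    assert(shouldWithhold("<", s2));     // partial match begins
    assert(shouldWithhold("<", s2));     // restart branch: "<" begins a new match
    assert(!shouldWithhold("oops", s2)); // full break: chunk is streamed, state resets
    return 0;
}

One behaviour worth flagging for review: because chat_completion pops each chunk before calling handleMatch, chunks withheld during a partial match are never re-emitted if the sequence later breaks, so a lone "<" that is not followed by the rest of the stop word silently disappears from the client's stream.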