20 changes: 20 additions & 0 deletions src/config.cpp
@@ -748,6 +748,26 @@ bool IsGraphCaptureEnabled(Config::SessionOptions& session_options) {
return false;
}

bool IsMultiProfileEnabled(const Config::SessionOptions& session_options) {
for (const auto& provider : session_options.providers) {
const auto provider_options = std::find_if(session_options.provider_options.begin(),
session_options.provider_options.end(),
[&provider](const Config::ProviderOptions& po) {
return po.name == provider;
});
if (provider_options != session_options.provider_options.end()) {
if (provider_options->name == "NvTensorRtRtx") {
for (const auto& value : provider_options->options) {
if (value.first == "nv_multi_profile_enable" && value.second == "1") {
return true;
}
}
}
}
}
return false;
}

struct Root_Element : JSON::Element {
explicit Root_Element(Config& config) : config_{config} {}

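For context (not part of the diff): IsMultiProfileEnabled fires when the NvTensorRtRtx provider carries nv_multi_profile_enable set to "1". A minimal genai_config.json fragment that would enable it could look like the sketch below; the exact nesting is an assumption based on the session_options / provider_options layout the code above reads.

  {
    "model": {
      "decoder": {
        "session_options": {
          "providers": [ "NvTensorRtRtx" ],
          "provider_options": [
            { "NvTensorRtRtx": { "nv_multi_profile_enable": "1" } }
          ]
        }
      }
    }
  }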
1 change: 1 addition & 0 deletions src/config.h
@@ -240,5 +240,6 @@ void ClearProviders(Config& config);
void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value);
void OverlayConfig(Config& config, std::string_view json);
bool IsGraphCaptureEnabled(Config::SessionOptions& session_options);
bool IsMultiProfileEnabled(const Config::SessionOptions& session_options);

} // namespace Generators
1 change: 1 addition & 0 deletions src/generators.cpp
@@ -246,6 +246,7 @@ GeneratorParams::GeneratorParams(const Config& config)
GeneratorParams::GeneratorParams(const Model& model)
: config{*model.config_.get()},
use_graph_capture{IsGraphCaptureEnabled(model.config_->model.decoder.session_options)},
use_multi_profile{IsMultiProfileEnabled(model.config_->model.decoder.session_options)},
p_device{model.p_device_inputs_} {
if (use_graph_capture) {
max_batch_size = 1; // set it to 1 by default
1 change: 1 addition & 0 deletions src/generators.h
@@ -72,6 +72,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec

int max_batch_size{0};
bool use_graph_capture{};
bool use_multi_profile{};
int BatchBeamSize() const { return search.num_beams * search.batch_size; }

DeviceInterface* p_device{}; // Scoring device (usually CPU, but can be CUDA)
122 changes: 117 additions & 5 deletions src/models/model.cpp
@@ -45,9 +45,16 @@ void State::Run(OrtSession& session, bool graph_capture_this_run) {

if (first_run_) {
extra_outputs_.Add(session.GetOutputNames());
if (params_->use_multi_profile) {
// Use the context phase profile (index 0) for the first run
run_options_->AddConfigEntry("nv_profile_index", "0");
}
first_run_ = false;
} else {
extra_outputs_.Update();
if (params_->use_multi_profile) {
run_options_->AddConfigEntry("nv_profile_index", "1");
}
}

if (g_log.enabled && g_log.model_input_values) {
@@ -273,11 +280,111 @@ int32_t Tokenizer::TokenToTokenId(const char* token) const {
return token_id;
}

/**
 * @brief Builds multi-profile shapes for the TensorRT execution provider.
 *
 * Generates separate min/opt/max optimization profiles for the context (prefill) phase and the
 * generation (decode) phase. Each profile covers the input tensors (input_ids, attention_mask,
 * position_ids) and the key-value cache tensors, with dimensions derived from the model
 * configuration.
 */
void ConfigureMultiProfile(const Config& config, OrtSessionOptions& session_options) {
// Get model parameters from decoder config
const int num_layers = config.model.decoder.num_hidden_layers;
const int num_kv_heads = config.model.decoder.num_key_value_heads;
const int head_dim = config.model.decoder.head_size;

// Get max context length from config
const int max_context_len = config.model.context_length;
const int opt_context_len = config.model.context_length / 2;
const int min_seq_len = 1;

// Extract KV cache name patterns from decoder config
std::string_view past_key_pattern = config.model.decoder.inputs.past_key_names;
std::string_view past_value_pattern = config.model.decoder.inputs.past_value_names;

// Helper function to add input shapes (input_ids, attention_mask, position_ids)
const auto add_input_shapes = [](std::ostringstream& shapes, int seq_len, bool append = false) {
if (append) shapes << ",";
shapes << Config::Defaults::InputIdsName << ":1x" << seq_len << ","
<< Config::Defaults::AttentionMaskName << ":1x" << seq_len << ","
<< Config::Defaults::PositionIdsName << ":1x" << seq_len;
};

// Helper function to add generation phase input shapes
const auto add_generation_input_shapes = [](std::ostringstream& shapes, int context_len) {
shapes << "," << Config::Defaults::AttentionMaskName << ":1x" << context_len << ","
<< Config::Defaults::InputIdsName << ":1x1,"
<< Config::Defaults::PositionIdsName << ":1x1";
};

// Helper function to add empty KV cache shapes for all layers
const auto add_empty_key_value_cache_shapes = [](std::ostringstream& shapes,
std::string_view key_pattern,
std::string_view value_pattern,
int num_layers,
int num_kv_heads,
int head_dim) {
for (int i = 0; i < num_layers; i++) {
// Use the existing function to format the key/value names
const std::string key_name = ComposeKeyValueName(std::string(key_pattern), i);
const std::string value_name = ComposeKeyValueName(std::string(value_pattern), i);

shapes << "," << key_name << ":1x" << num_kv_heads << "x0x" << head_dim;
shapes << "," << value_name << ":1x" << num_kv_heads << "x0x" << head_dim;
}
};

// Helper function to add KV cache with sequence length
const auto add_key_value_cache_shapes = [](std::ostringstream& shapes,
std::string_view key_pattern,
std::string_view value_pattern,
int seq_len,
int num_layers,
int num_kv_heads,
int head_dim) {
for (int i = 0; i < num_layers; i++) {
// Use the existing function to format the key/value names
const std::string key_name = ComposeKeyValueName(std::string(key_pattern), i);
const std::string value_name = ComposeKeyValueName(std::string(value_pattern), i);

shapes << "," << key_name << ":1x" << num_kv_heads << "x" << seq_len << "x" << head_dim;
shapes << "," << value_name << ":1x" << num_kv_heads << "x" << seq_len << "x" << head_dim;
}
};

std::ostringstream min_shapes, opt_shapes, max_shapes;

// MIN SHAPES (context phase and first token generation)
add_input_shapes(min_shapes, min_seq_len);
add_empty_key_value_cache_shapes(min_shapes, past_key_pattern, past_value_pattern, num_layers, num_kv_heads, head_dim);
add_generation_input_shapes(min_shapes, min_seq_len);
add_key_value_cache_shapes(min_shapes, past_key_pattern, past_value_pattern, min_seq_len, num_layers, num_kv_heads, head_dim);

// OPT SHAPES (prefill with medium context and generation after medium context)
add_input_shapes(opt_shapes, opt_context_len);
add_empty_key_value_cache_shapes(opt_shapes, past_key_pattern, past_value_pattern, num_layers, num_kv_heads, head_dim);
add_generation_input_shapes(opt_shapes, opt_context_len);
add_key_value_cache_shapes(opt_shapes, past_key_pattern, past_value_pattern, opt_context_len - 1, num_layers, num_kv_heads, head_dim);

// MAX SHAPES (prefill with maximum context and generation after maximum context)
add_input_shapes(max_shapes, max_context_len);
add_empty_key_value_cache_shapes(max_shapes, past_key_pattern, past_value_pattern, num_layers, num_kv_heads, head_dim);
add_generation_input_shapes(max_shapes, max_context_len);
add_key_value_cache_shapes(max_shapes, past_key_pattern, past_value_pattern, max_context_len - 1, num_layers, num_kv_heads, head_dim);

// Add the constructed profiles to session options
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_profile_min_shapes", min_shapes.str().c_str());
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_profile_opt_shapes", opt_shapes.str().c_str());
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_profile_max_shapes", max_shapes.str().c_str());
}
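
For illustration only: assuming a hypothetical single-layer model with num_key_value_heads = 8, head_size = 64, and the usual default tensor names (input_ids, attention_mask, position_ids, and a past_key_values.%d.key/value pattern — these come from the config, so they are an assumption here), the min_shapes string built above would be (one string, wrapped for readability):

  input_ids:1x1,attention_mask:1x1,position_ids:1x1,
  past_key_values.0.key:1x8x0x64,past_key_values.0.value:1x8x0x64,
  attention_mask:1x1,input_ids:1x1,position_ids:1x1,
  past_key_values.0.key:1x8x1x64,past_key_values.0.value:1x8x1x64

The first half (full-length inputs, empty KV cache) describes the context phase; the second half (single-token inputs, populated KV cache) describes the generation phase.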

DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
const std::vector<std::string>& providers,
const std::vector<Config::ProviderOptions>& provider_options_list,
bool is_primary_session_options,
-bool disable_graph_capture) {
+bool disable_graph_capture,
+const Config& config) {
DeviceInterface* p_device{};

auto providers_list = providers;
@@ -380,6 +487,11 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
} else if (provider_options.name == "NvTensorRtRtx") {
// Once the NvTensorRtRtx provider is set in ONNX Runtime, GenAI treats it as the CUDA device.
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_cuda_graph_enable", "1");

if (IsMultiProfileEnabled(config.model.decoder.session_options)) {
ConfigureMultiProfile(config, session_options);
}

p_device = GetDeviceInterface(DeviceType::CUDA);
}
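
Net effect, for reference: with multi-profile enabled, the NvTensorRtRtx session carries four config entries (keys taken from the code above; shape strings abbreviated):

  ep.nvtensorrtrtxexecutionprovider.nv_cuda_graph_enable  = "1"
  ep.nvtensorrtrtxexecutionprovider.nv_profile_min_shapes = "input_ids:1x1,..."
  ep.nvtensorrtrtxexecutionprovider.nv_profile_opt_shapes = "input_ids:1x{context_length/2},..."
  ep.nvtensorrtrtxexecutionprovider.nv_profile_max_shapes = "input_ids:1x{context_length},..."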

@@ -409,7 +521,7 @@ static const uint8_t g_trivial_model[] = {
// the allocator used is not destroyed until last. This keeps the allocator around until exit, after all other memory
// has been destroyed. Without this, we will crash in the Onnxruntime BFCArena code when deleting tensors due to the
// arena already being destroyed.
-void EnsureDeviceOrtInit(DeviceInterface& device) {
+void EnsureDeviceOrtInit(DeviceInterface& device, const Config& config) {
// CPU Allocator is a special case, it's not in the owned 'allocator_device_' table below so we handle it separately
// OpenVINO delegates to the CPU device allocator
auto type = device.GetType();
@@ -439,7 +551,7 @@ void EnsureDeviceOrtInit(DeviceInterface& device) {
provider_options_list.back().options.emplace_back("enable_htp_shared_memory_allocator", "1");
}
const std::vector<std::string> providers{device_type_names[static_cast<int>(type)]};
-SetProviderSessionOptions(*session_options, providers, provider_options_list, true, false);
+SetProviderSessionOptions(*session_options, providers, provider_options_list, true, false, config);
session_options->SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); // Errors only here, as warnings are not useful to the user

allocator.session_ = OrtSession::Create(GetOrtEnv(), g_trivial_model, sizeof(g_trivial_model), session_options.get());
@@ -525,7 +637,7 @@ std::vector<const char*> SessionInfo::GetOutputSymbolicShape(const std::string&

Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
CreateSessionOptions();
-EnsureDeviceOrtInit(*p_device_);
+EnsureDeviceOrtInit(*p_device_, *config_);

// Only CUDA and DML do every input on the device
if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML)
@@ -638,7 +750,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_

auto session_device = SetProviderSessionOptions(session_options, config_session_options.providers,
config_session_options.provider_options, is_primary_session_options,
-disable_graph_capture);
+disable_graph_capture, *config_);

if (!p_device_) {
p_device_ = session_device;