Compute buffer and KV-cache aware layer distribution for multi-GPU inference #14484

Status: Open — wants to merge 1 commit into master
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1107,6 +1107,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.requested_n_ctx = params.n_ctx;

if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
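For reference, a minimal sketch of how the requested context size reaches the loader with this change, assuming the common_params / common_model_params_to_llama interface shown in the hunk above (the helper and the 16384 value are illustrative only):

#include "common.h"

// Hypothetical helper: forward the caller's -c / --ctx-size value so that
// load_tensors() can plan the KV cache for the context size actually used.
llama_model_params make_model_params(common_params & params) {
    params.n_ctx = 16384;                                   // example value from -c / --ctx-size
    llama_model_params mparams = common_model_params_to_llama(params);
    // mparams.requested_n_ctx is now 16384 instead of the 0 ("auto") default
    return mparams;
}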
3 changes: 3 additions & 0 deletions include/llama.h
@@ -322,6 +322,9 @@ extern "C" {
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;

// expected context size for memory allocation planning (0 = auto)
uint32_t requested_n_ctx;

// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
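A hedged sketch of direct C API usage with the new field; llama_model_default_params() appears in this diff, while llama_model_load_from_file() is assumed to be the current loading entry point, and the wrapper and values are illustrative only:

#include "llama.h"

// Hypothetical loader wrapper: announce the intended context size up front so
// the layer split can budget KV cache and compute buffers per GPU.
llama_model * load_with_planned_ctx(const char * path_model) {
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers    = 999;      // offload as many layers as fit
    mparams.requested_n_ctx = 32768;    // match the n_ctx the context will be created with
    return llama_model_load_from_file(path_model, mparams);
}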
308 changes: 308 additions & 0 deletions src/llama-model.cpp
@@ -17,10 +17,12 @@
#include <cassert>
#include <cmath>
#include <cfloat>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <functional>
#include <map>
#include <numeric>
#include <regex>
#include <sstream>
#include <stdexcept>
@@ -1580,6 +1582,311 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
splits[i] /= split_sum;
}

// KV-cache aware layer distribution for heterogeneous GPUs
if (all_zero && n_devices() > 1 && split_mode == LLAMA_SPLIT_MODE_LAYER) {
// Determine context size for memory planning
uint32_t n_ctx_for_kv = 0;
if (params.requested_n_ctx > 0) {
// Use the explicitly requested context size from model params
n_ctx_for_kv = params.requested_n_ctx;
LLAMA_LOG_INFO("%s: Using requested_n_ctx=%u for KV cache calculation\n",
__func__, n_ctx_for_kv);
} else {
// Use a conservative default for memory planning
n_ctx_for_kv = std::min(32768u, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: Using default n_ctx=%u for KV cache calculation (training context: %u)\n",
__func__, n_ctx_for_kv, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: (set requested_n_ctx in model params to match your actual context size)\n", __func__);
}

// Only apply KV-aware distribution if we have a valid context size
if (n_ctx_for_kv > 0 && n_gpu_layers > 0) {
LLAMA_LOG_INFO("%s: Implementing KV-cache aware layer distribution\n", __func__);

// Calculate memory requirements per layer
const int64_t n_head_kv = hparams.n_head_kv();
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_kv = n_embd_head * n_head_kv;

// KV cache element size (typically f16 = 2 bytes, but can be quantized)
const size_t kv_size_element = 2; // sizeof(ggml_fp16_t)

// Total KV cache size for all layers (K and V)
// KV cache = 2 (K+V) * n_ctx * n_layers * n_embd_kv * element_size
const size_t kv_cache_size_total = 2ULL * n_ctx_for_kv * n_layer * n_embd_kv * kv_size_element;

// Estimate model weight size per layer
const size_t model_size_total = ml.n_bytes;
const size_t weight_size_per_layer = model_size_total / n_layer;

// Calculate actual compute buffer size based on attention matrix requirements
// Attention matrix: n_kv × n_ubatch × n_head × sizeof(float)
// This is the dominant memory consumer during inference
const int64_t n_head = hparams.n_head();
const size_t n_ubatch = 512; // Default physical batch size (from context params)
const size_t compute_buffer_size = n_ctx_for_kv * n_ubatch * n_head * sizeof(float);
const size_t min_overhead = 512ULL * 1024 * 1024; // 512MB base overhead

LLAMA_LOG_INFO("%s: Compute buffer size: %.2f MB (context=%u, ubatch=%zu, heads=%lld)\n",
__func__,
compute_buffer_size / 1024.0 / 1024.0,
n_ctx_for_kv, n_ubatch, (long long)n_head);

// For memory calculation, we need to account for KV cache being shared across layers on each device
// We'll calculate this dynamically during layer assignment

LLAMA_LOG_INFO("%s: Per-layer memory: weights=%.2f MB\n",
__func__,
weight_size_per_layer / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: Total KV cache size: %.2f MB\n",
__func__,
kv_cache_size_total / 1024.0 / 1024.0);

// Get memory info and calculate layer assignments
std::vector<int> layers_per_gpu(n_devices(), 0);
std::vector<size_t> gpu_free_memory(n_devices());

// Get free memory for each device and check if they can handle compute buffers
std::vector<bool> device_excluded(n_devices(), false);
for (size_t i = 0; i < n_devices(); ++i) {
ggml_backend_dev_t dev = devices[i];
size_t total, free;
ggml_backend_dev_memory(dev, &free, &total);
gpu_free_memory[i] = free;

// Check if device can handle minimum requirements (1 layer + compute buffer + KV cache)
size_t min_kv_cache = kv_cache_size_total / n_devices(); // Conservative estimate
size_t min_required = weight_size_per_layer + min_kv_cache + compute_buffer_size + min_overhead;

if (free < min_required) {
device_excluded[i] = true;
LLAMA_LOG_WARN("%s: Device %zu [%s]: %.2f MB free - excluding (needs %.2f MB minimum)\n",
__func__, i, ggml_backend_dev_name(dev),
free / 1024.0 / 1024.0, min_required / 1024.0 / 1024.0);
}
}

// Estimate total memory requirements and warn if insufficient
size_t total_gpu_memory = 0;
for (size_t i = 0; i < n_devices(); ++i) {
total_gpu_memory += gpu_free_memory[i];
}

// Rough estimate: KV cache + model weights + compute buffers (conservative estimate)
size_t estimated_compute_buffers = kv_cache_size_total; // Compute buffers often similar to KV cache size
size_t estimated_total_needed = kv_cache_size_total + model_size_total + estimated_compute_buffers;

if (estimated_total_needed > total_gpu_memory) {
LLAMA_LOG_WARN("%s: Memory estimate: %.2f GB needed vs %.2f GB available\n",
__func__,
estimated_total_needed / 1024.0 / 1024.0 / 1024.0,
total_gpu_memory / 1024.0 / 1024.0 / 1024.0);
LLAMA_LOG_WARN("%s: Context size may be too large for available memory\n", __func__);
}

// Sort devices by available memory (largest first), excluding unusable devices
std::vector<size_t> gpu_indices;
for (size_t i = 0; i < n_devices(); ++i) {
if (!device_excluded[i]) {
gpu_indices.push_back(i);
}
}
std::sort(gpu_indices.begin(), gpu_indices.end(),
[&gpu_free_memory](size_t a, size_t b) {
return gpu_free_memory[a] > gpu_free_memory[b];
});

if (gpu_indices.empty()) {
LLAMA_LOG_ERROR("%s: No GPUs have sufficient memory for compute buffers\n", __func__);
// Fall back to original allocation
return true;
}

// Assign layers greedily to GPUs with most memory first
int act_gpu_layers = n_gpu_layers; // Local copy that can be modified
int remaining_layers = act_gpu_layers;

// First pass: assign layers based on weights only (KV cache and compute buffers handled separately)
size_t weight_per_layer = weight_size_per_layer;

for (size_t idx : gpu_indices) {
// Reserve memory for compute buffer and base overhead
size_t reserved = compute_buffer_size + min_overhead;
if (gpu_free_memory[idx] <= reserved) {
LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, can't fit compute buffer (%.2f MB)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
gpu_free_memory[idx] / 1024 / 1024,
reserved / 1024.0 / 1024.0);
continue;
}

size_t available_for_model = gpu_free_memory[idx] - reserved;
int layers_that_fit = available_for_model / weight_per_layer;

if (layers_that_fit > 0 && remaining_layers > 0) {
int layers_to_assign = std::min(layers_that_fit, remaining_layers);
layers_per_gpu[idx] = layers_to_assign;
remaining_layers -= layers_to_assign;

LLAMA_LOG_INFO("%s: Device %zu [%s]: %zu MB free, assigned %d layers (%.2f MB weights, %.2f MB compute buffer)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
gpu_free_memory[idx] / 1024 / 1024,
layers_per_gpu[idx],
(layers_to_assign * weight_per_layer) / 1024.0 / 1024.0,
compute_buffer_size / 1024.0 / 1024.0);
} else {
LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, assigned 0 layers (need %.2f MB per layer + %.2f MB compute buffer)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
gpu_free_memory[idx] / 1024 / 1024,
weight_per_layer / 1024.0 / 1024.0,
compute_buffer_size / 1024.0 / 1024.0);
}
}

// Second pass: iteratively check if KV cache can fit proportionally
bool kv_fit_check_needed = (remaining_layers == 0);
int iterations = 0;
const int max_iterations = 10;

while (kv_fit_check_needed && iterations < max_iterations) {
kv_fit_check_needed = false;
iterations++;

// Calculate current total assigned layers
int total_assigned = 0;
for (size_t idx = 0; idx < n_devices(); ++idx) {
total_assigned += layers_per_gpu[idx];
}

if (total_assigned == 0) break;

// Check KV cache distribution for each device
for (size_t idx = 0; idx < n_devices(); ++idx) {
if (layers_per_gpu[idx] > 0) {
double layer_ratio = (double)layers_per_gpu[idx] / total_assigned;
size_t kv_cache_for_device = (size_t)(kv_cache_size_total * layer_ratio);
size_t weights = layers_per_gpu[idx] * weight_per_layer;
size_t total_memory_needed = weights + kv_cache_for_device + compute_buffer_size + min_overhead;

if (total_memory_needed > gpu_free_memory[idx]) {
// Device can't fit current allocation, reduce layers
size_t available_memory = gpu_free_memory[idx];
if (available_memory > min_overhead + kv_cache_for_device + compute_buffer_size) {
size_t available_for_weights = available_memory - min_overhead - kv_cache_for_device - compute_buffer_size;
int new_layer_count = available_for_weights / weight_per_layer;
new_layer_count = std::max(0, new_layer_count);

if (new_layer_count < layers_per_gpu[idx]) {
LLAMA_LOG_WARN("%s: Device %zu: Reducing layers from %d to %d due to KV cache requirements (%.2f MB KV cache)\n",
__func__, idx, layers_per_gpu[idx], new_layer_count,
kv_cache_for_device / 1024.0 / 1024.0);
remaining_layers += layers_per_gpu[idx] - new_layer_count;
layers_per_gpu[idx] = new_layer_count;
kv_fit_check_needed = true;
}
} else {
// Device can't even fit the minimum requirements
LLAMA_LOG_WARN("%s: Device %zu: Removing all %d layers - insufficient memory for KV cache\n",
__func__, idx, layers_per_gpu[idx]);
remaining_layers += layers_per_gpu[idx];
layers_per_gpu[idx] = 0;
kv_fit_check_needed = true;
}
}
}
}
}

// Third pass: redistribute any remaining layers to devices with available capacity
if (remaining_layers > 0) {
LLAMA_LOG_INFO("%s: Attempting to redistribute %d remaining layers\n", __func__, remaining_layers);

// Calculate current memory usage for each device that has layers assigned
for (size_t idx : gpu_indices) {
if (layers_per_gpu[idx] > 0 && remaining_layers > 0) {
// Calculate current memory usage
int current_assigned = 0;
for (size_t i = 0; i < n_devices(); ++i) {
current_assigned += layers_per_gpu[i];
}

double layer_ratio = (double)layers_per_gpu[idx] / current_assigned;
size_t current_kv_cache = (size_t)(kv_cache_size_total * layer_ratio);
size_t current_weights = layers_per_gpu[idx] * weight_per_layer;
size_t current_usage = current_weights + current_kv_cache + compute_buffer_size + min_overhead;

if (gpu_free_memory[idx] > current_usage) {
// Calculate how many additional layers could fit
// We need to account for proportional increase in KV cache
int additional_layers = 0;
for (int test_layers = 1; test_layers <= remaining_layers; test_layers++) {
int new_total_layers = layers_per_gpu[idx] + test_layers;
int new_total_assigned = current_assigned + test_layers;
double new_layer_ratio = (double)new_total_layers / new_total_assigned;
size_t new_kv_cache = (size_t)(kv_cache_size_total * new_layer_ratio);
size_t new_weights = new_total_layers * weight_per_layer;
size_t new_total_usage = new_weights + new_kv_cache + compute_buffer_size + min_overhead;

if (new_total_usage <= gpu_free_memory[idx]) {
additional_layers = test_layers;
} else {
break;
}
}

if (additional_layers > 0) {
int layers_to_add = std::min(additional_layers, remaining_layers);
layers_per_gpu[idx] += layers_to_add;
remaining_layers -= layers_to_add;

LLAMA_LOG_INFO("%s: Device %zu [%s]: redistributed %d additional layers (total now %d)\n",
__func__, idx, ggml_backend_dev_name(devices[idx]),
layers_to_add, layers_per_gpu[idx]);
}
}
}
}
}

// Warn if we couldn't place all layers
if (remaining_layers > 0) {
LLAMA_LOG_ERROR("%s: WARNING: Could not assign %d layers to GPUs. Consider:\n",
__func__, remaining_layers);
LLAMA_LOG_ERROR("%s: - Reducing context size (current: %u)\n",
__func__, n_ctx_for_kv);
LLAMA_LOG_ERROR("%s: - Using fewer layers (-ngl)\n", __func__);
LLAMA_LOG_ERROR("%s: - Adding more GPU memory\n", __func__);

// Put remaining layers on CPU (will be updated below)
}

// Convert layer counts to split ratios
splits.clear();
splits.resize(n_devices());
float cumsum = 0.0f;

// Calculate total layers actually assigned
int total_assigned_layers = 0;
for (size_t i = 0; i < n_devices(); ++i) {
total_assigned_layers += layers_per_gpu[i];
}

// Update act_gpu_layers to match what we actually assigned
act_gpu_layers = total_assigned_layers;

for (size_t i = 0; i < n_devices(); ++i) {
cumsum += (float)layers_per_gpu[i] / act_gpu_layers;
splits[i] = cumsum;
}

LLAMA_LOG_INFO("%s: Final split ratios: ", __func__);
for (size_t i = 0; i < n_devices(); ++i) {
LLAMA_LOG_CONT("%.3f ", splits[i]);
}
LLAMA_LOG_CONT("\n");
}
}

ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
@@ -14837,6 +15144,7 @@ llama_model_params llama_model_default_params() {
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.requested_n_ctx =*/ 0,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
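To make the sizing heuristic concrete, here is a small standalone sketch that evaluates the two formulas the new code relies on, using made-up hyperparameters (48 layers, 8 KV heads, head dimension 128, 32 attention heads; nothing here comes from a real model):

#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t n_ctx       = 32768;  // requested_n_ctx
    const uint64_t n_layer     = 48;
    const uint64_t n_head_kv   = 8;
    const uint64_t n_embd_head = 128;
    const uint64_t n_head      = 32;
    const uint64_t n_ubatch    = 512;    // default physical batch size

    const uint64_t n_embd_kv = n_embd_head * n_head_kv;  // 1024
    // K and V in fp16: 2 * n_ctx * n_layer * n_embd_kv * 2 bytes -> 6144 MiB here
    const uint64_t kv_cache_size_total = 2ULL * n_ctx * n_layer * n_embd_kv * 2;
    // attention scores: n_ctx * n_ubatch * n_head * sizeof(float) -> 2048 MiB here
    const uint64_t compute_buffer_size = n_ctx * n_ubatch * n_head * sizeof(float);

    printf("KV cache total : %.2f MiB\n", kv_cache_size_total / 1024.0 / 1024.0);
    printf("compute buffer : %.2f MiB\n", compute_buffer_size / 1024.0 / 1024.0);
    return 0;
}

With these example numbers, each participating GPU must reserve roughly 2 GiB for the attention scratch plus the 512 MB base overhead before any layer weights or its proportional share of the ~6 GiB KV cache are placed on it.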