Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 35 additions & 34 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,19 @@ const (
)

var DefaultConfig = Config{
DefaultBlockSize: DefaultBlockSize,
AutoTune: true,
BlockSize: DefaultBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}

type Config struct {
// If set to true, the plugin will automatically adjust the configuration based on various
// metrics from the model servers.
AutoTune bool `json:"autoTune"`
// The input prompt is broken into sizes of BlockSize to calculate block hashes . Requests
// with length shorter than the block size will be ignored.
DefaultBlockSize int `json:"blockSize"`
BlockSize int `json:"blockSize"`
// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
// be ignored.
MaxPrefixBlocksToMatch int `json:"maxPrefixBlocksToMatch"`
Expand Down Expand Up @@ -148,11 +152,7 @@ var (

// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := Config{
DefaultBlockSize: DefaultBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
parameters := DefaultConfig

if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
Expand All @@ -167,40 +167,32 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle

// New initializes a new prefix Plugin and returns its pointer.
func New(ctx context.Context, config Config) *Plugin {
capacity := config.LRUCapacityPerServer
if capacity <= 0 {
capacity = DefaultLRUCapacityPerServer
if config.LRUCapacityPerServer <= 0 {
config.LRUCapacityPerServer = DefaultLRUCapacityPerServer
log.FromContext(ctx).V(logutil.DEFAULT).Info(
"LRUCapacityPerServer is not positive, using default value",
"defaultCapacity", DefaultLRUCapacityPerServer,
)
}

blockSize := config.DefaultBlockSize
if blockSize <= 0 {
blockSize = DefaultBlockSize
log.FromContext(ctx).V(logutil.DEFAULT).Info("DefaultBlockSize is not positive, using default value",
if config.BlockSize <= 0 {
config.BlockSize = DefaultBlockSize
log.FromContext(ctx).V(logutil.DEFAULT).Info("BlockSize is not positive, using default value",
"default", DefaultBlockSize)
}

maxPrefixBlocks := config.MaxPrefixBlocksToMatch
if maxPrefixBlocks <= 0 {
maxPrefixBlocks = DefaultMaxPrefixBlocks
if config.MaxPrefixBlocksToMatch <= 0 {
config.MaxPrefixBlocksToMatch = DefaultMaxPrefixBlocks
log.FromContext(ctx).V(logutil.DEFAULT).Info("MaxPrefixBlocksToMatch is not positive, using default value",
"default", DefaultMaxPrefixBlocks)
}

validConfig := Config{
DefaultBlockSize: blockSize,
MaxPrefixBlocksToMatch: maxPrefixBlocks,
LRUCapacityPerServer: capacity,
}

log.FromContext(ctx).V(logutil.DEFAULT).Info("PrefixCachePlugin initialized", "config", config)
return &Plugin{
typedName: plugins.TypedName{Type: PrefixCachePluginType, Name: PrefixCachePluginType},
config: validConfig,
config: config,
pluginState: plugins.NewPluginState(ctx),
indexer: newIndexer(ctx, capacity),
indexer: newIndexer(ctx, config.LRUCapacityPerServer),
}
}

Expand All @@ -218,7 +210,7 @@ func (p *Plugin) WithName(name string) *Plugin {
// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
Expand Down Expand Up @@ -248,8 +240,12 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
// PreRequest records in the plugin cache the result of the scheduling selection.
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
gpuBlocks := primaryProfileResult.TargetPods[0].GetMetrics().CacheNumGPUBlocks
targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile

gpuBlocks := p.config.LRUCapacityPerServer
if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 {
gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks
}

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
Expand All @@ -265,16 +261,16 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
p.wg.Add(1)
go func() {
p.indexer.Add(state.PrefixHashes, Server{
ServerID(targetPod.NamespacedName),
ServerID(targetPod.GetPod().NamespacedName),
gpuBlocks,
})
p.wg.Done()
}()

total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
matchLen := state.PrefixCacheServers[ServerID(targetPod.GetPod().NamespacedName)]

blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config)
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
}

Expand Down Expand Up @@ -388,9 +384,14 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
return json.Marshal(request.Body.ChatCompletions.Messages)
}

func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
func getBlockSize(pods []types.Pod, config Config) int {
if !config.AutoTune {
return config.BlockSize
}

// Fallback to BlockSize if no metrics are available.
if len(pods) == 0 {
return defaultBlockSize
return config.BlockSize
}

// Since all PODs originate from the same inference pool, they are considered to have identical configurations.
Expand All @@ -401,5 +402,5 @@ func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
return cacheBlockSize * averageCharactersPerToken
}
}
return defaultBlockSize
return config.BlockSize
}
115 changes: 105 additions & 10 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (

func TestPrefixPluginCompletion(t *testing.T) {
config := Config{
DefaultBlockSize: 4,
BlockSize: 4,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func TestPrefixPluginCompletion(t *testing.T) {

func TestPrefixPluginChatCompletions(t *testing.T) {
config := Config{
DefaultBlockSize: 4,
BlockSize: 4,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
Expand Down Expand Up @@ -235,7 +235,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {

func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
config := Config{
DefaultBlockSize: 8, // Use larger block size for more predictable JSON marshaling
BlockSize: 8, // Use larger block size for more predictable JSON marshaling
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
Expand Down Expand Up @@ -349,7 +349,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
blockSize := 4
maxPrefixBlocks := 50000
config := Config{
DefaultBlockSize: blockSize,
BlockSize: blockSize,
MaxPrefixBlocksToMatch: maxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
Expand Down Expand Up @@ -409,7 +409,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
{
name: "all zero",
config: Config{
DefaultBlockSize: 0,
BlockSize: 0,
MaxPrefixBlocksToMatch: 0,
LRUCapacityPerServer: 0,
},
Expand All @@ -420,7 +420,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
{
name: "negative values",
config: Config{
DefaultBlockSize: -5,
BlockSize: -5,
MaxPrefixBlocksToMatch: -10,
LRUCapacityPerServer: -100,
},
Expand All @@ -431,7 +431,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
{
name: "mixed valid and invalid",
config: Config{
DefaultBlockSize: 32, // valid
BlockSize: 32, // valid
MaxPrefixBlocksToMatch: -1, // invalid
LRUCapacityPerServer: 50000, // valid
},
Expand All @@ -442,7 +442,7 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
{
name: "all valid",
config: Config{
DefaultBlockSize: 64,
BlockSize: 64,
MaxPrefixBlocksToMatch: 200,
LRUCapacityPerServer: 30000,
},
Expand All @@ -459,13 +459,108 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {

assert.NotEmpty(t, plugin)
assert.NotEmpty(t, plugin.indexer)
assert.Equal(t, tt.expectBlock, plugin.config.DefaultBlockSize)
assert.Equal(t, tt.expectBlock, plugin.config.BlockSize)
assert.Equal(t, tt.expectMaxMatch, plugin.config.MaxPrefixBlocksToMatch)
assert.Equal(t, tt.expectCapacity, plugin.config.LRUCapacityPerServer)
})
}
}

func TestPrefixPluginAutoTune(t *testing.T) {
// Setup common test data
podName := "pod-autotune"
pod := &types.PodMetrics{
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: podName}},
MetricsState: &backendmetrics.MetricsState{
CacheBlockSize: 16, // 16 tokens * 4 chars/token = 64 chars per block
CacheNumGPUBlocks: 1000, // 1000 blocks capacity
},
}
pods := []types.Pod{pod}

req := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model",
Body: &types.LLMRequestBody{
Completions: &types.CompletionsRequest{
// Length 128 chars.
// If AutoTune=true (block size 64): 2 blocks
// If AutoTune=false (block size 32): 4 blocks
Prompt: strings.Repeat("a", 128),
},
},
}

t.Run("AutoTune Enabled", func(t *testing.T) {
config := Config{
AutoTune: true,
BlockSize: 32, // Should be ignored in favor of pod metrics (64)
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
// Should be ignored in favor of pod metrics (1000)
LRUCapacityPerServer: 1,
}
plugin := New(context.Background(), config)

// 1. Verify Score uses pod metrics for block size
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
_ = scores

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
// Block size from pod is 16 tokens * 4 = 64 chars.
// Prompt is 128 chars.
// Expected blocks: 128/64 = 2 hashes (model hash is used as seed but not returned as a block)
assert.Equal(t, 2, len(state.PrefixHashes), "Should use pod block size (64 chars) -> 2 body blocks")

// 2. Verify PreRequest uses pod metrics for capacity
schedulingResult := &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod}},
},
}
plugin.PreRequest(context.Background(), req, schedulingResult)
plugin.wg.Wait()

// Check indexer state
assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName))
})

t.Run("AutoTune Disabled", func(t *testing.T) {
config := Config{
AutoTune: false,
BlockSize: 32, // Should be used (32 chars)
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: 1, // Should be used, and the first hash should be evicted due to the small
}
plugin := New(context.Background(), config)

// 1. Verify Score uses config BlockSize
req.RequestId = uuid.NewString() // New request ID
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
_ = scores

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
// Block size from config is 32 chars.
// Prompt is 128 chars.
// 128 / 32 = 4 chunks.
assert.Equal(t, 4, len(state.PrefixHashes), "Should use config block size (32 chars) -> 4 body blocks")

// 2. Verify PreRequest uses config LRUCapacityPerServer
schedulingResult := &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
"default": {TargetPods: []types.Pod{pod}},
},
}
plugin.PreRequest(context.Background(), req, schedulingResult)
plugin.wg.Wait()

assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName))
})
}

// randomPrompt generates a pseudo-random string of length n using lowercase letters.
func randomPrompt(n int) string {
runes := []rune("abcdefghijklmnopqrstuvwxyz")
Expand All @@ -481,7 +576,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
blockSize := 8
maxPrefixBlocks := 50000
config := Config{
DefaultBlockSize: blockSize,
BlockSize: blockSize,
MaxPrefixBlocksToMatch: maxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
Expand Down