From bd9f990e959694ddd0c65958dc66ddb9da045c19 Mon Sep 17 00:00:00 2001
From: Christopher Petito
Date: Mon, 24 Nov 2025 15:35:36 +0100
Subject: [PATCH] Fix thinking budget and rag strategy serialization/deserialization

Fixes pushing/pulling agent teams using these values

Signed-off-by: Christopher Petito
---
 pkg/config/latest/types.go      | 186 ++++++++++++++++++++++++++++++++
 pkg/config/latest/types_test.go | 185 +++++++++++++++++++++++++++++++
 pkg/config/roundtrip_test.go    | 140 ++++++++++++++++++++++++
 3 files changed, 511 insertions(+)
 create mode 100644 pkg/config/roundtrip_test.go

diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go
index ecc691cd6..b8cf48b03 100644
--- a/pkg/config/latest/types.go
+++ b/pkg/config/latest/types.go
@@ -1,8 +1,11 @@
 package latest
 
 import (
+	"encoding/json"
 	"fmt"
 
+	"github.com/goccy/go-yaml"
+
 	"github.com/docker/cagent/pkg/config/types"
 )
 
@@ -194,6 +197,64 @@ func (t *ThinkingBudget) UnmarshalYAML(unmarshal func(any) error) error {
 	return nil
 }
 
+// MarshalYAML emits the budget in its flattened scalar form: the Effort string
+// when one is set, otherwise the Tokens integer (0 and -1 are valid budgets and
+// are written as-is).
+func (t ThinkingBudget) MarshalYAML() ([]byte, error) {
+	if t.Effort != "" {
+		return yaml.Marshal(t.Effort)
+	}
+	return yaml.Marshal(t.Tokens)
+}
+
+// MarshalJSON emits the same flattened scalar form as MarshalYAML so that both
+// encodings stay interchangeable.
+func (t ThinkingBudget) MarshalJSON() ([]byte, error) {
+	// json.Marshal is used instead of fmt's %q because %q produces Go escape
+	// sequences such as \x00 or \a that are not valid JSON.
+	if t.Effort != "" {
+		return json.Marshal(t.Effort)
+	}
+	return json.Marshal(t.Tokens)
+}
+
+// UnmarshalJSON accepts the flattened scalar form written by MarshalJSON: a JSON
+// integer becomes Tokens and a JSON string becomes Effort. The struct's default
+// object encoding is still accepted for data written before MarshalJSON existed,
+// JSON null is a no-op, and any other value is an error instead of silently
+// decoding to an empty budget.
+func (t *ThinkingBudget) UnmarshalJSON(data []byte) error {
+	// By convention UnmarshalJSON leaves the receiver untouched on null.
+	if string(data) == "null" {
+		return nil
+	}
+
+	// Try integer tokens first.
+	var n int
+	if err := json.Unmarshal(data, &n); err == nil {
+		*t = ThinkingBudget{Tokens: n}
+		return nil
+	}
+
+	// Then a string effort level.
+	var s string
+	if err := json.Unmarshal(data, &s); err == nil {
+		*t = ThinkingBudget{Effort: s}
+		return nil
+	}
+
+	// Finally the plain struct encoding; the local type has no methods, so
+	// this does not recurse into UnmarshalJSON.
+	type plain ThinkingBudget
+	var legacy plain
+	if err := json.Unmarshal(data, &legacy); err == nil {
+		*t = ThinkingBudget(legacy)
+		return nil
+	}
+
+	return fmt.Errorf("thinking budget must be an integer or a string, got %s", data)
+}
+
 // StructuredOutput defines a JSON schema for structured output
 type StructuredOutput struct {
 	// Name is the name of the response format
@@ -283,6 +344,131 @@ func (s *RAGStrategyConfig) UnmarshalYAML(unmarshal func(any) error) error {
 	return nil
 }
 
+// MarshalYAML writes the strategy with Params flattened into the top level
+// rather than nested under a "params" key.
+func (s RAGStrategyConfig) MarshalYAML() ([]byte, error) {
+	flat, err := s.buildFlattenedMap()
+	if err != nil {
+		return nil, err
+	}
+	return yaml.Marshal(flat)
+}
+
+// MarshalJSON writes the same flattened layout as MarshalYAML so that both
+// encodings stay interchangeable.
+func (s RAGStrategyConfig) MarshalJSON() ([]byte, error) {
+	flat, err := s.buildFlattenedMap()
+	if err != nil {
+		return nil, err
+	}
+	return json.Marshal(flat)
+}
+
+// UnmarshalJSON decodes the flattened layout written by MarshalJSON. The known
+// keys (type, docs, database, chunking, limit) populate the matching fields and
+// every remaining key is kept in Params as strategy-specific configuration.
+// Malformed docs, database or chunking values are returned as errors instead of
+// being silently dropped.
+func (s *RAGStrategyConfig) UnmarshalJSON(data []byte) error {
+	// Decode into a generic map first so unknown keys are not lost.
+	var raw map[string]any
+	if err := json.Unmarshal(data, &raw); err != nil {
+		return err
+	}
+
+	if t, ok := raw["type"].(string); ok {
+		s.Type = t
+		delete(raw, "type")
+	}
+
+	if docs, ok := raw["docs"].([]any); ok {
+		parsed := make([]string, len(docs))
+		for i, d := range docs {
+			str, ok := d.(string)
+			if !ok {
+				return fmt.Errorf("docs[%d] must be a string, got %T", i, d)
+			}
+			parsed[i] = str
+		}
+		s.Docs = parsed
+		delete(raw, "docs")
+	}
+
+	if dbRaw, ok := raw["database"]; ok {
+		switch v := dbRaw.(type) {
+		case nil:
+			// Explicit null: leave Database as it is.
+		case string:
+			var db RAGDatabaseConfig
+			db.value = v
+			s.Database = db
+		default:
+			return fmt.Errorf("database must be a string, got %T", dbRaw)
+		}
+		delete(raw, "database")
+	}
+
+	if chunkRaw, ok := raw["chunking"]; ok {
+		// Round-trip through JSON so RAGChunkingConfig decodes itself.
+		chunkBytes, err := json.Marshal(chunkRaw)
+		if err != nil {
+			return fmt.Errorf("parsing chunking config: %w", err)
+		}
+		var chunk RAGChunkingConfig
+		if err := json.Unmarshal(chunkBytes, &chunk); err != nil {
+			return fmt.Errorf("parsing chunking config: %w", err)
+		}
+		s.Chunking = chunk
+		delete(raw, "chunking")
+	}
+
+	// encoding/json decodes every number in a map[string]any as float64.
+	if limit, ok := raw["limit"].(float64); ok {
+		s.Limit = int(limit)
+		delete(raw, "limit")
+	}
+
+	// Whatever is left is strategy-specific configuration.
+	s.Params = raw
+
+	return nil
+}
+
+// buildFlattenedMap returns the flattened key/value view shared by MarshalYAML
+// and MarshalJSON. Params are copied first so that a typed field always wins
+// when a Params key collides with a known key such as "type" or "limit". It
+// returns an error when Database is non-empty but AsString fails.
+func (s RAGStrategyConfig) buildFlattenedMap() (map[string]any, error) {
+	// 5 is the number of known top-level keys written below.
+	result := make(map[string]any, len(s.Params)+5)
+	for k, v := range s.Params {
+		result[k] = v
+	}
+
+	if s.Type != "" {
+		result["type"] = s.Type
+	}
+	if len(s.Docs) > 0 {
+		result["docs"] = s.Docs
+	}
+	if !s.Database.IsEmpty() {
+		dbStr, err := s.Database.AsString()
+		if err != nil {
+			return nil, fmt.Errorf("serializing database config: %w", err)
+		}
+		result["database"] = dbStr
+	}
+	// Only include chunking if any of its fields are set.
+	if s.Chunking.Size > 0 || s.Chunking.Overlap > 0 || s.Chunking.RespectWordBoundaries {
+		result["chunking"] = s.Chunking
+	}
+	if s.Limit > 0 {
+		result["limit"] = s.Limit
+	}
+
+	return result, nil
+}
+
 // unmarshalDatabaseConfig handles DatabaseConfig unmarshaling from raw YAML data.
 // For RAG strategies, the database configuration is intentionally simple:
 // a single string value under the `database` key that points to the SQLite
diff --git a/pkg/config/latest/types_test.go b/pkg/config/latest/types_test.go
index c4b0c6b83..4791cca8a 100644
--- a/pkg/config/latest/types_test.go
+++ b/pkg/config/latest/types_test.go
@@ -32,3 +32,188 @@ func TestCommandsUnmarshal_List(t *testing.T) {
 	require.Equal(t, "check disk", c["df"])
 	require.Equal(t, "list files", c["ls"])
 }
+
+func TestThinkingBudget_MarshalUnmarshal_String(t *testing.T) {
+	t.Parallel()
+
+	// Test string effort level
+	input := []byte(`thinking_budget: minimal`)
+	var config struct {
+		ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
+	}
+
+	// Unmarshal
+	err := yaml.Unmarshal(input, &config)
+	require.NoError(t, err)
+	require.NotNil(t, config.ThinkingBudget)
+	require.Equal(t, "minimal", config.ThinkingBudget.Effort)
+	require.Equal(t, 0, config.ThinkingBudget.Tokens)
+
+	// Marshal back
+	output, err := yaml.Marshal(config)
+	require.NoError(t, err)
+	require.Equal(t, "thinking_budget: minimal\n", string(output))
+}
+
+func TestThinkingBudget_MarshalUnmarshal_Integer(t *testing.T) {
+	t.Parallel()
+
+	// Test integer token budget
+	input := []byte(`thinking_budget: 8192`)
+	var config struct {
+		ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
+	}
+
+	// Unmarshal
+	err := yaml.Unmarshal(input, &config)
+	require.NoError(t, err)
+	require.NotNil(t, config.ThinkingBudget)
+	require.Empty(t, config.ThinkingBudget.Effort)
+	require.Equal(t, 8192, config.ThinkingBudget.Tokens)
+
+	// Marshal back
+	output, err := yaml.Marshal(config)
+	require.NoError(t, err)
+	require.Equal(t, "thinking_budget: 8192\n", string(output))
+}
+
+func TestThinkingBudget_MarshalUnmarshal_NegativeInteger(t *testing.T) {
+	t.Parallel()
+
+	// Test negative integer token budget (e.g., -1 for Gemini dynamic thinking)
+	input := []byte(`thinking_budget: -1`)
+	var config struct {
+		ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
+	}
+
+	// Unmarshal
+	err := yaml.Unmarshal(input, &config)
+	require.NoError(t, err)
+	require.NotNil(t, config.ThinkingBudget)
+	require.Empty(t, config.ThinkingBudget.Effort)
+	require.Equal(t, -1, config.ThinkingBudget.Tokens)
+
+	// Marshal back
+	output, err := yaml.Marshal(config)
+	require.NoError(t, err)
+	require.Equal(t, "thinking_budget: -1\n", string(output))
+}
+
+func TestThinkingBudget_MarshalUnmarshal_Zero(t *testing.T) {
+	t.Parallel()
+
+	// Test zero token budget (e.g., 0 for Gemini no thinking)
+	input := []byte(`thinking_budget: 0`)
+	var config struct {
+		ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"`
+	}
+
+	// Unmarshal
+	err := yaml.Unmarshal(input, &config)
+	require.NoError(t, err)
+	require.NotNil(t, config.ThinkingBudget)
+	require.Empty(t, config.ThinkingBudget.Effort)
+	require.Equal(t, 0, config.ThinkingBudget.Tokens)
+
+	// Marshal back
+	output, err := yaml.Marshal(config)
+	require.NoError(t, err)
+	require.Equal(t, "thinking_budget: 0\n", string(output))
+}
+
+func TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams(t *testing.T) {
+	t.Parallel()
+
+	// Test that params are flattened during unmarshal and remain flattened after marshal
+	input := []byte(`type: chunked-embeddings
+model: embeddinggemma
+database: ./rag/test.db
+threshold: 0.5
+vector_dimensions: 768
+`)
+
+	var strategy RAGStrategyConfig
+
+	// Unmarshal
+	err := yaml.Unmarshal(input, &strategy)
+	require.NoError(t, err)
+	require.Equal(t, "chunked-embeddings", strategy.Type)
+	require.Equal(t, "./rag/test.db", mustGetDBString(t, strategy.Database))
+	require.NotNil(t, strategy.Params)
+	require.Equal(t, "embeddinggemma", strategy.Params["model"])
+	require.InEpsilon(t, 0.5, strategy.Params["threshold"], 0.001)
+	// YAML may unmarshal numbers as different numeric types (int, uint64, float64)
+	require.InEpsilon(t, float64(768), toFloat64(strategy.Params["vector_dimensions"]), 0.001)
+
+	// Marshal back
+	output, err := yaml.Marshal(strategy)
+	require.NoError(t, err)
+
+	// Verify it's still flattened (no "params:" key)
+	outputStr := string(output)
+	require.Contains(t, outputStr, "type: chunked-embeddings")
+	require.Contains(t, outputStr, "model: embeddinggemma")
+	require.Contains(t, outputStr, "threshold: 0.5")
+	require.Contains(t, outputStr, "vector_dimensions: 768")
+	require.NotContains(t, outputStr, "params:")
+
+	// Unmarshal again to verify round-trip
+	var strategy2 RAGStrategyConfig
+	err = yaml.Unmarshal(output, &strategy2)
+	require.NoError(t, err)
+	require.Equal(t, strategy.Type, strategy2.Type)
+	require.Equal(t, strategy.Params["model"], strategy2.Params["model"])
+	require.Equal(t, strategy.Params["threshold"], strategy2.Params["threshold"])
+	// YAML may unmarshal numbers as different numeric types (int, uint64, float64)
+	// Just verify the numeric value is correct
+	require.InEpsilon(t, float64(768), toFloat64(strategy2.Params["vector_dimensions"]), 0.001)
+}
+
+func TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase(t *testing.T) {
+	t.Parallel()
+
+	input := []byte(`type: chunked-embeddings
+database: ./test.db
+model: test-model
+`)
+
+	var strategy RAGStrategyConfig
+	err := yaml.Unmarshal(input, &strategy)
+	require.NoError(t, err)
+
+	// Marshal back
+	output, err := yaml.Marshal(strategy)
+	require.NoError(t, err)
+
+	// Should contain database as a simple string, not nested with sub-fields
+	outputStr := string(output)
+	require.Contains(t, outputStr, "database: ./test.db")
+	require.NotContains(t, outputStr, " value:") // Should not be nested with internal fields
+	require.Contains(t, outputStr, "model: test-model")
+	require.NotContains(t, outputStr, "params:") // Should be flattened
+}
+
+func mustGetDBString(t *testing.T, db RAGDatabaseConfig) string {
+	t.Helper()
+	str, err := db.AsString()
+	require.NoError(t, err)
+	return str
+}
+
+// toFloat64 converts various numeric types to float64 for comparison; other types yield 0
+func toFloat64(v any) float64 {
+	switch val := v.(type) {
+	case int:
+		return float64(val)
+	case int64:
+		return float64(val)
+	case uint64:
+		return float64(val)
+	case float64:
+		return val
+	case float32:
+		return float64(val)
+	default:
+		return 0
+	}
+}
diff --git a/pkg/config/roundtrip_test.go b/pkg/config/roundtrip_test.go
new file mode 100644
index 000000000..d48368a88
--- /dev/null
+++ b/pkg/config/roundtrip_test.go
@@ -0,0 +1,140 @@
+package config
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/goccy/go-yaml"
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"github.com/stretchr/testify/require"
+
+	"github.com/docker/cagent/pkg/config/latest"
+)
+
+// TestExampleYAMLRoundtrip tests that all example YAML files can be serialized/deserialized
+// without losing information. This simulates the push/pull flow where configs are:
+// 1. Read from YAML file
+// 2. Parsed into Config struct
+// 3. Marshaled back to YAML (for OCI packaging)
+// 4. Unmarshaled again (when pulled)
+func TestExampleYAMLRoundtrip(t *testing.T) {
+	t.Parallel()
+
+	examplesDir := filepath.Join("..", "..", "examples")
+	if _, err := os.Stat(examplesDir); os.IsNotExist(err) {
+		t.Skip("examples directory not found")
+	}
+
+	// Collect all YAML files from examples directory (including subdirectories)
+	var yamlFiles []string
+	err := filepath.Walk(examplesDir, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if !info.IsDir() && (filepath.Ext(path) == ".yaml" || filepath.Ext(path) == ".yml") {
+			yamlFiles = append(yamlFiles, path)
+		}
+		return nil
+	})
+	require.NoError(t, err)
+	require.NotEmpty(t, yamlFiles, "no YAML files found in examples directory")
+
+	ctx := t.Context()
+
+	for _, yamlFile := range yamlFiles {
+		t.Run(filepath.Base(yamlFile), func(t *testing.T) {
+			t.Parallel()
+
+			// Step 1: Read original YAML file
+			originalBytes, err := os.ReadFile(yamlFile)
+			require.NoError(t, err, "failed to read %s", yamlFile)
+
+			// Step 2: Parse into Config struct (simulates push reading the file)
+			cfg, err := LoadConfigBytes(ctx, originalBytes)
+			require.NoError(t, err, "failed to load config from %s", yamlFile)
+			require.NotNil(t, cfg, "config should not be nil for %s", yamlFile)
+
+			// Step 3: Marshal back to YAML with same options as push command
+			// (see pkg/oci/package.go PackageFileAsOCIToStore)
+			marshaledBytes, err := yaml.MarshalWithOptions(cfg, yaml.Indent(2))
+			require.NoError(t, err, "failed to marshal config for %s", yamlFile)
+
+			// Step 4: Unmarshal again (simulates pull reading from OCI)
+			var cfg2 *latest.Config
+			cfg2, err = LoadConfigBytes(ctx, marshaledBytes)
+			require.NoError(t, err, "failed to load marshaled config for %s", yamlFile)
+			require.NotNil(t, cfg2, "round-tripped config should not be nil for %s", yamlFile)
+
+			// Step 5: Compare the two parsed configs - they should be identical
+			// Compared structurally with go-cmp (avoids YAML formatting differences)
+			assertConfigsEqual(t, cfg, cfg2, yamlFile)
+
+			// Step 6: Marshal the round-tripped config again and reload it (stability test)
+			marshaledBytes2, err := yaml.MarshalWithOptions(cfg2, yaml.Indent(2))
+			require.NoError(t, err, "failed to re-marshal config for %s", yamlFile)
+
+			// The twice-round-tripped config must still equal the original
+			var cfg3 *latest.Config
+			cfg3, err = LoadConfigBytes(ctx, marshaledBytes2)
+			require.NoError(t, err, "failed to load re-marshaled config for %s", yamlFile)
+
+			assertConfigsEqual(t, cfg, cfg3, yamlFile)
+		})
+	}
+}
+
+// assertConfigsEqual compares two configs for semantic equality using go-cmp
+func assertConfigsEqual(t *testing.T, cfg1, cfg2 *latest.Config, filename string) {
+	t.Helper()
+
+	// Define comparison options to handle normalization and special cases
+	opts := []cmp.Option{
+		// Sort maps for consistent comparison (map iteration order is random)
+		cmpopts.SortMaps(func(a, b string) bool { return a < b }),
+
+		// Treat nil and &true as equal for every *bool field (needed for ParallelToolCalls)
+		// Config validation sets ParallelToolCalls to &true if nil, so after roundtrip it may differ
+		cmp.Comparer(func(a, b *bool) bool {
+			// Both nil is equal
+			if a == nil && b == nil {
+				return true
+			}
+			// One nil, one true is equal (normalized case)
+			if (a == nil && b != nil && *b) || (b == nil && a != nil && *a) {
+				return true
+			}
+			// Both non-nil, compare values
+			if a != nil && b != nil {
+				return *a == *b
+			}
+			return false
+		}),
+
+		// Handle RAGDatabaseConfig which has unexported fields
+		// Compare using the public AsString() method
+		cmp.Comparer(func(a, b latest.RAGDatabaseConfig) bool {
+			aStr, aErr := a.AsString()
+			bStr, bErr := b.AsString()
+			// If both error, consider them equal (both invalid)
+			if aErr != nil && bErr != nil {
+				return true
+			}
+			// If one errors, not equal
+			if aErr != nil || bErr != nil {
+				return false
+			}
+			// Compare the string values
+			return aStr == bStr
+		}),
+
+		// Treat nil and empty slices as equal (common normalization during YAML marshal/unmarshal)
+		cmpopts.EquateEmpty(),
+	}
+
+	// Use cmp.Diff to get detailed differences
+	if diff := cmp.Diff(cfg1, cfg2, opts...); diff != "" {
+		t.Errorf("Config mismatch for %s (-want +got):\n%s", filename, diff)
+	}
+}