Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .ccs-fork-upstream.env
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
UPSTREAM_TAG=v6.9.42
UPSTREAM_COMMIT=359ec30d0c5674659d9d73080de378f9a7417c4a
UPSTREAM_TAG=v6.9.43
UPSTREAM_COMMIT=e3e60f914ba82a6caa7a17a717f65a3b2f02285f
4 changes: 4 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ max-retry-interval: 30
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
disable-cooling: false

# When true, disable the built-in image_generation tool globally.
# The server will stop injecting image_generation and will also remove it from request payload tools arrays.
disable-image-generation: false

# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
# When > 0, overrides the default worker count (16).
# auth-auto-refresh-workers: 16
Expand Down
4 changes: 4 additions & 0 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
}

if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration {
log.Infof("disable-image-generation updated: %t -> %t", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration)
}

applySignatureCacheConfig(oldCfg, cfg)

if s.handlers != nil && s.handlers.AuthManager != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
cfg.UsageStatisticsEnabled = false
cfg.DisableCooling = false
cfg.DisableImageGeneration = false
cfg.Pprof.Enable = false
cfg.Pprof.Addr = DefaultPprofAddr
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
Expand Down
6 changes: 6 additions & 0 deletions internal/config/sdk_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ type SDKConfig struct {
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`

// DisableImageGeneration disables the built-in image_generation tool when true.
// When enabled, the server will avoid injecting image_generation into request payloads,
// will remove any existing image_generation tool entries from tools arrays, and will
// return 404 for /v1/images/generations and /v1/images/edits.
DisableImageGeneration bool `yaml:"disable-image-generation" json:"disable-image-generation"`

// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
// Default is false for safety; when false, /v1internal:* requests are rejected.
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
Expand Down
12 changes: 9 additions & 3 deletions internal/runtime/executor/codex_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}

url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
Expand Down Expand Up @@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream")
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}

url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
Expand Down Expand Up @@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel)
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}

url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
Expand Down
280 changes: 162 additions & 118 deletions internal/runtime/executor/helps/payload_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if cfg == nil || len(payload) == 0 {
return payload
}
rules := cfg.Payload
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
return payload
}
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model == "" && requestedModel == "" {
return payload
}
candidates := payloadModelCandidates(model, requestedModel)
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue

rules := cfg.Payload
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
if hasPayloadRules {
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model != "" || requestedModel != "" {
candidates := payloadModelCandidates(model, requestedModel)
source := original
if len(source) == 0 {
source = payload
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}

if cfg.DisableImageGeneration {
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
}
return out
}
Expand Down Expand Up @@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
return r + "." + p
}

func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
if len(payload) == 0 {
return payload
}
toolType = strings.TrimSpace(toolType)
if toolType == "" {
return payload
}
toolsPath := buildPayloadPath(root, "tools")
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
}

func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
tools := gjson.GetBytes(payload, toolsPath)
if !tools.Exists() || !tools.IsArray() {
return payload
}
removed := false
filtered := []byte(`[]`)
for _, tool := range tools.Array() {
if tool.Get("type").String() == toolType {
removed = true
continue
}
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
if errSet != nil {
continue
}
filtered = updated
}
if !removed {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
if errSet != nil {
return payload
}
return updated
}

func payloadRawValue(value any) ([]byte, bool) {
if value == nil {
return nil, false
Expand Down
Loading