Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions internal/command/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,21 @@ func NewWebCmd() *cobra.Command {
}

func runWebServer(port int, host string, openBrowser bool) error {
// Check if we need setup (no providers configured).
needsSetup := config.NeedsSetup()

cfg, err := config.LoadConfig()
if err != nil {
return fmt.Errorf("config error: %w", err)
var cfg *config.Config
if !needsSetup {
var err error
cfg, err = config.LoadConfig()
if err != nil {
return fmt.Errorf("config error: %w", err)
}
} else {
// Create a minimal config for setup mode.
cfg = &config.Config{
MaxIterations: 1000,
}
}

ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
Expand All @@ -68,12 +79,15 @@ func runWebServer(port int, host string, openBrowser bool) error {
skillLoader.ScanProjectSkills(pwd)

systemPrompt := prompts.GetSystemPrompt(platform, pwd, "local", envInfo, skillLoader.Descriptions())
providerName, modelName := cfg.GetProviderModel()

providers := cfg.GetProviders()
providerCfg := providers[providerName]
if providerCfg == nil {
return fmt.Errorf("provider %q not found in config", providerName)
var providerName, modelName string
if !needsSetup {
providerName, modelName = cfg.GetProviderModel()
providers := cfg.GetProviders()
providerCfg := providers[providerName]
if providerCfg == nil {
return fmt.Errorf("provider %q not found in config", providerName)
}
}

registry := internalmodel.NewModelRegistry()
Expand Down Expand Up @@ -171,7 +185,12 @@ func runWebServer(port int, host string, openBrowser bool) error {

createAgent := func(prov, mod string) (*adk.ChatModelAgent, error) {
// Resolve provider config.
provCfg := providers[prov]
// Reload config to pick up any new providers added via setup.
currentCfg, err := config.LoadConfig()
if err != nil {
return nil, fmt.Errorf("config error: %w", err)
}
provCfg := currentCfg.GetProviders()[prov]
if provCfg == nil {
return nil, fmt.Errorf("provider %q not configured", prov)
}
Expand Down Expand Up @@ -240,9 +259,13 @@ func runWebServer(port int, host string, openBrowser bool) error {
return agent.NewAgent(ctx, cm, tools, systemPrompt, approvalState.RequestApproval, middlewares, handlers)
}

ag, err := createAgent(providerName, modelName)
if err != nil {
return fmt.Errorf("error creating agent: %w", err)
var ag *adk.ChatModelAgent
var agentErr error
if !needsSetup {
ag, agentErr = createAgent(providerName, modelName)
if agentErr != nil {
return fmt.Errorf("error creating agent: %w", agentErr)
}
}

switchProject := func(newPwd string) (*adk.ChatModelAgent, *session.Recorder, error) {
Expand Down Expand Up @@ -301,6 +324,7 @@ func runWebServer(port int, host string, openBrowser bool) error {
WechatClient: wechatClient,
WebHandler: webHandler,
EventHandler: finalHandler,
NeedsSetup: needsSetup,
})

// Set handler for approval routing.
Expand Down
161 changes: 161 additions & 0 deletions internal/config/model_state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package config

import (
"encoding/json"
"os"
"path/filepath"
"sync"
)

const modelStateFile = "model_state.json"

// ModelState tracks recent, favorite, and visibility settings for models.
type ModelState struct {
Recent []ModelRef `json:"recent,omitempty"`
Favorite []ModelRef `json:"favorite,omitempty"`
// EnabledModels lists models explicitly enabled by the user (shown in model selector).
// If nil/empty, default-enabled models from the registry are used.
EnabledModels []ModelRef `json:"enabled_models,omitempty"`
// DisabledModels lists models explicitly disabled by the user (hidden from model selector).
DisabledModels []ModelRef `json:"disabled_models,omitempty"`
}

// ModelRef uniquely identifies a model in "provider/model" format.
type ModelRef struct {
Provider string `json:"provider"`
Model string `json:"model"`
}

var (
modelStateMu sync.Mutex
)

// modelStatePath returns the path to the model state file.
func modelStatePath() (string, error) {
return filepath.Join(ConfigDir(), modelStateFile), nil
}

// LoadModelState loads the model state from disk.
func LoadModelState() (*ModelState, error) {
modelStateMu.Lock()
defer modelStateMu.Unlock()

p, err := modelStatePath()
if err != nil {
return &ModelState{}, nil
}

data, err := os.ReadFile(p)
if err != nil {
return &ModelState{}, nil
}

var state ModelState
if err := json.Unmarshal(data, &state); err != nil {
return &ModelState{}, nil
}
return &state, nil
}

// SaveModelState writes the model state to disk.
func SaveModelState(state *ModelState) error {
modelStateMu.Lock()
defer modelStateMu.Unlock()

p, err := modelStatePath()
if err != nil {
return err
}

dir := filepath.Dir(p)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}

data, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return os.WriteFile(p, data, 0644)
}

// AddRecent adds a model to the recent list (deduped, max 10).
func (s *ModelState) AddRecent(ref ModelRef) {
// Remove if already present
filtered := make([]ModelRef, 0, len(s.Recent))
for _, r := range s.Recent {
if r.Provider != ref.Provider || r.Model != ref.Model {
filtered = append(filtered, r)
}
}
// Prepend
s.Recent = append([]ModelRef{ref}, filtered...)
// Cap at 10
if len(s.Recent) > 10 {
s.Recent = s.Recent[:10]
}
}

// ToggleFavorite adds or removes a model from favorites. Returns true if now favorite.
func (s *ModelState) ToggleFavorite(ref ModelRef) bool {
for i, r := range s.Favorite {
if r.Provider == ref.Provider && r.Model == ref.Model {
s.Favorite = append(s.Favorite[:i], s.Favorite[i+1:]...)
return false
}
}
s.Favorite = append(s.Favorite, ref)
return true
}

// IsFavorite returns whether the given model is in the favorites list.
func (s *ModelState) IsFavorite(ref ModelRef) bool {
for _, r := range s.Favorite {
if r.Provider == ref.Provider && r.Model == ref.Model {
return true
}
}
return false
}

// IsModelEnabled returns whether the given model should be shown in the model selector.
// Logic: if the model is in EnabledModels, it's enabled.
// If the model is in DisabledModels, it's disabled.
// Otherwise, fallback to the defaultEnabled parameter (from registry).
func (s *ModelState) IsModelEnabled(ref ModelRef, defaultEnabled bool) bool {
for _, r := range s.DisabledModels {
if r.Provider == ref.Provider && r.Model == ref.Model {
return false
}
}
for _, r := range s.EnabledModels {
if r.Provider == ref.Provider && r.Model == ref.Model {
return true
}
}
return defaultEnabled
}

// SetModelEnabled explicitly enables or disables a model in the model selector.
func (s *ModelState) SetModelEnabled(ref ModelRef, enabled bool) {
// Remove from both lists first
s.EnabledModels = removeModelRef(s.EnabledModels, ref)
s.DisabledModels = removeModelRef(s.DisabledModels, ref)

if enabled {
s.EnabledModels = append(s.EnabledModels, ref)
} else {
s.DisabledModels = append(s.DisabledModels, ref)
}
}

// removeModelRef removes a model ref from a slice.
func removeModelRef(refs []ModelRef, ref ModelRef) []ModelRef {
result := make([]ModelRef, 0, len(refs))
for _, r := range refs {
if r.Provider != ref.Provider || r.Model != ref.Model {
result = append(result, r)
}
}
return result
}
83 changes: 81 additions & 2 deletions internal/model/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type RegistryModel struct {
Cost *ModelCost `json:"cost,omitempty"`
Limit *ModelLimit `json:"limit,omitempty"`
Status string `json:"status,omitempty"`
Recommended bool `json:"recommended,omitempty"`
DefaultEnabled bool `json:"default_enabled,omitempty"`
}

// ModelModalities describes input/output modalities.
Expand Down Expand Up @@ -151,8 +153,8 @@ func (r *ModelRegistry) ListProviderModels(providerID string, toolCallOnly bool)
}
models = append(models, m)
}
// Sort by ID for consistent ordering
sortModelsByID(models)
// Sort: recommended first, then by ID
sortModels(models)
return models
}

Expand Down Expand Up @@ -183,3 +185,80 @@ func sortModelsByID(models []*RegistryModel) {
}
}
}

// sortModels sorts models: recommended first, then by ID.
func sortModels(models []*RegistryModel) {
for i := 0; i < len(models); i++ {
for j := i + 1; j < len(models); j++ {
iRec := models[i].Recommended
jRec := models[j].Recommended
if (!iRec && jRec) || (iRec == jRec && models[i].ID > models[j].ID) {
models[i], models[j] = models[j], models[i]
}
}
}
}

// recommendedModels defines recommended and default-enabled models per provider.
// Key: provider ID, Value: map of model ID → true (recommended + default enabled).
var recommendedModels = map[string]map[string]bool{
"zhipuai": {
"glm-5.1": true,
"glm-5": true,
},
"zhipuai-coding-plan": {
"glm-5.1": true,
"glm-5": true,
},
"deepseek": {
"deepseek-v4-pro": true,
},
"alibaba-cn": {
"qwen3.6-plus": true,
"MiniMax/MiniMax-M2.7": true,
"deepseek-v3-2-exp": true,
"kimi-k2.6": true,
},
"alibaba-coding-plan-cn": {
"qwen3.6-plus": true,
},
"moonshotai": {
"kimi-k2.6": true,
},
"minimax": {
"MiniMax-M2.7": true,
},
"minimax-coding-plan": {
"MiniMax-M2.7": true,
},
"openai": {
"gpt-4.1": true,
"o4-mini": true,
},
"anthropic": {
"claude-sonnet-4-20250514": true,
},
"google": {
"gemini-2.5-pro": true,
},
}

func init() {
applyRecommendedModels()
}

// applyRecommendedModels sets Recommended and DefaultEnabled on models in the generated registry.
func applyRecommendedModels() {
for provID, models := range recommendedModels {
prov, ok := generatedProviders[provID]
if !ok {
continue
}
for modelID := range models {
if m, ok := prov.Models[modelID]; ok {
m.Recommended = true
m.DefaultEnabled = true
}
}
}
}
44 changes: 44 additions & 0 deletions internal/model/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package model

import (
"context"
"fmt"
"net/http"
"time"
)

// ValidateProvider tests connectivity to a provider by making a lightweight
// GET /models request. Returns nil on success, or a descriptive error.
func ValidateProvider(ctx context.Context, apiKey, baseURL string) error {
if baseURL == "" {
return fmt.Errorf("base URL is empty")
}

client := &http.Client{Timeout: 10 * time.Second}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}

resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("connection failed: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("invalid API key (401 Unauthorized)")
}
if resp.StatusCode == http.StatusForbidden {
return fmt.Errorf("access denied (403 Forbidden) — check API key permissions")
}
if resp.StatusCode >= 400 {
return fmt.Errorf("server returned %d %s", resp.StatusCode, resp.Status)
}

return nil
}
Loading