Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions internal/config/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ type GitConfig struct {
AutoPullInterval time.Duration `yaml:"auto_pull_interval,omitempty"`
}

// OAuthConfig holds OAuth server configuration.
type OAuthConfig struct {
AccessTokenTTL time.Duration `yaml:"access_token_ttl,omitempty"`
RefreshTokenTTL time.Duration `yaml:"refresh_token_ttl,omitempty"`
}

// MCPConfig holds MCP server-related configuration for AI agent integration.
type MCPConfig struct {
Bind string `yaml:"bind,omitempty"`
Expand All @@ -68,6 +74,7 @@ type MCPConfig struct {
TLSCertFile string `yaml:"tls_cert_file,omitempty"`
TLSKeyFile string `yaml:"tls_key_file,omitempty"`
AllowInsecureBind bool `yaml:"allow_insecure_bind,omitempty"`
OAuth *OAuthConfig `yaml:"oauth,omitempty"`
}

// UpdateConfig holds update check-related configuration.
Expand Down Expand Up @@ -137,6 +144,10 @@ func defaultMCPConfig() MCPConfig {
ApprovalTimeout: 30 * time.Second,
RateLimit: 60,
MetricsAuthRequired: true,
OAuth: &OAuthConfig{
AccessTokenTTL: 24 * time.Hour,
RefreshTokenTTL: 720 * time.Hour,
},
}
}

Expand Down Expand Up @@ -197,6 +208,13 @@ type fileGitConfig struct {
CommitTemplate *string `yaml:"commit_template,omitempty"`
}

// fileOAuthConfig is the file-based OAuth configuration with pointer fields
// for optional YAML unmarshaling.
type fileOAuthConfig struct {
AccessTokenTTL *time.Duration `yaml:"access_token_ttl,omitempty"`
RefreshTokenTTL *time.Duration `yaml:"refresh_token_ttl,omitempty"`
}

// fileMCPConfig is the file-based MCP configuration with pointer fields
// for optional YAML unmarshaling.
type fileMCPConfig struct {
Expand All @@ -217,6 +235,7 @@ type fileMCPConfig struct {
TLSCertFile *string `yaml:"tls_cert_file,omitempty"`
TLSKeyFile *string `yaml:"tls_key_file,omitempty"`
AllowInsecureBind *bool `yaml:"allow_insecure_bind,omitempty"`
OAuth *fileOAuthConfig `yaml:"oauth,omitempty"`
}

// fileUpdateConfig is the file-based update configuration with pointer fields
Expand Down Expand Up @@ -363,6 +382,17 @@ func MergeFileMCPConfig(fileCfg *fileMCPConfig, defaults MCPConfig) MCPConfig {
if fileCfg.AllowInsecureBind != nil {
result.AllowInsecureBind = *fileCfg.AllowInsecureBind
}
if fileCfg.OAuth != nil {
if result.OAuth == nil {
result.OAuth = &OAuthConfig{}
}
if fileCfg.OAuth.AccessTokenTTL != nil {
result.OAuth.AccessTokenTTL = *fileCfg.OAuth.AccessTokenTTL
}
if fileCfg.OAuth.RefreshTokenTTL != nil {
result.OAuth.RefreshTokenTTL = *fileCfg.OAuth.RefreshTokenTTL
}
}
return result
}

Expand Down
74 changes: 74 additions & 0 deletions internal/config/schema_oauth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package config

import (
"testing"
"time"
)

func TestDefaultOAuthConfig(t *testing.T) {
cfg := defaultMCPConfig()
if cfg.OAuth == nil {
t.Fatal("OAuth config is nil in defaults")
}
if cfg.OAuth.AccessTokenTTL != 24*time.Hour {
t.Errorf("default AccessTokenTTL = %v, want 24h", cfg.OAuth.AccessTokenTTL)
}
if cfg.OAuth.RefreshTokenTTL != 720*time.Hour {
t.Errorf("default RefreshTokenTTL = %v, want 720h (30d)", cfg.OAuth.RefreshTokenTTL)
}
}

func TestMergeFileOAuthConfig(t *testing.T) {
accessTTL := 10 * time.Second
refreshTTL := 30 * time.Second

fileCfg := &fileMCPConfig{
OAuth: &fileOAuthConfig{
AccessTokenTTL: &accessTTL,
RefreshTokenTTL: &refreshTTL,
},
}

result := MergeFileMCPConfig(fileCfg, defaultMCPConfig())

if result.OAuth == nil {
t.Fatal("OAuth config is nil after merge")
}
if result.OAuth.AccessTokenTTL != accessTTL {
t.Errorf("AccessTokenTTL = %v, want %v", result.OAuth.AccessTokenTTL, accessTTL)
}
if result.OAuth.RefreshTokenTTL != refreshTTL {
t.Errorf("RefreshTokenTTL = %v, want %v", result.OAuth.RefreshTokenTTL, refreshTTL)
}
}

func TestMergeFileOAuthConfig_PartialOverride(t *testing.T) {
refreshTTL := 100 * time.Hour

fileCfg := &fileMCPConfig{
OAuth: &fileOAuthConfig{
RefreshTokenTTL: &refreshTTL,
},
}

result := MergeFileMCPConfig(fileCfg, defaultMCPConfig())

if result.OAuth.AccessTokenTTL != 24*time.Hour {
t.Errorf("AccessTokenTTL = %v, want default 24h", result.OAuth.AccessTokenTTL)
}
if result.OAuth.RefreshTokenTTL != 100*time.Hour {
t.Errorf("RefreshTokenTTL = %v, want 100h", result.OAuth.RefreshTokenTTL)
}
}

func TestMergeFileOAuthConfig_NilOAuth(t *testing.T) {
fileCfg := &fileMCPConfig{}
result := MergeFileMCPConfig(fileCfg, defaultMCPConfig())

if result.OAuth == nil {
t.Fatal("OAuth config should not be nil after merge with nil file cfg")
}
if result.OAuth.AccessTokenTTL != 24*time.Hour {
t.Errorf("AccessTokenTTL = %v, want default 24h", result.OAuth.AccessTokenTTL)
}
}
12 changes: 11 additions & 1 deletion internal/mcp/serverbootstrap/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,17 @@ func RunHTTPServerOnListener(ctx context.Context, listener net.Listener, v *vaul
mux.HandleFunc("GET /mcp/oauth/authorize", oauthAuthorizeHandler.ServeHTTP)

// Token endpoint uses the scoped token registry instead of the legacy bearer token.
oauthTokenHandler := mcp.OriginValidationMiddleware(addr, handleOAuthToken(oauthStore, registry))
accessTokenTTL := 24 * time.Hour
refreshTokenTTL := 720 * time.Hour
if v != nil && v.Config != nil && v.Config.MCP != nil && v.Config.MCP.OAuth != nil {
if v.Config.MCP.OAuth.AccessTokenTTL > 0 {
accessTokenTTL = v.Config.MCP.OAuth.AccessTokenTTL
}
if v.Config.MCP.OAuth.RefreshTokenTTL > 0 {
refreshTokenTTL = v.Config.MCP.OAuth.RefreshTokenTTL
}
}
oauthTokenHandler := mcp.OriginValidationMiddleware(addr, handleOAuthToken(oauthStore, registry, accessTokenTTL, refreshTokenTTL))
mux.HandleFunc("POST /mcp/oauth/token", oauthTokenHandler.ServeHTTP)

const maxRequestBodySize = 1 * 1024 * 1024
Expand Down
101 changes: 65 additions & 36 deletions internal/mcp/serverbootstrap/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func handleOAuthRegister(clientStore *oauthClientStore) http.HandlerFunc {
"client_id_issued_at": time.Now().Unix(),
"client_secret_expires_at": 0,
"token_endpoint_auth_method": "none",
"grant_types": []string{"authorization_code"},
"grant_types": []string{"authorization_code", "refresh_token"},
"response_types": []string{"code"},
"redirect_uris": req.RedirectURIs,
})
Expand Down Expand Up @@ -222,55 +222,84 @@ func handleOAuthAuthorize(store *oauthCodeStore, clientStore *oauthClientStore)
}

// handleOAuthToken implements the authorization code grant (RFC 6749 §4.1.3)
// with PKCE verification (RFC 7636). On success it mints a fresh scoped MCP
// token via the TokenRegistry instead of returning the global legacy bearer
// token. The scoped token has a 24-hour TTL.
func handleOAuthToken(store *oauthCodeStore, registry *mcp.TokenRegistry) http.HandlerFunc {
// with PKCE verification (RFC 7636) and refresh token support (RFC 6749 §6).
// On success it mints a fresh scoped MCP token via the TokenRegistry instead
// of returning the global legacy bearer token.
func handleOAuthToken(store *oauthCodeStore, registry *mcp.TokenRegistry, accessTokenTTL, refreshTokenTTL time.Duration) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_request"})
return
}
if r.FormValue("grant_type") != "authorization_code" {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "unsupported_grant_type"})
return
}

if registry == nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "server_error"})
return
}

pending, ok := store.take(r.FormValue("code"))
if !ok || time.Now().After(pending.expiresAt) {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"})
return
}
if !verifyS256(r.FormValue("code_verifier"), pending.codeChallenge) {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"})
return
}
grantType := r.FormValue("grant_type")

switch grantType {
case "authorization_code":
pending, ok := store.take(r.FormValue("code"))
if !ok || time.Now().After(pending.expiresAt) {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"})
return
}
if !verifyS256(r.FormValue("code_verifier"), pending.codeChallenge) {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"})
return
}

label := fmt.Sprintf("oauth-%s", pending.clientID[:8])
tok, rawToken, rawRefresh, err := registry.CreateWithRefresh(
label, []string{"*"}, "oauth", accessTokenTTL, refreshTokenTTL,
)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "server_error"})
return
}

expiresIn := 0
if tok.ExpiresAt != nil {
expiresIn = int(time.Until(*tok.ExpiresAt).Seconds())
}

writeJSON(w, http.StatusOK, map[string]any{
"access_token": rawToken,
"token_type": "Bearer",
"expires_in": expiresIn,
"refresh_token": rawRefresh,
})

// Mint a fresh scoped token instead of returning the legacy bearer token.
// This ensures every OAuth-issued token is independently revocable and
// auditable — the global legacy token is never exposed to OAuth clients.
label := fmt.Sprintf("oauth-%s", pending.clientID[:8])
tok, rawToken, err := registry.Create(label, []string{"*"}, "oauth", 24*time.Hour)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "server_error"})
return
}
case "refresh_token":
rawRefresh := r.FormValue("refresh_token")
if rawRefresh == "" {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_request", "error_description": "refresh_token is required"})
return
}

newTok, rawAccess, rawRefresh, err := registry.RotateViaRefreshToken(rawRefresh)
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant", "error_description": "invalid or expired refresh token"})
return
}

expiresIn := 0
if newTok.ExpiresAt != nil {
expiresIn = int(time.Until(*newTok.ExpiresAt).Seconds())
}

writeJSON(w, http.StatusOK, map[string]any{
"access_token": rawAccess,
"token_type": "Bearer",
"expires_in": expiresIn,
"refresh_token": rawRefresh,
})

expiresIn := 0
if tok.ExpiresAt != nil {
expiresIn = int(time.Until(*tok.ExpiresAt).Seconds())
default:
writeJSON(w, http.StatusBadRequest, map[string]string{"error": "unsupported_grant_type"})
}

writeJSON(w, http.StatusOK, map[string]any{
"access_token": rawToken,
"token_type": "Bearer",
"expires_in": expiresIn,
})
}
}

Expand Down
Loading
Loading