Skip to content

Commit

Permalink
Quit agent endpoint with config (#14223)
Browse files Browse the repository at this point in the history
* Add agent/v1/quit endpoint
  * Closes #11089
* Agent quit API behind config setting
* Normalise test config whitespace
* Document config option

Co-authored-by: Rémi Lapeyre <remi.lapeyre@lenstra.fr>
Co-authored-by: Ben Ash <32777270+benashz@users.noreply.github.com>
  • Loading branch information
3 people committed Feb 25, 2022
1 parent ed46b97 commit e9df7a6
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 35 deletions.
3 changes: 3 additions & 0 deletions changelog/14223.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
agent: The `agent/v1/quit` endpoint can now be used to stop the Vault Agent remotely
```
22 changes: 22 additions & 0 deletions command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,10 @@ func (c *AgentCommand) Run(args []string) int {

// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
quitEnabled := lnConfig.AgentAPI != nil && lnConfig.AgentAPI.EnableQuit

mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
mux.Handle(consts.AgentPathQuit, c.handleQuit(quitEnabled))
mux.Handle(consts.AgentPathMetrics, c.handleMetrics())
mux.Handle("/", muxHandler)

Expand Down Expand Up @@ -1047,3 +1050,22 @@ func (c *AgentCommand) handleMetrics() http.Handler {
}
})
}

func (c *AgentCommand) handleQuit(enabled bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !enabled {
w.WriteHeader(http.StatusNotFound)
return
}

switch r.Method {
case http.MethodPost:
default:
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

c.logger.Debug("received quit request")
close(c.ShutdownCh)
})
}
179 changes: 161 additions & 18 deletions command/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -510,21 +511,31 @@ cache {
}
listener "tcp" {
address = "127.0.0.1:8101"
address = "%s"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8102"
address = "%s"
tls_disable = true
require_request_header = false
}
listener "tcp" {
address = "127.0.0.1:8103"
address = "%s"
tls_disable = true
require_request_header = true
}
`
config = fmt.Sprintf(config, roleIDPath, secretIDPath)
listenAddr1 := generateListenerAddress(t)
listenAddr2 := generateListenerAddress(t)
listenAddr3 := generateListenerAddress(t)
config = fmt.Sprintf(
config,
roleIDPath,
secretIDPath,
listenAddr1,
listenAddr2,
listenAddr3,
)
configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)

Expand Down Expand Up @@ -563,19 +574,19 @@ listener "tcp" {

// Test against a listener configuration that omits
// 'require_request_header', with the header missing from the request.
agentClient := newApiClient("http://127.0.0.1:8101", false)
agentClient := newApiClient("http://"+listenAddr1, false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(t, agentClient, req, 200)

// Test against a listener configuration that sets 'require_request_header'
// to 'false', with the header missing from the request.
agentClient = newApiClient("http://127.0.0.1:8102", false)
agentClient = newApiClient("http://"+listenAddr2, false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(t, agentClient, req, 200)

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with the header missing from the request.
agentClient = newApiClient("http://127.0.0.1:8103", false)
agentClient = newApiClient("http://"+listenAddr3, false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
resp, err := agentClient.RawRequest(req)
if err == nil {
Expand All @@ -587,7 +598,7 @@ listener "tcp" {

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with an invalid header present in the request.
agentClient = newApiClient("http://127.0.0.1:8103", false)
agentClient = newApiClient("http://"+listenAddr3, false)
h := agentClient.Headers()
h[consts.RequestHeaderName] = []string{"bogus"}
agentClient.SetHeaders(h)
Expand All @@ -602,7 +613,7 @@ listener "tcp" {

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with the proper header present in the request.
agentClient = newApiClient("http://127.0.0.1:8103", true)
agentClient = newApiClient("http://"+listenAddr3, true)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(t, agentClient, req, 200)
}
Expand All @@ -613,16 +624,17 @@ listener "tcp" {
func TestAgent_RequireAutoAuthWithForce(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
// Create a config file
config := `
config := fmt.Sprintf(`
cache {
use_auto_auth_token = "force"
}
listener "tcp" {
address = "127.0.0.1:8101"
address = "%s"
tls_disable = true
}
`
`, generateListenerAddress(t))

configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)

Expand Down Expand Up @@ -1623,7 +1635,7 @@ func TestAgent_Cache_Retry(t *testing.T) {
cache {
}
`
listenAddr := "127.0.0.1:18123"
listenAddr := generateListenerAddress(t)
listenConfig := fmt.Sprintf(`
listener "tcp" {
address = "%s"
Expand Down Expand Up @@ -1861,7 +1873,7 @@ func TestAgent_TemplateConfig_ExitOnRetryFailure(t *testing.T) {
t.Fatal(err)
}

listenAddr := "127.0.0.1:18123"
listenAddr := generateListenerAddress(t)
listenConfig := fmt.Sprintf(`
listener "tcp" {
address = "%s"
Expand Down Expand Up @@ -2021,14 +2033,15 @@ func TestAgent_Metrics(t *testing.T) {
serverClient := cluster.Cores[0].Client

// Create a config file
config := `
listenAddr := generateListenerAddress(t)
config := fmt.Sprintf(`
cache {}
listener "tcp" {
address = "127.0.0.1:8101"
address = "%s"
tls_disable = true
}
`
`, listenAddr)
configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)

Expand Down Expand Up @@ -2062,7 +2075,7 @@ listener "tcp" {
}()

conf := api.DefaultConfig()
conf.Address = "http://127.0.0.1:8101"
conf.Address = "http://" + listenAddr
agentClient, err := api.NewClient(conf)
if err != nil {
t.Fatalf("err: %s", err)
Expand All @@ -2082,3 +2095,133 @@ listener "tcp" {
"Points",
})
}

func TestAgent_Quit(t *testing.T) {
//----------------------------------------------------
// Start the server and agent
//----------------------------------------------------
logger := logging.NewVaultLogger(hclog.Error)
cluster := vault.NewTestCluster(t,
&vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
LogicalBackends: map[string]logical.Factory{
"kv": logicalKv.Factory,
},
},
&vault.TestClusterOptions{
NumCores: 1,
})
cluster.Start()
defer cluster.Cleanup()

vault.TestWaitActive(t, cluster.Cores[0].Core)
serverClient := cluster.Cores[0].Client

// Unset the environment variable so that agent picks up the right test
// cluster address
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
err := os.Unsetenv(api.EnvVaultAddress)
if err != nil {
t.Fatal(err)
}

listenAddr := generateListenerAddress(t)
listenAddr2 := generateListenerAddress(t)
config := fmt.Sprintf(`
vault {
address = "%s"
tls_skip_verify = true
}
listener "tcp" {
address = "%s"
tls_disable = true
}
listener "tcp" {
address = "%s"
tls_disable = true
agent_api {
enable_quit = true
}
}
cache {}
`, serverClient.Address(), listenAddr, listenAddr2)

configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)

// Start the agent
_, cmd := testAgentCommand(t, logger)
cmd.startedCh = make(chan struct{})

wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
cmd.Run([]string{"-config", configPath})
wg.Done()
}()

select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
client.SetToken(serverClient.Token())
client.SetMaxRetries(0)
err = client.SetAddress("http://" + listenAddr)
if err != nil {
t.Fatal(err)
}

// First try on listener 1 where the API should be disabled.
resp, err := client.RawRequest(client.NewRequest(http.MethodPost, "/agent/v1/quit"))
if err == nil {
t.Fatalf("expected error")
}
if resp != nil && resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected %d but got: %d", http.StatusNotFound, resp.StatusCode)
}

// Now try on listener 2 where the quit API should be enabled.
err = client.SetAddress("http://" + listenAddr2)
if err != nil {
t.Fatal(err)
}

_, err = client.RawRequest(client.NewRequest(http.MethodPost, "/agent/v1/quit"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

select {
case <-cmd.ShutdownCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}

wg.Wait()
}

// Get a randomly assigned port and then free it again before returning it.
// There is still a race when trying to use it, but should work better
// than a static port.
func generateListenerAddress(t *testing.T) string {
t.Helper()

ln1, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
listenAddr := ln1.Addr().String()
ln1.Close()
return listenAddr
}
38 changes: 22 additions & 16 deletions command/server/config_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,22 +780,25 @@ func testConfig_Sanitized(t *testing.T) {
func testParseListeners(t *testing.T) {
obj, _ := hcl.Parse(strings.TrimSpace(`
listener "tcp" {
address = "127.0.0.1:443"
cluster_address = "127.0.0.1:8201"
tls_disable = false
tls_cert_file = "./certs/server.crt"
tls_key_file = "./certs/server.key"
tls_client_ca_file = "./certs/rootca.crt"
tls_min_version = "tls12"
tls_max_version = "tls13"
tls_require_and_verify_client_cert = true
tls_disable_client_certs = true
telemetry {
unauthenticated_metrics_access = true
}
profiling {
unauthenticated_pprof_access = true
}
address = "127.0.0.1:443"
cluster_address = "127.0.0.1:8201"
tls_disable = false
tls_cert_file = "./certs/server.crt"
tls_key_file = "./certs/server.key"
tls_client_ca_file = "./certs/rootca.crt"
tls_min_version = "tls12"
tls_max_version = "tls13"
tls_require_and_verify_client_cert = true
tls_disable_client_certs = true
telemetry {
unauthenticated_metrics_access = true
}
profiling {
unauthenticated_pprof_access = true
}
agent_api {
enable_quit = true
}
}`))

config := Config{
Expand Down Expand Up @@ -833,6 +836,9 @@ listener "tcp" {
Profiling: configutil.ListenerProfiling{
UnauthenticatedPProfAccess: true,
},
AgentAPI: &configutil.AgentAPI{
EnableQuit: true,
},
CustomResponseHeaders: DefaultCustomHeaders,
},
},
Expand Down
7 changes: 7 additions & 0 deletions internalshared/configutil/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ type Listener struct {
SocketUser string `hcl:"socket_user"`
SocketGroup string `hcl:"socket_group"`

AgentAPI *AgentAPI `hcl:"agent_api"`

Telemetry ListenerTelemetry `hcl:"telemetry"`
Profiling ListenerProfiling `hcl:"profiling"`
InFlightRequestLogging ListenerInFlightRequestLogging `hcl:"inflight_requests_logging"`
Expand All @@ -111,6 +113,11 @@ type Listener struct {
CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"`
}

// AgentAPI allows users to select which parts of the Agent API they want enabled.
type AgentAPI struct {
EnableQuit bool `hcl:"enable_quit"`
}

func (l *Listener) GoString() string {
return fmt.Sprintf("*%#v", *l)
}
Expand Down
Loading

0 comments on commit e9df7a6

Please sign in to comment.