Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vault Agent Cache Auto-Auth SSRF Protection #7627

Merged
merged 29 commits into from
Oct 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
853a109
implement SSRF protection header
mjarmy Oct 3, 2019
84e3086
add test for SSRF protection header
mjarmy Oct 4, 2019
b499b32
cleanup
mjarmy Oct 4, 2019
6f07e18
refactor
mjarmy Oct 8, 2019
7044ac9
merge from master
mjarmy Oct 8, 2019
9adadc5
implement SSRF header on a per-listener basis
mjarmy Oct 9, 2019
d10eef8
cleanup
mjarmy Oct 9, 2019
0e5f902
cleanup
mjarmy Oct 9, 2019
4a03773
creat unit test for agent SSRF
mjarmy Oct 10, 2019
4ca14d8
improve unit test for agent SSRF
mjarmy Oct 10, 2019
0491fec
add VaultRequest SSRF header to CLI
mjarmy Oct 10, 2019
91e22a5
merge from master
mjarmy Oct 10, 2019
2125a20
fix unit test
mjarmy Oct 10, 2019
467698c
cleanup
mjarmy Oct 10, 2019
199ff09
merge from master
mjarmy Oct 11, 2019
a915b64
improve test suite
mjarmy Oct 11, 2019
34ff6dd
simplify check for Vault-Request header
mjarmy Oct 11, 2019
6d9c4a1
add constant for Vault-Request header
mjarmy Oct 11, 2019
22e199e
improve test suite
mjarmy Oct 11, 2019
14ee72d
change 'config' to 'agentConfig'
mjarmy Oct 11, 2019
5ea8878
Revert "change 'config' to 'agentConfig'"
mjarmy Oct 11, 2019
3e6acbc
do not remove header from request
mjarmy Oct 11, 2019
15ca067
change header name to X-Vault-Request
mjarmy Oct 11, 2019
44b75af
merge from master
mjarmy Oct 11, 2019
12c8c2b
simplify http.Handler logic
mjarmy Oct 11, 2019
335c5ec
cleanup
mjarmy Oct 11, 2019
ca26eba
simplify http.Handler logic
mjarmy Oct 11, 2019
14559e5
use stdlib errors package
mjarmy Oct 11, 2019
28e9dbc
merge from master
mjarmy Oct 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,14 @@ func NewClient(c *Config) (*Client, error) {
}

client := &Client{
addr: u,
config: c,
addr: u,
config: c,
headers: make(http.Header),
}

// Add the VaultRequest SSRF protection header
client.headers[consts.RequestHeaderName] = []string{"true"}

if token := os.Getenv(EnvVaultToken); token != "" {
client.token = token
}
Expand Down
48 changes: 41 additions & 7 deletions command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package command

import (
"context"
"errors"
"flag"
"fmt"
"io"
Expand All @@ -28,13 +29,14 @@ import (
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/config"
agentConfig "github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/command/agent/sink/inmem"
gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/version"
"github.com/kr/pretty"
"github.com/mitchellh/cli"
Expand Down Expand Up @@ -192,7 +194,7 @@ func (c *AgentCommand) Run(args []string) int {
}

// Load the configuration
config, err := config.LoadConfig(c.flagConfigs[0])
config, err := agentConfig.LoadConfig(c.flagConfigs[0])
if err != nil {
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
return 1
Expand Down Expand Up @@ -418,11 +420,8 @@ func (c *AgentCommand) Run(args []string) int {
})
}

// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))

mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))
// Create the request handler
cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)

var listeners []net.Listener
for i, lnConfig := range config.Listeners {
Expand All @@ -434,6 +433,25 @@ func (c *AgentCommand) Run(args []string) int {

listeners = append(listeners, ln)

// Parse 'require_request_header' listener config option, and wrap
// the request handler if necessary
muxHandler := cacheHandler
if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok {
switch v {
case true:
muxHandler = verifyRequestHeader(muxHandler)
case false /* noop */ :
default:
c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v))
return 1
}
}

// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
mux.Handle("/", muxHandler)

scheme := "https://"
if tlsConf == nil {
scheme = "http://"
Expand Down Expand Up @@ -536,6 +554,22 @@ func (c *AgentCommand) Run(args []string) int {
return 0
}

// verifyRequestHeader wraps an http.Handler inside a Handler that checks for
// the request header that is used for SSRF protection.
func verifyRequestHeader(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" {
logical.RespondError(w,
http.StatusPreconditionFailed,
errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)))
return
}

handler.ServeHTTP(w, r)
})
}

func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
Expand Down
3 changes: 3 additions & 0 deletions command/agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type Listener struct {
Config map[string]interface{}
}

// RequireRequestHeader is a listener configuration option
const RequireRequestHeader = "require_request_header"

type AutoAuth struct {
Method *Method `hcl:"-"`
Sinks []*Sink `hcl:"sinks"`
Expand Down
250 changes: 250 additions & 0 deletions command/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package command

import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"reflect"
"sync"
"testing"
"time"

hclog "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
"github.com/hashicorp/vault/command/agent"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
Expand Down Expand Up @@ -370,3 +377,246 @@ auto_auth {
t.Fatal("sink 1/2 values don't match")
}
}

func TestAgent_RequireRequestHeader(t *testing.T) {

// request issues HTTP requests.
request := func(client *api.Client, req *api.Request, expectedStatusCode int) map[string]interface{} {
resp, err := client.RawRequest(req)
if err != nil {
t.Fatalf("err: %s", err)
}
if resp.StatusCode != expectedStatusCode {
t.Fatalf("expected status code %d, not %d", expectedStatusCode, resp.StatusCode)
}

bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("err: %s", err)
}
if len(bytes) == 0 {
return nil
}

var body map[string]interface{}
err = json.Unmarshal(bytes, &body)
if err != nil {
t.Fatalf("err: %s", err)
}
return body
}

// makeTempFile creates a temp file and populates it.
makeTempFile := func(name, contents string) string {
f, err := ioutil.TempFile("", name)
if err != nil {
t.Fatal(err)
}
path := f.Name()
f.WriteString(contents)
f.Close()
return path
}

// newApiClient creates an *api.Client.
newApiClient := func(addr string, includeVaultRequestHeader bool) *api.Client {
conf := api.DefaultConfig()
conf.Address = addr
cli, err := api.NewClient(conf)
if err != nil {
t.Fatalf("err: %s", err)
}

h := cli.Headers()
val, ok := h[consts.RequestHeaderName]
if !ok || !reflect.DeepEqual(val, []string{"true"}) {
t.Fatalf("invalid %s header", consts.RequestHeaderName)
}
if !includeVaultRequestHeader {
delete(h, consts.RequestHeaderName)
cli.SetHeaders(h)
}

return cli
}

//----------------------------------------------------
// Start the server and agent
//----------------------------------------------------

// Start a vault server
logger := logging.NewVaultLogger(hclog.Trace)
cluster := vault.NewTestCluster(t,
&vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
},
&vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
serverClient := cluster.Cores[0].Client

// Enable the approle auth method
req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
req.BodyBytes = []byte(`{
"type": "approle"
}`)
request(serverClient, req, 204)

// Create a named role
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
req.BodyBytes = []byte(`{
"secret_id_num_uses": "10",
"secret_id_ttl": "1m",
"token_max_ttl": "1m",
"token_num_uses": "10",
"token_ttl": "1m"
}`)
request(serverClient, req, 204)

// Fetch the RoleID of the named role
req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
body := request(serverClient, req, 200)
data := body["data"].(map[string]interface{})
roleID := data["role_id"].(string)

// Get a SecretID issued against the named role
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
body = request(serverClient, req, 200)
data = body["data"].(map[string]interface{})
secretID := data["secret_id"].(string)

// Write the RoleID and SecretID to temp files
roleIDPath := makeTempFile("role_id.txt", roleID+"\n")
secretIDPath := makeTempFile("secret_id.txt", secretID+"\n")
defer os.Remove(roleIDPath)
defer os.Remove(secretIDPath)

// Get a temp file path we can use for the sink
sinkPath := makeTempFile("sink.txt", "")
defer os.Remove(sinkPath)

// Create a config file
config := `
auto_auth {
method "approle" {
mount_path = "auth/approle"
config = {
role_id_file_path = "%s"
secret_id_file_path = "%s"
}
}

sink "file" {
config = {
path = "%s"
}
}
}

cache {
use_auto_auth_token = true
}

listener "tcp" {
address = "127.0.0.1:8101"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8102"
tls_disable = true
require_request_header = false
}
listener "tcp" {
address = "127.0.0.1:8103"
tls_disable = true
require_request_header = true
}
`
config = fmt.Sprintf(config, roleIDPath, secretIDPath, sinkPath)
configPath := makeTempFile("config.hcl", config)
defer os.Remove(configPath)

// Start the agent
ui, cmd := testAgentCommand(t, logger)
cmd.client = serverClient
cmd.startedCh = make(chan struct{})

wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
code := cmd.Run([]string{"-config", configPath})
if code != 0 {
t.Errorf("non-zero return code when running agent: %d", code)
t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
}
wg.Done()
}()

select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}

// defer agent shutdown
defer func() {
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}()

//----------------------------------------------------
// Perform the tests
//----------------------------------------------------

// Test against a listener configuration that omits
// 'require_request_header', with the header missing from the request.
agentClient := newApiClient("http://127.0.0.1:8101", false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(agentClient, req, 200)

// Test against a listener configuration that sets 'require_request_header'
// to 'false', with the header missing from the request.
agentClient = newApiClient("http://127.0.0.1:8102", false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(agentClient, req, 200)

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with the header missing from the request.
agentClient = newApiClient("http://127.0.0.1:8103", false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
resp, err := agentClient.RawRequest(req)
if err == nil {
t.Fatalf("expected error")
}
if resp.StatusCode != http.StatusPreconditionFailed {
t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
}

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with an invalid header present in the request.
agentClient = newApiClient("http://127.0.0.1:8103", false)
h := agentClient.Headers()
h[consts.RequestHeaderName] = []string{"bogus"}
agentClient.SetHeaders(h)
req = agentClient.NewRequest("GET", "/v1/sys/health")
resp, err = agentClient.RawRequest(req)
if err == nil {
t.Fatalf("expected error")
}
if resp.StatusCode != http.StatusPreconditionFailed {
t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
}

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with the proper header present in the request.
agentClient = newApiClient("http://127.0.0.1:8103", true)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(agentClient, req, 200)
}
Loading