Skip to content

Commit

Permalink
Allow TLS cipher suites to be set for the OPA server
Browse files Browse the repository at this point in the history
This change adds a new flag to `opa run` to allow
users to specify a list of enabled TLS 1.0–1.2 cipher
suites. This allows users to control the cipher suites
the OPA server supports during a TLS handshake.

Signed-off-by: Ashutosh Narkar <anarkar4387@gmail.com>
  • Loading branch information
ashutosh-narkar committed Jan 22, 2024
1 parent c0589c1 commit 64e4115
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 36 deletions.
37 changes: 37 additions & 0 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type runCmdParams struct {
skipBundleVerify bool
skipKnownSchemaCheck bool
excludeVerifyFiles []string
cipherSuites []string
}

func newRunParams() runCmdParams {
Expand Down Expand Up @@ -181,6 +182,9 @@ be expanded in the future. To disable this, use the --skip-known-schema-check fl
The --v1-compatible flag can be used to opt-in to OPA features and behaviors that will be enabled by default in a future OPA v1.0 release.
Current behaviors enabled by this flag include:
- setting OPA's listening address to "localhost:8181" by default.
The --tls-cipher-suites flag can be used to specify the list of enabled TLS 1.0–1.2 cipher suites. Note that TLS 1.3
cipher suites are not configurable. See https://godoc.org/crypto/tls#pkg-constants for supported cipher suites.
`,

Run: func(cmd *cobra.Command, args []string) {
Expand Down Expand Up @@ -221,6 +225,7 @@ Current behaviors enabled by this flag include:
runCommand.Flags().IntVar(&cmdParams.rt.GracefulShutdownPeriod, "shutdown-grace-period", 10, "set the time (in seconds) that the server will wait to gracefully shut down")
runCommand.Flags().IntVar(&cmdParams.rt.ShutdownWaitPeriod, "shutdown-wait-period", 0, "set the time (in seconds) that the server will wait before initiating shutdown")
runCommand.Flags().BoolVar(&cmdParams.skipKnownSchemaCheck, "skip-known-schema-check", false, "disables type checking on known input schemas")
runCommand.Flags().StringSliceVar(&cmdParams.cipherSuites, "tls-cipher-suites", []string{}, "set list of enabled TLS 1.0–1.2 cipher suites")
addConfigOverrides(runCommand.Flags(), &cmdParams.rt.ConfigOverrides)
addConfigOverrideFiles(runCommand.Flags(), &cmdParams.rt.ConfigOverrideFiles)
addBundleModeFlag(runCommand.Flags(), &cmdParams.rt.BundleMode, false)
Expand Down Expand Up @@ -332,6 +337,15 @@ func initRuntime(ctx context.Context, params runCmdParams, args []string, addrSe

params.rt.SkipKnownSchemaCheck = params.skipKnownSchemaCheck

if len(params.cipherSuites) > 0 {
cipherSuites, err := verifyCipherSuites(params.cipherSuites)
if err != nil {
return nil, err
}

params.rt.CipherSuites = cipherSuites
}

rt, err := runtime.NewRuntime(ctx, params.rt)
if err != nil {
return nil, err
Expand All @@ -355,6 +369,29 @@ func startRuntime(ctx context.Context, rt *runtime.Runtime, serverMode bool) {
}
}

func verifyCipherSuites(cipherSuites []string) (*[]uint16, error) {
cipherSuitesMap := map[string]uint16{}

for _, c := range tls.CipherSuites() {
cipherSuitesMap[c.Name] = c.ID
}

for _, c := range tls.InsecureCipherSuites() {
cipherSuitesMap[c.Name] = c.ID
}

cipherSuitesIds := []uint16{}
for _, c := range cipherSuites {
val, ok := cipherSuitesMap[c]
if !ok {
return nil, fmt.Errorf("invalid cipher suite %v", c)
}
cipherSuitesIds = append(cipherSuitesIds, val)
}

return &cipherSuitesIds, nil
}

func historyPath() string {
home := os.Getenv("HOME")
if len(home) == 0 {
Expand Down
42 changes: 42 additions & 0 deletions cmd/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package cmd
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -149,6 +151,46 @@ func TestInitRuntimeVerifyNonBundle(t *testing.T) {
}
}

func TestInitRuntimeCipherSuites(t *testing.T) {

params := newTestRunParams()

// no cipher suites
rt, err := initRuntime(context.Background(), params, nil, false)
if err != nil {
t.Fatal(err)
}

if len(params.cipherSuites) != 0 || rt.Params.CipherSuites != nil {
t.Fatal("expected no value defined for cipher suites")
}

// secure and insecure cipher suites
params.cipherSuites = []string{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_RC4_128_SHA"}
rt, err = initRuntime(context.Background(), params, nil, false)
if err != nil {
t.Fatal(err)
}

if rt.Params.CipherSuites == nil {
t.Fatal("expected value defined for cipher suites")
}

expectedCipherSuitesIds := []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, tls.TLS_RSA_WITH_RC4_128_SHA}

if !reflect.DeepEqual(*rt.Params.CipherSuites, expectedCipherSuitesIds) {
t.Fatalf("expected cipher suites %v but got %v", expectedCipherSuitesIds, *rt.Params.CipherSuites)
}

// invalid cipher suites
params.cipherSuites = []string{"foo"}

_, err = initRuntime(context.Background(), params, nil, false)
if err == nil {
t.Fatal("expected error but got nil")
}
}

func TestInitRuntimeSkipKnownSchemaCheck(t *testing.T) {

fs := map[string]string{
Expand Down
4 changes: 4 additions & 0 deletions runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ type Params struct {
// This flag allows users to opt-in to the new behavior and helps transition to the future release upon which
// the new behavior will be enabled by default.
V1Compatible bool

// CipherSuites specifies the list of enabled TLS 1.0–1.2 cipher suites
CipherSuites *[]uint16
}

// LoggingConfig stores the configuration for OPA's logging behaviour.
Expand Down Expand Up @@ -550,6 +553,7 @@ func (rt *Runtime) Serve(ctx context.Context) error {
WithRuntime(rt.Manager.Info).
WithMetrics(rt.metrics).
WithMinTLSVersion(rt.Params.MinTLSVersion).
WithCipherSuites(rt.Params.CipherSuites).
WithDistributedTracingOpts(rt.Params.DistributedTracingOpts)

// If decision_logging plugin enabled, check to see if we opted in to the ND builtins cache.
Expand Down
65 changes: 39 additions & 26 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ type Server struct {
distributedTracingOpts tracing.Options
ndbCacheEnabled bool
unixSocketPerm *string
cipherSuites *[]uint16
}

// Metrics defines the interface that the server requires for recording HTTP
Expand Down Expand Up @@ -400,6 +401,12 @@ func (s *Server) WithNDBCacheEnabled(ndbCacheEnabled bool) *Server {
return s
}

// WithCipherSuites sets the list of enabled TLS 1.0–1.2 cipher suites.
func (s *Server) WithCipherSuites(cipherSuites *[]uint16) *Server {
s.cipherSuites = cipherSuites
return s
}

// WithUnixSocketPermission sets the permission for the Unix domain socket if used to listen for
// incoming connections. Applies to the sockets the server is listening on including diagnostic API's.
func (s *Server) WithUnixSocketPermission(unixSocketPerm *string) *Server {
Expand Down Expand Up @@ -635,38 +642,44 @@ func (s *Server) getListenerForHTTPSServer(u *url.URL, h http.Handler, t httpLis
return nil, nil, fmt.Errorf("TLS certificate required but not supplied")
}

httpsServer := http.Server{
Addr: u.Host,
Handler: h,
TLSConfig: &tls.Config{
GetCertificate: s.getCertificate,
// GetConfigForClient is used to ensure that a fresh config is provided containing the latest cert pool.
// This is not required, but appears to be how connect time updates config should be done:
// https://github.com/golang/go/issues/16066#issuecomment-250606132
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
s.tlsConfigMtx.Lock()
defer s.tlsConfigMtx.Unlock()

cfg := &tls.Config{
GetCertificate: s.getCertificate,
ClientCAs: s.certPool,
}
tlsConfig := tls.Config{
GetCertificate: s.getCertificate,
// GetConfigForClient is used to ensure that a fresh config is provided containing the latest cert pool.
// This is not required, but appears to be how connect time updates config should be done:
// https://github.com/golang/go/issues/16066#issuecomment-250606132
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
s.tlsConfigMtx.Lock()
defer s.tlsConfigMtx.Unlock()

if s.authentication == AuthenticationTLS {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
cfg := &tls.Config{
GetCertificate: s.getCertificate,
ClientCAs: s.certPool,
}

if s.minTLSVersion != 0 {
cfg.MinVersion = s.minTLSVersion
} else {
cfg.MinVersion = defaultMinTLSVersion
}
if s.authentication == AuthenticationTLS {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

return cfg, nil
},
if s.minTLSVersion != 0 {
cfg.MinVersion = s.minTLSVersion
} else {
cfg.MinVersion = defaultMinTLSVersion
}

if s.cipherSuites != nil {
cfg.CipherSuites = *s.cipherSuites
}

return cfg, nil
},
}

httpsServer := http.Server{
Addr: u.Host,
Handler: h,
TLSConfig: &tlsConfig,
}

l := newHTTPListener(&httpsServer, t)

httpsLoop := func() error { return l.ListenAndServeTLS("", "") }
Expand Down
55 changes: 45 additions & 10 deletions test/e2e/tls/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -88,23 +89,52 @@ allow {
testServerParams.MinTLSVersion = TLSVersion
}

// RSA cipher suite given server's key is RSA
testServerParams.CipherSuites = &[]uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}

testRuntime, err = e2e.NewTestRuntime(testServerParams)
if err != nil {
fatal(err)
}

// We need a client with proper TLS setup, otherwise the health check
// that loops to determine if the server is ready will fail.
testRuntime.Client = newClient(0, pool, "testdata/client-cert.pem", "testdata/client-key.pem")
testRuntime.Client = newClient(0, pool, nil, "testdata/client-cert.pem", "testdata/client-key.pem")

os.Exit(testRuntime.RunTests(m))
}

func TestCipherSuites(t *testing.T) {
endpoint := testRuntime.URL()
t.Run("Cipher suite supported by both client and server", func(t *testing.T) {

c := newClient(tls.VersionTLS12, pool, &[]uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, "testdata/client-cert.pem", "testdata/client-key.pem")
_, err := c.Get(endpoint)
if err != nil {
t.Fatal(err)
}
})
t.Run("No cipher suite supported by both client and server", func(t *testing.T) {

// Since server's key is RSA, client specifying an ECDSA cipher suite should result in an error
c := newClient(tls.VersionTLS12, pool, &[]uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, "testdata/client-cert.pem", "testdata/client-key.pem")
_, err := c.Get(endpoint)
if err == nil {
t.Error("expected err - no cipher suite supported by both client and server, got nil")
}

expErr := "tls: handshake failure"
if !strings.Contains(err.Error(), expErr) {
t.Fatalf("unexpected error message %v", err)
}
})
}

func TestMinTLSVersion(t *testing.T) {
endpoint := testRuntime.URL()
t.Run("TLS version not suported by server", func(t *testing.T) {
t.Run("TLS version not supported by server", func(t *testing.T) {

c := newClient(tls.VersionTLS10, pool, "testdata/client-cert.pem", "testdata/client-key.pem")
c := newClient(tls.VersionTLS10, pool, nil, "testdata/client-cert.pem", "testdata/client-key.pem")
_, err := c.Get(endpoint)

if err == nil {
Expand All @@ -114,7 +144,7 @@ func TestMinTLSVersion(t *testing.T) {
})
t.Run("TLS Version supported by server", func(t *testing.T) {

c := newClient(tls.VersionTLS12, pool, "testdata/client-cert.pem", "testdata/client-key.pem")
c := newClient(tls.VersionTLS12, pool, nil, "testdata/client-cert.pem", "testdata/client-key.pem")
resp, err := c.Get(endpoint)
if err != nil {
t.Fatalf("GET: %v", err)
Expand All @@ -134,7 +164,7 @@ func TestNotDefaultTLSVersion(t *testing.T) {
endpoint := testRuntime.URL()
t.Run("server started with min TLS Version 1.3, client connecting with not supported TLS version", func(t *testing.T) {

c := newClient(tls.VersionTLS10, pool, "testdata/client-cert.pem", "testdata/client-key.pem")
c := newClient(tls.VersionTLS10, pool, nil, "testdata/client-cert.pem", "testdata/client-key.pem")
_, err := c.Get(endpoint)

if err == nil {
Expand All @@ -148,7 +178,7 @@ func TestNotDefaultTLSVersion(t *testing.T) {

t.Run("server started with min TLS Version 1.3, client connecting supported TLS version", func(t *testing.T) {

c := newClient(tls.VersionTLS13, pool, "testdata/client-cert.pem", "testdata/client-key.pem")
c := newClient(tls.VersionTLS13, pool, nil, "testdata/client-cert.pem", "testdata/client-key.pem")
resp, err := c.Get(endpoint)
if err != nil {
t.Fatalf("GET: %v", err)
Expand All @@ -168,7 +198,7 @@ func TestAuthenticationTLS(t *testing.T) {
// already queries the health endpoint using a properly authenticated, and
// authorized, http client.
t.Run("happy path", func(t *testing.T) {
c := newClient(0, pool, "testdata/client-cert.pem", "testdata/client-key.pem")
c := newClient(0, pool, nil, "testdata/client-cert.pem", "testdata/client-key.pem")
resp, err := c.Get(endpoint)
if err != nil {
t.Fatalf("GET: %v", err)
Expand All @@ -180,7 +210,7 @@ func TestAuthenticationTLS(t *testing.T) {
})

t.Run("authn successful, authz failed", func(t *testing.T) {
c := newClient(0, pool, "testdata/client-cert-2.pem", "testdata/client-key-2.pem")
c := newClient(0, pool, nil, "testdata/client-cert-2.pem", "testdata/client-key-2.pem")
resp, err := c.Get(endpoint)
if err != nil {
t.Fatalf("GET: %v", err)
Expand All @@ -192,15 +222,15 @@ func TestAuthenticationTLS(t *testing.T) {
})

t.Run("client trusts server, but doesn't provide client cert", func(t *testing.T) {
c := newClient(0, pool)
c := newClient(0, pool, nil)
_, err := c.Get(endpoint)
if _, ok := err.(*url.Error); !ok {
t.Errorf("expected *url.Error, got %T: %v", err, err)
}
})
}

func newClient(maxTLSVersion uint16, pool *x509.CertPool, clientKeyPair ...string) *http.Client {
func newClient(maxTLSVersion uint16, pool *x509.CertPool, cipherSuites *[]uint16, clientKeyPair ...string) *http.Client {
c := *http.DefaultClient
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{
Expand All @@ -217,6 +247,11 @@ func newClient(maxTLSVersion uint16, pool *x509.CertPool, clientKeyPair ...strin
if maxTLSVersion != 0 {
tr.TLSClientConfig.MaxVersion = maxTLSVersion
}

if cipherSuites != nil {
tr.TLSClientConfig.CipherSuites = *cipherSuites
}

c.Transport = tr
return &c
}

0 comments on commit 64e4115

Please sign in to comment.