Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1.10.x] rpc: authorize raft requests #10931

Merged
merged 10 commits into from
Aug 26, 2021
3 changes: 3 additions & 0 deletions .changelog/10931.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:security
rpc: authorize raft requests (CVE-2021-37219)
eculver marked this conversation as resolved.
Show resolved Hide resolved
```
3 changes: 2 additions & 1 deletion agent/consul/raft_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"sync"
"time"

"github.com/hashicorp/raft"

"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/raft"
)

// RaftLayer implements the raft.StreamLayer interface,
Expand Down
20 changes: 16 additions & 4 deletions agent/consul/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) {
s.handleConsulConn(conn)

case pool.RPCRaft:
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(conn)
s.handleRaftRPC(conn)

case pool.RPCTLS:
// Don't allow malicious client to create TLS-in-TLS for ever.
Expand Down Expand Up @@ -289,8 +288,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
s.handleConsulConn(tlsConn)

case pool.ALPN_RPCRaft:
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(tlsConn)
s.handleRaftRPC(tlsConn)

case pool.ALPN_RPCMultiplexV2:
s.handleMultiplexV2(tlsConn)
Expand Down Expand Up @@ -461,6 +459,20 @@ func (s *Server) handleSnapshotConn(conn net.Conn) {
}()
}

func (s *Server) handleRaftRPC(conn net.Conn) {
if tlsConn, ok := conn.(*tls.Conn); ok {
err := s.tlsConfigurator.AuthorizeServerConn(s.config.Datacenter, tlsConn)
if err != nil {
s.rpcLogger().Warn(err.Error(), "from", conn.RemoteAddr(), "operation", "raft RPC")
conn.Close()
return
}
}

metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(conn)
}

func (s *Server) handleALPN_WANGossipPacketStream(conn net.Conn) error {
defer conn.Close()

Expand Down
278 changes: 274 additions & 4 deletions agent/consul/rpc_test.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,43 @@
package consul

import (
"bufio"
"bytes"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-msgpack/codec"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/hashicorp/consul/agent/connect"

"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs"
tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil"
)

func TestRPC_NoLeader_Fail(t *testing.T) {
Expand Down Expand Up @@ -683,10 +694,10 @@ func TestRPC_RPCMaxConnsPerClient(t *testing.T) {
magicByte pool.RPCType
tlsEnabled bool
}{
{"RPC", pool.RPCMultiplexV2, false},
{"RPC TLS", pool.RPCMultiplexV2, true},
{"Raft", pool.RPCRaft, false},
{"Raft TLS", pool.RPCRaft, true},
{"RPC v2", pool.RPCMultiplexV2, false},
{"RPC v2 TLS", pool.RPCMultiplexV2, true},
{"RPC", pool.RPCConsul, false},
{"RPC TLS", pool.RPCConsul, true},
}

for _, tc := range cases {
Expand Down Expand Up @@ -1011,3 +1022,262 @@ type isReadRequest struct {
func (r isReadRequest) IsRead() bool {
return true
}

func TestRPC_AuthorizeRaftRPC(t *testing.T) {
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "consul"})
require.NoError(t, err)

dir := testutil.TempDir(t, "certs")
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
require.NoError(t, err)

newCert := func(t *testing.T, caPEM, pk, node, name string) {
t.Helper()

signer, err := tlsutil.ParseSigner(pk)
require.NoError(t, err)

pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
Signer: signer,
CA: caPEM,
Name: name,
Days: 5,
DNSNames: []string{node + "." + name, name, "localhost"},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
})
require.NoError(t, err)

err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
require.NoError(t, err)
}

newCert(t, caPEM, pk, "srv1", "server.dc1.consul")

_, connectCApk, err := connect.GeneratePrivateKey()
require.NoError(t, err)

_, srv := testServerWithConfig(t, func(c *Config) {
c.Domain = "consul." // consul. is the default value in agent/config
c.CAFile = filepath.Join(dir, "ca.pem")
c.CertFile = filepath.Join(dir, "srv1-server.dc1.consul.pem")
c.KeyFile = filepath.Join(dir, "srv1-server.dc1.consul.key")
c.VerifyIncoming = true
c.VerifyServerHostname = true
// Enable Auto-Encrypt so that Conenct CA roots are added to the
// tlsutil.Configurator.
c.AutoEncryptAllowTLS = true
c.CAConfig = &structs.CAConfiguration{
ClusterID: connect.TestClusterID,
Provider: structs.ConsulCAProvider,
Config: map[string]interface{}{"PrivateKey": connectCApk},
}
})
defer srv.Shutdown()

// Wait for ConnectCA initiation to complete.
retry.Run(t, func(r *retry.R) {
_, root := srv.caManager.getCAProvider()
if root == nil {
r.Fatal("ConnectCA root is still nil")
}
})

useTLSByte := func(t *testing.T, c *tlsutil.Configurator) net.Conn {
wrapper := tlsutil.SpecificDC("dc1", c.OutgoingRPCWrapper())
tlsEnabled := func(_ raft.ServerAddress) bool {
return true
}

rl := NewRaftLayer(nil, nil, wrapper, tlsEnabled)
conn, err := rl.Dial(raft.ServerAddress(srv.Listener.Addr().String()), 100*time.Millisecond)
require.NoError(t, err)
return conn
}

useNativeTLS := func(t *testing.T, c *tlsutil.Configurator) net.Conn {
wrapper := c.OutgoingALPNRPCWrapper()
dialer := &net.Dialer{Timeout: 100 * time.Millisecond}

rawConn, err := dialer.Dial("tcp", srv.Listener.Addr().String())
require.NoError(t, err)

tlsConn, err := wrapper("dc1", "srv1", pool.ALPN_RPCRaft, rawConn)
require.NoError(t, err)
return tlsConn
}

setupAgentTLSCert := func(name string) func(t *testing.T) string {
return func(t *testing.T) string {
newCert(t, caPEM, pk, "node1", name)
return filepath.Join(dir, "node1-"+name)
}
}

setupConnectCACert := func(name string) func(t *testing.T) string {
return func(t *testing.T) string {
_, caRoot := srv.caManager.getCAProvider()
newCert(t, caRoot.RootCert, connectCApk, "node1", name)
return filepath.Join(dir, "node1-"+name)
}
}

type testCase struct {
name string
conn func(t *testing.T, c *tlsutil.Configurator) net.Conn
setupCert func(t *testing.T) string
expectError bool
}

run := func(t *testing.T, tc testCase) {
certPath := tc.setupCert(t)

cfg := tlsutil.Config{
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
Domain: "consul",
}
c, err := tlsutil.NewConfigurator(cfg, hclog.New(nil))
require.NoError(t, err)

_, err = doRaftRPC(tc.conn(t, c), srv.config.NodeName)
if tc.expectError {
if !isConnectionClosedError(err) {
t.Fatalf("expected a connection closed error, got: %v", err)
}
return
}
require.NoError(t, err)
}

var testCases = []testCase{
{
name: "TLS byte with client cert",
setupCert: setupAgentTLSCert("client.dc1.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "TLS byte with server cert in different DC",
setupCert: setupAgentTLSCert("server.dc2.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "TLS byte with server cert in same DC",
setupCert: setupAgentTLSCert("server.dc1.consul"),
conn: useTLSByte,
},
{
name: "TLS byte with ConnectCA leaf cert",
setupCert: setupConnectCACert("server.dc1.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "native TLS with client cert",
setupCert: setupAgentTLSCert("client.dc1.consul"),
conn: useNativeTLS,
expectError: true,
},
{
name: "native TLS with server cert in different DC",
setupCert: setupAgentTLSCert("server.dc2.consul"),
conn: useNativeTLS,
expectError: true,
},
{
name: "native TLS with server cert in same DC",
setupCert: setupAgentTLSCert("server.dc1.consul"),
conn: useNativeTLS,
},
{
name: "native TLS with ConnectCA leaf cert",
setupCert: setupConnectCACert("server.dc1.consul"),
conn: useNativeTLS,
expectError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

func doRaftRPC(conn net.Conn, leader string) (raft.AppendEntriesResponse, error) {
var resp raft.AppendEntriesResponse

var term uint64 = 0xc
a := raft.AppendEntriesRequest{
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
Term: 0,
Leader: []byte(leader),
PrevLogEntry: 0,
PrevLogTerm: term,
LeaderCommitIndex: 50,
}

if err := appendEntries(conn, a, &resp); err != nil {
return resp, err
}
return resp, nil
}

func appendEntries(conn net.Conn, req raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error {
w := bufio.NewWriter(conn)
enc := codec.NewEncoder(w, &codec.MsgpackHandle{})

const rpcAppendEntries = 0
if err := w.WriteByte(rpcAppendEntries); err != nil {
return fmt.Errorf("failed to write raft-RPC byte: %w", err)
}

if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to send append entries RPC: %w", err)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("failed to flush RPC: %w", err)
}

if err := decodeRaftRPCResponse(conn, resp); err != nil {
return fmt.Errorf("response error: %w", err)
}
return nil
}

// copied and modified from raft/net_transport.go
func decodeRaftRPCResponse(conn net.Conn, resp *raft.AppendEntriesResponse) error {
r := bufio.NewReader(conn)
dec := codec.NewDecoder(r, &codec.MsgpackHandle{})

var rpcError string
if err := dec.Decode(&rpcError); err != nil {
return fmt.Errorf("failed to decode response error: %w", err)
}
if err := dec.Decode(resp); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
if rpcError != "" {
return fmt.Errorf("rpc error: %v", rpcError)
}
return nil
}

func isConnectionClosedError(err error) bool {
switch {
case err == nil:
return false
case errors.Is(err, io.EOF):
return true
case strings.Contains(err.Error(), "connection reset by peer"):
return true
default:
return false
}
}
Loading