Skip to content

Commit

Permalink
[1.10.x] rpc: authorize raft requests (#10931)
Browse files Browse the repository at this point in the history
  • Loading branch information
eculver committed Aug 26, 2021
1 parent b7a4fe0 commit 3357e57
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 33 deletions.
3 changes: 3 additions & 0 deletions .changelog/10931.txt
@@ -0,0 +1,3 @@
```release-note:security
rpc: authorize raft requests [CVE-2021-37219](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37219)
```
3 changes: 2 additions & 1 deletion agent/consul/raft_rpc.go
Expand Up @@ -6,9 +6,10 @@ import (
"sync"
"time"

"github.com/hashicorp/raft"

"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/raft"
)

// RaftLayer implements the raft.StreamLayer interface,
Expand Down
20 changes: 16 additions & 4 deletions agent/consul/rpc.go
Expand Up @@ -200,8 +200,7 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) {
s.handleConsulConn(conn)

case pool.RPCRaft:
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(conn)
s.handleRaftRPC(conn)

case pool.RPCTLS:
// Don't allow malicious client to create TLS-in-TLS for ever.
Expand Down Expand Up @@ -289,8 +288,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
s.handleConsulConn(tlsConn)

case pool.ALPN_RPCRaft:
metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(tlsConn)
s.handleRaftRPC(tlsConn)

case pool.ALPN_RPCMultiplexV2:
s.handleMultiplexV2(tlsConn)
Expand Down Expand Up @@ -461,6 +459,20 @@ func (s *Server) handleSnapshotConn(conn net.Conn) {
}()
}

func (s *Server) handleRaftRPC(conn net.Conn) {
if tlsConn, ok := conn.(*tls.Conn); ok {
err := s.tlsConfigurator.AuthorizeServerConn(s.config.Datacenter, tlsConn)
if err != nil {
s.rpcLogger().Warn(err.Error(), "from", conn.RemoteAddr(), "operation", "raft RPC")
conn.Close()
return
}
}

metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(conn)
}

func (s *Server) handleALPN_WANGossipPacketStream(conn net.Conn) error {
defer conn.Close()

Expand Down
278 changes: 274 additions & 4 deletions agent/consul/rpc_test.go
@@ -1,32 +1,43 @@
package consul

import (
"bufio"
"bytes"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-msgpack/codec"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/hashicorp/consul/agent/connect"

"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs"
tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil"
)

func TestRPC_NoLeader_Fail(t *testing.T) {
Expand Down Expand Up @@ -683,10 +694,10 @@ func TestRPC_RPCMaxConnsPerClient(t *testing.T) {
magicByte pool.RPCType
tlsEnabled bool
}{
{"RPC", pool.RPCMultiplexV2, false},
{"RPC TLS", pool.RPCMultiplexV2, true},
{"Raft", pool.RPCRaft, false},
{"Raft TLS", pool.RPCRaft, true},
{"RPC v2", pool.RPCMultiplexV2, false},
{"RPC v2 TLS", pool.RPCMultiplexV2, true},
{"RPC", pool.RPCConsul, false},
{"RPC TLS", pool.RPCConsul, true},
}

for _, tc := range cases {
Expand Down Expand Up @@ -1011,3 +1022,262 @@ type isReadRequest struct {
func (r isReadRequest) IsRead() bool {
return true
}

func TestRPC_AuthorizeRaftRPC(t *testing.T) {
caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "consul"})
require.NoError(t, err)

dir := testutil.TempDir(t, "certs")
err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
require.NoError(t, err)

newCert := func(t *testing.T, caPEM, pk, node, name string) {
t.Helper()

signer, err := tlsutil.ParseSigner(pk)
require.NoError(t, err)

pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
Signer: signer,
CA: caPEM,
Name: name,
Days: 5,
DNSNames: []string{node + "." + name, name, "localhost"},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
})
require.NoError(t, err)

err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
require.NoError(t, err)
}

newCert(t, caPEM, pk, "srv1", "server.dc1.consul")

_, connectCApk, err := connect.GeneratePrivateKey()
require.NoError(t, err)

_, srv := testServerWithConfig(t, func(c *Config) {
c.Domain = "consul." // consul. is the default value in agent/config
c.CAFile = filepath.Join(dir, "ca.pem")
c.CertFile = filepath.Join(dir, "srv1-server.dc1.consul.pem")
c.KeyFile = filepath.Join(dir, "srv1-server.dc1.consul.key")
c.VerifyIncoming = true
c.VerifyServerHostname = true
// Enable Auto-Encrypt so that Conenct CA roots are added to the
// tlsutil.Configurator.
c.AutoEncryptAllowTLS = true
c.CAConfig = &structs.CAConfiguration{
ClusterID: connect.TestClusterID,
Provider: structs.ConsulCAProvider,
Config: map[string]interface{}{"PrivateKey": connectCApk},
}
})
defer srv.Shutdown()

// Wait for ConnectCA initiation to complete.
retry.Run(t, func(r *retry.R) {
_, root := srv.caManager.getCAProvider()
if root == nil {
r.Fatal("ConnectCA root is still nil")
}
})

useTLSByte := func(t *testing.T, c *tlsutil.Configurator) net.Conn {
wrapper := tlsutil.SpecificDC("dc1", c.OutgoingRPCWrapper())
tlsEnabled := func(_ raft.ServerAddress) bool {
return true
}

rl := NewRaftLayer(nil, nil, wrapper, tlsEnabled)
conn, err := rl.Dial(raft.ServerAddress(srv.Listener.Addr().String()), 100*time.Millisecond)
require.NoError(t, err)
return conn
}

useNativeTLS := func(t *testing.T, c *tlsutil.Configurator) net.Conn {
wrapper := c.OutgoingALPNRPCWrapper()
dialer := &net.Dialer{Timeout: 100 * time.Millisecond}

rawConn, err := dialer.Dial("tcp", srv.Listener.Addr().String())
require.NoError(t, err)

tlsConn, err := wrapper("dc1", "srv1", pool.ALPN_RPCRaft, rawConn)
require.NoError(t, err)
return tlsConn
}

setupAgentTLSCert := func(name string) func(t *testing.T) string {
return func(t *testing.T) string {
newCert(t, caPEM, pk, "node1", name)
return filepath.Join(dir, "node1-"+name)
}
}

setupConnectCACert := func(name string) func(t *testing.T) string {
return func(t *testing.T) string {
_, caRoot := srv.caManager.getCAProvider()
newCert(t, caRoot.RootCert, connectCApk, "node1", name)
return filepath.Join(dir, "node1-"+name)
}
}

type testCase struct {
name string
conn func(t *testing.T, c *tlsutil.Configurator) net.Conn
setupCert func(t *testing.T) string
expectError bool
}

run := func(t *testing.T, tc testCase) {
certPath := tc.setupCert(t)

cfg := tlsutil.Config{
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
Domain: "consul",
}
c, err := tlsutil.NewConfigurator(cfg, hclog.New(nil))
require.NoError(t, err)

_, err = doRaftRPC(tc.conn(t, c), srv.config.NodeName)
if tc.expectError {
if !isConnectionClosedError(err) {
t.Fatalf("expected a connection closed error, got: %v", err)
}
return
}
require.NoError(t, err)
}

var testCases = []testCase{
{
name: "TLS byte with client cert",
setupCert: setupAgentTLSCert("client.dc1.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "TLS byte with server cert in different DC",
setupCert: setupAgentTLSCert("server.dc2.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "TLS byte with server cert in same DC",
setupCert: setupAgentTLSCert("server.dc1.consul"),
conn: useTLSByte,
},
{
name: "TLS byte with ConnectCA leaf cert",
setupCert: setupConnectCACert("server.dc1.consul"),
conn: useTLSByte,
expectError: true,
},
{
name: "native TLS with client cert",
setupCert: setupAgentTLSCert("client.dc1.consul"),
conn: useNativeTLS,
expectError: true,
},
{
name: "native TLS with server cert in different DC",
setupCert: setupAgentTLSCert("server.dc2.consul"),
conn: useNativeTLS,
expectError: true,
},
{
name: "native TLS with server cert in same DC",
setupCert: setupAgentTLSCert("server.dc1.consul"),
conn: useNativeTLS,
},
{
name: "native TLS with ConnectCA leaf cert",
setupCert: setupConnectCACert("server.dc1.consul"),
conn: useNativeTLS,
expectError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

func doRaftRPC(conn net.Conn, leader string) (raft.AppendEntriesResponse, error) {
var resp raft.AppendEntriesResponse

var term uint64 = 0xc
a := raft.AppendEntriesRequest{
RPCHeader: raft.RPCHeader{ProtocolVersion: 3},
Term: 0,
Leader: []byte(leader),
PrevLogEntry: 0,
PrevLogTerm: term,
LeaderCommitIndex: 50,
}

if err := appendEntries(conn, a, &resp); err != nil {
return resp, err
}
return resp, nil
}

func appendEntries(conn net.Conn, req raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error {
w := bufio.NewWriter(conn)
enc := codec.NewEncoder(w, &codec.MsgpackHandle{})

const rpcAppendEntries = 0
if err := w.WriteByte(rpcAppendEntries); err != nil {
return fmt.Errorf("failed to write raft-RPC byte: %w", err)
}

if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to send append entries RPC: %w", err)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("failed to flush RPC: %w", err)
}

if err := decodeRaftRPCResponse(conn, resp); err != nil {
return fmt.Errorf("response error: %w", err)
}
return nil
}

// copied and modified from raft/net_transport.go
func decodeRaftRPCResponse(conn net.Conn, resp *raft.AppendEntriesResponse) error {
r := bufio.NewReader(conn)
dec := codec.NewDecoder(r, &codec.MsgpackHandle{})

var rpcError string
if err := dec.Decode(&rpcError); err != nil {
return fmt.Errorf("failed to decode response error: %w", err)
}
if err := dec.Decode(resp); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
if rpcError != "" {
return fmt.Errorf("rpc error: %v", rpcError)
}
return nil
}

func isConnectionClosedError(err error) bool {
switch {
case err == nil:
return false
case errors.Is(err, io.EOF):
return true
case strings.Contains(err.Error(), "connection reset by peer"):
return true
default:
return false
}
}

0 comments on commit 3357e57

Please sign in to comment.