Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pkg/transport: reload TLS certificates for every client request #7829

Merged
merged 2 commits into from Apr 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions integration/cluster.go
Expand Up @@ -76,6 +76,13 @@ var (
ClientCertAuth: true,
}

// testTLSInfoExpired is a TLS configuration whose key, cert, and trusted CA
// paths point into ./fixtures-expired, with client cert auth enabled.
// The TLS reload tests copy these files over valid ones to simulate
// certificate expiration.
testTLSInfoExpired = transport.TLSInfo{
KeyFile: "./fixtures-expired/server-key.pem",
CertFile: "./fixtures-expired/server.pem",
TrustedCAFile: "./fixtures-expired/etcd-root-ca.pem",
ClientCertAuth: true,
}

plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "integration")
)

Expand Down
62 changes: 62 additions & 0 deletions integration/util_test.go
@@ -0,0 +1,62 @@
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package integration

import (
"io"
"os"
"path/filepath"

"github.com/coreos/etcd/pkg/transport"
)

// copyTLSFiles duplicates the key, cert, and trusted CA files referenced by
// ti into the dst directory under fixed file names (server-key.pem,
// server.pem, etcd-root-ca.pem) and returns a TLSInfo pointing at the copies.
// ClientCertAuth is carried over from ti. If any copy fails, it returns an
// empty TLSInfo together with that error.
func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) {
	cloned := transport.TLSInfo{
		KeyFile:        filepath.Join(dst, "server-key.pem"),
		CertFile:       filepath.Join(dst, "server.pem"),
		TrustedCAFile:  filepath.Join(dst, "etcd-root-ca.pem"),
		ClientCertAuth: ti.ClientCertAuth,
	}

	// source/destination pairs, copied in order: key, cert, trusted CA
	pairs := [][2]string{
		{ti.KeyFile, cloned.KeyFile},
		{ti.CertFile, cloned.CertFile},
		{ti.TrustedCAFile, cloned.TrustedCAFile},
	}
	for _, p := range pairs {
		if err := copyFile(p[0], p[1]); err != nil {
			return transport.TLSInfo{}, err
		}
	}
	return cloned, nil
}

// copyFile copies the contents of the file at src into the file at dst.
// dst is created if it does not exist and truncated otherwise, and the
// written data is synced before returning.
//
// It returns the first error encountered while opening src, creating dst,
// copying, syncing, or closing dst. Errors are returned unwrapped so that
// callers can still use helpers such as os.IsNotExist on them.
func copyFile(src, dst string) (err error) {
	f, err := os.Open(src)
	if err != nil {
		return err
	}
	// read-only handle; a close error here cannot lose data
	defer f.Close()

	w, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// A failed close on a written file can mean lost data, so report
		// it unless an earlier error already explains the failure.
		if cerr := w.Close(); err == nil {
			err = cerr
		}
	}()

	if _, err = io.Copy(w, f); err != nil {
		return err
	}
	return w.Sync()
}
205 changes: 205 additions & 0 deletions integration/v3_grpc_test.go
Expand Up @@ -16,17 +16,21 @@ package integration

import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"math/rand"
"os"
"reflect"
"testing"
"time"

"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/etcdserver/api/v3rpc"
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/pkg/testutil"

"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -1374,6 +1378,207 @@ func TestTLSGRPCAcceptSecureAll(t *testing.T) {
}
}

// TestTLSReloadAtomicReplace ensures server reloads expired/valid certs
// when all certs are atomically replaced by directory renaming.
// It expects client dials to time out while the expired certs are in place,
// and to succeed again once the valid certs are renamed back.
func TestTLSReloadAtomicReplace(t *testing.T) {
	defer testutil.AfterTest(t)

	// clone valid,expired certs to separate directories for atomic renaming
	vDir, err := ioutil.TempDir(os.TempDir(), "fixtures-valid")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(vDir)
	ts, err := copyTLSFiles(testTLSInfo, vDir)
	if err != nil {
		t.Fatal(err)
	}
	eDir, err := ioutil.TempDir(os.TempDir(), "fixtures-expired")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(eDir)
	if _, err = copyTLSFiles(testTLSInfoExpired, eDir); err != nil {
		t.Fatal(err)
	}

	tDir, err := ioutil.TempDir(os.TempDir(), "fixtures")
	if err != nil {
		t.Fatal(err)
	}
	// tDir only reserves a unique path name; remove the directory itself
	// so that it can later be used as a rename target.
	os.RemoveAll(tDir)
	defer os.RemoveAll(tDir)

	// start with valid certs
	clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
	defer clus.Terminate(t)

	// concurrent client dialing while certs transition from valid to expired
	errc := make(chan error, 1)
	go func() {
		for {
			// 'err' is deliberately local to this goroutine so it does not
			// race with the 'err' used by the test goroutine.
			cc, err := ts.ClientConfig()
			if err != nil {
				if os.IsNotExist(err) {
					// from concurrent renaming
					continue
				}
				// t.Fatal must only be called from the goroutine running
				// the test; hand the error over so that the select below
				// fails the test with it.
				errc <- err
				return
			}
			cli, cerr := clientv3.New(clientv3.Config{
				Endpoints:   []string{clus.Members[0].GRPCAddr()},
				DialTimeout: time.Second,
				TLS:         cc,
			})
			if cerr != nil {
				errc <- cerr
				return
			}
			cli.Close()
		}
	}()

	// replace certs directory with expired ones
	if err = os.Rename(vDir, tDir); err != nil {
		t.Fatal(err)
	}
	if err = os.Rename(eDir, vDir); err != nil {
		t.Fatal(err)
	}

	// after rename,
	// 'vDir' contains expired certs
	// 'tDir' contains valid certs
	// 'eDir' does not exist

	select {
	case err = <-errc:
		if err != grpc.ErrClientConnTimeout {
			t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("failed to receive dial timeout error")
	}

	// now, replace expired certs back with valid ones
	if err = os.Rename(tDir, eDir); err != nil {
		t.Fatal(err)
	}
	if err = os.Rename(vDir, tDir); err != nil {
		t.Fatal(err)
	}
	if err = os.Rename(eDir, vDir); err != nil {
		t.Fatal(err)
	}

	// new incoming client request should trigger
	// listener to reload valid certs
	// (named tlsCfg rather than tls to avoid shadowing the crypto/tls package)
	var tlsCfg *tls.Config
	tlsCfg, err = ts.ClientConfig()
	if err != nil {
		t.Fatal(err)
	}
	var cl *clientv3.Client
	cl, err = clientv3.New(clientv3.Config{
		Endpoints:   []string{clus.Members[0].GRPCAddr()},
		DialTimeout: time.Second,
		TLS:         tlsCfg,
	})
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	cl.Close()
}

// TestTLSReloadCopy ensures server reloads expired/valid certs
// when new certs are copied over, one by one. It expects client dials
// to time out while the expired certs are in place, and to succeed again
// once the valid certs are copied back.
func TestTLSReloadCopy(t *testing.T) {
defer testutil.AfterTest(t)

// clone certs directory, free to overwrite
cDir, err := ioutil.TempDir(os.TempDir(), "fixtures-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(cDir)
ts, err := copyTLSFiles(testTLSInfo, cDir)
if err != nil {
t.Fatal(err)
}

// start with valid certs
clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
defer clus.Terminate(t)

// concurrent client dialing while certs transition from valid to expired;
// the goroutine exits after reporting its first dial error on errc
errc := make(chan error, 1)
go func() {
for {
// 'err' here is local to the goroutine and shadows the outer 'err'
cc, err := ts.ClientConfig()
if err != nil {
// from concurrent certs overwriting: the key/cert pair may be
// read while only partially replaced, so retry on these errors
switch err.Error() {
case "tls: private key does not match public key":
fallthrough
case "tls: failed to find any PEM data in key input":
continue
}
// NOTE(review): t.Fatal is called from a goroutine other than the one
// running the test, which the testing package documents as unsupported;
// consider sending err on errc instead.
t.Fatal(err)
}
cli, cerr := clientv3.New(clientv3.Config{
Endpoints: []string{clus.Members[0].GRPCAddr()},
DialTimeout: time.Second,
TLS: cc,
})
if cerr != nil {
errc <- cerr
return
}
cli.Close()
}
}()

// overwrite valid certs with expired ones
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copyTLSFiles to replace the copyFiles?

// (e.g. simulate cert expiration in practice)
if _, err = copyTLSFiles(testTLSInfoExpired, cDir); err != nil {
t.Fatal(err)
}

// the dialing goroutine is expected to report a dial timeout
// within 5 seconds of the expired certs being copied in
select {
case gerr := <-errc:
if gerr != grpc.ErrClientConnTimeout {
t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, gerr)
}
case <-time.After(5 * time.Second):
t.Fatal("failed to receive dial timeout error")
}

// now, replace expired certs back with valid ones
if _, err = copyTLSFiles(testTLSInfo, cDir); err != nil {
t.Fatal(err)
}

// new incoming client request should trigger
// listener to reload valid certs
// NOTE(review): the local variable 'tls' shadows the crypto/tls package
// for the rest of this function.
var tls *tls.Config
tls, err = ts.ClientConfig()
if err != nil {
t.Fatal(err)
}
var cl *clientv3.Client
cl, err = clientv3.New(clientv3.Config{
Endpoints: []string{clus.Members[0].GRPCAddr()},
DialTimeout: time.Second,
TLS: tls,
})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
cl.Close()
}

func TestGRPCRequireLeader(t *testing.T) {
defer testutil.AfterTest(t)

Expand Down
8 changes: 8 additions & 0 deletions pkg/transport/listener.go
Expand Up @@ -172,6 +172,14 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
MinVersion: tls.VersionTLS12,
ServerName: info.ServerName,
}
// this only reloads certs when there's a client request
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
}
cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
}
return cfg, nil
}

Expand Down