Skip to content

Commit

Permalink
Merge pull request #585 from dedis/issue583
Browse files Browse the repository at this point in the history
Make websocket start correctly when cert comes from reloader
  • Loading branch information
Gaylor Bosson committed Oct 9, 2019
2 parents 3561427 + 7e9aef1 commit c4fa81c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 6 deletions.
27 changes: 22 additions & 5 deletions local.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type LocalTest struct {
webSocketTLSCertificate []byte
// TLS certificate key if we want TLS for websocket
webSocketTLSCertificateKey []byte
// True if the unit test wants that webSocketTLSCertificate and webSocketTLSCertificateKey
// should be used as filenames.
webSocketTLSReadFiles bool
// the context for the local connections
// it enables to have multiple local test running simultaneously
ctx *network.LocalManager
Expand Down Expand Up @@ -598,12 +601,26 @@ func (l *LocalTest) NewServer(s network.Suite, port int) *Server {
server = l.newTCPServer(s)
// Set TLS certificate if any configuration available
if l.wantsTLS() {
cert, err := tls.X509KeyPair(l.webSocketTLSCertificate, l.webSocketTLSCertificateKey)
if err != nil {
panic(err)
}
server.WebSocket.Lock()
server.WebSocket.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
if l.webSocketTLSReadFiles {
cr, err := NewCertificateReloader(
string(l.webSocketTLSCertificate),
string(l.webSocketTLSCertificateKey))
if err != nil {
log.Error("cannot configure TLS reloader", err)
return nil
}
server.WebSocket.TLSConfig = &tls.Config{
GetCertificate: cr.GetCertificateFunc(),
}

} else {
cert, err := tls.X509KeyPair(l.webSocketTLSCertificate, l.webSocketTLSCertificateKey)
if err != nil {
panic(err)
}
server.WebSocket.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
server.WebSocket.Unlock()
}
server.StartInBackground()
Expand Down
2 changes: 1 addition & 1 deletion websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (w *WebSocket) start() {
go func() {
// Check if server is configured for TLS
started <- true
if w.server.Server.TLSConfig != nil && len(w.server.Server.TLSConfig.Certificates) >= 1 {
if w.server.Server.TLSConfig != nil && (w.server.TLSConfig.GetCertificate != nil || len(w.server.Server.TLSConfig.Certificates) >= 1) {
w.server.ListenAndServeTLS("", "")
} else {
w.server.ListenAndServe()
Expand Down
53 changes: 53 additions & 0 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,59 @@ func TestClientTLS_Send(t *testing.T) {
require.True(t, client.Tx() > client.Rx())
}

func TestClientTLS_certfile_Send(t *testing.T) {
// like TestClientTLSfile_Send, but uses cert and key from a file
// to solve issue 583.
cert, key, err := getSelfSignedCertificateAndKey()
require.Nil(t, err)
CAPool := x509.NewCertPool()
CAPool.AppendCertsFromPEM(cert)

f1, err := ioutil.TempFile("", "cert")
require.NoError(t, err)
defer os.Remove(f1.Name())
f1.Write(cert)
f1.Close()

f2, err := ioutil.TempFile("", "key")
require.NoError(t, err)
defer os.Remove(f2.Name())
f2.Write(key)
f2.Close()

local := NewTCPTest(tSuite)
local.webSocketTLSCertificate = []byte(f1.Name())
local.webSocketTLSCertificateKey = []byte(f2.Name())
local.webSocketTLSReadFiles = true
defer local.CloseAll()

// register service
RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
return &simpleService{
ctx: c,
}, nil
})
defer ServiceFactory.Unregister(backForthServiceName)

// create servers
servers, el, _ := local.GenTree(4, false)
client := local.NewClient(backForthServiceName)
client.TLSClientConfig = &tls.Config{RootCAs: CAPool}

r := &SimpleRequest{
ServerIdentities: el,
Val: 10,
}
sr := &SimpleResponse{}
require.Equal(t, uint64(0), client.Rx())
require.Equal(t, uint64(0), client.Tx())
require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
require.Equal(t, sr.Val, int64(10))
require.NotEqual(t, uint64(0), client.Rx())
require.NotEqual(t, uint64(0), client.Tx())
require.True(t, client.Tx() > client.Rx())
}

func TestClient_Parallel(t *testing.T) {
nbrNodes := 4
nbrParallel := 20
Expand Down

0 comments on commit c4fa81c

Please sign in to comment.