Skip to content

Commit

Permalink
client: support server TLS certificate change
Browse files Browse the repository at this point in the history
- `client/comms`: Instead of just keeping track of the connection
  being connected or not, a `ConnectionStatus` type is introduced
  which includes `InvalidCert` in addition to `Connected` and
  `Disconnected`. This is then passed to core and the UI through
  the `ConnEventNote`.

- `client/core`: A new function `UpdateCert` is added which uses
  TLS certificate supplied by the user to connect to a server. If
  the connection is successful, the new certificate is stored in
  the db.

- `client/webserver`: A new api request `/api/updatecert` is added
  to allow the user to update the TLS Certificate for a host.

- `ui`: The settings page is updated. Instead of all the options for
  each server being on the settings page, there is just a link to a
  dex settings page specific to a server. This contains all the options
  that previously existed on the settings page in addition to the new
  option to update the TLS Certificate.
  • Loading branch information
martonp committed May 21, 2022
1 parent 05276d7 commit 30e29d7
Show file tree
Hide file tree
Showing 42 changed files with 836 additions and 336 deletions.
71 changes: 42 additions & 29 deletions client/comms/wsconn.go
Expand Up @@ -41,6 +41,15 @@ const (
DefaultResponseTimeout = 30 * time.Second
)

// ConnectionStatus represents the current status of the websocket connection.
type ConnectionStatus uint32

const (
Disconnected ConnectionStatus = iota
Connected
InvalidCert
)

// ErrInvalidCert is the error returned when attempting to use an invalid cert
// to set up a ws connection.
var ErrInvalidCert = fmt.Errorf("invalid certificate")
Expand Down Expand Up @@ -88,7 +97,7 @@ type WsCfg struct {
//
// NOTE: Disconnect event notifications may lag behind actual
// disconnections.
ConnectEventFunc func(bool)
ConnectEventFunc func(ConnectionStatus)

// Logger is the logger for the WsConn.
Logger dex.Logger
Expand All @@ -112,8 +121,7 @@ type wsConn struct {
wsMtx sync.Mutex
ws *websocket.Conn

connectedMtx sync.RWMutex
connected bool
connectionStatus uint32 // atomic

reqMtx sync.RWMutex
respHandlers map[uint64]*responseHandler
Expand Down Expand Up @@ -163,20 +171,16 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) {

// IsDown indicates if the connection is known to be down.
func (conn *wsConn) IsDown() bool {
conn.connectedMtx.RLock()
defer conn.connectedMtx.RUnlock()
return !conn.connected
return atomic.LoadUint32(&conn.connectionStatus) != uint32(Connected)
}

// setConnected updates the connection's connected state and runs the
// setConnectionStatus updates the connection's status and runs the
// ConnectEventFunc in case of a change.
func (conn *wsConn) setConnected(connected bool) {
conn.connectedMtx.Lock()
statusChange := conn.connected != connected
conn.connected = connected
conn.connectedMtx.Unlock()
func (conn *wsConn) setConnectionStatus(status ConnectionStatus) {
oldStatus := atomic.SwapUint32(&conn.connectionStatus, uint32(status))
statusChange := oldStatus != uint32(status)
if statusChange && conn.cfg.ConnectEventFunc != nil {
conn.cfg.ConnectEventFunc(connected)
conn.cfg.ConnectEventFunc(status)
}
}

Expand All @@ -195,11 +199,13 @@ func (conn *wsConn) connect(ctx context.Context) error {
if err != nil {
var e x509.UnknownAuthorityError
if errors.As(err, &e) {
conn.setConnectionStatus(InvalidCert)
if conn.tlsCfg == nil {
return ErrCertRequired
}
return ErrInvalidCert
}
conn.setConnectionStatus(Disconnected)
return err
}

Expand Down Expand Up @@ -241,7 +247,7 @@ func (conn *wsConn) connect(ctx context.Context) error {
conn.ws = ws
conn.wsMtx.Unlock()

conn.setConnected(true)
conn.setConnectionStatus(Connected)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
Expand All @@ -264,7 +270,7 @@ func (conn *wsConn) close() {
// run as a goroutine. Increment the wg before calling read.
func (conn *wsConn) read(ctx context.Context) {
reconnect := func() {
conn.setConnected(false)
conn.setConnectionStatus(Disconnected)
conn.reconnectCh <- struct{}{}
}

Expand Down Expand Up @@ -406,17 +412,34 @@ func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
var ctxInternal context.Context
ctxInternal, conn.cancel = context.WithCancel(ctx)

conn.wg.Add(1)
err := conn.connect(ctxInternal)
if err != nil {
// If the certificate is invalid or missing, do not start the reconnect
// loop, and return an error with no WaitGroup.
if errors.Is(err, ErrInvalidCert) || errors.Is(err, ErrCertRequired) {
conn.cancel()
conn.wg.Wait() // probably a no-op
close(conn.readCh)
return nil, err
}

// The read loop would normally trigger keepAlive, but it wasn't started
// on account of a connect error.
conn.log.Errorf("Initial connection failed, starting reconnect loop: %v", err)
time.AfterFunc(5*time.Second, func() {
conn.reconnectCh <- struct{}{}
})
}

conn.wg.Add(2)
go func() {
defer conn.wg.Done()
conn.keepAlive(ctxInternal)
}()

conn.wg.Add(1)
go func() {
defer conn.wg.Done()
<-ctxInternal.Done()
conn.setConnected(false)
conn.setConnectionStatus(Disconnected)
conn.wsMtx.Lock()
if conn.ws != nil {
conn.log.Debug("Sending close 1000 (normal) message.")
Expand All @@ -427,16 +450,6 @@ func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
close(conn.readCh) // signal to MessageSource receivers that the wsConn is dead
}()

err := conn.connect(ctxInternal)
if err != nil {
// The read loop would normally trigger keepAlive, but it wasn't started
// on account of a connect error.
conn.log.Errorf("Initial connection failed, starting reconnect loop: %v", err)
time.AfterFunc(5*time.Second, func() {
conn.reconnectCh <- struct{}{}
})
}

return &conn.wg, err
}

Expand Down
24 changes: 24 additions & 0 deletions client/core/account.go
Expand Up @@ -2,6 +2,7 @@ package core

import (
"encoding/hex"
"errors"
"fmt"

"decred.org/dcrdex/client/db"
Expand Down Expand Up @@ -190,3 +191,26 @@ func (c *Core) AccountImport(pw []byte, acct Account) error {

return nil
}

// UpdateCert attempts to connect to a server using a new TLS certificate. If
// the connection is successful, then the cert in the database is updated.
func (c *Core) UpdateCert(host string, cert []byte) error {
accountInfo, err := c.db.Account(host)
if err != nil {
return err
}

accountInfo.Cert = cert

_, connected := c.connectAccount(accountInfo)
if !connected {
return errors.New("failed to connect using new cert")
}

err = c.db.UpdateAccountInfo(accountInfo)
if err != nil {
return fmt.Errorf("failed to update account info: %w", err)
}

return nil
}
77 changes: 77 additions & 0 deletions client/core/account_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"decred.org/dcrdex/client/db"
"decred.org/dcrdex/dex/encode"
"decred.org/dcrdex/dex/order"
"decred.org/dcrdex/server/account"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
Expand Down Expand Up @@ -179,6 +180,82 @@ func TestAccountDisable(t *testing.T) {
}
}

func TestUpdateCert(t *testing.T) {
rig := newTestRig()
tCore := rig.core
rig.db.acct.Paid = true
rig.db.acct.FeeCoin = encode.RandomBytes(32)

tests := []struct {
name string
host string
acctErr bool
updateAccountInfoErr bool
queueConfig bool
expectError bool
}{
{
name: "ok",
host: rig.db.acct.Host,
queueConfig: true,
},
{
name: "connect error",
host: rig.db.acct.Host,
queueConfig: false,
expectError: true,
},
{
name: "db get account error",
host: rig.db.acct.Host,
queueConfig: true,
acctErr: true,
expectError: true,
},
{
name: "db update account err",
host: rig.db.acct.Host,
queueConfig: true,
updateAccountInfoErr: true,
expectError: true,
},
}

for _, test := range tests {
rig.db.verifyUpdateAccountInfo = false
if test.updateAccountInfoErr {
rig.db.updateAccountInfoErr = errors.New("")
} else {
rig.db.updateAccountInfoErr = nil
}
if test.acctErr {
rig.db.acctErr = errors.New("")
} else {
rig.db.acctErr = nil
}
randomCert := encode.RandomBytes(32)
if test.queueConfig {
rig.queueConfig()
}
err := tCore.UpdateCert(test.host, randomCert)
if test.expectError {
if err == nil {
t.Fatalf("%s: expected error but did not get", test.name)
}
continue
}
if err != nil {
t.Fatalf("%s: unexpected error: %v", test.name, err)
}
if !rig.db.verifyUpdateAccountInfo {
t.Fatalf("%s: expected update account to be called but it was not", test.name)
}
if !bytes.Equal(randomCert, rig.db.acct.Cert) {
t.Fatalf("%s: expected account to be updated with cert but it was not", test.name)
}
}
}

func TestAccountExportPasswordError(t *testing.T) {
rig := newTestRig()
tCore := rig.core
Expand Down

0 comments on commit 30e29d7

Please sign in to comment.