Skip to content

Commit

Permalink
client: support server TLS certificate change
Browse files Browse the repository at this point in the history
- `client/comms`: Instead of just keeping track of the connection
  being conneccted or not, a `ConnectionStatus` type is introduced
  which includes `InvalidCert` in addition to `Connected` and
  `Disconnected`. This is then passed to core and the UI through
  the `ConnEventNote`.

- `client/core`: A new function `UpdateCert` is added which uses
  TLS certificate supplied by the user to connect to a server. If
  the connection is successful, the new certificate is stored in
  the db.

- `client/webserver`: A new api request `/api/updatecert` is added
  to allow the user to update the TLS Certificate for a host.

- `ui`: The settings page is updated. Instead of all the options for
  each server being on the settings page, there is just a link to a
  dex settings page specific to a server. This contains all the options
  that previously existed on the settings page in addition to the new
  option to update the TLS Certificate.
  • Loading branch information
martonp committed May 4, 2022
1 parent 5e3f071 commit 1b5370b
Show file tree
Hide file tree
Showing 40 changed files with 763 additions and 319 deletions.
35 changes: 23 additions & 12 deletions client/comms/wsconn.go
Expand Up @@ -41,6 +41,15 @@ const (
DefaultResponseTimeout = 30 * time.Second
)

// ConnectionStatus represents the current status of the websocket connection.
type ConnectionStatus uint32

const (
Disconnected ConnectionStatus = iota
Connected
InvalidCert
)

// ErrInvalidCert is the error returned when attempting to use an invalid cert
// to set up a ws connection.
var ErrInvalidCert = fmt.Errorf("invalid certificate")
Expand Down Expand Up @@ -88,7 +97,7 @@ type WsCfg struct {
//
// NOTE: Disconnect event notifications may lag behind actual
// disconnections.
ConnectEventFunc func(bool)
ConnectEventFunc func(ConnectionStatus)

// Logger is the logger for the WsConn.
Logger dex.Logger
Expand All @@ -112,8 +121,8 @@ type wsConn struct {
wsMtx sync.Mutex
ws *websocket.Conn

connectedMtx sync.RWMutex
connected bool
connectedMtx sync.RWMutex
connectionStatus ConnectionStatus

reqMtx sync.RWMutex
respHandlers map[uint64]*responseHandler
Expand Down Expand Up @@ -165,18 +174,18 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) {
func (conn *wsConn) IsDown() bool {
conn.connectedMtx.RLock()
defer conn.connectedMtx.RUnlock()
return !conn.connected
return conn.connectionStatus != Connected
}

// setConnected updates the connection's connected state and runs the
// setConnectionStatus updates the connection's status and runs the
// ConnectEventFunc in case of a change.
func (conn *wsConn) setConnected(connected bool) {
func (conn *wsConn) setConnectionStatus(status ConnectionStatus) {
conn.connectedMtx.Lock()
statusChange := conn.connected != connected
conn.connected = connected
statusChange := conn.connectionStatus != status
conn.connectionStatus = status
conn.connectedMtx.Unlock()
if statusChange && conn.cfg.ConnectEventFunc != nil {
conn.cfg.ConnectEventFunc(connected)
conn.cfg.ConnectEventFunc(status)
}
}

Expand All @@ -195,11 +204,13 @@ func (conn *wsConn) connect(ctx context.Context) error {
if err != nil {
var e x509.UnknownAuthorityError
if errors.As(err, &e) {
conn.setConnectionStatus(InvalidCert)
if conn.tlsCfg == nil {
return ErrCertRequired
}
return ErrInvalidCert
}
conn.setConnectionStatus(Disconnected)
return err
}

Expand Down Expand Up @@ -241,7 +252,7 @@ func (conn *wsConn) connect(ctx context.Context) error {
conn.ws = ws
conn.wsMtx.Unlock()

conn.setConnected(true)
conn.setConnectionStatus(Connected)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
Expand All @@ -264,7 +275,7 @@ func (conn *wsConn) close() {
// run as a goroutine. Increment the wg before calling read.
func (conn *wsConn) read(ctx context.Context) {
reconnect := func() {
conn.setConnected(false)
conn.setConnectionStatus(Disconnected)
conn.reconnectCh <- struct{}{}
}

Expand Down Expand Up @@ -416,7 +427,7 @@ func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
go func() {
defer conn.wg.Done()
<-ctxInternal.Done()
conn.setConnected(false)
conn.setConnectionStatus(Disconnected)
conn.wsMtx.Lock()
if conn.ws != nil {
conn.log.Debug("Sending close 1000 (normal) message.")
Expand Down
26 changes: 26 additions & 0 deletions client/core/account.go
Expand Up @@ -2,6 +2,7 @@ package core

import (
"encoding/hex"
"errors"
"fmt"

"decred.org/dcrdex/client/db"
Expand Down Expand Up @@ -190,3 +191,28 @@ func (c *Core) AccountImport(pw []byte, acct Account) error {

return nil
}

// UpdateCert attempts to connect to a server using a new TLS certificate. If
// the connection is successful, then the cert in the database is updated.
func (c *Core) UpdateCert(host string, cert []byte) error {
account, err := c.db.Account(host)
if err != nil {
return err
}
if account == nil {
return fmt.Errorf("account does not exist for host: %v", host)
}
account.Cert = cert

_, connected := c.connectAccount(account)
if !connected {
return errors.New("failed to connect using new cert")
}

err = c.db.UpdateAccountInfo(account)
if err != nil {
return fmt.Errorf("failed to update account info: %w", err)
}

return nil
}
77 changes: 77 additions & 0 deletions client/core/account_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"decred.org/dcrdex/client/db"
"decred.org/dcrdex/dex/encode"
"decred.org/dcrdex/dex/order"
"decred.org/dcrdex/server/account"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
Expand Down Expand Up @@ -179,6 +180,82 @@ func TestAccountDisable(t *testing.T) {
}
}

func TestUpdateCert(t *testing.T) {
rig := newTestRig()
tCore := rig.core
rig.db.acct.Paid = true
rig.db.acct.FeeCoin = encode.RandomBytes(32)

tests := []struct {
name string
host string
acctErr bool
updateAccountInfoErr bool
queueConfig bool
expectError bool
}{
{
name: "ok",
host: rig.db.acct.Host,
queueConfig: true,
},
{
name: "connect error",
host: rig.db.acct.Host,
queueConfig: false,
expectError: true,
},
{
name: "db get account error",
host: rig.db.acct.Host,
queueConfig: true,
acctErr: true,
expectError: true,
},
{
name: "db update account err",
host: rig.db.acct.Host,
queueConfig: true,
updateAccountInfoErr: true,
expectError: true,
},
}

for _, test := range tests {
rig.db.verifyUpdateAccountInfo = false
if test.updateAccountInfoErr {
rig.db.updateAccountInfoErr = errors.New("")
} else {
rig.db.updateAccountInfoErr = nil
}
if test.acctErr {
rig.db.acctErr = errors.New("")
} else {
rig.db.acctErr = nil
}
randomCert := encode.RandomBytes(32)
if test.queueConfig {
rig.queueConfig()
}
err := tCore.UpdateCert(test.host, randomCert)
if test.expectError {
if err == nil {
t.Fatalf("%s: expected error but did not get", test.name)
}
continue
}
if err != nil {
t.Fatalf("%s: unexpected error: %v", test.name, err)
}
if !rig.db.verifyUpdateAccountInfo {
t.Fatalf("%s: expected update account to be called but it was not", test.name)
}
if !bytes.Equal(randomCert, rig.db.acct.Cert) {
t.Fatalf("%s: expected account to be updated with cert but it was not", test.name)
}
}
}

func TestAccountExportPasswordError(t *testing.T) {
rig := newTestRig()
tCore := rig.core
Expand Down
54 changes: 30 additions & 24 deletions client/core/core.go
Expand Up @@ -134,8 +134,9 @@ type dexConnection struct {

epochMtx sync.RWMutex
epoch map[string]uint64
// connected is a best guess on the ws connection status.
connected uint32

// connectionStatus is a best guess on the ws connection status.
connectionStatus uint32

pendingFeeMtx sync.RWMutex
pendingFee *pendingFeeState
Expand Down Expand Up @@ -294,12 +295,14 @@ func (dc *dexConnection) exchangeInfo() *Exchange {
dc.cfgMtx.RLock()
cfg := dc.cfg
dc.cfgMtx.RUnlock()
connectionStatus := comms.ConnectionStatus(
atomic.LoadUint32(&dc.connectionStatus))
if cfg == nil { // no config, assets, or markets data
return &Exchange{
Host: dc.acct.host,
AcctID: acctID,
Connected: atomic.LoadUint32(&dc.connected) == 1,
PendingFee: dc.getPendingFee(),
Host: dc.acct.host,
AcctID: acctID,
ConnectionStatus: connectionStatus,
PendingFee: dc.getPendingFee(),
}
}

Expand Down Expand Up @@ -328,16 +331,18 @@ func (dc *dexConnection) exchangeInfo() *Exchange {
feeAssets["dcr"] = dcrAsset
}

connectionStatus = comms.ConnectionStatus(
atomic.LoadUint32(&dc.connectionStatus))
return &Exchange{
Host: dc.acct.host,
AcctID: acctID,
Markets: dc.marketMap(),
Assets: assets,
Connected: atomic.LoadUint32(&dc.connected) == 1,
Fee: dcrAsset,
RegFees: feeAssets,
PendingFee: dc.getPendingFee(),
CandleDurs: cfg.BinSizes,
Host: dc.acct.host,
AcctID: acctID,
Markets: dc.marketMap(),
Assets: assets,
ConnectionStatus: connectionStatus,
Fee: dcrAsset,
RegFees: feeAssets,
PendingFee: dc.getPendingFee(),
CandleDurs: cfg.BinSizes,
}
}

Expand Down Expand Up @@ -947,10 +952,12 @@ func (c *Core) dex(addr string) (*dexConnection, bool, error) {
c.connMtx.RLock()
dc, found := c.conns[host]
c.connMtx.RUnlock()
connected := found && atomic.LoadUint32(&dc.connected) == 1
if !found {
return nil, false, fmt.Errorf("unknown DEX %s", addr)
}
connectionStatus := comms.ConnectionStatus(
atomic.LoadUint32(&dc.connectionStatus))
connected := connectionStatus == comms.Connected
return dc, connected, nil
}

Expand Down Expand Up @@ -5555,6 +5562,7 @@ func (c *Core) connectDEX(acctInfo *db.AccountInfo, temporary ...bool) (*dexConn
apiVer: -1,
reportingConnects: reporting,
spots: make(map[string]*msgjson.Spot),
connectionStatus: uint32(comms.Disconnected),
// On connect, must set: cfg, epoch, and assets.
}

Expand Down Expand Up @@ -5584,8 +5592,8 @@ func (c *Core) connectDEX(acctInfo *db.AccountInfo, temporary ...bool) (*dexConn
return nil, errors.New("a TLS connection is required when not using a hidden service")
}

wsCfg.ConnectEventFunc = func(connected bool) {
c.handleConnectEvent(dc, connected)
wsCfg.ConnectEventFunc = func(status comms.ConnectionStatus) {
c.handleConnectEvent(dc, status)
}
wsCfg.ReconnectSync = func() {
go c.handleReconnect(host)
Expand Down Expand Up @@ -5739,11 +5747,9 @@ func (dc *dexConnection) broadcastingConnect() bool {
// lost or established.
//
// NOTE: Disconnect event notifications may lag behind actual disconnections.
func (c *Core) handleConnectEvent(dc *dexConnection, connected bool) {
var v uint32
func (c *Core) handleConnectEvent(dc *dexConnection, status comms.ConnectionStatus) {
topic := TopicDEXDisconnected
if connected {
v = 1
if status == comms.Connected {
topic = TopicDEXConnected
} else {
for _, tracker := range dc.trackedTrades() {
Expand All @@ -5759,10 +5765,10 @@ func (c *Core) handleConnectEvent(dc *dexConnection, connected bool) {
tracker.mtx.Unlock()
}
}
atomic.StoreUint32(&dc.connected, v)
atomic.StoreUint32(&dc.connectionStatus, uint32(status))
if dc.broadcastingConnect() {
subject, details := c.formatDetails(topic, dc.acct.host)
dc.notify(newConnEventNote(topic, subject, dc.acct.host, connected, details, db.Poke))
dc.notify(newConnEventNote(topic, subject, dc.acct.host, status, details, db.Poke))
}
}

Expand Down

0 comments on commit 1b5370b

Please sign in to comment.