Skip to content

Commit

Permalink
client: support server TLS certificate change
Browse files Browse the repository at this point in the history
- `client/comms`: Instead of just keeping track of the connection
  being conneccted or not, a `ConnectionStatus` type is introduced
  which includes `InvalidCert` in addition to `Connected` and
  `Disconnected`. This is then passed to core and the UI through
  the `ConnEventNote`.

- `client/core`: A new function `UpdateCert` is added which uses
  TLS certificate supplied by the user to connect to a server. If
  the connection is successful, the new certificate is stored in
  the db.

- `client/webserver`: A new api request `/api/updatecert` is added
  to allow the user to update the TLS Certificate for a host.

- `ui`: The settings page is updated. Instead of all the options for
  each server being on the settings page, there is just a link to a
  dex settings page specific to a server. This contains all the options
  that previously existed on the settings page in addition to the new
  option to update the TLS Certificate.
  • Loading branch information
martonp committed May 3, 2022
1 parent 5e3f071 commit a87930b
Show file tree
Hide file tree
Showing 38 changed files with 681 additions and 316 deletions.
35 changes: 23 additions & 12 deletions client/comms/wsconn.go
Expand Up @@ -41,6 +41,15 @@ const (
DefaultResponseTimeout = 30 * time.Second
)

// ConnectionStatus represents the current status of the websocket connection.
type ConnectionStatus uint32

const (
Disconnected ConnectionStatus = iota
Connected
InvalidCert
)

// ErrInvalidCert is the error returned when attempting to use an invalid cert
// to set up a ws connection.
var ErrInvalidCert = fmt.Errorf("invalid certificate")
Expand Down Expand Up @@ -88,7 +97,7 @@ type WsCfg struct {
//
// NOTE: Disconnect event notifications may lag behind actual
// disconnections.
ConnectEventFunc func(bool)
ConnectEventFunc func(ConnectionStatus)

// Logger is the logger for the WsConn.
Logger dex.Logger
Expand All @@ -112,8 +121,8 @@ type wsConn struct {
wsMtx sync.Mutex
ws *websocket.Conn

connectedMtx sync.RWMutex
connected bool
connectedMtx sync.RWMutex
connectionStatus ConnectionStatus

reqMtx sync.RWMutex
respHandlers map[uint64]*responseHandler
Expand Down Expand Up @@ -165,18 +174,18 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) {
func (conn *wsConn) IsDown() bool {
conn.connectedMtx.RLock()
defer conn.connectedMtx.RUnlock()
return !conn.connected
return conn.connectionStatus != Connected
}

// setConnected updates the connection's connected state and runs the
// setConnectionStatus updates the connection's status and runs the
// ConnectEventFunc in case of a change.
func (conn *wsConn) setConnected(connected bool) {
func (conn *wsConn) setConnectionStatus(status ConnectionStatus) {
conn.connectedMtx.Lock()
statusChange := conn.connected != connected
conn.connected = connected
statusChange := conn.connectionStatus != status
conn.connectionStatus = status
conn.connectedMtx.Unlock()
if statusChange && conn.cfg.ConnectEventFunc != nil {
conn.cfg.ConnectEventFunc(connected)
conn.cfg.ConnectEventFunc(status)
}
}

Expand All @@ -195,11 +204,13 @@ func (conn *wsConn) connect(ctx context.Context) error {
if err != nil {
var e x509.UnknownAuthorityError
if errors.As(err, &e) {
conn.setConnectionStatus(InvalidCert)
if conn.tlsCfg == nil {
return ErrCertRequired
}
return ErrInvalidCert
}
conn.setConnectionStatus(Disconnected)
return err
}

Expand Down Expand Up @@ -241,7 +252,7 @@ func (conn *wsConn) connect(ctx context.Context) error {
conn.ws = ws
conn.wsMtx.Unlock()

conn.setConnected(true)
conn.setConnectionStatus(Connected)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
Expand All @@ -264,7 +275,7 @@ func (conn *wsConn) close() {
// run as a goroutine. Increment the wg before calling read.
func (conn *wsConn) read(ctx context.Context) {
reconnect := func() {
conn.setConnected(false)
conn.setConnectionStatus(Disconnected)
conn.reconnectCh <- struct{}{}
}

Expand Down Expand Up @@ -416,7 +427,7 @@ func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
go func() {
defer conn.wg.Done()
<-ctxInternal.Done()
conn.setConnected(false)
conn.setConnectionStatus(Disconnected)
conn.wsMtx.Lock()
if conn.ws != nil {
conn.log.Debug("Sending close 1000 (normal) message.")
Expand Down
79 changes: 55 additions & 24 deletions client/core/core.go
Expand Up @@ -134,8 +134,9 @@ type dexConnection struct {

epochMtx sync.RWMutex
epoch map[string]uint64
// connected is a best guess on the ws connection status.
connected uint32

// connectionStatus is a best guess on the ws connection status.
connectionStatus uint32

pendingFeeMtx sync.RWMutex
pendingFee *pendingFeeState
Expand Down Expand Up @@ -294,12 +295,14 @@ func (dc *dexConnection) exchangeInfo() *Exchange {
dc.cfgMtx.RLock()
cfg := dc.cfg
dc.cfgMtx.RUnlock()
connectionStatus := comms.ConnectionStatus(
atomic.LoadUint32(&dc.connectionStatus))
if cfg == nil { // no config, assets, or markets data
return &Exchange{
Host: dc.acct.host,
AcctID: acctID,
Connected: atomic.LoadUint32(&dc.connected) == 1,
PendingFee: dc.getPendingFee(),
Host: dc.acct.host,
AcctID: acctID,
ConnectionStatus: connectionStatus,
PendingFee: dc.getPendingFee(),
}
}

Expand Down Expand Up @@ -328,16 +331,18 @@ func (dc *dexConnection) exchangeInfo() *Exchange {
feeAssets["dcr"] = dcrAsset
}

connectionStatus = comms.ConnectionStatus(
atomic.LoadUint32(&dc.connectionStatus))
return &Exchange{
Host: dc.acct.host,
AcctID: acctID,
Markets: dc.marketMap(),
Assets: assets,
Connected: atomic.LoadUint32(&dc.connected) == 1,
Fee: dcrAsset,
RegFees: feeAssets,
PendingFee: dc.getPendingFee(),
CandleDurs: cfg.BinSizes,
Host: dc.acct.host,
AcctID: acctID,
Markets: dc.marketMap(),
Assets: assets,
ConnectionStatus: connectionStatus,
Fee: dcrAsset,
RegFees: feeAssets,
PendingFee: dc.getPendingFee(),
CandleDurs: cfg.BinSizes,
}
}

Expand Down Expand Up @@ -947,10 +952,12 @@ func (c *Core) dex(addr string) (*dexConnection, bool, error) {
c.connMtx.RLock()
dc, found := c.conns[host]
c.connMtx.RUnlock()
connected := found && atomic.LoadUint32(&dc.connected) == 1
if !found {
return nil, false, fmt.Errorf("unknown DEX %s", addr)
}
connectionStatus := comms.ConnectionStatus(
atomic.LoadUint32(&dc.connectionStatus))
connected := connectionStatus == comms.Connected
return dc, connected, nil
}

Expand Down Expand Up @@ -2769,6 +2776,31 @@ func (c *Core) EstimateRegistrationTxFee(host string, certI interface{}, assetID
return txFee, nil
}

// UpdateCert attempts to connect to a server using a new TLS certificate. If
// the connection is successful, then the cert in the database is updated.
func (c *Core) UpdateCert(host string, cert []byte) error {
account, err := c.db.Account(host)
if err != nil {
return err
}
if account == nil {
return fmt.Errorf("account does not exist for host: %v", host)
}
account.Cert = cert

_, connected := c.connectAccount(account)
if !connected {
return errors.New("failed to connect using new cert")
}

err = c.db.UpdateAccountInfo(account)
if err != nil {
return fmt.Errorf("failed to update account info: %w", err)
}

return nil
}

// Register registers an account with a new DEX. If an error occurs while
// fetching the DEX configuration or creating the fee transaction, it will be
// returned immediately.
Expand Down Expand Up @@ -5555,6 +5587,7 @@ func (c *Core) connectDEX(acctInfo *db.AccountInfo, temporary ...bool) (*dexConn
apiVer: -1,
reportingConnects: reporting,
spots: make(map[string]*msgjson.Spot),
connectionStatus: uint32(comms.Disconnected),
// On connect, must set: cfg, epoch, and assets.
}

Expand Down Expand Up @@ -5584,8 +5617,8 @@ func (c *Core) connectDEX(acctInfo *db.AccountInfo, temporary ...bool) (*dexConn
return nil, errors.New("a TLS connection is required when not using a hidden service")
}

wsCfg.ConnectEventFunc = func(connected bool) {
c.handleConnectEvent(dc, connected)
wsCfg.ConnectEventFunc = func(status comms.ConnectionStatus) {
c.handleConnectEvent(dc, status)
}
wsCfg.ReconnectSync = func() {
go c.handleReconnect(host)
Expand Down Expand Up @@ -5739,11 +5772,9 @@ func (dc *dexConnection) broadcastingConnect() bool {
// lost or established.
//
// NOTE: Disconnect event notifications may lag behind actual disconnections.
func (c *Core) handleConnectEvent(dc *dexConnection, connected bool) {
var v uint32
func (c *Core) handleConnectEvent(dc *dexConnection, status comms.ConnectionStatus) {
topic := TopicDEXDisconnected
if connected {
v = 1
if status == comms.Connected {
topic = TopicDEXConnected
} else {
for _, tracker := range dc.trackedTrades() {
Expand All @@ -5759,10 +5790,10 @@ func (c *Core) handleConnectEvent(dc *dexConnection, connected bool) {
tracker.mtx.Unlock()
}
}
atomic.StoreUint32(&dc.connected, v)
atomic.StoreUint32(&dc.connectionStatus, uint32(status))
if dc.broadcastingConnect() {
subject, details := c.formatDetails(topic, dc.acct.host)
dc.notify(newConnEventNote(topic, subject, dc.acct.host, connected, details, db.Poke))
dc.notify(newConnEventNote(topic, subject, dc.acct.host, status, details, db.Poke))
}
}

Expand Down
8 changes: 4 additions & 4 deletions client/core/core_test.go
Expand Up @@ -267,7 +267,7 @@ func testDexConnection(ctx context.Context, crypter *tCrypter) (*dexConnection,
trades: make(map[order.OrderID]*trackedTrade),
epoch: map[string]uint64{tDcrBtcMktName: 0},
apiVer: serverdex.PreAPIVersion,
connected: 1,
connectionStatus: uint32(comms.Connected),
reportingConnects: 1,
spots: make(map[string]*msgjson.Spot),
}, conn, acct
Expand Down Expand Up @@ -396,7 +396,7 @@ func (tdb *TDB) DisableAccount(url string) error {
return tdb.disableAccountErr
}

func (tdb *TDB) UpdateAccount(ai *db.AccountInfo) error {
func (tdb *TDB) UpdateAccountInfo(ai *db.AccountInfo) error {
tdb.verifyUpdateAccount = true
tdb.acct = ai
return nil
Expand Down Expand Up @@ -2662,12 +2662,12 @@ wait:
rig.dc.acct.unlock(rig.crypter)

// DEX not connected
atomic.StoreUint32(&rig.dc.connected, 0)
atomic.StoreUint32(&rig.dc.connectionStatus, uint32(comms.Disconnected))
_, err = tCore.Trade(tPW, form)
if err == nil {
t.Fatalf("no error for disconnected dex")
}
atomic.StoreUint32(&rig.dc.connected, 1)
atomic.StoreUint32(&rig.dc.connectionStatus, uint32(comms.Connected))

// No base asset
form.Base = 12345
Expand Down
13 changes: 7 additions & 6 deletions client/core/notification.go
Expand Up @@ -6,6 +6,7 @@ package core
import (
"fmt"

"decred.org/dcrdex/client/comms"
"decred.org/dcrdex/client/db"
"decred.org/dcrdex/dex"
"decred.org/dcrdex/dex/msgjson"
Expand Down Expand Up @@ -335,20 +336,20 @@ func (on *EpochNotification) String() string {
// ConnEventNote is a notification regarding individual DEX connection status.
type ConnEventNote struct {
db.Notification
Host string `json:"host"`
Connected bool `json:"connected"`
Host string `json:"host"`
ConnectionStatus comms.ConnectionStatus `json:"connectionStatus"`
}

const (
TopicDEXConnected Topic = "DEXConnected"
TopicDEXDisconnected Topic = "DEXDisconnected"
)

func newConnEventNote(topic Topic, subject, host string, connected bool, details string, severity db.Severity) *ConnEventNote {
func newConnEventNote(topic Topic, subject, host string, status comms.ConnectionStatus, details string, severity db.Severity) *ConnEventNote {
return &ConnEventNote{
Notification: db.NewNotification(NoteTypeConnEvent, topic, subject, details, severity),
Host: host,
Connected: connected,
Notification: db.NewNotification(NoteTypeConnEvent, topic, subject, details, severity),
Host: host,
ConnectionStatus: status,
}
}

Expand Down
3 changes: 2 additions & 1 deletion client/core/trade_simnet_test.go
Expand Up @@ -39,6 +39,7 @@ import (

"decred.org/dcrdex/client/asset/btc"
"decred.org/dcrdex/client/asset/dcr"
"decred.org/dcrdex/client/comms"
"decred.org/dcrdex/client/db"
"decred.org/dcrdex/dex"
"decred.org/dcrdex/dex/calc"
Expand Down Expand Up @@ -677,7 +678,7 @@ func TestOrderStatusReconciliation(t *testing.T) {
disconnectTimeout := 10 * sleepFactor * time.Second
disconnected := client2.notes.find(context.Background(), disconnectTimeout, func(n Notification) bool {
connNote, ok := n.(*ConnEventNote)
return ok && connNote.Host == dexHost && !connNote.Connected
return ok && connNote.Host == dexHost && connNote.ConnectionStatus != comms.Connected
})
if !disconnected {
t.Fatalf("client 2 dex not disconnected after %v", disconnectTimeout)
Expand Down
19 changes: 10 additions & 9 deletions client/core/types.go
Expand Up @@ -11,6 +11,7 @@ import (
"sync"

"decred.org/dcrdex/client/asset"
"decred.org/dcrdex/client/comms"
"decred.org/dcrdex/client/db"
"decred.org/dcrdex/dex"
"decred.org/dcrdex/dex/calc"
Expand Down Expand Up @@ -491,15 +492,15 @@ type PendingFeeState struct {

// Exchange represents a single DEX with any number of markets.
type Exchange struct {
Host string `json:"host"`
AcctID string `json:"acctID"`
Markets map[string]*Market `json:"markets"`
Assets map[uint32]*dex.Asset `json:"assets"`
Connected bool `json:"connected"`
Fee *FeeAsset `json:"feeAsset"` // DEPRECATED. DCR.
RegFees map[string]*FeeAsset `json:"regFees"`
PendingFee *PendingFeeState `json:"pendingFee,omitempty"`
CandleDurs []string `json:"candleDurs"`
Host string `json:"host"`
AcctID string `json:"acctID"`
Markets map[string]*Market `json:"markets"`
Assets map[uint32]*dex.Asset `json:"assets"`
ConnectionStatus comms.ConnectionStatus `json:"connectionStatus"`
Fee *FeeAsset `json:"feeAsset"` // DEPRECATED. DCR.
RegFees map[string]*FeeAsset `json:"regFees"`
PendingFee *PendingFeeState `json:"pendingFee,omitempty"`
CandleDurs []string `json:"candleDurs"`
}

// newDisplayID creates a display-friendly market ID for a base/quote ID pair.
Expand Down
13 changes: 13 additions & 0 deletions client/db/bolt/db.go
Expand Up @@ -521,6 +521,19 @@ func (db *BoltDB) CreateAccount(ai *dexdb.AccountInfo) error {
})
}

// UpdateAccountInfo updates the account info for an existing account with
// the same Host as the parameter. If no account exists with this host,
// an error is returned.
func (db *BoltDB) UpdateAccountInfo(ai *dexdb.AccountInfo) error {
return db.acctsUpdate(func(accts *bbolt.Bucket) error {
acct := accts.Bucket([]byte(ai.Host))
if acct == nil {
return fmt.Errorf("account not found for %s", ai.Host)
}
return acct.Put(accountKey, ai.Encode())
})
}

// deleteAccount removes the account by host.
func (db *BoltDB) deleteAccount(host string) error {
acctKey := []byte(host)
Expand Down

0 comments on commit a87930b

Please sign in to comment.