diff --git a/dex/marshal.go b/dex/marshal.go index 21ae4435ca..1632456757 100644 --- a/dex/marshal.go +++ b/dex/marshal.go @@ -24,6 +24,18 @@ func (b Bytes) MarshalJSON() ([]byte, error) { return json.Marshal(hex.EncodeToString(b)) } +// Scan implements the sql.Scanner interface. +func (b *Bytes) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + *b = Bytes(src) + return nil + case nil: + return nil + } + return fmt.Errorf("cannot convert %T to Bytes", src) +} + // UnmarshalJSON satisfies the json.Unmarshaler interface, and expects a UTF-8 // encoding of a hex string. func (b *Bytes) UnmarshalJSON(encHex []byte) (err error) { diff --git a/server/admin/api.go b/server/admin/api.go index 9efd36d2fe..d1576e1481 100644 --- a/server/admin/api.go +++ b/server/admin/api.go @@ -161,12 +161,16 @@ func (s *Server) apiAccounts(w http.ResponseWriter, _ *http.Request) { // apiAccountInfo is the handler for the '/account/{account id}' API request. func (s *Server) apiAccountInfo(w http.ResponseWriter, r *http.Request) { - acct := strings.ToLower(chi.URLParam(r, accountNameKey)) - acctIDSlice, err := hex.DecodeString(acct) + acctIDStr := chi.URLParam(r, accountNameKey) + acctIDSlice, err := hex.DecodeString(acctIDStr) if err != nil { http.Error(w, fmt.Sprintf("could not decode accout id: %v", err), http.StatusBadRequest) return } + if len(acctIDSlice) != account.HashSize { + http.Error(w, "account id has incorrect length", http.StatusBadRequest) + return + } var acctID account.AccountID copy(acctID[:], acctIDSlice) acctInfo, err := s.core.AccountInfo(acctID) diff --git a/server/admin/server_test.go b/server/admin/server_test.go index c21097f03d..95fd24770f 100644 --- a/server/admin/server_test.go +++ b/server/admin/server_test.go @@ -874,6 +874,20 @@ func TestAccountInfo(t *testing.T) { t.Errorf("unexpected response %q, wanted %q", w.Body.String(), exp) } + // ok, upper case account id + w = httptest.NewRecorder() + r, _ = http.NewRequest("GET", "https://localhost/account/"+strings.ToUpper(acctIDStr), nil) + r.RemoteAddr = "localhost" + + mux.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("apiAccounts returned code %d, expected %d", w.Code, http.StatusOK) + } + if exp != w.Body.String() { + t.Errorf("unexpected response %q, wanted %q", w.Body.String(), exp) + } + // acct id is not hex w = httptest.NewRecorder() r, _ = http.NewRequest("GET", "https://localhost/account/nothex", nil) @@ -885,6 +899,17 @@ func TestAccountInfo(t *testing.T) { t.Fatalf("apiAccounts returned code %d, expected %d", w.Code, http.StatusBadRequest) } + // acct id wrong length + w = httptest.NewRecorder() + r, _ = http.NewRequest("GET", "https://localhost/account/"+acctIDStr[2:], nil) + r.RemoteAddr = "localhost" + + mux.ServeHTTP(w, r) + + if w.Code != http.StatusBadRequest { + t.Fatalf("apiAccounts returned code %d, expected %d", w.Code, http.StatusBadRequest) + } + // core.Account error core.accountErr = errors.New("error") diff --git a/server/db/driver/pg/accounts.go b/server/db/driver/pg/accounts.go index a5610e7244..24344f669c 100644 --- a/server/db/driver/pg/accounts.go +++ b/server/db/driver/pg/accounts.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" - "decred.org/dcrdex/dex" "decred.org/dcrdex/server/account" "decred.org/dcrdex/server/db" "decred.org/dcrdex/server/db/driver/pg/internal" @@ -54,18 +53,14 @@ func (a *Archiver) Accounts() ([]*db.Account, error) { } defer rows.Close() var accts []*db.Account - var accountID, pubkey, feeCoin []byte - var brokenRule byte + var feeAddress sql.NullString for rows.Next() { a := new(db.Account) - err = rows.Scan(&accountID, &pubkey, &a.FeeAddress, &feeCoin, &brokenRule) + err = rows.Scan(&a.AccountID, &a.Pubkey, &feeAddress, &a.FeeCoin, &a.BrokenRule) if err != nil { return nil, err } - copy(a.AccountID[:], accountID) - a.Pubkey = dex.Bytes(pubkey) - a.FeeCoin = dex.Bytes(feeCoin) - a.BrokenRule = account.Rule(brokenRule) + a.FeeAddress = feeAddress.String accts = append(accts, a) } if err = rows.Err(); err != nil { @@ -78,15 +73,12 @@ func (a *Archiver) Accounts() ([]*db.Account, error) { func (a *Archiver) AccountInfo(aid account.AccountID) (*db.Account, error) { stmt := fmt.Sprintf(internal.SelectAccountInfo, a.tables.accounts) acct := new(db.Account) - var accountID, pubkey, feeCoin []byte - var brokenRule byte - if err := a.db.QueryRow(stmt, aid).Scan(&accountID, &pubkey, &acct.FeeAddress, &feeCoin, &brokenRule); err != nil { + var feeAddress sql.NullString + if err := a.db.QueryRow(stmt, aid).Scan(&acct.AccountID, &acct.Pubkey, &feeAddress, + &acct.FeeCoin, &acct.BrokenRule); err != nil { return nil, err } - copy(acct.AccountID[:], accountID) - acct.Pubkey = dex.Bytes(pubkey) - acct.FeeCoin = dex.Bytes(feeCoin) - acct.BrokenRule = account.Rule(brokenRule) + acct.FeeAddress = feeAddress.String return acct, nil } diff --git a/server/db/driver/pg/accounts_online_test.go b/server/db/driver/pg/accounts_online_test.go index eb76b18b84..8d790e70d3 100644 --- a/server/db/driver/pg/accounts_online_test.go +++ b/server/db/driver/pg/accounts_online_test.go @@ -4,6 +4,7 @@ package pg import ( "encoding/hex" + "fmt" "reflect" "testing" @@ -98,6 +99,41 @@ func TestAccounts(t *testing.T) { t.Fatal("error getting account info: actual does not equal expected") } + // The Account ID cannot be null. broken_rule has a default value of 0 + // and is unexpected to become null. + nullAccounts := `UPDATE %s + SET + pubkey = null , + fee_address = null, + fee_coin = null;` + + stmt := fmt.Sprintf(nullAccounts, archie.tables.accounts) + if _, err = sqlExec(archie.db, stmt); err != nil { + t.Fatalf("error nullifying account: %v", err) + } + + accts, err = archie.Accounts() + if err != nil { + t.Fatalf("error getting null accounts: %v", err) + } + + // All fields except account ID are null. + if accts[0].AccountID.String() != "0a9912205b2cbab0c25c2de30bda9074de0ae23b065489a99199bad763f102cc" || + accts[0].Pubkey.String() != "" || + accts[0].FeeAddress != "" || + accts[0].FeeCoin.String() != "" || + byte(accts[0].BrokenRule) != byte(0) { + t.Fatal("accounts has unexpected data") + } + + anAcct, err = archie.AccountInfo(accts[0].AccountID) + if err != nil { + t.Fatalf("error getting null account info: %v", err) + } + if !reflect.DeepEqual(accts[0], anAcct) { + t.Fatal("error getting null account info: actual does not equal expected") + } + // Close the account for failure to complete a swap. archie.CloseAccount(tAcctID, account.FailureToAct) _, _, open = archie.Account(tAcctID)