Skip to content

Commit

Permalink
ONTO-481 fix account duplicate import at sigsvr import cmd (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
gasby88 authored and laizy committed Sep 6, 2018
1 parent e6ca859 commit ad3a82e
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cmd/sigsvr/handlers/create_account.go
Expand Up @@ -46,7 +46,7 @@ func CreateAccount(req *clisvrcom.CliRpcRequest, resp *clisvrcom.CliRpcResponse)
log.Errorf("CreateAccount Qid:%s NewAccountData error:%s", req.Qid, err)
return
}
err = clisvrcom.DefWalletStore.AddAccountData(accData)
_, err = clisvrcom.DefWalletStore.AddAccountData(accData)
if err != nil {
resp.ErrorCode = clisvrcom.CLIERR_INTERNAL_ERR
resp.ErrorInfo = "create wallet failed"
Expand Down
2 changes: 1 addition & 1 deletion cmd/sigsvr/handlers/sig_raw_tx_test.go
Expand Up @@ -61,7 +61,7 @@ func TestMain(m *testing.M) {
log.Errorf("NewWalletStore error:%s", err)
return
}
err = clisvrcom.DefWalletStore.AddAccountData(testWallet.GetWalletData().Accounts[0])
_, err = clisvrcom.DefWalletStore.AddAccountData(testWallet.GetWalletData().Accounts[0])
if err != nil {
log.Errorf("AddAccountData error:%s", err)
return
Expand Down
13 changes: 11 additions & 2 deletions cmd/sigsvr/import_wallet.go
Expand Up @@ -62,13 +62,22 @@ func importWallet(ctx *cli.Context) error {
if *walletStore.WalletScrypt != *walletData.Scrypt {
return fmt.Errorf("import account failed, wallet scrypt:%+v != %+v", walletData.Scrypt, walletStore.WalletScrypt)
}
addNum := 0
updateNum := 0
for i := 0; i < len(walletData.Accounts); i++ {
err = walletStore.AddAccountData(walletData.Accounts[i])
ok, err := walletStore.AddAccountData(walletData.Accounts[i])
if err != nil {
return fmt.Errorf("import account address:%s error:%s", walletData.Accounts[i].Address, err)
}
if ok {
addNum++
} else {
updateNum++
}
}
cmd.PrintInfoMsg("Import account success.")
cmd.PrintInfoMsg("Account number:%d", len(walletData.Accounts))
cmd.PrintInfoMsg("Total account number:%d", len(walletData.Accounts))
cmd.PrintInfoMsg("Add account number:%d", addNum)
cmd.PrintInfoMsg("Update account number:%d", updateNum)
return nil
}
5 changes: 5 additions & 0 deletions cmd/sigsvr/store/common.go
Expand Up @@ -34,6 +34,7 @@ const (
WALLET_ACCOUNT_INDEX_PREFIX = 0x05
WALLET_ACCOUNT_PREFIX = 0x06
WALLET_EXTRA_PREFIX = 0x07
WALLET_ACCOUNT_NUMBER = 0x08
)

func GetWalletInitKey() []byte {
Expand Down Expand Up @@ -69,3 +70,7 @@ func GetAccountKey(address string) []byte {
func GetWalletExtraKey() []byte {
return []byte{WALLET_EXTRA_PREFIX}
}

func GetWalletAccountNumberKey() []byte {
return []byte{WALLET_ACCOUNT_NUMBER}
}
102 changes: 87 additions & 15 deletions cmd/sigsvr/store/wallet_store.go
Expand Up @@ -257,41 +257,61 @@ func (this *WalletStore) NewAccountData(typeCode keypair.KeyType, curveCode byte
return accData, nil
}

func (this *WalletStore) AddAccountData(accData *account.AccountData) error {
func (this *WalletStore) AddAccountData(accData *account.AccountData) (bool, error) {
isExist, err := this.IsAccountExist(accData.Address)
if err != nil {
return false, err
}

this.lock.Lock()
defer this.lock.Unlock()

if this.nextAccountIndex == 0 {
accountNum, err := this.GetAccountNumber()
if err != nil {
return false, fmt.Errorf("GetAccountNumber error:%s", err)
}
if accountNum == 0 {
accData.IsDefault = true
} else {
accData.IsDefault = false
}

batch := &leveldb.Batch{}
data, err := json.Marshal(accData)
if err != nil {
return err
return false, err
}

batch := &leveldb.Batch{}
//Put account
batch.Put(GetAccountKey(accData.Address), data)

accountIndex := this.nextAccountIndex
//Put account index
batch.Put(GetAccountIndexKey(accountIndex), []byte(accData.Address))
nextIndex := this.nextAccountIndex
if !isExist {
accountIndex := nextIndex
//Put account index
batch.Put(GetAccountIndexKey(accountIndex), []byte(accData.Address))

nextIndex++
data = make([]byte, 4, 4)
binary.LittleEndian.PutUint32(data, nextIndex)

//Put next account index
batch.Put(GetNextAccountIndexKey(), data)

nextIndex := accountIndex + 1
data = make([]byte, 4, 4)
binary.LittleEndian.PutUint32(data, nextIndex)
accountNum++
binary.LittleEndian.PutUint32(data, accountNum)

//Put next account index
batch.Put(GetNextAccountIndexKey(), data)
//Put account number
batch.Put(GetWalletAccountNumberKey(), data)
}

err = this.db.Write(batch, nil)
if err != nil {
return err
return false, err
}
this.nextAccountIndex = nextIndex
return nil

isAdd := !isExist
return isAdd, nil
}

func (this *WalletStore) getNextAccountIndex() (uint32, error) {
Expand Down Expand Up @@ -321,6 +341,17 @@ func (this *WalletStore) GetAccountDataByAddress(address string) (*account.Accou
return accData, nil
}

func (this *WalletStore) IsAccountExist(address string) (bool, error) {
data, err := this.db.Get(GetAccountKey(address), nil)
if err != nil {
if err == leveldb.ErrNotFound {
return false, nil
}
return false, err
}
return len(data) != 0, nil
}

func (this *WalletStore) GetAccountDataByIndex(index uint32) (*account.AccountData, error) {
address, err := this.GetAccountAddress(index)
if err != nil {
Expand All @@ -342,3 +373,44 @@ func (this *WalletStore) GetAccountAddress(index uint32) (string, error) {
}
return string(data), nil
}

func (this *WalletStore) setAccountNumber(number uint32) error {
data := make([]byte, 4, 4)
binary.LittleEndian.PutUint32(data, number)
return this.db.Put(GetWalletAccountNumberKey(), data, nil)
}

func (this *WalletStore) GetAccountNumber() (uint32, error) {
data, err := this.db.Get(GetWalletAccountNumberKey(), nil)
if err == nil {
return binary.LittleEndian.Uint32(data), nil
}
if err != leveldb.ErrNotFound {
return 0, err
}
//Keep downward compatible
nextIndex, err := this.getNextAccountIndex()
if err != nil {
return 0, fmt.Errorf("getNextAccountIndex error:%s", err)
}
if nextIndex == 0 {
return 0, nil
}
addresses := make(map[string]string, 0)
for i := uint32(0); i < nextIndex; i++ {
address, err := this.GetAccountAddress(i)
if err != nil {
return 0, fmt.Errorf("GetAccountAddress Index:%d error:%s", i, err)
}
if address == "" {
continue
}
addresses[address] = ""
}
accNum := uint32(len(addresses))
err = this.setAccountNumber(accNum)
if err != nil {
return 0, fmt.Errorf("setAccountNumber error")
}
return accNum, nil
}
13 changes: 9 additions & 4 deletions sigsvr.go
Expand Up @@ -63,23 +63,28 @@ func startSigSvr(ctx *cli.Context) {

walletDirPath := ctx.String(utils.GetFlagName(utils.CliWalletDirFlag))
if walletDirPath == "" {
log.Infof("Please using --walletdir flag to specific wallet saving path")
log.Errorf("Please using --walletdir flag to specific wallet saving path")
return
}

walletStore, err := store.NewWalletStore(walletDirPath)
if err != nil {
log.Infof("NewWalletStore error:%s", err)
log.Errorf("NewWalletStore error:%s", err)
return
}
clisvrcom.DefWalletStore = walletStore

log.Infof("Load wallet data success. Account number:%d", walletStore.GetNextAccountIndex())
accountNum, err := walletStore.GetAccountNumber()
if err != nil {
log.Errorf("GetAccountNumber error:%s", err)
return
}
log.Infof("Load wallet data success. Account number:%d", accountNum)

rpcAddress := ctx.String(utils.GetFlagName(utils.CliAddressFlag))
rpcPort := ctx.Uint(utils.GetFlagName(utils.CliRpcPortFlag))
if rpcPort == 0 {
log.Infof("Please using sig server port by --%s flag", utils.GetFlagName(utils.CliRpcPortFlag))
log.Errorf("Please using sig server port by --%s flag", utils.GetFlagName(utils.CliRpcPortFlag))
return
}
go cmdsvr.DefCliRpcSvr.Start(rpcAddress, rpcPort)
Expand Down

0 comments on commit ad3a82e

Please sign in to comment.