Skip to content

Commit

Permalink
accounts/keystore: Prevent race in Import function
Browse files Browse the repository at this point in the history
  • Loading branch information
blukat29 committed Apr 25, 2024
1 parent 9f6cf6a commit 7d95568
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 8 deletions.
17 changes: 13 additions & 4 deletions accounts/keystore/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"crypto/ecdsa"
crand "crypto/rand"
"errors"
"fmt"
"math/big"
"os"
"path/filepath"
Expand Down Expand Up @@ -68,7 +67,8 @@ type KeyStore struct {
updateScope event.SubscriptionScope // Subscription scope tracking current live listeners
updating bool // Whether the event notification loop is running

mu sync.RWMutex
mu sync.RWMutex // Protects the in-memory structures
importMu sync.Mutex // Protects the keystore import process
}

type unlocked struct {
Expand Down Expand Up @@ -485,6 +485,11 @@ func (ks *KeyStore) Import(keyJSON []byte, passphrase, newPassphrase string) (ac
if err != nil {
return accounts.Account{}, err
}
ks.importMu.Lock()
defer ks.importMu.Unlock()
if ks.cache.hasAddress(key.GetAddress()) {
return accounts.Account{}, errors.New("account already exists")
}
return ks.importKey(key, newPassphrase)
}

Expand Down Expand Up @@ -523,17 +528,21 @@ func (ks *KeyStore) ImportECDSAWithAddress(priv *ecdsa.PrivateKey, passphrase st
} else {
key = newKeyFromECDSA(priv)
}
ks.importMu.Lock()
defer ks.importMu.Unlock()
if ks.cache.hasAddress(key.GetAddress()) {
return accounts.Account{}, fmt.Errorf("account already exists")
return accounts.Account{}, errors.New("account already exists")
}
return ks.importKey(key, passphrase)
}

// ImportECDSA stores the given key into the key directory, encrypting it with the passphrase.
func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (accounts.Account, error) {
key := newKeyFromECDSA(priv)
ks.importMu.Lock()
defer ks.importMu.Unlock()
if ks.cache.hasAddress(key.GetAddress()) {
return accounts.Account{}, fmt.Errorf("account already exists")
return accounts.Account{}, errors.New("account already exists")
}
return ks.importKey(key, passphrase)
}
Expand Down
89 changes: 85 additions & 4 deletions accounts/keystore/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@ import (
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/klaytn/klaytn/accounts"
"github.com/klaytn/klaytn/blockchain/types"
"github.com/klaytn/klaytn/common"
"github.com/klaytn/klaytn/crypto"
"github.com/klaytn/klaytn/event"
"github.com/klaytn/klaytn/params"
"github.com/stretchr/testify/assert"

"github.com/klaytn/klaytn/accounts"
"github.com/klaytn/klaytn/common"
"github.com/klaytn/klaytn/event"
)

var testSigData = make([]byte, 32)
Expand Down Expand Up @@ -282,6 +283,86 @@ func TestWalletNotifierLifecycle(t *testing.T) {
t.Errorf("wallet notifier didn't terminate after unsubscribe")
}

// TestImportExport tests the import functionality of a keystore.
func TestImportECDSA(t *testing.T) {
dir, ks := tmpKeyStore(t, true)
defer os.RemoveAll(dir)
key, err := crypto.GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %v", key)
}
if _, err = ks.ImportECDSA(key, "old"); err != nil {
t.Errorf("importing failed: %v", err)
}
if _, err = ks.ImportECDSA(key, "old"); err == nil {
t.Errorf("importing same key twice succeeded")
}
if _, err = ks.ImportECDSA(key, "new"); err == nil {
t.Errorf("importing same key twice succeeded")
}
}

// TestImportECDSA tests the import and export functionality of a keystore.
func TestImportExport(t *testing.T) {
dir, ks := tmpKeyStore(t, true)
defer os.RemoveAll(dir)
acc, err := ks.NewAccount("old")
if err != nil {
t.Fatalf("failed to create account: %v", acc)
}
json, err := ks.Export(acc, "old", "new")
if err != nil {
t.Fatalf("failed to export account: %v", acc)
}
dir2, ks2 := tmpKeyStore(t, true)
defer os.RemoveAll(dir2)
if _, err = ks2.Import(json, "old", "old"); err == nil {
t.Errorf("importing with invalid password succeeded")
}
acc2, err := ks2.Import(json, "new", "new")
if err != nil {
t.Errorf("importing failed: %v", err)
}
if acc.Address != acc2.Address {
t.Error("imported account does not match exported account")
}
if _, err = ks2.Import(json, "new", "new"); err == nil {
t.Errorf("importing a key twice succeeded")
}
}

// TestImportRace tests the keystore on races.
// This test should fail under -race if importing races.
func TestImportRace(t *testing.T) {
dir, ks := tmpKeyStore(t, true)
defer os.RemoveAll(dir)
acc, err := ks.NewAccount("old")
if err != nil {
t.Fatalf("failed to create account: %v", acc)
}
json, err := ks.Export(acc, "old", "new")
if err != nil {
t.Fatalf("failed to export account: %v", acc)
}
dir2, ks2 := tmpKeyStore(t, true)
defer os.RemoveAll(dir2)
var atom uint32
var wg sync.WaitGroup
wg.Add(2)
for i := 0; i < 2; i++ {
go func() {
defer wg.Done()
if _, err := ks2.Import(json, "new", "new"); err != nil {
atomic.AddUint32(&atom, 1)
}
}()
}
wg.Wait()
if atom != 1 {
t.Errorf("Import is racy")
}
}

type walletEvent struct {
accounts.WalletEvent
a accounts.Account
Expand Down

0 comments on commit 7d95568

Please sign in to comment.