Skip to content

Commit

Permalink
refactor: Assert before dereference in CWallet::GetDatabase
Browse files Browse the repository at this point in the history
  • Loading branch information
promag committed Sep 19, 2020
1 parent c6a5cd7 commit a66278f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/wallet/scriptpubkeyman.h
Expand Up @@ -33,7 +33,7 @@ class WalletStorage
public:
virtual ~WalletStorage() = default;
virtual const std::string GetDisplayName() const = 0;
virtual WalletDatabase& GetDatabase() = 0;
virtual WalletDatabase& GetDatabase() const = 0;
virtual bool IsWalletFlagSet(uint64_t) const = 0;
virtual void UnsetBlankWalletFlag(WalletBatch&) = 0;
virtual bool CanSupportFeature(enum WalletFeature) const = 0;
Expand Down
56 changes: 28 additions & 28 deletions src/wallet/wallet.cpp
Expand Up @@ -418,7 +418,7 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase,
return false;
if (!crypter.Encrypt(_vMasterKey, pMasterKey.second.vchCryptedKey))
return false;
WalletBatch(*database).WriteMasterKey(pMasterKey.first, pMasterKey.second);
WalletBatch(GetDatabase()).WriteMasterKey(pMasterKey.first, pMasterKey.second);
if (fWasLocked)
Lock();
return true;
Expand All @@ -431,7 +431,7 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase,

void CWallet::chainStateFlushed(const CBlockLocator& loc)
{
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());
batch.WriteBestBlock(loc);
}

Expand All @@ -451,7 +451,7 @@ void CWallet::SetMinVersion(enum WalletFeature nVersion, WalletBatch* batch_in,
nWalletMaxVersion = nVersion;

{
WalletBatch* batch = batch_in ? batch_in : new WalletBatch(*database);
WalletBatch* batch = batch_in ? batch_in : new WalletBatch(GetDatabase());
if (nWalletVersion > 40000)
batch->WriteMinVersion(nWalletVersion);
if (!batch_in)
Expand Down Expand Up @@ -503,12 +503,12 @@ bool CWallet::HasWalletSpend(const uint256& txid) const

void CWallet::Flush()
{
database->Flush();
GetDatabase().Flush();
}

void CWallet::Close()
{
database->Close();
GetDatabase().Close();
}

void CWallet::SyncMetaData(std::pair<TxSpends::iterator, TxSpends::iterator> range)
Expand Down Expand Up @@ -634,7 +634,7 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase)
{
LOCK(cs_wallet);
mapMasterKeys[++nMasterKeyMaxID] = kMasterKey;
WalletBatch* encrypted_batch = new WalletBatch(*database);
WalletBatch* encrypted_batch = new WalletBatch(GetDatabase());
if (!encrypted_batch->TxnBegin()) {
delete encrypted_batch;
encrypted_batch = nullptr;
Expand Down Expand Up @@ -686,12 +686,12 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase)

// Need to completely rewrite the wallet file; if we don't, bdb might keep
// bits of the unencrypted private key in slack space in the database file.
database->Rewrite();
GetDatabase().Rewrite();

// BDB seems to have a bad habit of writing old data into
// slack space in .dat files; that is bad if the old data is
// unencrypted private keys. So:
database->ReloadDbEnv();
GetDatabase().ReloadDbEnv();

}
NotifyStatusChanged(this);
Expand All @@ -702,7 +702,7 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase)
DBErrors CWallet::ReorderTransactions()
{
LOCK(cs_wallet);
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());

// Old wallets didn't have any defined order for transactions
// Probably a bad idea to change the output of this
Expand Down Expand Up @@ -763,7 +763,7 @@ int64_t CWallet::IncOrderPosNext(WalletBatch* batch)
if (batch) {
batch->WriteOrderPosNext(nOrderPosNext);
} else {
WalletBatch(*database).WriteOrderPosNext(nOrderPosNext);
WalletBatch(GetDatabase()).WriteOrderPosNext(nOrderPosNext);
}
return nRet;
}
Expand Down Expand Up @@ -793,7 +793,7 @@ bool CWallet::MarkReplaced(const uint256& originalHash, const uint256& newHash)

wtx.mapValue["replaced_by_txid"] = newHash.ToString();

WalletBatch batch(*database, "r+");
WalletBatch batch(GetDatabase(), "r+");

bool success = true;
if (!batch.WriteTx(wtx)) {
Expand Down Expand Up @@ -865,7 +865,7 @@ CWalletTx* CWallet::AddToWallet(CTransactionRef tx, const CWalletTx::Confirmatio
{
LOCK(cs_wallet);

WalletBatch batch(*database, "r+", fFlushOnClose);
WalletBatch batch(GetDatabase(), "r+", fFlushOnClose);

uint256 hash = tx->GetHash();

Expand Down Expand Up @@ -1064,7 +1064,7 @@ bool CWallet::AbandonTransaction(const uint256& hashTx)
{
LOCK(cs_wallet);

WalletBatch batch(*database, "r+");
WalletBatch batch(GetDatabase(), "r+");

std::set<uint256> todo;
std::set<uint256> done;
Expand Down Expand Up @@ -1127,7 +1127,7 @@ void CWallet::MarkConflicted(const uint256& hashBlock, int conflicting_height, c
return;

// Do not flush the wallet here for performance reasons
WalletBatch batch(*database, "r+", false);
WalletBatch batch(GetDatabase(), "r+", false);

std::set<uint256> todo;
std::set<uint256> done;
Expand Down Expand Up @@ -1465,13 +1465,13 @@ void CWallet::SetWalletFlag(uint64_t flags)
{
LOCK(cs_wallet);
m_wallet_flags |= flags;
if (!WalletBatch(*database).WriteWalletFlags(m_wallet_flags))
if (!WalletBatch(GetDatabase()).WriteWalletFlags(m_wallet_flags))
throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed");
}

void CWallet::UnsetWalletFlag(uint64_t flag)
{
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());
UnsetWalletFlagWithDB(batch, flag);
}

Expand Down Expand Up @@ -1510,7 +1510,7 @@ bool CWallet::AddWalletFlags(uint64_t flags)
LOCK(cs_wallet);
// We should never be writing unknown non-tolerable wallet flags
assert(((flags & KNOWN_WALLET_FLAGS) >> 32) == (flags >> 32));
if (!WalletBatch(*database).WriteWalletFlags(flags)) {
if (!WalletBatch(GetDatabase()).WriteWalletFlags(flags)) {
throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed");
}

Expand Down Expand Up @@ -1601,7 +1601,7 @@ bool CWallet::ImportScriptPubKeys(const std::string& label, const std::set<CScri
return false;
}
if (apply_label) {
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());
for (const CScript& script : script_pub_keys) {
CTxDestination dest;
ExtractDestination(script, dest);
Expand Down Expand Up @@ -3188,10 +3188,10 @@ DBErrors CWallet::LoadWallet(bool& fFirstRunRet)
LOCK(cs_wallet);

fFirstRunRet = false;
DBErrors nLoadWalletRet = WalletBatch(*database,"cr+").LoadWallet(this);
DBErrors nLoadWalletRet = WalletBatch(GetDatabase(), "cr+").LoadWallet(this);
if (nLoadWalletRet == DBErrors::NEED_REWRITE)
{
if (database->Rewrite("\x04pool"))
if (GetDatabase().Rewrite("\x04pool"))
{
for (const auto& spk_man_pair : m_spk_managers) {
spk_man_pair.second->RewriteDB();
Expand All @@ -3215,7 +3215,7 @@ DBErrors CWallet::LoadWallet(bool& fFirstRunRet)
DBErrors CWallet::ZapSelectTx(std::vector<uint256>& vHashIn, std::vector<uint256>& vHashOut)
{
AssertLockHeld(cs_wallet);
DBErrors nZapSelectTxRet = WalletBatch(*database, "cr+").ZapSelectTx(vHashIn, vHashOut);
DBErrors nZapSelectTxRet = WalletBatch(GetDatabase(), "cr+").ZapSelectTx(vHashIn, vHashOut);
for (const uint256& hash : vHashOut) {
const auto& it = mapWallet.find(hash);
wtxOrdered.erase(it->second.m_it_wtxOrdered);
Expand All @@ -3227,7 +3227,7 @@ DBErrors CWallet::ZapSelectTx(std::vector<uint256>& vHashIn, std::vector<uint256

if (nZapSelectTxRet == DBErrors::NEED_REWRITE)
{
if (database->Rewrite("\x04pool"))
if (GetDatabase().Rewrite("\x04pool"))
{
for (const auto& spk_man_pair : m_spk_managers) {
spk_man_pair.second->RewriteDB();
Expand Down Expand Up @@ -3265,14 +3265,14 @@ bool CWallet::SetAddressBookWithDB(WalletBatch& batch, const CTxDestination& add

bool CWallet::SetAddressBook(const CTxDestination& address, const std::string& strName, const std::string& strPurpose)
{
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());
return SetAddressBookWithDB(batch, address, strName, strPurpose);
}

bool CWallet::DelAddressBook(const CTxDestination& address)
{
bool is_mine;
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());
{
LOCK(cs_wallet);
// If we want to delete receiving addresses, we need to take care that DestData "used" (and possibly newer DestData) gets preserved (and the "deleted" address transformed into a change entry instead of actually being deleted)
Expand Down Expand Up @@ -4019,7 +4019,7 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain& chain, const std::st
int rescan_height = 0;
if (!gArgs.GetBoolArg("-rescan", false))
{
WalletBatch batch(*walletInstance->database);
WalletBatch batch(walletInstance->GetDatabase());
CBlockLocator locator;
if (batch.ReadBestBlock(locator)) {
if (const Optional<int> fork_height = chain.findLocatorFork(locator)) {
Expand Down Expand Up @@ -4082,7 +4082,7 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain& chain, const std::st
}
}
walletInstance->chainStateFlushed(chain.getTipLocator());
walletInstance->database->IncrementUpdateCounter();
walletInstance->GetDatabase().IncrementUpdateCounter();
}

{
Expand Down Expand Up @@ -4163,7 +4163,7 @@ void CWallet::postInitProcess()

bool CWallet::BackupWallet(const std::string& strDest) const
{
return database->Backup(strDest);
return GetDatabase().Backup(strDest);
}

CKeyPool::CKeyPool()
Expand Down Expand Up @@ -4466,7 +4466,7 @@ void CWallet::SetupDescriptorScriptPubKeyMans()

void CWallet::AddActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal)
{
WalletBatch batch(*database);
WalletBatch batch(GetDatabase());
if (!batch.WriteActiveScriptPubKeyMan(static_cast<uint8_t>(type), id, internal)) {
throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed");
}
Expand Down
10 changes: 7 additions & 3 deletions src/wallet/wallet.h
Expand Up @@ -698,7 +698,7 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati
std::string m_name;

/** Internal database handle. */
std::unique_ptr<WalletDatabase> database;
std::unique_ptr<WalletDatabase> m_database;

/**
* The following is used to keep track of how far behind the wallet is
Expand Down Expand Up @@ -732,7 +732,11 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati
*/
mutable RecursiveMutex cs_wallet;

WalletDatabase& GetDatabase() override { return *database; }
WalletDatabase& GetDatabase() const override
{
assert(static_cast<bool>(m_database));
return *m_database;
}

/**
* Select a set of coins such that nValueRet >= nTargetValue and at least
Expand All @@ -754,7 +758,7 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati
CWallet(interfaces::Chain* chain, const std::string& name, std::unique_ptr<WalletDatabase> database)
: m_chain(chain),
m_name(name),
database(std::move(database))
m_database(std::move(database))
{
}

Expand Down

0 comments on commit a66278f

Please sign in to comment.