diff --git a/doc/abc_update_logs.md b/doc/abc_update_logs.md index 69cf33f89..805560bbd 100644 --- a/doc/abc_update_logs.md +++ b/doc/abc_update_logs.md @@ -372,7 +372,7 @@ XXXXX - Partial upgrade of wallet stuff [SECP256k1] Add the CMake/Ninja build to Travis [CMAKE] Add a check-extended target Log env path in BerkeleyEnvironment::Flush - wallet: detecting duplicate wallet by comparing the db filename. +XXXXX wallet: detecting duplicate wallet by comparing the db filename. ##### [bugfix] wallet: Fix duplicate fileid detection ##### [wallet] Reopen CDBEnv after encryption instead of shutting down Make ECM error message more helpful diff --git a/src/wallet/db.cpp b/src/wallet/db.cpp index 6f545cea4..e09fd0302 100644 --- a/src/wallet/db.cpp +++ b/src/wallet/db.cpp @@ -67,9 +67,9 @@ bool WalletDatabaseFileId::operator==(const WalletDatabaseFileId &rhs) const { return memcmp(value, &rhs.value, sizeof(value)) == 0; } -BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path, - std::string &database_filename) { - fs::path env_directory; +static void SplitWalletPath(const fs::path &wallet_path, + fs::path &env_directory, + std::string &database_filename) { if (fs::is_regular_file(wallet_path)) { // Special case for backwards compatibility: if wallet path points to an // existing file, treat it as the path to a BDB data file in a parent @@ -82,6 +82,26 @@ BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path, env_directory = wallet_path; database_filename = "wallet.dat"; } +} + +bool IsWalletLoaded(const fs::path &wallet_path) { + fs::path env_directory; + std::string database_filename; + SplitWalletPath(wallet_path, env_directory, database_filename); + + LOCK(cs_db); + auto env = g_dbenvs.find(env_directory.string()); + if (env == g_dbenvs.end()) { + return false; + } + + return env->second.IsDatabaseLoaded(database_filename); +} + +BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path, + std::string &database_filename) { + fs::path env_directory; + SplitWalletPath(wallet_path, env_directory, database_filename); LOCK(cs_db); // Note: An ununsed temporary BerkeleyEnvironment object may be created // inside the emplace function if the key already exists. This is a little @@ -105,13 +125,13 @@ void BerkeleyEnvironment::Close() { fDbEnvInit = false; - for (auto &db : mapDb) { + for (auto &db : m_databases) { auto count = mapFileUseCount.find(db.first); assert(count == mapFileUseCount.end() || count->second == 0); - if (db.second) { - db.second->close(0); - delete db.second; - db.second = nullptr; + BerkeleyDatabase &database = db.second.get(); + if (database.m_db) { + database.m_db->close(0); + database.m_db.reset(); } } @@ -518,7 +538,7 @@ BerkeleyBatch::BerkeleyBatch(BerkeleyDatabase &database, const char *pszMode, "BerkeleyBatch: Failed to open database environment."); } - pdb = env->mapDb[strFilename]; + pdb = database.m_db.get(); if (pdb == nullptr) { int ret; std::unique_ptr pdb_temp = @@ -571,7 +591,7 @@ BerkeleyBatch::BerkeleyBatch(BerkeleyDatabase &database, const char *pszMode, } pdb = pdb_temp.release(); - env->mapDb[strFilename] = pdb; + database.m_db.reset(pdb); if (fCreate && !Exists(std::string("version"))) { bool fTmp = fReadOnly; @@ -626,12 +646,13 @@ void BerkeleyBatch::Close() { void BerkeleyEnvironment::CloseDb(const std::string &strFile) { LOCK(cs_db); - if (mapDb[strFile] != nullptr) { + auto it = m_databases.find(strFile); + assert(it != m_databases.end()); + BerkeleyDatabase &database = it->second.get(); + if (database.m_db) { // Close the database handle - Db *pdb = mapDb[strFile]; - pdb->close(0); - delete pdb; - mapDb[strFile] = nullptr; + database.m_db->close(0); + database.m_db.reset(); } } diff --git a/src/wallet/db.h b/src/wallet/db.h index 3d902369f..5d6299c01 100644 --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -29,6 +29,8 @@ struct WalletDatabaseFileId { bool operator==(const WalletDatabaseFileId &rhs) const; }; +class BerkeleyDatabase; + class BerkeleyEnvironment { private: bool fDbEnvInit; @@ -41,7 +43,7 @@ class BerkeleyEnvironment { public: std::unique_ptr dbenv; std::map mapFileUseCount; - std::map mapDb; + std::map> m_databases; std::unordered_map m_fileids; std::condition_variable_any m_db_in_use; @@ -52,6 +54,9 @@ class BerkeleyEnvironment { void MakeMock(); bool IsMock() const { return fMockDb; } bool IsInitialized() const { return fDbEnvInit; } + bool IsDatabaseLoaded(const std::string &db_filename) const { + return m_databases.find(db_filename) != m_databases.end(); + } fs::path Directory() const { return strPath; } /** @@ -94,6 +99,9 @@ class BerkeleyEnvironment { } }; +/** Return whether a wallet database is currently loaded. */ +bool IsWalletLoaded(const fs::path &wallet_path); + /** Get BerkeleyEnvironment and database filename given a wallet path. */ BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path, std::string &database_filename); @@ -116,6 +124,8 @@ class BerkeleyDatabase { : nUpdateCounter(0), nLastSeen(0), nLastFlushed(0), nLastWalletUpdate(0) { env = GetWalletEnv(wallet_path, strFile); + auto inserted = env->m_databases.emplace(strFile, std::ref(*this)); + assert(inserted.second); if (mock) { env->Close(); env->Reset(); @@ -123,6 +133,13 @@ class BerkeleyDatabase { } } + ~BerkeleyDatabase() { + if (env) { + size_t erased = env->m_databases.erase(strFile); + assert(erased == 1); + } + } + /** Return object for accessing database at specified path. */ static std::unique_ptr Create(const fs::path &path) { return std::make_unique(path); @@ -166,6 +183,12 @@ class BerkeleyDatabase { unsigned int nLastFlushed; int64_t nLastWalletUpdate; + /** + * Database pointer. This is initialized lazily and reset during flushes, + * so it can be null. + */ + std::unique_ptr m_db; + private: /** BerkeleyDB specific */ BerkeleyEnvironment *env; diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index c59c08137..5165272f0 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -4670,6 +4670,90 @@ CWallet::GetDestValues(const std::string &prefix) const { } return values; } +#ifdef ADD_VERIFY +bool CWallet::Verify(const CChainParams &chainParams, + const WalletLocation &location, bool salvage_wallet, + std::string &error_string, std::string &warning_string) { + // Do some checking on wallet path. It should be either a: + // + // 1. Path where a directory can be created. + // 2. Path to an existing directory. + // 3. Path to a symlink to a directory. + // 4. For backwards compatibility, the name of a data file in -walletdir. + LOCK(cs_wallets); + const fs::path &wallet_path = location.GetPath(); + fs::file_type path_type = fs::symlink_status(wallet_path).type(); + if (!(path_type == fs::file_not_found || path_type == fs::directory_file || + (path_type == fs::symlink_file && fs::is_directory(wallet_path)) || + (path_type == fs::regular_file && + fs::path(location.GetName()).filename() == location.GetName()))) { + error_string = + strprintf("Invalid -wallet path '%s'. -wallet path should point to " + "a directory where wallet.dat and " + "database/log.?????????? files can be stored, a location " + "where such a directory could be created, " + "or (for backwards compatibility) the name of an " + "existing data file in -walletdir (%s)", + location.GetName(), GetWalletDir()); + return false; + } + + // Make sure that the wallet path doesn't clash with an existing wallet path + if (IsWalletLoaded(wallet_path)) { + error_string = strprintf( + "Error loading wallet %s. Duplicate -wallet filename specified.", + location.GetName()); + return false; + } + + try { + if (!WalletBatch::VerifyEnvironment(wallet_path, error_string)) { + return false; + } + } catch (const fs::filesystem_error &e) { + error_string = strprintf("Error loading wallet %s. %s", + location.GetName(), e.what()); + return false; + } + + if (salvage_wallet) { + // Recover readable keypairs: + CWallet dummyWallet(chainParams, WalletLocation(), + WalletDatabase::CreateDummy()); + std::string backup_filename; + if (!WalletBatch::Recover( + wallet_path, static_cast(&dummyWallet), + WalletBatch::RecoverKeysOnlyFilter, backup_filename)) { + return false; + } + } + + return WalletBatch::VerifyDatabaseFile(wallet_path, warning_string, + error_string); +} + +#endif +#ifdef USE_PRESPLIT +void CWallet::MarkPreSplitKeys() { + WalletBatch batch(*database); + for (auto it = setExternalKeyPool.begin(); + it != setExternalKeyPool.end();) { + int64_t index = *it; + CKeyPool keypool; + if (!batch.ReadPool(index, keypool)) { + throw std::runtime_error(std::string(__func__) + + ": read keypool entry failed"); + } + keypool.m_pre_split = true; + if (!batch.WritePool(index, keypool)) { + throw std::runtime_error(std::string(__func__) + + ": writing modified keypool entry failed"); + } + set_pre_split_keypool.insert(index); + it = setExternalKeyPool.erase(it); + } +} +#endif CWallet *CWallet::CreateWalletFromFile(const CChainParams &chainParams, const WalletLocation &location, diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 9ec0a5354..8412e62ca 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -778,6 +778,7 @@ class CWallet final : public CCryptoKeyStore, public CValidationInterface { void LoadKeyPool(int64_t nIndex, const CKeyPool &keypool) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + //void MarkPreSplitKeys() EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); // Map from Key ID to key metadata. std::map mapKeyMetadata; diff --git a/test/functional/wallet_multiwallet.py b/test/functional/wallet_multiwallet.py index c8b23aaf7..43f2c7ddc 100755 --- a/test/functional/wallet_multiwallet.py +++ b/test/functional/wallet_multiwallet.py @@ -203,6 +203,10 @@ def wallet(name): return node.get_wallet_rpc(name) assert_raises_rpc_error(-4, 'Wallet file verification failed: Error loading wallet w1. Duplicate -wallet filename specified.', self.nodes[0].loadwallet, wallet_names[0]) + # Fail to load duplicate wallets by different ways (directory and filepath) + assert_raises_rpc_error(-4, "Wallet file verification failed: Error loading wallet wallet.dat. Duplicate -wallet filename specified.", + self.nodes[0].loadwallet, 'wallet.dat') + # Fail to load if one wallet is a copy of another assert_raises_rpc_error(-1, "BerkeleyBatch: Can't open database w8_copy (duplicates fileid", self.nodes[0].loadwallet, 'w8_copy')