diff --git a/src/Features/Blockcore.Features.ColdStaking/ColdStakingManager.cs b/src/Features/Blockcore.Features.ColdStaking/ColdStakingManager.cs index 37949af40..b60ef862c 100644 --- a/src/Features/Blockcore.Features.ColdStaking/ColdStakingManager.cs +++ b/src/Features/Blockcore.Features.ColdStaking/ColdStakingManager.cs @@ -258,12 +258,15 @@ public HdAccount GetOrCreateColdStakingAccount(string walletName, bool isColdWal /// track addresses under the HD account. /// /// - public override Wallet.Types.Wallet RecoverWallet(string password, string name, string mnemonic, DateTime creationTime, string passphrase, int? coinType = null) + public override Wallet.Types.Wallet RecoverWallet(string password, string name, string mnemonic, DateTime creationTime, string passphrase, int? coinType = null, bool? isColdStakingWallet = false) { Wallet.Types.Wallet wallet = base.RecoverWallet(password, name, mnemonic, creationTime, passphrase, coinType); - this.GetOrCreateColdStakingAccount(wallet.Name, false, password); - this.GetOrCreateColdStakingAccount(wallet.Name, true, password); + if (isColdStakingWallet.HasValue && isColdStakingWallet == true) + { + this.GetOrCreateColdStakingAccount(wallet.Name, false, password); + this.GetOrCreateColdStakingAccount(wallet.Name, true, password); + } return wallet; } diff --git a/src/Features/Blockcore.Features.Wallet/Api/Controllers/WalletController.cs b/src/Features/Blockcore.Features.Wallet/Api/Controllers/WalletController.cs index 07cfb8a02..4241de09c 100644 --- a/src/Features/Blockcore.Features.Wallet/Api/Controllers/WalletController.cs +++ b/src/Features/Blockcore.Features.Wallet/Api/Controllers/WalletController.cs @@ -264,7 +264,7 @@ public IActionResult Recover([FromBody] WalletRecoveryRequest request) try { - Types.Wallet wallet = this.walletManager.RecoverWallet(request.Password, request.Name, request.Mnemonic, request.CreationDate, passphrase: request.Passphrase, request.CoinType); + Types.Wallet wallet = this.walletManager.RecoverWallet(request.Password, request.Name, request.Mnemonic, request.CreationDate, passphrase: request.Passphrase, request.CoinType, request.IsColdStakingWallet); this.SyncFromBestHeightForRecoveredWallets(request.CreationDate); diff --git a/src/Features/Blockcore.Features.Wallet/Api/Models/RequestModels.cs b/src/Features/Blockcore.Features.Wallet/Api/Models/RequestModels.cs index 31b59fb65..384c1be31 100644 --- a/src/Features/Blockcore.Features.Wallet/Api/Models/RequestModels.cs +++ b/src/Features/Blockcore.Features.Wallet/Api/Models/RequestModels.cs @@ -124,6 +124,11 @@ public class WalletRecoveryRequest : RequestModel /// Optional CoinType to overwrite the default . /// public int? CoinType { get; set; } + + /// + /// Optional flag that indicates if the "coldStakingColdAddresses" and "coldStakingHotAddresses" accounts should be restored. + /// + public bool? IsColdStakingWallet { get; set; } } /// diff --git a/src/Features/Blockcore.Features.Wallet/Interfaces/IWalletManager.cs b/src/Features/Blockcore.Features.Wallet/Interfaces/IWalletManager.cs index 7ae361253..529e23a32 100644 --- a/src/Features/Blockcore.Features.Wallet/Interfaces/IWalletManager.cs +++ b/src/Features/Blockcore.Features.Wallet/Interfaces/IWalletManager.cs @@ -147,7 +147,7 @@ public interface IWalletManager /// The date and time this wallet was created. /// Allow to override the default BIP44 cointype. /// The recovered wallet. - Types.Wallet RecoverWallet(string password, string name, string mnemonic, DateTime creationTime, string passphrase = null, int? coinType = null); + Types.Wallet RecoverWallet(string password, string name, string mnemonic, DateTime creationTime, string passphrase = null, int? coinType = null, bool? isColdStakingWallet = false); /// /// Recovers a wallet using extended public key and account index. diff --git a/src/Features/Blockcore.Features.Wallet/WalletManager.cs b/src/Features/Blockcore.Features.Wallet/WalletManager.cs index 4444fe607..2992c3876 100644 --- a/src/Features/Blockcore.Features.Wallet/WalletManager.cs +++ b/src/Features/Blockcore.Features.Wallet/WalletManager.cs @@ -461,7 +461,7 @@ private SecureString CacheSecret(string name, string walletPassword, TimeSpan du } /// - public virtual Types.Wallet RecoverWallet(string password, string name, string mnemonic, DateTime creationTime, string passphrase, int? coinType = null) + public virtual Types.Wallet RecoverWallet(string password, string name, string mnemonic, DateTime creationTime, string passphrase, int? coinType = null, bool? isColdStakingWallet = false) { Guard.NotEmpty(password, nameof(password)); Guard.NotEmpty(name, nameof(name)); diff --git a/src/Tests/Blockcore.Features.Wallet.Tests/WalletControllerTest.cs b/src/Tests/Blockcore.Features.Wallet.Tests/WalletControllerTest.cs index 91b59dfba..27064ea6e 100644 --- a/src/Tests/Blockcore.Features.Wallet.Tests/WalletControllerTest.cs +++ b/src/Tests/Blockcore.Features.Wallet.Tests/WalletControllerTest.cs @@ -330,7 +330,7 @@ public void RecoverWalletSuccessfullyReturnsWalletModel() }; var mockWalletManager = new Mock(); - mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null)).Returns(wallet); + mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null, true)).Returns(wallet); Mock walletSyncManager = new Mock(); walletSyncManager.Setup(w => w.WalletTip).Returns(new ChainedHeader(this.Network.GetGenesis().Header, this.Network.GetGenesis().Header.GetHash(), 3)); @@ -368,7 +368,7 @@ public void RecoverWalletWithDatedAfterCurrentSyncHeightDoesNotMoveSyncHeight() DateTime lastBlockDateTime = chainIndexer.Tip.Header.BlockTime.DateTime; var mockWalletManager = new Mock(); - mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null)).Returns(wallet); + mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null, true)).Returns(wallet); Mock walletSyncManager = new Mock(); walletSyncManager.Setup(w => w.WalletTip).Returns(new ChainedHeader(this.Network.GetGenesis().Header, this.Network.GetGenesis().Header.GetHash(), 3)); @@ -419,7 +419,7 @@ public void RecoverWalletWithInvalidOperationExceptionReturnsConflict() { string errorMessage = "An error occurred."; var mockWalletManager = new Mock(); - mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null, true)) .Throws(new WalletException(errorMessage)); var controller = new WalletController(this.LoggerFactory.Object, mockWalletManager.Object, new Mock().Object, new Mock().Object, It.IsAny(), this.Network, this.chainIndexer, new Mock().Object, DateTimeProvider.Default); @@ -445,7 +445,7 @@ public void RecoverWalletWithInvalidOperationExceptionReturnsConflict() public void RecoverWalletWithFileNotFoundExceptionReturnsNotFound() { var mockWalletManager = new Mock(); - mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null, true)) .Throws(new FileNotFoundException("File not found.")); var controller = new WalletController(this.LoggerFactory.Object, mockWalletManager.Object, new Mock().Object, new Mock().Object, It.IsAny(), this.Network, this.chainIndexer, new Mock().Object, DateTimeProvider.Default); @@ -472,7 +472,7 @@ public void RecoverWalletWithFileNotFoundExceptionReturnsNotFound() public void RecoverWalletWithExceptionReturnsBadRequest() { var mockWalletManager = new Mock(); - mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null)) + mockWalletManager.Setup(w => w.RecoverWallet(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), null, null, true)) .Throws(new FormatException("Formatting failed.")); var controller = new WalletController(this.LoggerFactory.Object, mockWalletManager.Object, new Mock().Object, new Mock().Object, It.IsAny(), this.Network, this.chainIndexer, new Mock().Object, DateTimeProvider.Default);