Skip to content

Commit

Permalink
Add wallet function to bump fee with RBF (#2392)
Browse files Browse the repository at this point in the history
* Add wallet function to bump fee

* Bump sequence number

* Respond to review

* Fix test
  • Loading branch information
benthecarman committed Dec 20, 2020
1 parent bc65614 commit 64a6b6b
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ trait WalletApi extends StartStopAsync[WalletApi] {
amounts: Vector[CurrencyUnit],
feeRate: FeeUnit)(implicit ec: ExecutionContext): Future[Transaction]

def bumpFeeRBF(
txId: DoubleSha256DigestBE,
newFeeRate: FeeUnit): Future[Transaction]

def makeOpReturnCommitment(
message: String,
hashMessage: Boolean,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.bitcoins.core.wallet.builder

import org.bitcoins.core.currency.{CurrencyUnit, Satoshis}
import org.bitcoins.core.number.Int64
import org.bitcoins.core.number.{Int64, UInt32}
import org.bitcoins.core.policy.Policy
import org.bitcoins.core.protocol.script.{ScriptPubKey, ScriptSignature}
import org.bitcoins.core.protocol.transaction._
Expand Down Expand Up @@ -66,8 +66,10 @@ abstract class FinalizerFactory[T <: RawTxFinalizer] {
outputs: Seq[TransactionOutput],
utxos: Seq[InputSigningInfo[InputInfo]],
feeRate: FeeUnit,
changeSPK: ScriptPubKey): RawTxBuilderWithFinalizer[T] = {
val inputs = InputUtil.calcSequenceForInputs(utxos)
changeSPK: ScriptPubKey,
defaultSequence: UInt32 = Policy.sequence): RawTxBuilderWithFinalizer[
T] = {
val inputs = InputUtil.calcSequenceForInputs(utxos, defaultSequence)
val lockTime = TxUtil.calcLockTime(utxos).get
val builder = RawTxBuilder().setLockTime(lockTime) ++= outputs ++= inputs
val finalizer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import org.bitcoins.core.script.control.OP_RETURN
import org.bitcoins.core.wallet.fee._
import org.bitcoins.core.wallet.utxo.TxoState
import org.bitcoins.crypto.CryptoUtil
import org.bitcoins.testkit.Implicits.GeneratorOps
import org.bitcoins.testkit.core.gen.FeeUnitGen
import org.bitcoins.testkit.wallet.BitcoinSWalletTest
import org.bitcoins.testkit.wallet.BitcoinSWalletTest.RandomFeeProvider
import org.bitcoins.testkit.wallet.FundWalletUtil.FundedWallet
Expand Down Expand Up @@ -255,6 +257,34 @@ class WalletSendingTest extends BitcoinSWalletTest {
}
}

it should "correctly bump the fee rate of a transaction" in { fundedWallet =>
val wallet = fundedWallet.wallet

val feeRate = FeeUnitGen.satsPerByte.sampleSome

for {
tx <- wallet.sendToAddress(testAddress, amountToSend, feeRate)

firstBal <- wallet.getBalance()

newFeeRate = SatoshisPerByte(feeRate.currencyUnit + Satoshis.one)
bumpedTx <- wallet.bumpFeeRBF(tx.txIdBE, newFeeRate)

txDb1Opt <- wallet.outgoingTxDAO.findByTxId(tx.txIdBE)
txDb2Opt <- wallet.outgoingTxDAO.findByTxId(bumpedTx.txIdBE)

secondBal <- wallet.getBalance()
} yield {
assert(txDb1Opt.isDefined)
assert(txDb2Opt.isDefined)
val txDb1 = txDb1Opt.get
val txDb2 = txDb2Opt.get

assert(txDb1.actualFee < txDb2.actualFee)
assert(firstBal - secondBal == txDb2.actualFee - txDb1.actualFee)
}
}

it should "fail to send from outpoints when already spent" in {
fundedWallet =>
val wallet = fundedWallet.wallet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,25 @@ class AddressDAOTest extends WalletDAOFixture {
} yield assert(readAddress.contains(createdAddress))
}

it should "find by script pub key" in { daos =>
val addressDAO = daos.addressDAO

val addr1 = WalletTestUtil.getAddressDb(WalletTestUtil.firstAccountDb)
val addr2 = WalletTestUtil.getAddressDb(WalletTestUtil.firstAccountDb,
addressIndex = 1)
val addr3 = WalletTestUtil.getAddressDb(WalletTestUtil.firstAccountDb,
addressIndex = 2)
val spks = Vector(addr1.scriptPubKey, addr2.scriptPubKey)

for {
created1 <- addressDAO.create(addr1)
created2 <- addressDAO.create(addr2)
created3 <- addressDAO.create(addr3)
found <- addressDAO.findByScriptPubKeys(spks)
} yield {
assert(found.contains(created1))
assert(found.contains(created2))
assert(!found.contains(created3))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.bitcoins.wallet.models
import org.bitcoins.core.api.wallet.db.{
LegacySpendingInfo,
NestedSegwitV0SpendingInfo,
ScriptPubKeyDb,
SegwitV0SpendingInfo
}
import org.bitcoins.core.protocol.script.ScriptSignature
Expand Down Expand Up @@ -204,4 +205,28 @@ class SpendingInfoDAOTest extends WalletDAOFixture {
case Some(other) => fail(s"did not get a nested segwit UTXO: $other")
}
}

it should "find incoming outputs dbs being spent, given a TX" in { daos =>
val utxoDAO = daos.utxoDAO

for {
created <- WalletTestUtil.insertNestedSegWitUTXO(daos)
db <- utxoDAO.read(created.id.get)

account <- daos.accountDAO.create(WalletTestUtil.firstAccountDb)
addr <- daos.addressDAO.create(getAddressDb(account))

// Add another utxo
u2 = WalletTestUtil.sampleSegwitUTXO(addr.scriptPubKey)
_ <- insertDummyIncomingTransaction(daos, u2)
_ <- utxoDAO.create(u2)

dbs <- utxoDAO.findDbsForTx(created.txid)
} yield {
assert(dbs.size == 1)
assert(db.isDefined)

assert(dbs == Vector(db.get))
}
}
}
95 changes: 86 additions & 9 deletions wallet/src/main/scala/org/bitcoins/wallet/Wallet.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.bitcoins.wallet

import java.time.Instant

import org.bitcoins.commons.jsonmodels.wallet.SyncHeightDescriptor
import org.bitcoins.core.api.chain.ChainQueryApi
import org.bitcoins.core.api.feeprovider.FeeRateApi
Expand All @@ -14,6 +12,7 @@ import org.bitcoins.core.crypto.ExtPublicKey
import org.bitcoins.core.currency._
import org.bitcoins.core.gcs.{GolombFilter, SimpleFilterMatcher}
import org.bitcoins.core.hd._
import org.bitcoins.core.number.UInt32
import org.bitcoins.core.protocol.BitcoinAddress
import org.bitcoins.core.protocol.blockchain.ChainParams
import org.bitcoins.core.protocol.script.ScriptPubKey
Expand All @@ -33,20 +32,16 @@ import org.bitcoins.core.wallet.utxo.TxoState.{
PendingConfirmationsReceived
}
import org.bitcoins.core.wallet.utxo._
import org.bitcoins.crypto.{
AesPassword,
CryptoUtil,
DoubleSha256Digest,
ECPublicKey
}
import org.bitcoins.crypto._
import org.bitcoins.keymanager.bip39.{BIP39KeyManager, BIP39LockedKeyManager}
import org.bitcoins.wallet.config.WalletAppConfig
import org.bitcoins.wallet.internal._
import org.bitcoins.wallet.models._
import scodec.bits.ByteVector

import java.time.Instant
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
import scala.util.{Failure, Random, Success}

abstract class Wallet
extends AnyHDWalletApi
Expand Down Expand Up @@ -499,6 +494,88 @@ abstract class Wallet
} yield tx
}

override def bumpFeeRBF(
txId: DoubleSha256DigestBE,
newFeeRate: FeeUnit): Future[Transaction] = {
for {
txDbOpt <- transactionDAO.findByTxId(txId)
tx <- txDbOpt match {
case Some(db) => Future.successful(db.transaction)
case None =>
Future.failed(
new RuntimeException(s"Unable to find transaction ${txId.hex}"))
}

outPoints = tx.inputs.map(_.previousOutput).toVector
spks = tx.outputs.map(_.scriptPubKey).toVector

utxos <- spendingInfoDAO.findByOutPoints(outPoints)
_ = require(utxos.nonEmpty, "Can only bump fee for our own transaction")
_ = require(utxos.size == tx.inputs.size,
"Can only bump fee for a transaction we own all the inputs")
spendingInfos <- FutureUtil.sequentially(utxos) { utxo =>
transactionDAO
.findByOutPoint(utxo.outPoint)
.map(txDbOpt =>
utxo.toUTXOInfo(keyManager = keyManager, txDbOpt.get.transaction))
}

_ = {
val inputAmount = utxos.foldLeft(CurrencyUnits.zero)(_ + _.output.value)

val oldFeeRate = newFeeRate match {
case _: SatoshisPerByte =>
SatoshisPerByte.calc(inputAmount, tx)
case _: SatoshisPerKiloByte =>
SatoshisPerKiloByte.calc(inputAmount, tx)
case _: SatoshisPerVirtualByte =>
SatoshisPerVirtualByte.calc(inputAmount, tx)
case _: SatoshisPerKW =>
SatoshisPerKW.calc(inputAmount, tx)
}

require(
oldFeeRate.currencyUnit < newFeeRate.currencyUnit,
s"Cannot bump to a lower fee ${oldFeeRate.currencyUnit} < ${newFeeRate.currencyUnit}")
}

myAddrs <- addressDAO.findByScriptPubKeys(spks)
_ = require(myAddrs.nonEmpty, "Must have an output we own")

changeSpks = myAddrs.flatMap { db =>
if (db.path.chain.chainType == HDChainType.Change) {
Some(db.scriptPubKey)
} else None
}

changeSpk =
if (changeSpks.nonEmpty) {
// Pick a random change spk
Random.shuffle(changeSpks).head
} else {
// If none are explicit change, pick a random one we own
Random.shuffle(myAddrs.map(_.scriptPubKey)).head
}

// Mark old outputs as replaced
oldUtxos <- spendingInfoDAO.findDbsForTx(txId)
_ <- spendingInfoDAO.updateAll(
oldUtxos.map(_.copyWithState(TxoState.DoesNotExist)))

sequence = tx.inputs.head.sequence + UInt32.one
outputs = tx.outputs.filterNot(_.scriptPubKey == changeSpk)
txBuilder = StandardNonInteractiveFinalizer.txBuilderFrom(outputs,
spendingInfos,
newFeeRate,
changeSpk,
sequence)

amount = outputs.foldLeft(CurrencyUnits.zero)(_ + _.value)
tx <-
finishSend(txBuilder, spendingInfos, amount, newFeeRate, Vector.empty)
} yield tx
}

override def sendWithAlgo(
address: BitcoinAddress,
amount: CurrencyUnit,
Expand Down
15 changes: 15 additions & 0 deletions wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,21 @@ case class AddressDAO()(implicit
})
}

def findByScriptPubKeys(
spks: Vector[ScriptPubKey]): Future[Vector[AddressDb]] = {
val query = table
.join(spkTable)
.on(_.scriptPubKeyId === _.id)
.filter(_._2.scriptPubKey.inSet(spks))

safeDatabase
.runVec(query.result.transactionally)
.map(res =>
res.map {
case (addrRec, spkRec) => addrRec.toAddressDb(spkRec.scriptPubKey)
})
}

private def findMostRecentForChain(account: HDAccount, chain: HDChainType) = {
addressesForAccountQuery(account.index)
.filter(_._1.purpose === account.purpose)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,15 @@ case class SpendingInfoDAO()(implicit

}

/**
* Fetches all the incoming TXOs in our DB that are in
* the transaction with the given TXID
*/
def findDbsForTx(txid: DoubleSha256DigestBE): Future[Vector[UTXORecord]] = {
val query = table.filter(_.txid === txid)
safeDatabase.runVec(query.result)
}

/**
* Fetches all the incoming TXOs in our DB that are in
* the transaction with the given TXID
Expand Down

0 comments on commit 64a6b6b

Please sign in to comment.