Skip to content

Commit

Permalink
Fix TLV parsing for non-standard strings (#2312)
Browse files Browse the repository at this point in the history
* Fix TLV parsing for non-standard strings

* Create function

* Fix oracle migrations

* Forced all TLV strings to be normalized implicitly

* Removed redundant normalization

* Fix oracle

* Bump migration test

* Fix 2.12.12 compile

* Use NetworkElement & StringFactory

Co-authored-by: nkohen <nadavk25@gmail.com>
  • Loading branch information
benthecarman and nkohen committed Dec 3, 2020
1 parent fd08c98 commit b3d70f5
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ case class OracleRoutes(oracle: DLCOracle)(implicit
case Some(event: OracleEvent) =>
val outcomesJson = event.eventDescriptorTLV match {
case enum: EnumEventDescriptorV0TLV =>
enum.outcomes.map(Str)
enum.outcomes.map(outcome => Str(outcome.normStr))
case range: RangeEventDescriptorV0TLV =>
val outcomes: Vector[Long] = {
val startL = range.start.toLong
Expand Down
3 changes: 3 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,12 @@ lazy val walletTest = project
.dependsOn(core % testAndCompile, testkit, wallet)
.enablePlugins(FlywayPlugin)

lazy val oracleDbSettings = dbFlywaySettings("oracle")

lazy val dlcOracle = project
.in(file("dlc-oracle"))
.settings(CommonSettings.prodSettings: _*)
.settings(oracleDbSettings: _*)
.settings(
name := "bitcoin-s-dlc-oracle",
libraryDependencies ++= Deps.dlcOracle
Expand Down
122 changes: 85 additions & 37 deletions core/src/main/scala/org/bitcoins/core/protocol/tlv/TLV.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ object TLV extends TLVParentFactory[TLV] {
case None => UnknownTLV(tpe, value)
}
}

def getStringBytes(str: NormalizedString): ByteVector = {
val strBytes = str.bytes
val size = BigSizeUInt(strBytes.size)

size.bytes ++ strBytes
}
}

sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
Expand Down Expand Up @@ -141,11 +148,11 @@ sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
}
}

def takeString(): String = {
def takeString(): NormalizedString = {
val size = BigSizeUInt(current)
skip(size.byteSize)
val strBytes = take(size.toInt)
new String(strBytes.toArray, StandardCharsets.UTF_8)
NormalizedString(strBytes)
}

def takeSPK(): ScriptPubKey = {
Expand All @@ -155,6 +162,51 @@ sealed trait TLVFactory[+T <: TLV] extends Factory[T] {
}
}

case class NormalizedString(private val str: String) extends NetworkElement {

val normStr: String = CryptoUtil.normalize(str)

override def equals(other: Any): Boolean = {
other match {
case otherStr: String =>
normStr == otherStr
case _ => other.equals(str)
}
}

override def toString: String = normStr

override def bytes: ByteVector = CryptoUtil.serializeForHash(normStr)
}

object NormalizedString extends StringFactory[NormalizedString] {

def apply(bytes: ByteVector): NormalizedString = {
NormalizedString(new String(bytes.toArray, StandardCharsets.UTF_8))
}

import scala.language.implicitConversions

implicit def stringToNormalized(str: String): NormalizedString =
NormalizedString(str)

implicit def normalizedToString(normalized: NormalizedString): String =
normalized.normStr

// If other kinds of Iterables are needed, there's a fancy thing to do
// that is done all over the Seq code using params and an implicit CanBuildFrom
implicit def stringVecToNormalized(
strs: Vector[String]): Vector[NormalizedString] =
strs.map(apply)

implicit def normalizedVecToString(
strs: Vector[NormalizedString]): Vector[String] =
strs.map(_.normStr)

override def fromString(string: String): NormalizedString =
NormalizedString(string)
}

case class UnknownTLV(tpe: BigSizeUInt, value: ByteVector) extends TLV {
require(!TLV.knownTypes.contains(tpe), s"Type $tpe is known")
}
Expand Down Expand Up @@ -253,16 +305,16 @@ object EventDescriptorTLV extends TLVParentFactory[EventDescriptorTLV] {
* @param outcomes The set of possible outcomes
* @see https://github.com/discreetlogcontracts/dlcspecs/blob/master/Oracle.md#simple-enumeration
*/
case class EnumEventDescriptorV0TLV(outcomes: Vector[String])
case class EnumEventDescriptorV0TLV(outcomes: Vector[NormalizedString])
extends EventDescriptorTLV {
override def tpe: BigSizeUInt = EnumEventDescriptorV0TLV.tpe

override val value: ByteVector = {
val starting = UInt16(outcomes.size).bytes

outcomes.foldLeft(starting) { (accum, outcome) =>
val outcomeBytes = CryptoUtil.serializeForHash(outcome)
accum ++ UInt16(outcomeBytes.length).bytes ++ outcomeBytes
val outcomeBytes = TLV.getStringBytes(outcome)
accum ++ outcomeBytes
}
}

Expand All @@ -278,19 +330,18 @@ object EnumEventDescriptorV0TLV extends TLVFactory[EnumEventDescriptorV0TLV] {

val count = UInt16(iter.takeBits(16))

val builder = Vector.newBuilder[String]
val builder = Vector.newBuilder[NormalizedString]

while (iter.index < value.length) {
val len = UInt16(iter.takeBits(16))
val outcomeBytes = iter.take(len.toInt)
val str = new String(outcomeBytes.toArray, StandardCharsets.UTF_8)
val str = iter.takeString()
builder.+=(str)
}

val result = builder.result()

require(count.toInt == result.size,
"Did not parse the expected number of outcomes")
require(
count.toInt == result.size,
s"Did not parse the expected number of outcomes, ${count.toInt} != ${result.size}")

EnumEventDescriptorV0TLV(result)
}
Expand All @@ -299,12 +350,12 @@ object EnumEventDescriptorV0TLV extends TLVFactory[EnumEventDescriptorV0TLV] {
sealed trait NumericEventDescriptorTLV extends EventDescriptorTLV {

/** The minimum valid value in the oracle can sign */
def min: Vector[String]
def min: Vector[NormalizedString]

def minNum: BigInt

/** The maximum valid value in the oracle can sign */
def max: Vector[String]
def max: Vector[NormalizedString]

def maxNum: BigInt

Expand All @@ -320,7 +371,7 @@ sealed trait NumericEventDescriptorTLV extends EventDescriptorTLV {
def base: UInt16

/** The unit of the outcome value */
def unit: String
def unit: NormalizedString

/** The precision of the outcome representing the base exponent
* by which to multiply the number represented by the composition
Expand Down Expand Up @@ -361,29 +412,26 @@ case class RangeEventDescriptorV0TLV(
start: Int32,
count: UInt32,
step: UInt16,
unit: String,
unit: NormalizedString,
precision: Int32)
extends NumericEventDescriptorTLV {

override val minNum: BigInt = BigInt(start.toInt)

override val min: Vector[String] = Vector(minNum.toString)
override val min: Vector[NormalizedString] = Vector(minNum.toString)

override val maxNum: BigInt =
start.toLong + (step.toLong * (count.toLong - 1))

override val max: Vector[String] = Vector(maxNum.toString)
override val max: Vector[NormalizedString] = Vector(maxNum.toString)

override val base: UInt16 = UInt16(10)

override val tpe: BigSizeUInt = RangeEventDescriptorV0TLV.tpe

override val value: ByteVector = {
val unitSize = BigSizeUInt(unit.length)
val unitBytes = CryptoUtil.serializeForHash(unit)

start.bytes ++ count.bytes ++ step.bytes ++
unitSize.bytes ++ unitBytes ++ precision.bytes
TLV.getStringBytes(unit) ++ precision.bytes
}

override def noncesNeeded: Int = 1
Expand Down Expand Up @@ -420,10 +468,10 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {

override lazy val maxNum: BigInt = base.toBigInt.pow(numDigits.toInt) - 1

private lazy val maxDigit = (base.toInt - 1).toString
private lazy val maxDigit: NormalizedString = (base.toInt - 1).toString

override lazy val max: Vector[String] = if (isSigned) {
"+" +: Vector.fill(numDigits.toInt)(maxDigit)
override lazy val max: Vector[NormalizedString] = if (isSigned) {
NormalizedString("+") +: Vector.fill(numDigits.toInt)(maxDigit)
} else {
Vector.fill(numDigits.toInt)(maxDigit)
}
Expand All @@ -434,8 +482,8 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
0
}

override lazy val min: Vector[String] = if (isSigned) {
"-" +: Vector.fill(numDigits.toInt)(maxDigit)
override lazy val min: Vector[NormalizedString] = if (isSigned) {
NormalizedString("-") +: Vector.fill(numDigits.toInt)(maxDigit)
} else {
Vector.fill(numDigits.toInt)("0")
}
Expand All @@ -450,10 +498,9 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
if (isSigned) ByteVector(TRUE_BYTE) else ByteVector(FALSE_BYTE)

val numDigitBytes = numDigits.bytes
val unitSize = BigSizeUInt(unit.length)
val unitBytes = CryptoUtil.serializeForHash(unit)
val unitBytes = TLV.getStringBytes(unit)

base.bytes ++ isSignedByte ++ unitSize.bytes ++ unitBytes ++ precision.bytes ++ numDigitBytes
base.bytes ++ isSignedByte ++ unitBytes ++ precision.bytes ++ numDigitBytes
}

override def noncesNeeded: Int = {
Expand All @@ -466,7 +513,7 @@ trait DigitDecompositionEventDescriptorV0TLV extends NumericEventDescriptorTLV {
case class SignedDigitDecompositionEventDescriptor(
base: UInt16,
numDigits: UInt16,
unit: String,
unit: NormalizedString,
precision: Int32)
extends DigitDecompositionEventDescriptorV0TLV {
override val isSigned: Boolean = true
Expand All @@ -476,7 +523,7 @@ case class SignedDigitDecompositionEventDescriptor(
case class UnsignedDigitDecompositionEventDescriptor(
base: UInt16,
numDigits: UInt16,
unit: String,
unit: NormalizedString,
precision: Int32)
extends DigitDecompositionEventDescriptorV0TLV {
override val isSigned: Boolean = false
Expand Down Expand Up @@ -509,7 +556,7 @@ object DigitDecompositionEventDescriptorV0TLV
base: UInt16,
isSigned: Boolean,
numDigits: Int,
unit: String,
unit: NormalizedString,
precision: Int32): DigitDecompositionEventDescriptorV0TLV = {
if (isSigned) {
SignedDigitDecompositionEventDescriptor(base,
Expand All @@ -534,7 +581,7 @@ case class OracleEventV0TLV(
nonces: Vector[SchnorrNonce],
eventMaturityEpoch: UInt32,
eventDescriptor: EventDescriptorTLV,
eventURI: String
eventId: NormalizedString
) extends OracleEventTLV {

require(eventDescriptor.noncesNeeded == nonces.size,
Expand All @@ -543,11 +590,12 @@ case class OracleEventV0TLV(
override def tpe: BigSizeUInt = OracleEventV0TLV.tpe

override val value: ByteVector = {
val uriBytes = CryptoUtil.serializeForHash(eventURI)
val eventIdBytes = TLV.getStringBytes(eventId)

val numNonces = UInt16(nonces.size)
val noncesBytes = nonces.foldLeft(numNonces.bytes)(_ ++ _.bytes)

noncesBytes ++ eventMaturityEpoch.bytes ++ eventDescriptor.bytes ++ uriBytes
noncesBytes ++ eventMaturityEpoch.bytes ++ eventDescriptor.bytes ++ eventIdBytes
}

/** Gets the maturation of the event since epoch */
Expand Down Expand Up @@ -579,9 +627,9 @@ object OracleEventV0TLV extends TLVFactory[OracleEventV0TLV] {
val eventMaturity = UInt32(iter.takeBits(32))
val eventDescriptor = EventDescriptorTLV(iter.current)
iter.skip(eventDescriptor.byteSize)
val eventURI = new String(iter.current.toArray, StandardCharsets.UTF_8)
val eventId = iter.takeString()

OracleEventV0TLV(nonces, eventMaturity, eventDescriptor, eventURI)
OracleEventV0TLV(nonces, eventMaturity, eventDescriptor, eventId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ class DbManagementTest extends BitcoinSAsyncTest with EmbeddedPg {
val result = oracleAppConfig.migrate()
oracleAppConfig.driver match {
case SQLite =>
val expected = 2
val expected = 3
assert(result == expected)
val flywayInfo = oracleAppConfig.info()

assert(flywayInfo.applied().length == expected)
assert(flywayInfo.pending().length == 0)
case PostgreSQL =>
val expected = 2
val expected = 3
assert(result == expected)
val flywayInfo = oracleAppConfig.info()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Fix dummy event descriptor to be a parsable one
UPDATE events SET event_descriptor_tlv = 'fdd8060800010564756d6d79' WHERE event_descriptor_tlv = 'fdd806090001000564756d6d79';
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Fix dummy event descriptor to be a parsable one
UPDATE `events` SET `event_descriptor_tlv` = 'fdd8060800010564756d6d79' WHERE `event_descriptor_tlv` = 'fdd806090001000564756d6d79';
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.bitcoins.dlc.oracle

import java.time.Instant

import org.bitcoins.commons.jsonmodels.dlc.SigningVersion
import org.bitcoins.core.config.BitcoinNetwork
import org.bitcoins.core.crypto.ExtKeyVersion.SegWitMainNetPriv
Expand All @@ -18,6 +16,7 @@ import org.bitcoins.dlc.oracle.storage._
import org.bitcoins.dlc.oracle.util.EventDbUtil
import org.bitcoins.keymanager.{DecryptedMnemonic, WalletStorage}

import java.time.Instant
import scala.concurrent.{ExecutionContext, Future}

case class DLCOracle(private val extPrivateKey: ExtPrivateKeyHardened)(implicit
Expand Down Expand Up @@ -273,7 +272,9 @@ case class DLCOracle(private val extPrivateKey: ExtPrivateKeyHardened)(implicit
s"No event saved with nonce ${nonce.hex} $outcome"))
}

eventOutcomeOpt <- eventOutcomeDAO.read((nonce, outcome.outcomeString))
hash = eventDb.signingVersion.calcOutcomeHash(eventDb.eventDescriptorTLV,
outcome.outcomeString)
eventOutcomeOpt <- eventOutcomeDAO.find(nonce, hash)
eventOutcomeDb <- eventOutcomeOpt match {
case Some(value) => Future.successful(value)
case None =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@ case class DLCOracleAppConfig(
}
logger.info(s"Applied $numMigrations to the dlc oracle project")

if (migrationsApplied() == 2) {
logger.debug(s"Doing V2 Migration")
val migrations = migrationsApplied()
if (migrations == 2 || migrations == 3) { // For V2/V3 migrations
logger.debug(s"Doing V2/V3 Migration")

val dummyMigrationTLV = EventDescriptorTLV("fdd806090001000564756d6d79")
val dummyMigrationTLV = EventDescriptorTLV("fdd8060800010564756d6d79")

val eventDAO = EventDAO()(ec, appConfig)
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ case class EventOutcomeDAO()(implicit
safeDatabase.runVec(query.result.transactionally)
}

def find(
nonce: SchnorrNonce,
hash: ByteVector): Future[Option[EventOutcomeDb]] = {
val query =
table.filter(item => item.nonce === nonce && item.hashedMessage === hash)

safeDatabase.run(query.result.transactionally).map(_.headOption)
}

class EventOutcomeTable(tag: Tag)
extends Table[EventOutcomeDb](tag, schemaName, "event_outcomes") {

Expand Down
Loading

0 comments on commit b3d70f5

Please sign in to comment.