Skip to content

Commit

Permalink
replace OneAnd party Sets with NonEmpty Set (#11420)
Browse files Browse the repository at this point in the history
* add PartySet alias for db-backend

* add PartySet alias for fetch-contracts

* add PartySet alias for http-json

* deprecate old apply

* quick builder for NonEmpty collections

* replace PartySet in db-backend

* replace PartySet in fetch-contracts

* lar.Party is also domain.Party

* add incl1 operator

* replace PartySet in http-json

* port tests

* into with Scala 2.12 needs collection-compat

* no changelog

CHANGELOG_BEGIN
CHANGELOG_END

* simplify a couple functions that don't need so much data transformation now

* clean up some OneAnds and HKTs

* deal with Scala 2.12 without having warning suppression

* better, more obscure choice for Scala 2.12
  • Loading branch information
S11001001 committed Oct 28, 2021
1 parent 570160b commit bf00956
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 102 deletions.
Expand Up @@ -203,13 +203,13 @@ sealed abstract class Queries(tablePrefix: String, tpIdCacheMaxEntries: Long)(im
)
}

final def lastOffset(parties: OneAnd[Set, String], tpid: SurrogateTpId)(implicit
final def lastOffset(parties: PartySet, tpid: SurrogateTpId)(implicit
log: LogHandler
): ConnectionIO[Map[String, String]] = {
import Queries.CompatImplicits.catsReducibleFromFoldable1
val q = sql"""
SELECT party, last_offset FROM $ledgerOffsetTableName WHERE tpid = $tpid AND
""" ++ Fragments.in(fr"party", parties)
""" ++ Fragments.in(fr"party", parties.toF)
q.query[(String, String)]
.to[Vector]
.map(_.toMap)
Expand Down Expand Up @@ -326,7 +326,7 @@ sealed abstract class Queries(tablePrefix: String, tpIdCacheMaxEntries: Long)(im
}

private[http] final def selectContracts(
parties: OneAnd[Set, String],
parties: PartySet,
tpid: SurrogateTpId,
predicate: Fragment,
)(implicit
Expand All @@ -339,7 +339,7 @@ sealed abstract class Queries(tablePrefix: String, tpIdCacheMaxEntries: Long)(im
* which query or queries produced each contract.
*/
private[http] def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
parties: PartySet,
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
Expand Down Expand Up @@ -399,7 +399,7 @@ sealed abstract class Queries(tablePrefix: String, tpIdCacheMaxEntries: Long)(im
}

private[http] final def fetchById(
parties: OneAnd[Set, String],
parties: PartySet,
tpid: SurrogateTpId,
contractId: String,
)(implicit
Expand All @@ -408,7 +408,7 @@ sealed abstract class Queries(tablePrefix: String, tpIdCacheMaxEntries: Long)(im
selectContracts(parties, tpid, sql"c.contract_id = $contractId").option

private[http] final def fetchByKey(
parties: OneAnd[Set, String],
parties: PartySet,
tpid: SurrogateTpId,
key: Hash,
)(implicit
Expand All @@ -434,6 +434,8 @@ sealed abstract class Queries(tablePrefix: String, tpIdCacheMaxEntries: Long)(im
}

object Queries {
type PartySet = NonEmpty[Set[String]]

sealed trait SurrogateTpIdTag
val SurrogateTpId = Tag.of[SurrogateTpIdTag]
type SurrogateTpId = Long @@ SurrogateTpIdTag // matches tpid (BIGINT) above
Expand Down Expand Up @@ -729,13 +731,13 @@ private final class PostgresQueries(tablePrefix: String, tpIdCacheMaxEntries: Lo
}

private[http] override def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
parties: PartySet,
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
log: LogHandler
): Query0[DBContract[Mark, JsValue, JsValue, Vector[String]]] = {
val partyVector = parties.toVector
val partyVector: Vector[String] = parties.toVector
import ipol.{gas, pas}
queryByCondition(
queries,
Expand Down Expand Up @@ -898,7 +900,7 @@ private final class OracleQueries(
}

private[http] override def selectContractsMultiTemplate[Mark](
parties: OneAnd[Set, String],
parties: PartySet,
queries: ISeq[(SurrogateTpId, Fragment)],
trackMatchIndices: MatchedQueryMarker[Mark],
)(implicit
Expand All @@ -920,7 +922,7 @@ private final class OracleQueries(
signatories, observers, agreement_text ${rownum getOrElse fr""}
FROM $contractTableName c
JOIN $contractStakeholdersViewName cst ON (c.contract_id = cst.contract_id)
WHERE (${Fragments.in(fr"cst.stakeholder", parties)})
WHERE (${Fragments.in(fr"cst.stakeholder", parties.toF)})
AND ($queriesCondition)"""
rownum.fold(dupQ)(_ => sql"SELECT $outerSelectList FROM ($dupQ) WHERE rownumber = 1")
},
Expand Down
Expand Up @@ -12,11 +12,7 @@ import util.{AbsoluteBookmark, BeginBookmark, ContractStreamStep, InsertDeleteSt
import util.IdentifierConverters.apiIdentifier
import com.daml.ledger.api.v1.transaction.Transaction
import com.daml.ledger.api.{v1 => lav1}
import scalaz.OneAnd._
import scalaz.std.set._
import scalaz.syntax.tag._
import scalaz.syntax.foldable._
import scalaz.OneAnd

private[daml] object AcsTxStreams {
import util.AkkaStreamsDoobie.{last, max, project2}
Expand Down Expand Up @@ -132,7 +128,7 @@ private[daml] object AcsTxStreams {
}

private[daml] def transactionFilter(
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateIds: List[TemplateId.RequiredPkg],
): lav1.transaction_filter.TransactionFilter = {
import lav1.transaction_filter._
Expand All @@ -141,6 +137,8 @@ private[daml] object AcsTxStreams {
if (templateIds.isEmpty) Filters.defaultInstance
else Filters(Some(lav1.transaction_filter.InclusiveFilters(templateIds.map(apiIdentifier))))

TransactionFilter(domain.Party.unsubst(parties.toVector).map(_ -> filters).toMap)
TransactionFilter(
domain.Party.unsubst((parties: Set[domain.Party]).toVector).map(_ -> filters).toMap
)
}
}
Expand Up @@ -8,6 +8,7 @@ import lf.data.Ref
import util.ClientUtil.boxedRecord
import com.daml.ledger.api.{v1 => lav1}
import com.daml.ledger.api.refinements.{ApiTypes => lar}
import com.daml.scalautil.nonempty.NonEmpty
import scalaz.std.list._
import scalaz.std.option._
import scalaz.std.string._
Expand All @@ -26,6 +27,8 @@ package object domain {
type Party = lar.Party
val Party = lar.Party

type PartySet = NonEmpty[Set[Party]]

type Offset = String @@ OffsetTag

private[daml] implicit final class `fc domain ErrorOps`[A](private val o: Option[A])
Expand Down Expand Up @@ -161,6 +164,7 @@ package domain {
final val ContractId = here.ContractId
type Party = here.Party
final val Party = here.Party
type PartySet = here.PartySet
type Offset = here.Offset
final val Offset = here.Offset
type ActiveContract[+LfV] = here.ActiveContract[LfV]
Expand Down
1 change: 1 addition & 0 deletions ledger-service/http-json/BUILD.bazel
Expand Up @@ -531,6 +531,7 @@ da_scala_benchmark_jmh(
"//libs-scala/doobie-slf4j",
"//libs-scala/oracle-testing",
"//libs-scala/ports",
"//libs-scala/scala-utils",
"@maven//:com_oracle_database_jdbc_ojdbc8",
"@maven//:io_dropwizard_metrics_metrics_core",
"@maven//:org_slf4j_slf4j_api",
Expand Down
Expand Up @@ -6,9 +6,12 @@ package com.daml.http.dbbackend
import com.daml.http.dbbackend.Queries.SurrogateTpId
import com.daml.http.domain.{Party, TemplateId}
import com.daml.http.util.Logging.instanceUUIDLogCtx
import com.daml.scalautil.Statement.discard
import com.daml.scalautil.nonempty.NonEmpty
import doobie.implicits._
import org.openjdk.jmh.annotations._
import scalaz.OneAnd

import scala.collection.compat._

class QueryBenchmark extends ContractDaoBenchmark {
@Param(Array("1", "5", "9"))
Expand Down Expand Up @@ -46,9 +49,13 @@ class QueryBenchmark extends ContractDaoBenchmark {
implicit val driver: SupportedJdbcDriver.TC = dao.jdbcDriver
val result = instanceUUIDLogCtx(implicit lc =>
dao
.transact(ContractDao.selectContracts(OneAnd(Party(party), Set.empty), tpid, fr"1 = 1"))
.transact(
ContractDao.selectContracts(NonEmpty.pour(Party(party)) into Set, tpid, fr"1 = 1")
)
.unsafeRunSync()
)
assert(result.size == batchSize)
}

discard(IterableOnce) // only needed for scala 2.12
}
Expand Up @@ -10,10 +10,13 @@ import com.daml.http.dbbackend.Queries.SurrogateTpId
import com.daml.http.domain.{Party, TemplateId}
import com.daml.http.query.ValuePredicate
import com.daml.http.util.Logging.instanceUUIDLogCtx
import com.daml.scalautil.Statement.discard
import com.daml.scalautil.nonempty.NonEmpty
import org.openjdk.jmh.annotations._
import scalaz.OneAnd
import spray.json._

import scala.collection.compat._

class QueryPayloadBenchmark extends ContractDaoBenchmark {
@Param(Array("1", "10", "100"))
var extraParties: Int = _
Expand Down Expand Up @@ -70,9 +73,13 @@ class QueryPayloadBenchmark extends ContractDaoBenchmark {
implicit val sjd: SupportedJdbcDriver.TC = dao.jdbcDriver
val result = instanceUUIDLogCtx(implicit lc =>
dao
.transact(ContractDao.selectContracts(OneAnd(Party(party), Set.empty), tpid, whereClause))
.transact(
ContractDao.selectContracts(NonEmpty.pour(Party(party)) into Set, tpid, whereClause)
)
.unsafeRunSync()
)
assert(result.size == batchSize)
}

discard(IterableOnce) // only needed for scala 2.12
}
Expand Up @@ -28,8 +28,6 @@ import com.daml.ledger.api.{v1 => lav1}
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
import doobie.free.{connection => fconn}
import fconn.ConnectionIO
import scalaz.OneAnd._
import scalaz.std.set._
import scalaz.std.vector._
import scalaz.std.list._
import scalaz.std.option.none
Expand All @@ -39,7 +37,7 @@ import scalaz.syntax.functor._
import scalaz.syntax.foldable._
import scalaz.syntax.order._
import scalaz.syntax.std.option._
import scalaz.{OneAnd, \/}
import scalaz.\/
import spray.json.{JsNull, JsValue}

import scala.concurrent.ExecutionContext
Expand All @@ -64,7 +62,7 @@ private class ContractsFetch(
def fetchAndPersistBracket[A](
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateIds: List[domain.TemplateId.RequiredPkg],
)(within: BeginBookmark[Terminates.AtAbsolute] => ConnectionIO[A])(implicit
ec: ExecutionContext,
Expand All @@ -87,7 +85,7 @@ private class ContractsFetch(
// has desynchronized
lagging <- (templateIds.toSet, bb.map(_.toDomain)) match {
case (NonEmpty(tids), AbsoluteBookmark(expectedOff)) =>
laggingOffsets(parties.toSet, expectedOff, tids)
laggingOffsets(parties, expectedOff, tids)
case _ => fconn.pure(none[(domain.Offset, Set[domain.TemplateId.RequiredPkg])])
}
retriedA <- lagging.cata(
Expand Down Expand Up @@ -119,7 +117,7 @@ private class ContractsFetch(
def fetchAndPersist(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateIds: List[domain.TemplateId.RequiredPkg],
)(implicit
ec: ExecutionContext,
Expand Down Expand Up @@ -411,6 +409,6 @@ private[http] object ContractsFetch {
private final case class FetchContext(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
)
}
Expand Up @@ -20,7 +20,6 @@ import com.daml.fetchcontracts.util.ContractStreamStep.{Acs, LiveBegin}
import com.daml.http.util.FutureUtil.toFuture
import com.daml.http.util.Logging.{InstanceUUID, RequestID}
import com.daml.jwt.domain.Jwt
import com.daml.ledger.api.refinements.{ApiTypes => lar}
import com.daml.ledger.api.v1.active_contracts_service.GetActiveContractsResponse
import com.daml.ledger.api.{v1 => api}
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
Expand Down Expand Up @@ -69,7 +68,7 @@ class ContractsService(

def resolveContractReference(
jwt: Jwt,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
contractLocator: domain.ContractLocator[LfValue],
ledgerId: LedgerApiDomain.LedgerId,
)(implicit
Expand Down Expand Up @@ -113,7 +112,7 @@ class ContractsService(

private[this] def findByContractKey(
jwt: Jwt,
parties: OneAnd[Set, lar.Party],
parties: domain.PartySet,
templateId: TemplateId.OptionalPkg,
ledgerId: LedgerApiDomain.LedgerId,
contractKey: LfValue,
Expand All @@ -133,7 +132,7 @@ class ContractsService(

private[this] def findByContractId(
jwt: Jwt,
parties: OneAnd[Set, lar.Party],
parties: domain.PartySet,
templateId: Option[domain.TemplateId.OptionalPkg],
ledgerId: LedgerApiDomain.LedgerId,
contractId: domain.ContractId,
Expand Down Expand Up @@ -239,7 +238,7 @@ class ContractsService(
def retrieveAll(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
)(implicit
lc: LoggingContextOf[InstanceUUID]
): SearchResult[Error \/ domain.ActiveContract[LfValue]] =
Expand Down Expand Up @@ -270,7 +269,7 @@ class ContractsService(
def search(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateIds: OneAnd[Set, domain.TemplateId.OptionalPkg],
queryParams: Map[String, JsValue],
)(implicit
Expand Down Expand Up @@ -384,7 +383,7 @@ class ContractsService(
}

private[this] def searchDbOneTpId_(
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateId: domain.TemplateId.RequiredPkg,
queryParams: Map[String, JsValue],
)(implicit
Expand All @@ -399,7 +398,7 @@ class ContractsService(
private[this] def searchInMemory(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateIds: Set[domain.TemplateId.RequiredPkg],
queryParams: InMemoryQuery,
)(implicit
Expand Down Expand Up @@ -441,7 +440,7 @@ class ContractsService(
private[this] def searchInMemoryOneTpId(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, domain.Party],
parties: domain.PartySet,
templateId: domain.TemplateId.RequiredPkg,
queryParams: InMemoryQuery.P,
)(implicit
Expand Down Expand Up @@ -469,7 +468,7 @@ class ContractsService(
private[http] def liveAcsAsInsertDeleteStepSource(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, lar.Party],
parties: domain.PartySet,
templateIds: List[domain.TemplateId.RequiredPkg],
): Source[ContractStreamStep.LAV1, NotUsed] = {
val txnFilter = util.Transactions.transactionFilterFor(parties, templateIds)
Expand All @@ -486,7 +485,7 @@ class ContractsService(
private[http] def insertDeleteStepSource(
jwt: Jwt,
ledgerId: LedgerApiDomain.LedgerId,
parties: OneAnd[Set, lar.Party],
parties: domain.PartySet,
templateIds: List[domain.TemplateId.RequiredPkg],
startOffset: Option[domain.StartingOffset] = None,
terminates: Terminates = Terminates.AtLedgerEnd,
Expand Down Expand Up @@ -571,7 +570,7 @@ object ContractsService {

final case class SearchContext[Tids[_], Pkgs[_]](
jwt: Jwt,
parties: OneAnd[Set, lar.Party],
parties: domain.PartySet,
templateIds: Tids[domain.TemplateId[Pkgs[String]]],
ledgerId: LedgerApiDomain.LedgerId,
)
Expand Down

0 comments on commit bf00956

Please sign in to comment.