Skip to content

Commit

Permalink
Contract keys to be typed after generated types
Browse files Browse the repository at this point in the history
Also adds support in daml-lf/interface for contract keys

Addresses #1586 (review)
  • Loading branch information
stefanobaghino-da committed Jun 14, 2019
1 parent 7f201b4 commit c65b649
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 79 deletions.
Expand Up @@ -119,11 +119,14 @@ final case class Enum(constructors: ImmArraySeq[Ref.Name]) extends DataType[Noth
def asDataType[RT, PVT]: DataType[RT, PVT] = this
}

final case class DefTemplate[+Ty](choices: Map[Ref.Name, TemplateChoice[Ty]]) {
final case class DefTemplate[+Ty](choices: Map[Ref.Name, TemplateChoice[Ty]], key: Option[Type]) {
def map[B](f: Ty => B): DefTemplate[B] = Functor[DefTemplate].map(this)(f)

def getChoices: j.Map[Ref.ChoiceName, _ <: TemplateChoice[Ty]] =
choices.asJava

def getKey: j.Optional[Type] =
key.fold(j.Optional.empty[Type])(k => j.Optional.of(k))
}

object DefTemplate {
Expand Down
Expand Up @@ -26,7 +26,6 @@ import com.digitalasset.daml.lf.data.Ref.{
PackageId,
QualifiedName
}
import com.digitalasset.daml.lf.iface.TemplateChoice.FWT

import scala.collection.JavaConverters._
import scala.collection.immutable.Map
Expand Down Expand Up @@ -213,12 +212,16 @@ object InterfaceReader {
point(InvalidDataTypeDefinition(
s"Cannot find a record associated with template: $templateName")))
case Some((rec, newState)) =>
val y: Errors[ErrorLoc, InterfaceReaderError] \/ Map[ChoiceName, FWT] =
locate('choices, choices(a, ctx))

y.fold(
newState.addError, { cs =>
newState.addTemplate(templateName, rec, DefTemplate(cs))
val templateArgs =
for {
choices <- locate('choices, choices(a, ctx))
key <- locate('key, key(a, ctx))
} yield (choices, key)

templateArgs.fold(
newState.addError, {
case (cs, k) =>
newState.addTemplate(templateName, rec, DefTemplate(cs, k))
}
)
}
Expand Down Expand Up @@ -250,6 +253,13 @@ object InterfaceReader {
choice = TemplateChoice(p, consuming = a.getConsuming, returnType = r)
} yield choice

private def key(
a: DamlLf1.DefTemplate,
ctx: Context
): InterfaceReaderError.Tree \/ Option[Type] =
if (a.hasKey) locate('key, rootErr(type_(a.getKey.getType, ctx)).map(Some(_)))
else \/-(None)

private def fullName(
m: ModuleName,
a: DamlLf1.DottedName): InterfaceReaderError \/ QualifiedName =
Expand Down
Expand Up @@ -64,10 +64,10 @@ void contractHasDeprecatedFromIdAndRecord() {

@Test
void contractHasFromIdAndRecord() {
SimpleTemplate.Contract emptyAgreement = SimpleTemplate.Contract.fromIdAndRecord("SomeId", simpleTemplateRecord, Optional.empty(), Optional.empty());
SimpleTemplate.Contract emptyAgreement = SimpleTemplate.Contract.fromIdAndRecord("SomeId", simpleTemplateRecord, Optional.empty());
assertFalse(emptyAgreement.agreementText.isPresent(), "Field agreementText should not be present");

SimpleTemplate.Contract nonEmptyAgreement = SimpleTemplate.Contract.fromIdAndRecord("SomeId", simpleTemplateRecord, Optional.of("I agree"), Optional.empty());
SimpleTemplate.Contract nonEmptyAgreement = SimpleTemplate.Contract.fromIdAndRecord("SomeId", simpleTemplateRecord, Optional.of("I agree"));
assertTrue(nonEmptyAgreement.agreementText.isPresent(), "Field agreementText should be present");
assertEquals(nonEmptyAgreement.agreementText, Optional.of("I agree"), "Unexpected agreementText");
}
Expand Down
Expand Up @@ -15,15 +15,12 @@ import com.daml.ledger.javaapi.data.{
ArchivedEvent,
Command,
CreatedEvent,
Decimal,
Event,
Filter,
FiltersByParty,
GetTransactionsRequest,
LedgerOffset,
NoFilter,
Party,
Record,
SubmitCommandsRequest,
Transaction,
Unit => DamlUnit
Expand Down Expand Up @@ -220,15 +217,13 @@ class CodegenLedgerTest extends FlatSpec with Matchers with BazelRunfiles {
wolpertinger.agreementText.get shouldBe s"${wolpertinger.data.name} has ${wolpertinger.data.wings} wings and is ${wolpertinger.data.age} years old."
}

it should "provide the contractKey" in withClient { client =>
it should "provide the key" in withClient { client =>
sendCmd(client, glookofly.create())

val wolpertinger :: _ = readActiveContracts(client)

wolpertinger.contractKey.isPresent shouldBe true
wolpertinger.contractKey.get.asRecord.isPresent shouldBe true
wolpertinger.contractKey.get.asRecord.get.getFields.asScala should contain only (new Record.Field(
"owner",
new Party(Alice)), new Record.Field("age", new Decimal(java.math.BigDecimal.valueOf(17.42))))
wolpertinger.key.isPresent shouldBe true
wolpertinger.key.get.owner shouldEqual "Alice"
wolpertinger.key.get.age shouldEqual java.math.BigDecimal.valueOf(17.42)
}
}
Expand Up @@ -6,7 +6,7 @@ package com.digitalasset.daml.lf.codegen.backend.java.inner
import java.util.Optional

import com.daml.ledger.javaapi
import com.daml.ledger.javaapi.data.{ContractId, CreatedEvent, Value}
import com.daml.ledger.javaapi.data.{ContractId, CreatedEvent}
import com.digitalasset.daml.lf.codegen.TypeWithContext
import com.digitalasset.daml.lf.codegen.backend.java.ObjectMethods
import com.digitalasset.daml.lf.data.Ref.{ChoiceName, PackageId, QualifiedName}
Expand Down Expand Up @@ -51,7 +51,7 @@ private[inner] object TemplateClass extends StrictLogging {
typeWithContext.interface.typeDecls,
typeWithContext.packageId,
packagePrefixes))
.addType(generateContractClass(className))
.addType(generateContractClass(className, template.key, packagePrefixes))
.addFields(RecordFields(fields).asJava)
.addMethods(RecordMethods(fields, className, IndexedSeq.empty, packagePrefixes).asJava)
.build()
Expand All @@ -62,116 +62,183 @@ private[inner] object TemplateClass extends StrictLogging {
private val idFieldName = "id"
private val dataFieldName = "data"
private val agreementFieldName = "agreementText"
private val contractKeyFieldName = "contractKey"
private val contractKeyFieldName = "key"

private val optionalString = ParameterizedTypeName.get(classOf[Optional[_]], classOf[String])
private val optionalValue = ParameterizedTypeName.get(classOf[Optional[_]], classOf[Value])
private def optional(name: TypeName) =
ParameterizedTypeName.get(ClassName.get(classOf[Optional[_]]), name)

private def generateContractClass(
templateClassName: ClassName,
key: Option[Type],
packagePrefixes: Map[PackageId, String]): TypeSpec = {

private def generateContractClass(templateClassName: ClassName): TypeSpec = {
val contractIdClassName = ClassName.bestGuess("ContractId")
val contractKeyClassName = key.map(toJavaTypeName(_, packagePrefixes))

val classBuilder =
TypeSpec.classBuilder("Contract").addModifiers(Modifier.STATIC, Modifier.PUBLIC)

classBuilder.addField(contractIdClassName, idFieldName, Modifier.PUBLIC, Modifier.FINAL)
classBuilder.addField(templateClassName, dataFieldName, Modifier.PUBLIC, Modifier.FINAL)
classBuilder.addField(optionalString, agreementFieldName, Modifier.PUBLIC, Modifier.FINAL)
classBuilder.addField(optionalValue, contractKeyFieldName, Modifier.PUBLIC, Modifier.FINAL)

classBuilder.addSuperinterface(ClassName.get(classOf[javaapi.data.Contract]))

val constructorBuilder = MethodSpec
.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
.addParameter(contractIdClassName, idFieldName)
.addParameter(templateClassName, dataFieldName)
.addParameter(optionalString, agreementFieldName)
.addParameter(optionalValue, contractKeyFieldName)

constructorBuilder.addStatement("this.$L = $L", idFieldName, idFieldName)
constructorBuilder.addStatement("this.$L = $L", dataFieldName, dataFieldName)
constructorBuilder.addStatement("this.$L = $L", agreementFieldName, agreementFieldName)
constructorBuilder.addStatement("this.$L = $L", contractKeyFieldName, contractKeyFieldName);

contractKeyClassName.foreach { name =>
classBuilder.addField(optional(name), contractKeyFieldName, Modifier.PUBLIC, Modifier.FINAL)
constructorBuilder.addParameter(optional(name), contractKeyFieldName)
constructorBuilder.addStatement("this.$L = $L", contractKeyFieldName, contractKeyFieldName)
}

val constructor = constructorBuilder.build()

classBuilder.addMethod(constructor)

val contractClassName = ClassName.bestGuess("Contract")
val fields = Array(idFieldName, dataFieldName, agreementFieldName)
classBuilder
.addMethod(generateFromIdAndRecord(contractClassName, templateClassName, contractIdClassName))
.addMethod(
generateFromIdAndRecord(
contractClassName,
templateClassName,
contractIdClassName,
contractKeyClassName))
.addMethod(
generateFromIdAndRecordDeprecated(
contractClassName,
templateClassName,
contractIdClassName))
contractIdClassName,
contractKeyClassName.isDefined))
.addMethod(
generateFromCreatedEvent(contractClassName, templateClassName, contractIdClassName))
generateFromCreatedEvent(
contractClassName,
templateClassName,
contractIdClassName,
contractKeyClassName))
.addMethods(ObjectMethods(contractClassName, fields, templateClassName).asJava)
.build()
}

private[inner] def generateFromIdAndRecord(
className: ClassName,
templateClassName: ClassName,
idClassName: ClassName): MethodSpec =
MethodSpec
.methodBuilder("fromIdAndRecord")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(className)
.addParameter(classOf[String], "contractId")
.addParameter(classOf[javaapi.data.Record], "record$")
.addParameter(optionalString, agreementFieldName)
.addParameter(optionalValue, contractKeyFieldName)
.addStatement("$T $L = new $T(contractId)", idClassName, idFieldName, idClassName)
.addStatement(
"$T $L = $T.fromValue(record$$)",
templateClassName,
dataFieldName,
templateClassName)
.addStatement(
idClassName: ClassName,
maybeContractKeyClassName: Option[TypeName]): MethodSpec = {

val params = Iterable(
ParameterSpec.builder(classOf[String], "contractId").build(),
ParameterSpec.builder(classOf[javaapi.data.Record], "record$").build(),
ParameterSpec.builder(optionalString, agreementFieldName).build()
) ++ maybeContractKeyClassName
.map(name => ParameterSpec.builder(optional(name), contractKeyFieldName).build)
.toList

val spec =
MethodSpec
.methodBuilder("fromIdAndRecord")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(className)
.addParameters(params.asJava)
.addStatement("$T $L = new $T(contractId)", idClassName, idFieldName, idClassName)
.addStatement(
"$T $L = $T.fromValue(record$$)",
templateClassName,
dataFieldName,
templateClassName)

if (maybeContractKeyClassName.isDefined) {
spec.addStatement(
"return new $T($L, $L, $L, $L)",
className,
idFieldName,
dataFieldName,
agreementFieldName,
contractKeyFieldName
)
.build()
contractKeyFieldName)
} else {
spec.addStatement(
"return new $T($L, $L, $L)",
className,
idFieldName,
dataFieldName,
agreementFieldName)
}

spec.build()
}

private[inner] def generateFromIdAndRecordDeprecated(
className: ClassName,
templateClassName: ClassName,
idClassName: ClassName): MethodSpec =
MethodSpec
.methodBuilder("fromIdAndRecord")
.addAnnotation(classOf[Deprecated])
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(className)
.addParameter(classOf[String], "contractId")
.addParameter(classOf[javaapi.data.Record], "record$")
.addStatement("$T $L = new $T(contractId)", idClassName, idFieldName, idClassName)
.addStatement(
"$T $L = $T.fromValue(record$$)",
templateClassName,
dataFieldName,
templateClassName)
.addStatement(
idClassName: ClassName,
hasContractKey: Boolean): MethodSpec = {
val spec =
MethodSpec
.methodBuilder("fromIdAndRecord")
.addAnnotation(classOf[Deprecated])
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(className)
.addParameter(classOf[String], "contractId")
.addParameter(classOf[javaapi.data.Record], "record$")
.addStatement("$T $L = new $T(contractId)", idClassName, idFieldName, idClassName)
.addStatement(
"$T $L = $T.fromValue(record$$)",
templateClassName,
dataFieldName,
templateClassName)

if (hasContractKey) {
spec.addStatement(
"return new $T($L, $L, $T.empty(), $T.empty())",
className,
idFieldName,
dataFieldName,
classOf[Optional[_]],
classOf[Optional[_]])
.build()
} else {
spec.addStatement(
"return new $T($L, $L, $T.empty())",
className,
idFieldName,
dataFieldName,
classOf[Optional[_]])
}

spec.build()
}

private[inner] def generateFromCreatedEvent(
className: ClassName,
templateClassName: ClassName,
idClassName: ClassName) = {
MethodSpec
.methodBuilder("fromCreatedEvent")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(className)
.addParameter(classOf[CreatedEvent], "event")
.addStatement(
"return fromIdAndRecord(event.getContractId(), event.getArguments(), event.getAgreementText(), event.getContractKey())")
.build()
idClassName: ClassName,
maybeContractKeyClassName: Option[TypeName]) = {
val spec =
MethodSpec
.methodBuilder("fromCreatedEvent")
.addModifiers(Modifier.PUBLIC, Modifier.STATIC)
.returns(className)
.addParameter(classOf[CreatedEvent], "event")

maybeContractKeyClassName.fold(spec.addStatement(
"return fromIdAndRecord(event.getContractId(), event.getArguments(), event.getAgreementText())")) {
name =>
spec.addStatement(
"return fromIdAndRecord(event.getContractId(), event.getArguments(), event.getAgreementText(), event.getContractKey().map(k -> $T.fromValue(k)))",
name
)
}
spec.build()
}

private def generateCreateMethod(name: ClassName): MethodSpec =
Expand Down

0 comments on commit c65b649

Please sign in to comment.