Skip to content

Commit

Permalink
Make sure we always check field type consistency on early returns.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcoblenz committed Dec 5, 2019
1 parent 2c942e2 commit 1d8dd96
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 45 deletions.
14 changes: 10 additions & 4 deletions resources/tests/type_checker_tests/Return.obs
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,18 @@ contract HasField {
h = new HasField(); // Ignore infinite recursion for now!
}

transaction t() {
transaction t2() {
if (true) {
disown h;
return;
return; // Error: h needs to be owned here.
}
disown h;
h = new HasField();
}

transaction t3() returns int {
if (true) {
disown h;
return 3; // Error: h needs to be owned here.
}
return 4;
}
}
14 changes: 12 additions & 2 deletions src/main/scala/edu/cmu/cs/obsidian/parser/AST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ sealed abstract class InvokableDeclaration() extends Declaration {
val thisType: ObsidianType
val thisFinalType: ObsidianType
val isStatic: Boolean
val initialFieldTypes: Map[String, ObsidianType] = Map.empty
val finalFieldTypes: Map[String, ObsidianType] = Map.empty

def bodyEnd : AST =
if (body.nonEmpty) {
body.last
} else {
this
}
}

// Expressions not containing other expressions
Expand Down Expand Up @@ -306,8 +315,9 @@ case class Transaction(name: String,
isPrivate: Boolean,
thisType: NonPrimitiveType,
thisFinalType: NonPrimitiveType,
initialFieldTypes: Map[String, ObsidianType] = Map.empty, // populated after parsing
finalFieldTypes: Map[String, ObsidianType] = Map.empty // populated after parsing
override val initialFieldTypes: Map[String, ObsidianType] = Map.empty,
override val finalFieldTypes: Map[String, ObsidianType] = Map.empty
// will populate initial and final field types after parsing
) extends InvokableDeclaration with IsAvailableInStates {
val tag: DeclarationTag = TransactionDeclTag

Expand Down
40 changes: 11 additions & 29 deletions src/main/scala/edu/cmu/cs/obsidian/typecheck/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1029,24 +1029,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
resultContext
}

private def checkFieldTypeConsistencyAfterConstructor(context: Context, constr: Constructor): Unit = {
// First check fields that may be of inconsistent type with their declarations to make sure they match
// either the declarations or the specified final types.
for ((field, typ) <- context.thisFieldTypes) {
val requiredFieldType = context.lookupDeclaredFieldTypeInThis(field)

requiredFieldType match {
case None =>
// Nothing to do; there was an assignment to a field type that may not be in scope, so there will be a separate error message.
case Some(declaredFieldType) =>
if (isSubtype(context.contractTable, typ, declaredFieldType, false).isDefined || (typ.isOwned != declaredFieldType.isOwned && !declaredFieldType.isBottom)) {
logError(constr, InvalidInconsistentFieldType(field, typ, declaredFieldType))
}
}
}
}

private def checkFieldTypeConsistencyAfterTransaction(context: Context, tx: Transaction): Unit = {
private def checkFieldTypeConsistencyAfterTransaction(context: Context, tx: InvokableDeclaration, exitLocation: AST): Unit = {
// First check fields that may be of inconsistent type with their declarations to make sure they match
// either the declarations or the specified final types.
for ((field, typ) <- context.thisFieldTypes) {
Expand All @@ -1067,10 +1050,10 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {

requiredFieldType match {
case None =>
assert(false, "Bug: invalid field in field type context")
// Nothing to do; there was an assignment to a field type that may not be in scope, so there will be a separate error message.
case Some(declaredFieldType) =>
if (isSubtype(context.contractTable, typ, declaredFieldType, false).isDefined || (typ.isOwned != declaredFieldType.isOwned && !declaredFieldType.isBottom)) {
logError(tx, InvalidInconsistentFieldType(field, typ, declaredFieldType))
logError(exitLocation, InvalidInconsistentFieldType(field, typ, declaredFieldType))
}
}
}
Expand Down Expand Up @@ -1537,6 +1520,8 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
(contextPrime.updated(name, declaredType), VariableDeclWithInit(typ, name, ePrime))

case Return() =>
checkFieldTypeConsistencyAfterTransaction(context, decl, s)

decl match {
/* the tx/function must have no return type */
case tx: Transaction if tx.retType.isEmpty =>
Expand Down Expand Up @@ -1579,6 +1564,8 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
}

checkForUnusedOwnershipErrors(s, contextPrime, thisSetToExclude ++ argsSetToExclude)
checkFieldTypeConsistencyAfterTransaction(context, decl, s)


if (retTypeOpt.isDefined && !retTypeOpt.get.isBottom) {
checkIsSubtype(context.contractTable, s, typ, retTypeOpt.get)
Expand Down Expand Up @@ -2387,13 +2374,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {

// Don't need to check interface methods to make sure they return
if (!hasReturnStatement(tx.body) && !impl.isInterface && tx.retType.isDefined) {
val ast =
if (tx.body.nonEmpty) {
tx.body.last
} else {
tx
}
logError(ast, MustReturnError(tx.name))
logError(tx.bodyEnd, MustReturnError(tx.name))
} else if (tx.retType.isEmpty) {
// We check for unused ownership errors at each return; if there isn't guaranteed to be one at the end, check separately.
// Every arg whose output type is owned should be owned at the end.
Expand All @@ -2413,7 +2394,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
}

// Check to make sure all the field types are consistent with their declarations.
checkFieldTypeConsistencyAfterTransaction(outputContext, tx)
checkFieldTypeConsistencyAfterTransaction(outputContext, tx, tx.bodyEnd)

// todo: check that every declared variable is initialized before use
tx.copy(body = newStatements)
Expand Down Expand Up @@ -2654,7 +2635,8 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {

checkForUnusedOwnershipErrors(constr, outputContext, Set("this"))
checkForUnusedStateInitializers(outputContext)
checkFieldTypeConsistencyAfterConstructor(outputContext, constr)

checkFieldTypeConsistencyAfterTransaction(outputContext, constr, constr.bodyEnd)

// if the contract contains states, its constructor must contain a state transition
if (hasStates && !hasTransition(constr.body)) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/edu/cmu/cs/obsidian/typecheck/Error.scala
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ case class InvalidNonThisFieldAccess() extends Error {


case class InvalidInconsistentFieldType(fieldName: String, actualType: ObsidianType, expectedType: ObsidianType) extends Error {
val msg: String = s"At the ends of transactions, all fields must reference objects consistent with their declared types. " +
val msg: String = s"When transactions exit, all fields must reference objects consistent with their declared types. " +
s" Field '$fieldName' is of type $actualType but was declared as $expectedType."
}

Expand Down
26 changes: 17 additions & 9 deletions src/test/scala/edu/cmu/cs/obsidian/tests/TypeCheckerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ class TypeCheckerTests extends JUnitSuite {
::
(MustReturnError("branching_return2"), 58)
::
(InvalidInconsistentFieldType("h",
ContractReferenceType(ContractType("HasField", Nil), Shared(), NotRemoteReferenceType()),
ContractReferenceType(ContractType("HasField", Nil), Owned(), NotRemoteReferenceType())), 81)
::
(InvalidInconsistentFieldType("h",
ContractReferenceType(ContractType("HasField", Nil), Shared(), NotRemoteReferenceType()),
ContractReferenceType(ContractType("HasField", Nil), Owned(), NotRemoteReferenceType())), 88)
::
Nil
)
}
Expand Down Expand Up @@ -448,7 +456,7 @@ class TypeCheckerTests extends JUnitSuite {
::
(InvalidInconsistentFieldType("money",
ContractReferenceType(ContractType("Money", Nil), Unowned(), NotRemoteReferenceType()),
ContractReferenceType(ContractType("Money", Nil), Owned(), NotRemoteReferenceType())), 26)
ContractReferenceType(ContractType("Money", Nil), Owned(), NotRemoteReferenceType())), 27)
::
(OverwrittenOwnershipError("money"), 27)
::
Expand Down Expand Up @@ -479,7 +487,7 @@ class TypeCheckerTests extends JUnitSuite {
runTest("resources/tests/type_checker_tests/Ownership.obs",
(InvalidInconsistentFieldType("prescription",
ContractReferenceType(ContractType("Prescription", Nil), Unowned(), NotRemoteReferenceType()),
ContractReferenceType(ContractType("Prescription", Nil), Owned(), NotRemoteReferenceType())), 15)
ContractReferenceType(ContractType("Prescription", Nil), Owned(), NotRemoteReferenceType())), 16)
::
Nil
)
Expand Down Expand Up @@ -616,14 +624,14 @@ class TypeCheckerTests extends JUnitSuite {
(ReceiverTypeIncompatibleError("changeStateStateSpecified",
ContractReferenceType(ContractType("C", Nil), Owned(), NotRemoteReferenceType()),
StateType(ContractType("C", Nil), Set("S1"), NotRemoteReferenceType())), 44) ::
(InvalidInconsistentFieldType("s1C", StateType(ContractType("C", Nil), Set("S2"), NotRemoteReferenceType()), StateType(ContractType("C", Nil), Set("S1"), NotRemoteReferenceType())), 47) ::
(InvalidInconsistentFieldType("s1C", StateType(ContractType("C", Nil), Set("S2"), NotRemoteReferenceType()), StateType(ContractType("C", Nil), Set("S1"), NotRemoteReferenceType())), 51) ::
Nil
)
}

@Test def fieldTypeMismatchTest(): Unit = {
runTest("resources/tests/type_checker_tests/FieldTypeMismatch.obs",
(InvalidInconsistentFieldType("c", StateType(ContractType("C", Nil), Set("S2"), NotRemoteReferenceType()), StateType(ContractType("C", Nil), Set("S1"), NotRemoteReferenceType())), 24) ::
(InvalidInconsistentFieldType("c", StateType(ContractType("C", Nil), Set("S2"), NotRemoteReferenceType()), StateType(ContractType("C", Nil), Set("S1"), NotRemoteReferenceType())), 25) ::
Nil
)
}
Expand All @@ -640,7 +648,7 @@ class TypeCheckerTests extends JUnitSuite {
runTest("resources/tests/type_checker_tests/PrivateTransactions.obs",
(InvalidFinalFieldTypeDeclarationError("bogus"), 30)::
(FieldTypesDeclaredOnPublicTransactionError("t2"), 33)::
(InvalidInconsistentFieldType("c", StateType(ContractType("C", Nil), "S2", NotRemoteReferenceType()), StateType(ContractType("C", Nil), "S1", NotRemoteReferenceType())), 42)::
(InvalidInconsistentFieldType("c", StateType(ContractType("C", Nil), "S2", NotRemoteReferenceType()), StateType(ContractType("C", Nil), "S1", NotRemoteReferenceType())), 43)::
(FieldSubtypingError("c", StateType(ContractType("C", Nil), "S1", NotRemoteReferenceType()), StateType(ContractType("C", Nil), "S2", NotRemoteReferenceType())), 48)::
Nil)
}
Expand Down Expand Up @@ -784,7 +792,7 @@ class TypeCheckerTests extends JUnitSuite {
runTest("resources/tests/type_checker_tests/GenericsOwnership.obs",
(InvalidInconsistentFieldType("x",
GenericType(GenericVar(isAsset = false,"T",None),GenericBoundPerm(false, false, ContractType.topContractType, Unowned())),
GenericType(GenericVar(isAsset = false,"T",None),GenericBoundPerm(false, false, ContractType.topContractType, Owned()))), 12) ::
GenericType(GenericVar(isAsset = false,"T",None),GenericBoundPerm(false, false, ContractType.topContractType, Owned()))), 14) ::
Nil)
}

Expand Down Expand Up @@ -854,7 +862,7 @@ class TypeCheckerTests extends JUnitSuite {
GenericType(GenericVar(isAsset = false,"X",None),
GenericBoundPerm(interfaceSpecified = false, permSpecified = false, ContractType.topContractType, Unowned())),
GenericType(GenericVar(isAsset = false,"X",Some("s")),
GenericBoundPerm(interfaceSpecified = false, permSpecified = false, ContractType.topContractType, Unowned()))), 16) ::
GenericBoundPerm(interfaceSpecified = false, permSpecified = false, ContractType.topContractType, Unowned()))), 17) ::
(ReceiverTypeIncompatibleError("getX",
StateType(ContractType("A", Nil), "S2", NotRemoteReferenceType()),
StateType(ContractType("A", Nil), "S1", NotRemoteReferenceType())), 66) ::
Expand All @@ -881,7 +889,7 @@ class TypeCheckerTests extends JUnitSuite {
(InvalidInconsistentFieldType("c",
ContractReferenceType(ContractType("C", Nil), Owned(), NotRemoteReferenceType()),
ContractReferenceType(ContractType("C", Nil), Unowned(), NotRemoteReferenceType()))
, 8) :: Nil)
, 9) :: Nil)
}

@Test def genericsLinkedList(): Unit = {
Expand Down Expand Up @@ -948,7 +956,7 @@ class TypeCheckerTests extends JUnitSuite {
runTest("resources/tests/type_checker_tests/ConstructorFieldTypes.obs",
(InvalidInconsistentFieldType("seller",
StateType(ContractType("Seller", Nil), "InAuction", NotRemoteReferenceType()),
StateType(ContractType("Seller", Nil), "Unsold", NotRemoteReferenceType())), 11) :: Nil)
StateType(ContractType("Seller", Nil), "Unsold", NotRemoteReferenceType())), 13) :: Nil)
}

@Test def vals(): Unit = {
Expand Down

0 comments on commit 1d8dd96

Please sign in to comment.