Skip to content

Commit

Permalink
Added a special typechecking case for remote parameters to the main c…
Browse files Browse the repository at this point in the history
…lient transaction, since

these are owned by the blockchain, and loss of these (when they are temporarily owned due to dynamic state tests) is okay.
  • Loading branch information
mcoblenz committed Sep 4, 2019
1 parent 47fee46 commit bdfca5a
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 177 deletions.
10 changes: 7 additions & 3 deletions src/main/scala/edu/cmu/cs/obsidian/codegen/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,11 @@ class CodeGen (val target: Target, table: SymbolTable) {
val obsidianRetType = transaction.retType match {
case Some(typ) =>
typ match {
case np: NonPrimitiveType => Some(np.remoteType)
case np: NonPrimitiveType => {
// Which variety of remote reference type we pick now doesn't actually matter
// becuase the point is to generate the proper Java type.
Some(np.remoteType(NonTopLevelRemoteReferenceType()))
}
case _ => Some(typ)
}
case None => None
Expand Down Expand Up @@ -1115,7 +1119,7 @@ class CodeGen (val target: Target, table: SymbolTable) {
obsidianSerialized
} else {
// We can put any permission for referenceType, since we only need to get the translated name
val referenceType = ContractReferenceType(bound, Unowned(), isRemote = false)
val referenceType = ContractReferenceType(bound, Unowned(), NotRemoteReferenceType())
resolveType(referenceType, table, Some(model.ref(translationContext.getProtobufClassName(translationContext.contract)))).boxify()
}

Expand Down Expand Up @@ -2810,7 +2814,7 @@ class CodeGen (val target: Target, table: SymbolTable) {
narrowWith(invocation, params)

case Construction(contractType, args, isFFIInvocation) =>
val contractRefType = ContractReferenceType(contractType, Owned(), false)
val contractRefType = ContractReferenceType(contractType, Owned(), NotRemoteReferenceType())
val resolvedType = resolveType(contractRefType, table)

// Should've been found when typechecking, so this is safe
Expand Down
71 changes: 48 additions & 23 deletions src/main/scala/edu/cmu/cs/obsidian/parser/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,12 @@ object Parser extends Parsers {
case _ ~ params ~ _ => params
}.getOrElse(List())

val typ = extractTypeFromPermission(permissionToken, id._1, actualGenerics, isRemote, defaultOwned = false)
val remoteType = isRemote match {
case true => NonTopLevelRemoteReferenceType() // Temporary if this is a top-level argument. This will be fixed at a higher level in that case.
case false => NotRemoteReferenceType()
}

val typ = extractTypeFromPermission(permissionToken, id._1, actualGenerics, remoteType, defaultOwned = false)
typ.setLoc(id)
}
}
Expand All @@ -116,25 +121,25 @@ object Parser extends Parsers {
private def extractTypeFromPermission(permission: Option[~[Token, Seq[Identifier]]],
name: String,
genericParams: Seq[ObsidianType],
isRemote: Boolean,
remoteReferenceType: RemoteReferenceType,
defaultOwned: Boolean): NonPrimitiveType = {
val defaultPermission = if (defaultOwned) Owned() else Inferred()

val contractType = ContractType(name, genericParams)

permission match {
case None => ContractReferenceType(contractType, defaultPermission, isRemote)
case None => ContractReferenceType(contractType, defaultPermission, remoteReferenceType)
case Some(_ ~ permissionIdentSeq) =>
if (permissionIdentSeq.size == 1) {
val thePermissionOrState = permissionIdentSeq.head
val permission = resolvePermission(thePermissionOrState._1)
permission match {
case None => StateType(contractType, thePermissionOrState._1, isRemote)
case Some(p) => ContractReferenceType(contractType, p, isRemote)
case None => StateType(contractType, thePermissionOrState._1, remoteReferenceType)
case Some(p) => ContractReferenceType(contractType, p, remoteReferenceType)
}
} else {
val stateNames = permissionIdentSeq.map(_._1)
StateType(contractType, stateNames.toSet, isRemote)
StateType(contractType, stateNames.toSet, remoteReferenceType)
}
}
}
Expand Down Expand Up @@ -179,12 +184,12 @@ object Parser extends Parsers {
permission match {
case None =>
val correctedType = t match {
case ContractReferenceType(ct, Inferred(), isRemote) => ContractReferenceType(ct, Unowned(), isRemote)
case ContractReferenceType(ct, Inferred(), remoteReferenceType) => ContractReferenceType(ct, Unowned(), remoteReferenceType)
case _ => t
}
(correctedType, correctedType)
case Some(_ ~ idSeq) => {
val typOut = extractTypeFromPermission(permission, t.contractName, t.genericParams, t.isRemote, false)
val typOut = extractTypeFromPermission(permission, t.contractName, t.genericParams, t.remoteReferenceType, false)
(t, typOut)
}
}
Expand Down Expand Up @@ -569,29 +574,56 @@ object Parser extends Parsers {
}
}

def makeRemoteNonPrimitiveTypeTopLevelIfNeeded(np: NonPrimitiveType): NonPrimitiveType =
name match {
case MainT() =>
if (np.isRemote) {
np.remoteType(TopLevelRemoteReferenceType())
} else {
np
}
case _ => np
}

def makeRemoteTypeTopLevelIfNeeded(t: ObsidianType): ObsidianType =
name match {
case MainT() => t match {
case np: NonPrimitiveType => makeRemoteNonPrimitiveTypeTopLevelIfNeeded(np)
case _ => t
}
case _ => t
}

val thisContractType = ContractType(contractName, contractParams)

val finalType = thisArg match {
case None => ContractReferenceType(thisContractType, Shared(), false)
case Some(v) => v.typOut.withParams(contractParams)
case None => ContractReferenceType(thisContractType, Shared(), NotRemoteReferenceType())
case Some(v) => makeRemoteTypeTopLevelIfNeeded(v.typOut.withParams(contractParams))
}

val thisType = thisArg match {
case None => ContractReferenceType(ContractType(contractName, contractParams), Shared(), false)
case Some(variableDecl) => variableDecl.typIn.asInstanceOf[NonPrimitiveType].withParams(contractParams)
case None => ContractReferenceType(ContractType(contractName, contractParams), Shared(), NotRemoteReferenceType())
case Some(variableDecl) => makeRemoteNonPrimitiveTypeTopLevelIfNeeded(variableDecl.typIn.asInstanceOf[NonPrimitiveType].withParams(contractParams))
}

val initialFieldTypes = privateMethodFieldTypes match {
case None => Map.empty
case Some(_ ~ argDefList ~ _) => argDefList.map((v: VariableDeclWithSpec) => (v.varName, v.typIn))
case Some(_ ~ argDefList ~ _) => argDefList.map((v: VariableDeclWithSpec) => (v.varName, makeRemoteTypeTopLevelIfNeeded(v.typIn)))
}

val finalFieldTypes = privateMethodFieldTypes match {
case None => Map.empty
case Some(_ ~ argDefList ~ _) => argDefList.map((v: VariableDeclWithSpec) => (v.varName, v.typOut))
case Some(_ ~ argDefList ~ _) => argDefList.map((v: VariableDeclWithSpec) => (v.varName, makeRemoteTypeTopLevelIfNeeded(v.typOut)))
}

Transaction(nameString, params, filteredArgs, returns,
val argTypesUpdatedWithRemoteTypes = filteredArgs.map( (d: VariableDeclWithSpec) =>
d match {
case VariableDeclWithSpec(typIn, typOut, varName) =>
VariableDeclWithSpec(makeRemoteTypeTopLevelIfNeeded(typIn), makeRemoteTypeTopLevelIfNeeded(typOut), varName).setLoc(d)
}
)

Transaction(nameString, params, argTypesUpdatedWithRemoteTypes, returns,
ensures, body, opts.isStatic, opts.isPrivate,
thisType, finalType.asInstanceOf[NonPrimitiveType],
initialFieldTypes.toMap, finalFieldTypes.toMap).setLoc(t)
Expand All @@ -612,7 +644,7 @@ object Parser extends Parsers {
private def parseConstructor (params: Seq[GenericType]) = {
parseId ~ opt(AtT() ~! parseIdAlternatives) ~! LParenT() ~! parseArgDefList("") ~! RParenT() ~! LBraceT() ~! parseBody ~! RBraceT() ^^ {
case name ~ permission ~ _ ~ args ~ _ ~ _ ~ body ~ _ =>
val resultType = extractTypeFromPermission(permission, name._1, params, isRemote = false, defaultOwned = false)
val resultType = extractTypeFromPermission(permission, name._1, params, NotRemoteReferenceType(), defaultOwned = false)
Constructor(name._1, args, resultType, body).setLoc(name)
}
}
Expand Down Expand Up @@ -694,13 +726,6 @@ object Parser extends Parsers {
}
}

private def parseGenericParams = {
opt(RemoteT()) ~ parseId ~ LBracketT() ~ repsep(parseType, CommaT()) ~ RBracketT() ~ opt(AtT() ~ parseIdAlternatives) ^^ {
case isRemote ~ name ~ _ ~ genParams ~ _ ~ permission =>
extractTypeFromPermission(permission, name._1, genParams, isRemote.isDefined, defaultOwned = false)
}
}

private def parseGenericId = {
parseId ~! opt(LBracketT() ~ repsep(genericParam, CommaT()) ~ RBracketT()) ^^ {
case id ~ optParams =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class StateTable(

def name: String = astNodeRaw.name

def nonPrimitiveType = StateType(ContractType(contract.name, Nil), astNodeRaw.name, false)
def nonPrimitiveType = StateType(ContractType(contract.name, Nil), astNodeRaw.name, NotRemoteReferenceType())
def contractType: ContractType = ContractType(name, Nil)

def ast: State = astNode
Expand Down
27 changes: 14 additions & 13 deletions src/main/scala/edu/cmu/cs/obsidian/typecheck/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
}
case (StateType(ct, ss1, _), StateType(_, ss2, _)) =>
val unionStates = ss1.union(ss2)
Some(StateType(ct, unionStates, false))
Some(StateType(ct, unionStates, NotRemoteReferenceType()))

case (g1@GenericType(gVar1, gBound1), g2@GenericType(gVar2, gBound2)) =>
if (gVar1.permissionVar.isDefined && gVar2.permissionVar.isDefined) {
Expand Down Expand Up @@ -662,7 +662,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
val tableLookup = context.contractTable.lookupContract(x)
if (tableLookup.isDefined) {
val contractTable = tableLookup.get
val nonPrimitiveType = ContractReferenceType(contractTable.contractType, Shared(), false)
val nonPrimitiveType = ContractReferenceType(contractTable.contractType, Shared(), NotRemoteReferenceType())
(InterfaceContractType(contractTable.name, nonPrimitiveType), context, e)
}
else {
Expand Down Expand Up @@ -834,7 +834,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {

val (exprList, simpleType, contextPrime) = result match {
// Even if the args didn't check, we can still output a type
case None => (Nil, ContractReferenceType(contractType, Owned(), false), context)
case None => (Nil, ContractReferenceType(contractType, Owned(), NotRemoteReferenceType()), context)
case Some((newExprSequence, cntxt, constr)) =>
val outTyp = constr.asInstanceOf[Constructor].resultType match {
case ContractReferenceType(_, permission, isRemote) =>
Expand Down Expand Up @@ -1048,7 +1048,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
private def errorIfNotDisposable(variable: String, typ: ObsidianType, context: Context, ast: AST): Unit = {
typ match {
case t: NonPrimitiveType =>
if (t.isOwned && t.isAssetReference(context.contractTable) != No()) {
if (t.isOwned && t.isAssetReference(context.contractTable) != No() && t.remoteReferenceType != TopLevelRemoteReferenceType()) {
logError(ast, UnusedOwnershipError(variable))
}
case _ => ()
Expand Down Expand Up @@ -1229,7 +1229,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
}
else {
if (passedType.isOwned && initialType.permission == Shared() && !finalType.isOwned &&
passedType.isAssetReference(context.contractTable) != No()) {
passedType.isAssetReference(context.contractTable) != No() && passedType.remoteReferenceType != TopLevelRemoteReferenceType()) {
// Special case: passing an owned reference to a Shared >> Unowned arg will make the arg Unowned but also lose ownership!
logError(arg, LostOwnershipErrorDueToSharing(arg))
}
Expand Down Expand Up @@ -1266,7 +1266,8 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
case _ =>
// If the argument isn't bound to a variable but owns an asset, and this call is not going to consume ownership, then error.
if (argType.isOwned && argType.isAssetReference(context.contractTable) != No() &&
(declaredFinalType.isOwned || !declaredInitialType.isOwned))
(declaredFinalType.isOwned || !declaredInitialType.isOwned) &&
argType.remoteReferenceType != TopLevelRemoteReferenceType())
{
if (declaredInitialType.permission == Shared()) {
logError(arg, LostOwnershipErrorDueToSharing(arg))
Expand Down Expand Up @@ -1610,7 +1611,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
val newTypeTable = thisTable.contractTable.state(newStateName).get
val newSimpleType =
if (oldType.isOwned) {
StateType(thisTable.contractType, newStateName, false)
StateType(thisTable.contractType, newStateName, NotRemoteReferenceType())
}
else {
// If the old "this" was unowned, we'd better not steal ownership for ourselves here.
Expand Down Expand Up @@ -1768,15 +1769,15 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
logError(e, StateCheckRedundant())
}
val typeFalse =
StateType(np.contractType, specificStates -- states, np.isRemote)
StateType(np.contractType, specificStates -- states, np.remoteReferenceType)
(contextPrime.updated(x, newType).updatedMakingVariableVal(x),
np.permission,
contextPrime.updated(x, typeFalse).updatedMakingVariableVal(x))
case _ =>
if (allStates == states) {
logError(e, StateCheckRedundant())
}
val typeFalse = StateType(np.contractType, allStates -- states, np.isRemote)
val typeFalse = StateType(np.contractType, allStates -- states, np.remoteReferenceType)
(contextPrime.updated(x, newType).updatedMakingVariableVal(x),
np.permission,
contextPrime.updated(x, typeFalse).updatedMakingVariableVal(x))
Expand All @@ -1799,7 +1800,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
logError(e, StateCheckRedundant())
}
val newType = np.withTypeState(state)
val typeFalse = StateType(np.contractType, allStates -- states, np.isRemote)
val typeFalse = StateType(np.contractType, allStates -- states, np.remoteReferenceType)
resetOwnership = Some((x, np))
(contextPrime.updated(x, newType).updatedMakingVariableVal(x),
np.permission,
Expand Down Expand Up @@ -1875,10 +1876,10 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
val newType: ObsidianType =
contractTable.state(sc.stateName) match {
case Some(stTable) =>
StateType(contractTable.contractType, stTable.name, false)
StateType(contractTable.contractType, stTable.name, NotRemoteReferenceType())
case None =>
logError(sc, StateUndefinedError(contractTable.name, sc.stateName))
ContractReferenceType(contractTable.contractType, Owned(), false)
ContractReferenceType(contractTable.contractType, Owned(), NotRemoteReferenceType())
}

/* special case to allow types to change in the context if we match on a variable */
Expand Down Expand Up @@ -2335,7 +2336,7 @@ class Checker(globalTable: SymbolTable, verbose: Boolean = false) {
Map.empty,
constr.args.map((v: VariableDeclWithSpec) => v.varName).toSet)

val thisType = ContractReferenceType(table.contractType, Owned(), false)
val thisType = ContractReferenceType(table.contractType, Owned(), NotRemoteReferenceType())

initContext = initContext.updated("this", thisType)

Expand Down

0 comments on commit bdfca5a

Please sign in to comment.