Skip to content

Commit

Permalink
Always add variable to capture result of non-void invoke and simplify…
Browse files Browse the repository at this point in the history
… code that previously did so only when result was used in a check
  • Loading branch information
conradz committed Jun 20, 2024
1 parent e8da0c0 commit 7aeb1f4
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 130 deletions.
8 changes: 4 additions & 4 deletions src/main/scala/gvc/main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,13 @@ object Main extends App {
fileNames.irFileName,
IRPrinter.print(ir, includeSpecs = true)
)
if (config.dump.contains(Config.DumpSilver)) dump(silver.program.toString())
if (config.dump.contains(Config.DumpSilver)) dump(silver.toString())
else if (config.saveFiles)
writeFile(fileNames.silverFileName, silver.program.toString())
writeFile(fileNames.silverFileName, silver.toString())

val verificationStart = System.nanoTime()
silicon.start()
silicon.verify(silver.program) match {
silicon.verify(silver) match {
case verifier.Success => if (stopImmediately) silicon.stop()
case verifier.Failure(errors) =>
val message = errors.map(_.readableMessage).mkString("\n")
Expand All @@ -423,7 +423,7 @@ object Main extends App {
if (config.dump.contains(Config.DumpC0))
dumpC0(c0Source)
VerifiedOutput(
silver.program,
silver,
c0Source,
ProfilingInfo(
profilingInfo.getTotalConjuncts,
Expand Down
50 changes: 11 additions & 39 deletions src/main/scala/gvc/transformer/IRSilver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@ import scala.collection.mutable

case class SilverVarId(methodName: String, varName: String)

class SilverProgram(
val program: vpr.Program,

// Map of (methodName, varName) Silver variables that represent the result
// of the invoke
val temporaryVars: Map[SilverVarId, IR.Invoke]
)

object IRSilver {
def toSilver(program: IR.Program) = new Converter(program).convert()

Expand All @@ -22,22 +14,6 @@ object IRSilver {
val RenamedResult = "_result$"
}

private class TempVars(methodName: String, index: mutable.Map[SilverVarId, IR.Invoke]) {
private var counter = -1
val declarations = mutable.ListBuffer[vpr.LocalVarDecl]()

def next(invoke: IR.Invoke, t: vpr.Type): vpr.LocalVar = {
counter += 1
val name = Names.TempResultPrefix + counter

index += SilverVarId(methodName, name) -> invoke

val decl = vpr.LocalVarDecl(name, t)()
declarations += decl
decl.localVar
}
}

class Converter(ir: IR.Program) {
val fields = mutable.ListBuffer[vpr.Field]()
val structFields = mutable.Map[IR.StructField, vpr.Field]()
Expand All @@ -48,11 +24,10 @@ object IRSilver {
field
}

def convert(): SilverProgram = {
def convert(): vpr.Program = {
val predicates = ir.predicates.map(convertPredicate).toList
val tempVarIndex = mutable.Map[SilverVarId, IR.Invoke]()
val methods = (
ir.methods.map(convertMethod(_, tempVarIndex)) ++
ir.methods.map(convertMethod) ++
ir.dependencies.flatMap(_.methods.map(convertLibraryMethod))
).toList
val fields = this.fields.toSeq.sortBy(_.name).toList
Expand All @@ -66,7 +41,7 @@ object IRSilver {
Seq.empty
)()

new SilverProgram(program, tempVarIndex.toMap)
program
}

private def returnVarDecl(t: Option[IR.Type]): List[vpr.LocalVarDecl] = {
Expand All @@ -92,17 +67,14 @@ object IRSilver {
)()
}

private def convertMethod(method: IR.Method, tempVarIndex: mutable.Map[SilverVarId, IR.Invoke]): vpr.Method = {
var tempCount = 0
private def convertMethod(method: IR.Method): vpr.Method = {

val params = method.parameters.map(convertDecl).toList
val vars = method.variables.map(convertDecl).toList
val decls = method.variables.map(convertDecl).toList
val ret = returnVarDecl(method.returnType)
val pre = method.precondition.map(convertExpr).toSeq
val post = method.postcondition.map(convertExpr).toSeq
val tempVars = new TempVars(method.name, tempVarIndex)
val body = method.body.flatMap(convertOp(_, tempVars)).toList
val decls = vars ++ tempVars.declarations.toList
val body = method.body.flatMap(convertOp).toList

vpr.Method(
method.name,
Expand Down Expand Up @@ -139,10 +111,10 @@ object IRSilver {
def getReturnVar(method: IR.Method): vpr.LocalVar =
vpr.LocalVar(Names.ReturnVar, convertType(method.returnType.get))()

private def convertOp(op: IR.Op, tempVars: TempVars): Seq[vpr.Stmt] = op match {
private def convertOp(op: IR.Op): Seq[vpr.Stmt] = op match {
case iff: IR.If => {
val ifTrue = iff.ifTrue.flatMap(convertOp(_, tempVars)).toList
val ifFalse = iff.ifFalse.flatMap(convertOp(_, tempVars)).toList
val ifTrue = iff.ifTrue.flatMap(convertOp).toList
val ifFalse = iff.ifFalse.flatMap(convertOp).toList
Seq(
vpr.If(
convertExpr(iff.condition),
Expand All @@ -157,7 +129,7 @@ object IRSilver {
vpr.While(
convertExpr(loop.condition),
List(convertExpr(loop.invariant)),
vpr.Seqn(loop.body.flatMap(convertOp(_, tempVars)).toList, Seq.empty)()
vpr.Seqn(loop.body.flatMap(convertOp).toList, Seq.empty)()
)()
)
}
Expand All @@ -172,7 +144,7 @@ object IRSilver {

case None => invoke.callee.returnType match {
case Some(retType) =>
Some(tempVars.next(invoke, convertType(retType)))
throw new IRException("Cannot convert invoke of non-void method with no target")
case None =>
None
}
Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/gvc/transformer/IRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -668,10 +668,13 @@ object IRTransformer {
scope += new IR.Invoke(method, args, Some(target))
}

def invokeVoid(input: ResolvedInvoke, scope: Scope): Unit = {
def invokeVoid(input: ResolvedInvoke, scope: MethodScope): Unit = {
val method = resolveMethod(input)
val args = input.arguments.map(arg => transformExpr(arg, scope))
scope += new IR.Invoke(method, args, None)
// Add a variable to capture the result, even when it is not used. (The
// [conditions of] Viper run-time checks may reference it.)
val target = method.returnType.map(t => scope.method.addVar(t))
scope += new IR.Invoke(method, args, target)
}

def resolveMethod(invoke: ResolvedInvoke): IR.MethodDefinition =
Expand Down
42 changes: 7 additions & 35 deletions src/main/scala/gvc/weaver/Checker.scala
Original file line number Diff line number Diff line change
@@ -1,40 +1,13 @@
package gvc.weaver

import gvc.transformer.{IR, SilverVarId}
import gvc.transformer.IR
import Collector._
import scala.collection.mutable
import scala.annotation.tailrec

object Checker {
type StructIDTracker = Map[String, IR.StructField]

class CheckerMethod(
val method: IR.Method,
tempVars: Map[SilverVarId, IR.Invoke]
) extends CheckMethod {
val resultVars = mutable.Map[String, IR.Expression]()
def resultVar(name: String): IR.Expression = {
resultVars.getOrElseUpdate(
name, {
val invoke = tempVars.getOrElse(
SilverVarId(method.name, name),
throw new WeaverException(s"Missing temporary variable '$name'")
)
invoke.target.getOrElse {
val retType = invoke.method.returnType.getOrElse(
throw new WeaverException(
s"Invalid temporary variable '$name' for void '${invoke.callee.name}'"
)
)
val tempVar = method.addVar(retType)
invoke.target = Some(tempVar)
tempVar
}
}
)
}
}

def insert(program: Collector.CollectedProgram): Unit = {
val runtime = CheckRuntime.addToIR(program.program)

Expand All @@ -61,7 +34,6 @@ object Checker {
): Unit = {
val program = programData.program
val method = methodData.method
val checkMethod = new CheckerMethod(method, programData.temporaryVars)

val callsImprecise: Boolean = methodData.calls.exists(c =>
programData.methods.get(c.ir.callee.name) match {
Expand Down Expand Up @@ -106,7 +78,7 @@ object Checker {
}

def getCondition(cond: Condition): IR.Expression = cond match {
case ImmediateCondition(expr) => expr.toIR(program, checkMethod, None)
case ImmediateCondition(expr) => expr.toIR(program, method, None)
case cond: TrackedCondition => conditionVars(cond)
case NotCondition(value) =>
new IR.Unary(IR.UnaryOp.Not, getCondition(value))
Expand Down Expand Up @@ -193,7 +165,7 @@ object Checker {
// Insert the runtime checks
// Group them by location and condition, so that multiple checks can be contained in a single
// if block.
val context = CheckContext(program, checkMethod, implementation, runtime)
val context = CheckContext(program, method, implementation, runtime)
for ((loc, checkData) <- groupChecks(methodData.checks)) {
insertAt(
loc,
Expand All @@ -205,7 +177,7 @@ object Checker {

def getTemporaryOwnedFields(): IR.Var =
temporaryOwnedFields.getOrElse {
val tempVar = context.method.method.addVar(
val tempVar = context.method.addVar(
context.runtime.ownedFieldsRef,
CheckRuntime.Names.temporaryOwnedFields
)
Expand Down Expand Up @@ -394,7 +366,7 @@ object Checker {
conds.map(
c =>
new IR.Assign(conditionVars(c),
c.value.toIR(program, checkMethod, retVal)))
c.value.toIR(program, method, retVal)))
})
}

Expand Down Expand Up @@ -474,7 +446,7 @@ object Checker {

case class CheckContext(
program: IR.Program,
method: CheckMethod,
method: IR.Method,
implementation: CheckImplementation,
runtime: CheckRuntime
)
Expand Down Expand Up @@ -604,7 +576,7 @@ object Checker {
case u: CheckExpression.Unary =>
nesting(u.operand) + 1
case _: CheckExpression.Literal | _: CheckExpression.Var |
CheckExpression.Result | _: CheckExpression.ResultVar =>
CheckExpression.Result =>
1
}
}
41 changes: 14 additions & 27 deletions src/main/scala/gvc/weaver/Checks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ import gvc.transformer.{IR, IRSilver}

sealed trait Check

trait CheckMethod {
def method: IR.Method
def resultVar(name: String): IR.Expression
}

object Check {
def fromViper(
check: vpr.Exp,
Expand Down Expand Up @@ -75,7 +70,7 @@ case class PredicateAccessibilityCheck(
sealed trait CheckExpression extends Check {
def toIR(
p: IR.Program,
m: CheckMethod,
m: IR.Method,
returnValue: Option[IR.Expression]
): IR.Expression

Expand Down Expand Up @@ -103,7 +98,7 @@ object CheckExpression {
def op: IR.BinaryOp
def toIR(
p: IR.Program,
m: CheckMethod,
m: IR.Method,
r: Option[IR.Expression]
): IR.Binary =
new IR.Binary(op, left.toIR(p, m, r), right.toIR(p, m, r))
Expand Down Expand Up @@ -154,7 +149,7 @@ object CheckExpression {
def op: IR.UnaryOp
def toIR(
p: IR.Program,
m: CheckMethod,
m: IR.Method,
r: Option[IR.Expression]
): IR.Unary =
new IR.Unary(op, operand.toIR(p, m, r))
Expand All @@ -168,15 +163,8 @@ object CheckExpression {
}

case class Var(name: String) extends Expr {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) = {
m.method.variable(name)
}
def guard = None
}

case class ResultVar(name: String) extends Expr {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) = {
m.resultVar(name)
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) = {
m.variable(name)
}
def guard = None
}
Expand All @@ -191,14 +179,14 @@ object CheckExpression {
throw new WeaverException(s"Field '$fieldName' does not exist")
)

def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.FieldMember(root.toIR(p, m, r), getIRField(p))

def guard = Some(and(root.guard, Not(Eq(root, NullLit))))
}

case class Deref(operand: Expr) extends Expr {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.DereferenceMember(operand.toIR(p, m, r))
def guard = Some(and(operand.guard, Not(Eq(operand, NullLit))))
}
Expand All @@ -208,24 +196,24 @@ object CheckExpression {
}

case class IntLit(value: Int) extends Literal {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.IntLit(value)
}
case class CharLit(value: Char) extends Literal {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.CharLit(value)
}
case class StrLit(value: String) extends Literal {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.StringLit(value)
}
case object NullLit extends Literal {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.NullLit()
}
sealed trait BoolLit extends Literal {
def value: Boolean
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.BoolLit(value)
}
object BoolLit {
Expand All @@ -239,7 +227,7 @@ object CheckExpression {
}

case class Cond(cond: Expr, ifTrue: Expr, ifFalse: Expr) extends Expr {
def toIR(p: IR.Program, m: CheckMethod, r: Option[IR.Expression]) =
def toIR(p: IR.Program, m: IR.Method, r: Option[IR.Expression]) =
new IR.Conditional(
cond.toIR(p, m, r),
ifTrue.toIR(p, m, r),
Expand All @@ -266,7 +254,7 @@ object CheckExpression {
case object Result extends Expr {
def toIR(
p: IR.Program,
m: CheckMethod,
m: IR.Method,
returnValue: Option[IR.Expression]
): IR.Expression =
returnValue.getOrElse(
Expand Down Expand Up @@ -373,7 +361,6 @@ object CheckExpression {
v.name match {
case IRSilver.Names.ReturnVar => Result
case IRSilver.Names.RenamedResult => Var(IRSilver.Names.ReservedResult)
case temp if temp.startsWith(IRSilver.Names.TempResultPrefix) => ResultVar(temp)
case id => Var(id)
}

Expand Down
Loading

0 comments on commit 7aeb1f4

Please sign in to comment.