Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebAssembly backend #500

Open
wants to merge 14 commits into
base: scala-2
Choose a base branch
from
3 changes: 2 additions & 1 deletion core/src/main/scala/stainless/frontend/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ package object frontend {
val allComponents: Seq[Component] = Seq(
verification.VerificationComponent,
termination.TerminationComponent,
evaluators.EvaluatorComponent
evaluators.EvaluatorComponent,
wasmgen.WasmComponent
)

/**
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/scala/stainless/wasmgen/LibProvider.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright 2009-2019 EPFL, Lausanne */

package stainless.wasmgen

trait LibProvider {
import LibProvider.libPath

protected val trees: stainless.ast.Trees

def fun(name: String)(implicit s: trees.Symbols): trees.FunDef =
s.lookup[trees.FunDef](libPath + name)

def sort(name: String)(implicit s: trees.Symbols): trees.ADTSort =
s.lookup[trees.ADTSort](libPath + name)
}

object LibProvider {
val libPath = "stainless.wasm.Runtime."
}
68 changes: 68 additions & 0 deletions core/src/main/scala/stainless/wasmgen/WasmComponent.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Copyright 2009-2019 EPFL, Lausanne */

package stainless
package wasmgen

import inox.Context
import inox.transformers.SymbolTransformer
import extraction.StainlessPipeline
import utils.{CheckFilter, DependenciesFinder}

import scala.concurrent.Future

object DebugSectionWasm extends inox.DebugSection("wasm")

class WasmAnalysis extends AbstractAnalysis {
val name = "no analysis"
type Report = NoReport

def toReport = new NoReport
}

class WasmFilter(val context: Context) extends CheckFilter {
val trees: stainless.trees.type = stainless.trees

override def shouldBeChecked(fd: trees.FunDef): Boolean = {
fd.params.isEmpty && super.shouldBeChecked(fd)
}
}

object WasmComponent extends Component {
val name = "wasm-codegen"
val description = "Generate WebAssembly code that runs parameterless functions in the program"
type Report = NoReport
type Analysis = WasmAnalysis

override val lowering: SymbolTransformer {
val s: extraction.trees.type
val t: extraction.trees.type
} = inox.transformers.SymbolTransformer(new transformers.TreeTransformer {
val s: extraction.trees.type = extraction.trees
val t: extraction.trees.type = extraction.trees
})

def run(pipeline: StainlessPipeline)(implicit context: Context) =
new WasmComponentRun(pipeline)(context)
}

class WasmComponentRun(override val pipeline: StainlessPipeline)
(override implicit val context: Context) extends {
override val component = WasmComponent
override val trees: stainless.trees.type = stainless.trees
} with ComponentRun {

def parse(json: io.circe.Json): NoReport = new NoReport

override def createFilter: WasmFilter = new WasmFilter(this.context)

override lazy val dependenciesFinder: DependenciesFinder { val t: stainless.trees.type } = new WasmDependenciesFinder

private[stainless] def execute(functions: Seq[Identifier], symbols: trees.Symbols): Future[WasmAnalysis] = {
Future {
val intermSyms: intermediate.trees.Symbols = new intermediate.Lowering(context).transform(symbols)
val module = codegen.LinearMemoryCodeGen.transform(intermSyms, functions)
new wasm.FileWriter(module, context).writeFiles()
new WasmAnalysis
}
}
}
92 changes: 92 additions & 0 deletions core/src/main/scala/stainless/wasmgen/WasmDepFinder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/* Copyright 2009-2019 EPFL, Lausanne */

package stainless.wasmgen

import stainless.Identifier
import stainless.trees._
import stainless.utils.{DefinitionIdFinder, DependenciesFinder}

class WasmDefIdFinder(val s: Symbols) extends DefinitionIdFinder { outer =>
val trees = stainless.trees
private val lib: LibProvider { val trees: outer.trees.type } = new LibProvider {
protected val trees = outer.trees
}
private def fun(name: String) = lib.fun(name)(s).id
private def sort(name: String) = lib.sort(name)(s).id
private lazy val setIds = Set(sort("Set"), fun("SNil$0ToString"), fun("SCons$0ToString"))
private lazy val bagIds = Set(sort("Bag"), fun("BNil$0ToString"), fun("BCons$0ToString"))
private lazy val mapIds = Set(sort("Map"), fun("MNil$0ToString"), fun("MCons$0ToString"))

override def traverse(expr: Expr, env: Env): Unit = {
expr match {
// Tuples
case Tuple(elems) =>
ids += sort(s"Tuple${elems.size}")
case TupleSelect(tuple, _) =>
val TupleType(ts) = tuple.getType(s)
ids += sort(s"Tuple${ts.size}")
// Sets
case FiniteSet(_, _) | SetAdd(_, _) =>
ids += fun("setAdd")
ids ++= setIds
case ElementOfSet(_, _) =>
ids += fun("elementOfSet")
ids ++= setIds
case SubsetOf(_, _) =>
ids += fun("subsetOf")
ids ++= setIds
case SetIntersection(_, _) =>
ids += fun("setIntersection")
ids ++= setIds
case SetUnion(_, _) =>
ids += fun("setUnion")
ids ++= setIds
case SetDifference(_, _) =>
ids += fun("setDifference")
ids ++= setIds
// Bags
case FiniteBag(_, _) | BagAdd(_, _) =>
ids += fun("bagAdd")
ids ++= bagIds
case MultiplicityInBag(_, _) =>
ids += fun("bagMultiplicity")
ids ++= bagIds
case BagIntersection(_, _) =>
ids += fun("bagIntersection")
ids ++= bagIds
case BagUnion(_, _) =>
ids += fun("bagUnion")
ids ++= bagIds
case BagDifference(_, _) =>
ids += fun("bagDifference")
ids ++= bagIds
// Maps
case FiniteMap(_, _, _, _) | MapUpdated(_, _, _) =>
ids += fun("mapUpdated")
ids ++= mapIds
case MapApply(_, _) =>
ids += fun("mapApply")
ids ++= mapIds
case _ =>
}
super.traverse(expr, env)
}
}


class WasmDependenciesFinder extends DependenciesFinder {
val t: stainless.trees.type = stainless.trees
def traverser(s: Symbols): DefinitionIdFinder { val trees: t.type } = new WasmDefIdFinder(s)
private val lib: LibProvider { val trees: t.type } = new LibProvider {
protected val trees = t
}
override def findDependencies(roots: Set[Identifier], s: Symbols): Symbols = {
super.findDependencies(roots, s)
.withFunctions(Seq(
"toString", "digitToStringL", "digitToStringI",
"i32ToString", "i64ToString", "f64ToString",
"booleanToString", "funToString", "unitToString"
).map(lib.fun(_)(s)))
}
}

Loading