/
Compiler.scala
429 lines (395 loc) · 16.7 KB
/
Compiler.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
// See LICENSE for license details.
package firrtl
import logger._
import java.io.Writer
import annotations._
import scala.collection.mutable
import firrtl.annotations._ // Note that wildcard imports are not great....
import firrtl.ir.Circuit
import firrtl.Utils.{error, throwInternalError}
/** Factory methods for [[RenameMap]].
  *
  * A RenameMap maps old names to modified names. It is generated by
  * transformations that modify names.
  */
object RenameMap {
  /** Creates a RenameMap pre-populated with the given entries.
    *
    * @param map initial entries, each old name mapped to its replacement names
    * @return a fresh RenameMap to which every entry of `map` has been added
    */
  def apply(map: Map[Named, Seq[Named]]): RenameMap = {
    val populated = new RenameMap
    populated.addMap(map)
    populated
  }

  /** Creates an empty RenameMap. */
  def apply(): RenameMap = new RenameMap
}
/** Mutable record of renames: each old [[Named]] maps to the names that replace it.
  *
  * The String-based overloads build component names relative to the module
  * and circuit most recently supplied through [[setModule]] and [[setCircuit]]
  * (both default to the empty string).
  */
class RenameMap {
  /** Backing store. An entry whose value is an empty Seq is what `delete` writes. */
  val renameMap = new mutable.HashMap[Named, Seq[Named]]()
  private var circuitName: String = ""
  private var moduleName: String = ""

  /** Builds the component name for `name` inside the current module and circuit. */
  private def inCurrentModule(name: String): Named =
    ComponentName(name, ModuleName(moduleName, CircuitName(circuitName)))

  /** Sets the module name used by the String-based overloads. */
  def setModule(s: String): Unit = { moduleName = s }

  /** Sets the circuit name used by the String-based overloads. */
  def setCircuit(s: String): Unit = { circuitName = s }

  /** Records that component `from` was renamed to the single component `to`. */
  def rename(from: String, to: String): Unit = rename(from, Seq(to))

  /** Records that component `from` was renamed to every component in `tos`. */
  def rename(from: String, tos: Seq[String]): Unit = {
    val source = inCurrentModule(from)
    val targets = tos.map(inCurrentModule)
    rename(source, targets)
  }

  /** Records that `from` was renamed to the single name `to`. */
  def rename(from: Named, to: Named): Unit = rename(from, Seq(to))

  /** Appends `tos` to whatever `from` already maps to.
    *
    * Renaming a name to exactly itself (a one-element `tos` equal to `from`)
    * is ignored and leaves the map untouched.
    */
  def rename(from: Named, tos: Seq[Named]): Unit = {
    val renamedToItself = tos match {
      case Seq(only) => from == only
      case _ => false
    }
    if (!renamedToItself) {
      renameMap(from) = renameMap.getOrElse(from, Seq.empty) ++ tos
    }
  }

  /** Deletes every listed component of the current module. */
  def delete(names: Seq[String]): Unit =
    for (name <- names) delete(name)

  /** Deletes the component `name` of the current module. */
  def delete(name: String): Unit =
    delete(inCurrentModule(name))

  /** Marks `name` as deleted by mapping it to an empty Seq, replacing any earlier entry. */
  def delete(name: Named): Unit = {
    renameMap(name) = Seq.empty
  }

  /** Adds every entry of `map`, overwriting entries that share a key. */
  def addMap(map: Map[Named, Seq[Named]]) =
    renameMap.++=(map)

  /** Renders one `old=>new1, new2` line per entry, joined by newlines. */
  def serialize: String = {
    val lines = renameMap.map { case (from, tos) =>
      val targets = tos.map(_.serialize).mkString(", ")
      from.serialize + "=>" + targets
    }
    lines.mkString("\n")
  }
}
/** Container of all annotations for a Firrtl compiler.
  *
  * @param annotations every annotation currently held
  */
case class AnnotationMap(annotations: Seq[Annotation]) {
  /** Returns the annotations whose `transform` equals the class `id`.
    *
    * @param id class to look up
    * @return the matching annotations, in their original order
    */
  def get(id: Class[_]): Seq[Annotation] = annotations.filter(_.transform == id)

  /** Returns the annotations whose target is `named`.
    *
    * The comparison is made against each annotation's `target`. Comparing the
    * annotation itself with a [[Named]] relates two unrelated types and can
    * never succeed, which would make this lookup always return an empty Seq.
    *
    * @param named target to look up
    * @return the matching annotations, in their original order
    */
  def get(named: Named): Seq[Annotation] = annotations.filter(_.target == named)
}
/** Current State of the Circuit
  *
  * @constructor Creates a CircuitState object
  * @param circuit The current state of the Firrtl AST
  * @param form The current form of the circuit
  * @param annotations The current collection of [[firrtl.annotations.Annotation Annotation]]
  * @param renames A map of [[firrtl.annotations.Named Named]] things that have been renamed.
  *   Generally only a return value from [[Transform]]s
  */
case class CircuitState(
    circuit: Circuit,
    form: CircuitForm,
    annotations: Option[AnnotationMap] = None,
    renames: Option[RenameMap] = None) {

  /** The annotations currently held, or an empty Seq when no AnnotationMap is present. */
  private def heldAnnotations: Seq[Annotation] = annotations match {
    case Some(annoMap) => annoMap.annotations
    case None => Seq.empty
  }

  /** Helper for getting just an emitted circuit: the first one found, if any */
  def emittedCircuitOption: Option[EmittedCircuit] =
    emittedComponents.collectFirst { case circuit: EmittedCircuit => circuit }

  /** Helper for getting an [[EmittedCircuit]] when it is known to exist
    *
    * Throws a FIRRTLException listing the deleted annotations when none is present.
    */
  def getEmittedCircuit: EmittedCircuit = emittedCircuitOption.getOrElse {
    throw new FIRRTLException(s"No EmittedCircuit found! Did you delete any annotations?\n$deletedAnnotations")
  }

  /** Helper function for extracting emitted components (circuits and modules) from annotations */
  def emittedComponents: Seq[EmittedComponent] = {
    val emitted = heldAnnotations collect {
      case EmittedCircuitAnnotation(component) => component
      case EmittedModuleAnnotation(component) => component
    }
    emitted
  }

  /** Helper function for extracting the annotations recognised as deleted */
  def deletedAnnotations: Seq[Annotation] = {
    val deleted = heldAnnotations collect {
      case DeletedAnnotation(xformName, anno) =>
        DeletedAnnotation(xformName, anno)
    }
    deleted
  }
}
/** Current form of the Firrtl Circuit
  *
  * Form is a measure of additional restrictions on the legality of a Firrtl
  * circuit. There is a notion of "highness" and "lowness" implemented in the
  * compiler by extending scala.math.Ordered. "Lower" forms add additional
  * restrictions compared to "higher" forms. This means that "higher" forms are
  * strictly supersets of the "lower" forms. Thus, any transform that
  * operates on [[HighForm]] can also operate on [[MidForm]] or [[LowForm]]
  */
sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] {
  /** Orders forms by their numeric rank; `value` exists only to allow this comparison.
    *
    * @param that form to compare against
    * @return this form's rank minus the rank of `that`
    */
  def compare(that: CircuitForm): Int = {
    val otherRank = that.value
    value - otherRank
  }
}
/** Chirrtl Form
  *
  * The form of the circuit emitted by Chisel. Not a true Firrtl form.
  * Includes cmem, smem, and mport IR nodes which enable declaring memories
  * separately from their ports. A "Higher" form than [[HighForm]]
  * (ordering rank 3).
  *
  * See [[CDefMemory]] and [[CDefMPort]]
  */
final case object ChirrtlForm extends CircuitForm(3)
/** High Form
  *
  * As detailed in the Firrtl specification
  * [[https://github.com/ucb-bar/firrtl/blob/master/spec/spec.pdf]]
  *
  * Ordering rank 2: "lower" than [[ChirrtlForm]], "higher" than [[MidForm]].
  *
  * Also see [[firrtl.ir]]
  */
final case object HighForm extends CircuitForm(2)
/** Middle Form
  *
  * A "lower" form than [[HighForm]] (ordering rank 1) with the following restrictions:
  *  - All widths must be explicit
  *  - All whens must be removed
  *  - There can only be a single connection to any element
  */
final case object MidForm extends CircuitForm(1)
/** Low Form
  *
  * The "lowest" form (ordering rank 0). In addition to the restrictions in [[MidForm]]:
  *  - All aggregate types (vector/bundle) must have been removed
  *  - All implicit truncations must be made explicit
  */
final case object LowForm extends CircuitForm(0)
/** Unknown Form
  *
  * Often passes may modify a circuit (e.g. InferTypes), but return
  * a circuit in the same form it was given.
  *
  * For this use case, use UnknownForm. It cannot be compared against other
  * forms: calling a comparison ON UnknownForm reports an error. Note that only
  * the receiver side is guarded here; another form compared AGAINST
  * UnknownForm goes through [[CircuitForm]]'s ordinary rank arithmetic, with
  * UnknownForm's rank being -1.
  *
  * TODO(azidar): Replace with PreviousForm, which more explicitly encodes
  * this requirement.
  */
final case object UnknownForm extends CircuitForm(-1) {
  /** Always reports "Illegal to compare UnknownForm" instead of comparing. */
  override def compare(that: CircuitForm): Int = {
    error("Illegal to compare UnknownForm")
    0 // fallback result; presumably never reached because `error` throws — confirm in firrtl.Utils
  }
}
/** The basic unit of operating on a Firrtl AST */
abstract class Transform extends LazyLogging {
  /** A convenience function useful for debugging and error messages */
  def name: String = this.getClass.getSimpleName

  /** The [[firrtl.CircuitForm]] that this transform requires to operate on */
  def inputForm: CircuitForm

  /** The [[firrtl.CircuitForm]] that this transform outputs */
  def outputForm: CircuitForm

  /** Perform the transform, encode renaming with RenameMap, and can
    * delete annotations
    * Called by [[runTransform]].
    *
    * @param state Input Firrtl AST
    * @return A transformed Firrtl AST
    */
  protected def execute(state: CircuitState): CircuitState

  /** Convenience method to get annotations relevant to this Transform
    *
    * @param state The [[CircuitState]] from which to extract annotations
    * @return The annotations registered for this transform's class; empty when
    *         the state carries no annotations
    */
  final def getMyAnnotations(state: CircuitState): Seq[Annotation] =
    //TODO(azidar): ++ annotations.get(classOf[Transform])
    state.annotations.map(_.get(this.getClass)).getOrElse(Nil)

  /** Perform the transform and update annotations.
    *
    * Runs [[execute]] under a timer, logs progress, and reconciles the
    * annotations of the input state with those of the result. The returned
    * state never carries renames: they are consumed while remapping.
    *
    * @param state Input Firrtl AST
    * @return A transformed Firrtl AST
    */
  final def runTransform(state: CircuitState): CircuitState = {
    logger.info(s"======== Starting Transform $name ========")

    val (timeMillis, result) = Utils.time { execute(state) }

    logger.info(s"""----------------------------${"-" * name.size}---------\n""")
    logger.info(f"Time: $timeMillis%.1f ms")

    val remappedAnnotations = propagateAnnotations(state.annotations, result.annotations, result.renames)

    logger.info(s"Form: ${result.form}")
    logger.debug(s"Annotations:")
    for (anno <- remappedAnnotations) {
      logger.debug(anno.serialize)
    }
    logger.trace(s"Circuit:\n${result.circuit.serialize}")
    logger.info(s"======== Finished Transform $name ========\n")

    result.copy(annotations = Some(AnnotationMap(remappedAnnotations)), renames = None)
  }

  /** Propagate annotations and update their names.
    *
    * Annotations present before but absent afterwards are wrapped as deleted
    * by this transform (an already-deleted one has this transform's name
    * appended to its record). Every resulting annotation is then updated with
    * the renames recorded for its target.
    *
    * @param inAnno input AnnotationMap
    * @param resAnno result AnnotationMap
    * @param renameOpt result RenameMap
    * @return the updated annotations
    */
  final private def propagateAnnotations(
      inAnno: Option[AnnotationMap],
      resAnno: Option[AnnotationMap],
      renameOpt: Option[RenameMap]): Seq[Annotation] = {
    // Treats a missing AnnotationMap as an empty one.
    def asSet(annos: Option[AnnotationMap]): Set[Annotation] =
      annos.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet

    val before = asSet(inAnno)
    val after = asSet(resAnno)

    val removed = (before -- after).map {
      case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno)
      case anno => DeletedAnnotation(name, anno)
    }
    val added = after -- before
    val kept = after & before
    val merged = removed ++ added ++ kept

    // An annotation whose target has no recorded rename keeps its target.
    val renames = renameOpt.getOrElse(RenameMap()).renameMap
    merged.toSeq.flatMap { anno =>
      anno.update(renames.getOrElse(anno.target, Seq(anno.target)))
    }
  }
}
/** Mixin for anything whose work is an ordered sequence of [[Transform]]s. */
trait SeqTransformBased {
  /** The transforms to run, in execution order. */
  def transforms: Seq[Transform]

  /** Threads `state` through every transform in order.
    *
    * @param state initial circuit state
    * @return the state produced by the last transform, or `state` itself when
    *         there are no transforms
    */
  protected def runTransforms(state: CircuitState): CircuitState =
    transforms.foldLeft(state) { case (current, next) =>
      next.runTransform(current)
    }
}
/** For transformations that are simply a sequence of transforms */
abstract class SeqTransform extends Transform with SeqTransformBased {
  /** Runs every transform in order and stamps the result with this transform's
    * declared `outputForm`, keeping the final circuit, annotations and renames.
    *
    * @param state Input Firrtl AST
    * @return the state after the whole sequence has run
    */
  def execute(state: CircuitState): CircuitState = {
    /*
    require(state.form <= inputForm,
      s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
    */
    val finished = runTransforms(state)
    finished.copy(form = outputForm)
  }
}
/** Defines old API for Emission. Deprecated
  *
  * An Emitter is a [[Transform]] that additionally declares the legacy,
  * writer-based `emit` entry point.
  */
trait Emitter extends Transform {
/** Legacy emission entry point; superseded by emission annotations.
  *
  * @param state circuit state to emit
  * @param writer destination of the emitted output
  */
@deprecated("Use emission annotations instead", "firrtl 1.0")
def emit(state: CircuitState, writer: Writer): Unit
}
object CompilerUtils extends LazyLogging {
  /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit
    *
    * @param inputForm [[CircuitForm]] to lower from
    * @param outputForm [[CircuitForm]] to lower to
    * @return Sequence of transforms that will lower if outputForm is lower than inputForm;
    *         empty when outputForm is equal to or higher than inputForm
    */
  def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = {
    // Runs `steps`, then whatever is still needed to get from `reached` down to outputForm.
    def lowerVia(steps: Seq[Transform], reached: CircuitForm): Seq[Transform] =
      steps ++ getLoweringTransforms(reached, outputForm)

    inputForm match {
      // If outputForm is equal-to or higher than inputForm, nothing to lower
      case _ if outputForm >= inputForm => Seq.empty
      case ChirrtlForm =>
        lowerVia(Seq(new ChirrtlToHighFirrtl), HighForm)
      case HighForm =>
        lowerVia(
          Seq(new IRToWorkingIR, new ResolveAndCheck, new transforms.DedupModules,
              new HighFirrtlToMiddleFirrtl),
          MidForm)
      case MidForm =>
        lowerVia(Seq(new MiddleFirrtlToLowFirrtl), LowForm)
      case LowForm => throwInternalError // should be caught by the guard above
      case UnknownForm => throwInternalError // should be caught by the guard above
    }
  }

  /** Merge a Seq of lowering transforms with custom transforms
    *
    * Custom Transforms are inserted based on their [[Transform.inputForm]] and
    * [[Transform.outputForm]]. Custom transforms are inserted in order at the
    * last location in the Seq of transforms where previous.outputForm ==
    * customTransform.inputForm. If a customTransform outputs a higher form
    * than input, [[getLoweringTransforms]] is used to relower the circuit.
    *
    * @example
    * {{{
    *   // Let Transforms be represented by CircuitForm => CircuitForm
    *   val A = HighForm => MidForm
    *   val B = MidForm => LowForm
    *   val lowering = List(A, B) // Assume these transforms are used by getLoweringTransforms
    *   // Some custom transforms
    *   val C = LowForm => LowForm
    *   val D = MidForm => MidForm
    *   val E = LowForm => HighForm
    *   // All of the following comparisons are true
    *   mergeTransforms(lowering, List(C)) == List(A, B, C)
    *   mergeTransforms(lowering, List(D)) == List(A, D, B)
    *   mergeTransforms(lowering, List(E)) == List(A, B, E, A, B)
    *   mergeTransforms(lowering, List(C, E)) == List(A, B, C, E, A, B)
    *   mergeTransforms(lowering, List(E, C)) == List(A, B, E, A, B, C)
    *   // Notice that in the following, custom transform order is NOT preserved (see note)
    *   mergeTransforms(lowering, List(C, D)) == List(A, D, B, C)
    * }}}
    *
    * @note Order will be preserved for custom transforms so long as the
    * inputForm of a latter transforms is equal to or lower than the outputForm
    * of the previous transform.
    */
  def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] =
    custom.foldLeft(lowering) { (merged, xform) =>
      val anchor = merged.lastIndexWhere(_.outputForm == xform.inputForm)
      assert(anchor >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front
        s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform")
      val insertAt = anchor + 1 // +1 because we want to be AFTER the anchor
      merged.take(insertAt) ++ List(xform) ++
        getLoweringTransforms(xform.outputForm, xform.inputForm) ++ merged.drop(insertAt)
    }
}
/** A fixed pipeline of [[Transform]]s followed by an [[Emitter]]. */
trait Compiler extends LazyLogging {
  /** The emitter run after all other transforms. */
  def emitter: Emitter

  /** The sequence of transforms this compiler will execute
    * @note The inputForm of a given transform must be higher than or equal to the outputForm of the
    *  preceding transform. See [[CircuitForm]]
    */
  def transforms: Seq[Transform]

  // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms
  def inputForm: CircuitForm = transforms.head.inputForm
  def outputForm: CircuitForm = transforms.last.outputForm

  /** Checks that each transform accepts the form produced by its predecessor.
    *
    * Every adjacent pair is compared before the results are combined; fewer
    * than two transforms are trivially legal.
    */
  private def transformsLegal(xforms: Seq[Transform]): Boolean = {
    val pairChecks = xforms.zip(xforms.drop(1)).map { case (prev, next) =>
      next.inputForm >= prev.outputForm
    }
    pairChecks.forall(ok => ok)
  }

  // Evaluated in the trait body, i.e. while an implementing instance is being initialized.
  assert(transformsLegal(transforms),
    "Illegal Compiler, each transform must be able to accept the output of the previous transform!")

  /** Perform compilation
    *
    * @param state The Firrtl AST to compile
    * @param writer The java.io.Writer where the output of compilation will be emitted
    * @param customTransforms Any custom [[Transform]]s that will be inserted
    *   into the compilation process by [[CompilerUtils.mergeTransforms]]
    * @return result of compilation with emitted circuit annotated
    */
  @deprecated("Please use compileAndEmit or other compile method instead", "firrtl 1.0")
  def compile(state: CircuitState,
              writer: Writer,
              customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
    val emittedState = compileAndEmit(state, customTransforms)
    val emittedText = emittedState.getEmittedCircuit.value
    writer.write(emittedText)
    emittedState
  }

  /** Perform compilation and emit the whole Circuit
    *
    * This is intended as a convenience method wrapping up Annotation creation for the common case.
    * It creates a [[EmitCircuitAnnotation]] that will be consumed by this Transform's emitter. The
    * [[EmittedCircuit]] can be extracted from the returned [[CircuitState]] via
    * [[CircuitState.emittedCircuitOption]]
    *
    * @param state The Firrtl AST to compile
    * @param customTransforms Any custom [[Transform]]s that will be inserted
    *   into the compilation process by [[CompilerUtils.mergeTransforms]]
    * @return result of compilation with emitted circuit annotated
    */
  def compileAndEmit(state: CircuitState,
                     customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
    val emitAnno = EmitCircuitAnnotation(emitter.getClass)
    // TODO This is ridiculous. Fix Annotations
    val existing: Seq[Annotation] = state.annotations match {
      case Some(annoMap) => annoMap.annotations
      case None => Seq.empty
    }
    val withEmit = state.copy(annotations = Some(AnnotationMap(existing :+ emitAnno)))
    // Run compiler
    compile(withEmit, customTransforms)
  }

  /** Perform compilation
    *
    * Emission will only be performed if [[EmitAnnotation]]s are present
    *
    * @param state The Firrtl AST to compile
    * @param customTransforms Any custom [[Transform]]s that will be inserted into the compilation
    *   process by [[CompilerUtils.mergeTransforms]]
    * @return result of compilation
    */
  def compile(state: CircuitState, customTransforms: Seq[Transform]): CircuitState = {
    // The emitter always runs last, after the merged lowering and custom transforms.
    val pipeline = CompilerUtils.mergeTransforms(transforms, customTransforms) :+ emitter

    val (timeMillis, finalState) = Utils.time {
      pipeline.foldLeft(state) { (current, xform) => xform.runTransform(current) }
    }

    // NOTE(review): logged at error level although it is not an error — presumably so the
    // total time is always shown; confirm before changing the level.
    logger.error(f"Total FIRRTL Compile Time: $timeMillis%.1f ms")
    finalState
  }
}