Permalink
Browse files

That took 10 minutes

  • Loading branch information...
olafurpg committed Dec 9, 2016
1 parent efdbe17 commit da7374d0c0ed2775c1e2fa1f024307fbcd1fbd1d
@@ -0,0 +1,27 @@
package scalafix.rewrite
import scala.meta._
import scala.meta.tokens.Token.Comment
import scala.meta.tokens.Token.LeftBrace
import scala.meta.tokens.Token.RightBrace
import scala.meta.tokens.Token.RightParen
import scala.meta.tokens.Token.Space
import scalafix.util.Patch
import scalafix.util.Whitespace
case object NoPostfix extends Rewrite {
override def rewrite(ast: Tree, ctx: RewriteCtx): Seq[Patch] = {
import ctx.tokenList._
val patches: Seq[Patch] = ast.collect {
case Term.Select(_, name)
if revFind(name.tokens.head)(!_.is[Whitespace])
.exists(!_.is[Token.Dot]) =>
val end = name.tokens.head
val start =
if (prev(end).is[Space]) prev(end)
else end
Patch(start, end, s".$end")
}
patches
}
}
@@ -16,7 +16,11 @@ object Rewrite {
t.map(x => x.source -> x.value).toMap
}
val syntaxRewrites: Seq[Rewrite] = Seq(ProcedureSyntax, VolatileLazyVal)
val syntaxRewrites: Seq[Rewrite] = Seq(
ProcedureSyntax,
VolatileLazyVal,
NoPostfix
)
val semanticRewrites: Seq[Rewrite] = Seq(ExplicitImplicit)
val allRewrites: Seq[Rewrite] = syntaxRewrites ++ semanticRewrites
val name2rewrite: Map[String, Rewrite] =
@@ -14,11 +14,18 @@ class TokenList(tokens: Tokens) {
map.result()
}
def find(start: Token)(f: Token => Boolean): Option[Token] = {
def find(start: Token)(f: Token => Boolean): Option[Token] =
genericFind(next)(start)(f)
def revFind(start: Token)(f: Token => Boolean): Option[Token] =
genericFind(prev)(start)(f)
def genericFind(step: Token => Token)(start: Token)(
f: Token => Boolean): Option[Token] = {
def loop(curr: Token): Option[Token] = {
if (f(curr)) Option(curr)
else {
val iter = next(curr)
val iter = step(curr)
if (iter == curr) None // reached EOF
else loop(iter)
}
@@ -0,0 +1,5 @@
// foo
<<< postfix 1
class A { List(1, 2) tail }
>>>
class A { List(1, 2).tail }
@@ -40,8 +40,8 @@ class SemanticTests extends FunSuite {
private object fixer extends NscSemanticApi {
lazy val global: SemanticTests.this.g.type = SemanticTests.this.g
def apply(unit: g.CompilationUnit): Fixed =
fix(unit, ScalafixConfig(rewrites = List(rewrite)))
def apply(unit: g.CompilationUnit, config: ScalafixConfig): Fixed =
fix(unit, config)
}
def wrap(code: String, diffTest: DiffTest): String = {
@@ -79,10 +79,17 @@ class SemanticTests extends FunSuite {
unit
}
def fix(code: String): String = {
val Fixed.Success(fixed) = fixer(getTypedCompilationUnit(code))
def fix(code: String, t: DiffTest): String = {
val config =
ScalafixConfig.fromNames(List(t.spec.replaceAll("/.*", ""))) match {
case Right(x) => x
case Left(msg) => throw new IllegalArgumentException(msg)
}
val Fixed.Success(fixed) =
fixer(getTypedCompilationUnit(code), config)
fixed
}
case class MismatchException(details: String) extends Exception
private def checkMismatchesModuloDesugarings(obtained: m.Tree,
expected: m.Tree): Unit = {
@@ -141,7 +148,7 @@ class SemanticTests extends FunSuite {
}
def check(original: String, expectedStr: String, diffTest: DiffTest): Unit = {
val fixed = fix(wrap(original, diffTest))
val fixed = fix(wrap(original, diffTest), diffTest)
val obtained = parse(fixed)
val expected = parse(expectedStr)
try {

0 comments on commit da7374d

Please sign in to comment.