Skip to content

Commit

Permalink
ExplicitNonNullaryApply: workaround for scalameta/scalameta#1083
Browse files Browse the repository at this point in the history
and scalacenter/scalafix#1104

Given
```
trait Test {
  def shouldBe(r: Any) = ???
  def arg() = ""
  this shouldBe (arg)
}
```

Without changes in this commit then `this shouldBe (arg)`
+ will not be rewritten if scalacenter/scalafix#1104 is not fixed
+ will be rewritten to invalid code `this shouldBe (arg)()` if scalacenter/scalafix#1104 is fixed but scalameta/scalameta#1083 is not fixed
  • Loading branch information
giabao committed Apr 27, 2020
1 parent a51cd02 commit 27eda41
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
Expand Up @@ -19,4 +19,6 @@ trait ExplicitNonNullaryApplyInfixSpec extends Matchers {
lhs() shouldBe arg
lhs shouldBe arg()
lhs() shouldBe arg()

lhs shouldBe (arg)
}
Expand Up @@ -16,4 +16,6 @@ trait ExplicitNonNullaryApplyInfixSpec extends Matchers {
lhs() shouldBe arg()
lhs() shouldBe arg()
lhs() shouldBe arg()

lhs() shouldBe (arg())
}
33 changes: 33 additions & 0 deletions rewrites/src/main/scala/scalafix/v1/Workaround1104.scala
@@ -0,0 +1,33 @@
package scalafix.v1

import scala.meta._

/** workaround for:
* + https://github.com/scalameta/scalameta/issues/1083
* + https://github.com/scalacenter/scalafix/issues/1104 */
object Workaround1104 {
/** Same as [[scalafix.v1.XtensionTreeScalafix.symbol]]
* but because of scalameta/scalameta#1083,
* sometimes `XtensionTreeScalafix.symbol` return `Symbol.None``
* In that case, we retry getting symbols at pos `name.pos`` */
def symbol(name: Name)(implicit doc: SemanticDocument): Symbol =
doc.internal.symbol(name) match {
case Symbol.None if needWorkaround(name) =>
doc.internal.symbols(name.pos)
.toStream.headOption // Discard multi symbols
.getOrElse(Symbol.None)
case sym => sym
}

def needWorkaround(name: Name): Boolean = {
val s = name.syntax
s.startsWith("(") && s.endsWith(")") && s != name.value
}

/** @return if the token `name` instead of `)` for `(name)` */
def lastToken(name: Name): Token = {
val tokens = name.tokens
if (needWorkaround(name)) tokens(tokens.length - 2)
else tokens.last
}
}

0 comments on commit 27eda41

Please sign in to comment.