From d845f4aa1eac0a31fcf54811d411a3ced5741972 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 24 Feb 2014 23:47:22 +0000 Subject: [PATCH] Scalaxy/Privacy: recognize trivial types from collection constructors: List(1, 2), Map[A, B](...), but not List(1, "2") --- .../ExplicitTypeAnnotationsComponent.scala | 85 ++++++++++++++----- .../scalaxy/ExplicitTypeAnnotationsTest.scala | 68 ++++++++++++--- .../PluginComponentsIntegrationTest.scala | 16 ++-- .../src/test/scala/scalaxy/PrivacyTest.scala | 4 +- Privacy/src/test/scala/scalaxy/TestBase.scala | 1 + 5 files changed, 134 insertions(+), 40 deletions(-) diff --git a/Privacy/src/main/scala/scalaxy/privacy/ExplicitTypeAnnotationsComponent.scala b/Privacy/src/main/scala/scalaxy/privacy/ExplicitTypeAnnotationsComponent.scala index f4462126..c85db617 100644 --- a/Privacy/src/main/scala/scalaxy/privacy/ExplicitTypeAnnotationsComponent.scala +++ b/Privacy/src/main/scala/scalaxy/privacy/ExplicitTypeAnnotationsComponent.scala @@ -15,7 +15,8 @@ object ExplicitTypeAnnotationsComponent { val phaseName = "scalaxy-explicit-annotations" } class ExplicitTypeAnnotationsComponent( - val global: Global, runAfter: String = "parser") + val global: Global, + runAfter: String = "parser") extends PluginComponent { import global._ import definitions._ @@ -27,7 +28,7 @@ class ExplicitTypeAnnotationsComponent( override val runsAfter = runsRightAfter.toList override val runsBefore = List("typer") - object N { + private object N { def unapply(n: Name): Option[String] = if (n == null) None @@ -35,37 +36,81 @@ class ExplicitTypeAnnotationsComponent( Some(decode(n.toString)) } + private object TrivialCollectionName { + var rx = "List|Array|Set|Seq|Iterable|Traversable".r + def unapply(n: Name): Boolean = n.toString match { + case rx() => true + case _ => false + } + } + + private val StringTpe = typeOf[String] + override def newPhase(prev: Phase) = new StdPhase(prev) { def apply(unit: CompilationUnit) { new Traverser { - def isTrivialRHS(rhs: Tree): Boolean = rhs match { - case Literal(Constant(_)) => - true + def getTrivialTypeTree(tree: Tree): Option[Tree] = tree match { + /** Constant literals are the obvious case of trivial type. */ + case Literal(c @ Constant(_)) => + Some(TypeTree(c.tpe)) + /** String interpolation have a trivial String type. */ case Apply(Select(Apply(Ident(N("StringContext")), _), N("s")), _) => - true - - case Apply(Select(left, N("+" | "*")), List(right)) => - isTrivialRHS(left) && isTrivialRHS(right) + Some(TypeTree(StringTpe)) + + /** Collection-constructing expression like Array(a, b...) with a, b... trivially typed has a trivial type. */ + case Apply(colTpt @ Ident(TrivialCollectionName()), args) => + args.map(getTrivialTypeTree(_)).map(t => (t.toString -> t)).toMap.toList match { + // Don't accept nesting of collections as trivial: the component type must not have type params. + case List((str, Some(componentTpt))) if !str.contains("[") => + Some(TypeApply(colTpt, List(componentTpt))) + + case _ => + None + } + + /** + * Collection-constructing expressions like Array[Int](...) or Map[A, B](...) have + * a trivial type. + */ + case Apply(ta @ TypeApply(Ident(TrivialCollectionName() | N("Map")), tparams), _) => + Some(ta) + + /** + * String + Any has a trivial type. + * Homogeneous multiplications have a trivial type. + */ + case Apply(Select(left, N(op @ ("+" | "*"))), List(right)) => + (getTrivialTypeTree(left), getTrivialTypeTree(right)) match { + case (Some(leftTpe), Some(rightTpe)) if leftTpe.toString == rightTpe.toString || (op == "+" && leftTpe.toString == "String") => + Some(leftTpe) + case _ => + None + } case _ => - false + None } - def checkTypeTree(d: ValOrDefDef) { - // reporter.info(d.pos, "d.pos", force = true) - // reporter.info(d.tpt.pos, "d.tpt.pos (" + d.tpt.getClass.getName + ")", force = true) - // reporter.info(d.rhs.pos, "d.rhs.pos", force = true) + def checkTypeTree(d: ValOrDefDef) { if (d.tpt.pos != NoPosition && d.tpt.pos == d.pos && d.name != nme.CONSTRUCTOR && - d.mods.hasNoFlags(PRIVATE | PROTECTED | SYNTHETIC | OVERRIDE) && - !isTrivialRHS(d.rhs)) { - - reporter.warning( - if (d.pos == NoPosition) d.rhs.pos else d.pos, - s"Public member `${d.name}` with non-trivial value should have a explicit type annotation") + d.mods.hasNoFlags(PRIVATE | PROTECTED | SYNTHETIC | OVERRIDE)) { + + val pos = if (d.pos == NoPosition) d.rhs.pos else d.pos + getTrivialTypeTree(d.rhs) match { + case Some(tpt) => + reporter.info( + pos, + s"Extracted trivial type $tpt", force = true) + + case _ => + reporter.warning( + pos, + s"Public member `${d.name}` with non-trivial value should have an explicit type annotation") + } } } diff --git a/Privacy/src/test/scala/scalaxy/ExplicitTypeAnnotationsTest.scala b/Privacy/src/test/scala/scalaxy/ExplicitTypeAnnotationsTest.scala index a76678aa..1159ffdf 100644 --- a/Privacy/src/test/scala/scalaxy/ExplicitTypeAnnotationsTest.scala +++ b/Privacy/src/test/scala/scalaxy/ExplicitTypeAnnotationsTest.scala @@ -9,11 +9,17 @@ import scala.tools.nsc.reporters.{ StoreReporter, Reporter } import org.junit._ import org.junit.Assert._ +object ExplicitTypeAnnotationsTest { + def shouldHaveAnnotationMsg(name: String) = + s"Public member `$name` with non-trivial value should have an explicit type annotation" +} class ExplicitTypeAnnotationsTest extends TestBase { override def getInternalPhases(global: Global) = List(new ExplicitTypeAnnotationsComponent(global)) + import ExplicitTypeAnnotationsTest.shouldHaveAnnotationMsg + @Test def allGood { assertEquals( @@ -31,34 +37,72 @@ class ExplicitTypeAnnotationsTest extends TestBase { } @Test - def trivial { + def trivials { assertEquals( Nil, compile(""" class Foo { - def f = 10 + "blah" - val v = 10 + "blah" + def nullaryFunctionWithStringPlusNumber = "blah" + 10.0 + def valWithNumberTimesNumber = 10.0 * 20.0 + + def interpolatedString = s"blah $valWithNumberTimesNumber" + + def list = List(1, 2) + def set = Set("1", "2") + val seq = Seq(1, 2) + def array = Array(1, 2) + def iterable = Iterable(1, 2) + def traversable = Traversable(1, 2) + + val arrayWithType = Array[Int]() + val listWithType = List[(String, Long)]() + val mapWithTypes = Map[Int, (Double, Float)]() + + override def toString = if (arrayWithType.length > 10) "1" else "2" + } + """) + ) + } + @Test + def almostTrivials { + assertEquals( + List( + shouldHaveAnnotationMsg("nullaryFunctionWithNumberPlusString"), + shouldHaveAnnotationMsg("valWithNumberPlusString"), + shouldHaveAnnotationMsg("heterogeneousList"), + shouldHaveAnnotationMsg("heterogeneousArray"), + shouldHaveAnnotationMsg("nestedArray") + ), + compile(""" + class Foo { + def nullaryFunctionWithNumberPlusString = 10 + "blah" + val valWithNumberPlusString = 10 + "blah" + + def heterogeneousList = List(1, "2") + val heterogeneousArray = Array(1, "2") + def nestedArray = Array(Set[Int](), Set[Int]()) } """) ) } @Test - def allBad { + def nonTrivials { assertEquals( List( - "Public member `f` with non-trivial value should have a explicit type annotation", - "Public member `ff` with non-trivial value should have a explicit type annotation", - "Public member `v` with non-trivial value should have a explicit type annotation", - "Public member `vv` with non-trivial value should have a explicit type annotation" + shouldHaveAnnotationMsg("functionWithArgWithBranch"), + shouldHaveAnnotationMsg("nullaryFunctionWithCall"), + shouldHaveAnnotationMsg("valWithBranch"), + shouldHaveAnnotationMsg("valWithRef") ), compile(""" class Foo { - def f(x: Int) = if (x < 10) 1 else 2 - def ff = f(10) - val v = if (ff == 1) 1 else "blah" - val vv = v + def functionWithArgWithBranch(x: Int) = if (x < 10) 1 else 2 + def nullaryFunctionWithCall = functionWithArgWithBranch(10) + + val valWithBranch = if (nullaryFunctionWithCall == 1) 1 else "blah" + val valWithRef = valWithBranch } """) ) diff --git a/Privacy/src/test/scala/scalaxy/PluginComponentsIntegrationTest.scala b/Privacy/src/test/scala/scalaxy/PluginComponentsIntegrationTest.scala index 085379e8..a7f6cbc0 100644 --- a/Privacy/src/test/scala/scalaxy/PluginComponentsIntegrationTest.scala +++ b/Privacy/src/test/scala/scalaxy/PluginComponentsIntegrationTest.scala @@ -14,17 +14,19 @@ class PluginComponentsIntegrationTest extends TestBase { override def getInternalPhases(global: Global) = PrivacyPlugin.getInternalPhases(global) + import ExplicitTypeAnnotationsTest.shouldHaveAnnotationMsg + @Test def mixAllFeatures { assertEquals( List( - "Public member `ff` with non-trivial value should have a explicit type annotation", - "Public member `ffff` with non-trivial value should have a explicit type annotation", - "Public member `g` with non-trivial value should have a explicit type annotation", - "Public member `vvv` with non-trivial value should have a explicit type annotation", - "Public member `barf` with non-trivial value should have a explicit type annotation", - "Public member `barv` with non-trivial value should have a explicit type annotation", - "Public member `barvv` with non-trivial value should have a explicit type annotation" + shouldHaveAnnotationMsg("ff"), + shouldHaveAnnotationMsg("ffff"), + shouldHaveAnnotationMsg("g"), + shouldHaveAnnotationMsg("vvv"), + shouldHaveAnnotationMsg("barf"), + shouldHaveAnnotationMsg("barv"), + shouldHaveAnnotationMsg("barvv") ), compile(""" @public class Foo { diff --git a/Privacy/src/test/scala/scalaxy/PrivacyTest.scala b/Privacy/src/test/scala/scalaxy/PrivacyTest.scala index f179d149..fdb4a167 100644 --- a/Privacy/src/test/scala/scalaxy/PrivacyTest.scala +++ b/Privacy/src/test/scala/scalaxy/PrivacyTest.scala @@ -80,7 +80,9 @@ class PrivacyTest extends TestBase { @Test def moduleMembers { assertEquals( - List("value privateByDefault is not a member of object Foo"), + List( + "value privateByDefault is not a member of object Foo" + ), compile(""" object Foo { val privateByDefault = 10 diff --git a/Privacy/src/test/scala/scalaxy/TestBase.scala b/Privacy/src/test/scala/scalaxy/TestBase.scala index 72271220..831e2e77 100644 --- a/Privacy/src/test/scala/scalaxy/TestBase.scala +++ b/Privacy/src/test/scala/scalaxy/TestBase.scala @@ -3,6 +3,7 @@ package scalaxy.privacy.test import scala.language.existentials import scalaxy.privacy.PrivacyCompiler +import scala.reflect.ClassTag import scala.tools.nsc.Global import scala.reflect.internal.util.Position import scala.tools.nsc.reporters.StoreReporter