From 1d88777cd0ae2515ac7fd4cf58abd709b9f05bb9 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 15 Feb 2023 11:35:40 +0000 Subject: [PATCH] Extend pattern type constraining to closed hierarchies --- .../src/dotty/tools/dotc/core/Flags.scala | 2 ++ .../dotc/core/PatternTypeConstrainer.scala | 20 +++--------- .../dotty/tools/dotc/typer/RefChecks.scala | 32 +++++++++++-------- tests/neg/i18552.check | 4 +-- tests/pos/i4790.scala | 21 ++++++++++++ 5 files changed, 49 insertions(+), 30 deletions(-) create mode 100644 tests/pos/i4790.scala diff --git a/compiler/src/dotty/tools/dotc/core/Flags.scala b/compiler/src/dotty/tools/dotc/core/Flags.scala index b1bf7a266c91..91397311c6da 100644 --- a/compiler/src/dotty/tools/dotc/core/Flags.scala +++ b/compiler/src/dotty/tools/dotc/core/Flags.scala @@ -563,6 +563,8 @@ object Flags { val JavaOrPrivateOrSynthetic: FlagSet = Artifact | JavaDefined | Private | Synthetic val PrivateOrSynthetic: FlagSet = Artifact | Private | Synthetic val EnumCase: FlagSet = Case | Enum + val CaseOrFinalOrSealed: FlagSet = Case | Final | Sealed + val CaseOrSealed: FlagSet = Case | Sealed val CovariantLocal: FlagSet = Covariant | Local // A covariant type parameter val ContravariantLocal: FlagSet = Contravariant | Local // A contravariant type parameter val EffectivelyErased = ConstructorProxy | Erased diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 9baf0c40a80b..01930396c17b 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -221,29 +221,19 @@ trait PatternTypeConstrainer { self: TypeComparer => * * It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`. * Note, however, that if `D` was a final class, we *could* rely on that relationship. - * To support typical case classes, we also assume that this relationship holds for them and their parent traits. - * This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those - * case classes without also appropriately extending the relevant case class - * (see `RefChecks#checkCaseClassInheritanceInvariant`). + * Case classes and sealed traits (and sealed classes) are supported, + * by assuming that this relationship holds for them and their parent traits. + * This is enforced by checking no subclass of them mixes in any parent trait with a different type argument. + * (see `RefChecks#checkVariantInheritanceProblems`). */ def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { case tp: SingletonType => true - case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case) + case tp: ClassInfo => tp.cls.isOneOf(CaseOrFinalOrSealed) case tp: TypeProxy => refinementIsInvariant(tp.superType) case _ => false } - def widenVariantParams(tp: Type) = tp match { - case tp @ AppliedType(tycon, args) => - val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) => - if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg - ) - tp.derivedAppliedType(tycon, args1) - case tp => - tp - } - val patternCls = patternTp.classSymbol val scrutineeCls = scrutineeTp.classSymbol diff --git a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala index cb1aea27c444..e954424bd381 100644 --- a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala +++ b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala @@ -901,20 +901,26 @@ object RefChecks { } } - /** Check that inheriting a case class does not constitute a variant refinement + /** Check that inheriting a case class or a sealed trait (or a sealed class) does not constitute a variant refinement * of a base type of the case class. It is because of this restriction that we - * can assume invariant refinement for case classes in `constrainPatternType`. + * can assume invariant refinement for these classes in `constrainSimplePatternType`. */ - def checkCaseClassInheritanceInvariant() = + def checkVariantInheritanceProblems() = for - caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)) - baseCls <- caseCls.info.baseClasses.tail + middle <- clazz.info.baseClasses.tail + if middle.isOneOf(CaseOrSealed) + baseCls <- middle.info.baseClasses.tail if baseCls.typeParams.exists(_.paramVarianceSign != 0) - problem <- variantInheritanceProblems(baseCls, caseCls, i"base $baseCls", "case ") - withExplain = problem.appendExplanation: - """Refining a basetype of a case class is not allowed. - |This is a limitation that enables better GADT constraints in case class patterns""".stripMargin - do report.errorOrMigrationWarning(withExplain, clazz.srcPos, MigrationVersion.Scala2to3) + problem <- { + val middleStr = if middle.is(Case) then "case " else if middle.is(Sealed) then "sealed " else "" + variantInheritanceProblems(baseCls, middle, "", middleStr) + } + do + val withExplain = problem.appendExplanation: + """Refining a basetype of a case class or a sealed trait (or a sealed class) is not allowed. + |This is a limitation that enables better GADT constraints in case class and sealed hierarchy patterns""".stripMargin + report.errorOrMigrationWarning(withExplain, clazz.srcPos, MigrationVersion.Scala2to3) + checkNoAbstractMembers() if (abstractErrors.isEmpty) checkNoAbstractDecls(clazz) @@ -923,7 +929,7 @@ object RefChecks { report.error(abstractErrorMessage, clazz.srcPos) checkMemberTypesOK() - checkCaseClassInheritanceInvariant() + checkVariantInheritanceProblems() } if (!clazz.is(Trait) && checker.checkInheritedTraitParameters) { @@ -943,7 +949,7 @@ object RefChecks { for { cls <- clazz.info.baseClasses.tail if cls.paramAccessors.nonEmpty && !mixins.contains(cls) - problem <- variantInheritanceProblems(cls, clazz.asClass.superClass, i"parameterized base $cls", "super") + problem <- variantInheritanceProblems(cls, clazz.asClass.superClass, "parameterized ", "super") } report.error(problem, clazz.srcPos) } @@ -966,7 +972,7 @@ object RefChecks { if (combinedBT =:= thisBT) None // ok else Some( - em"""illegal inheritance: $clazz inherits conflicting instances of $baseStr. + em"""illegal inheritance: $clazz inherits conflicting instances of ${baseStr}base $baseCls. | | Direct basetype: $thisBT | Basetype via $middleStr$middle: $combinedBT""") diff --git a/tests/neg/i18552.check b/tests/neg/i18552.check index a7a04ed78b47..5918f2b97283 100644 --- a/tests/neg/i18552.check +++ b/tests/neg/i18552.check @@ -8,6 +8,6 @@ |--------------------------------------------------------------------------------------------------------------------- | Explanation (enabled by `-explain`) |- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - | Refining a basetype of a case class is not allowed. - | This is a limitation that enables better GADT constraints in case class patterns + | Refining a basetype of a case class or a sealed trait (or a sealed class) is not allowed. + | This is a limitation that enables better GADT constraints in case class and sealed hierarchy patterns --------------------------------------------------------------------------------------------------------------------- diff --git a/tests/pos/i4790.scala b/tests/pos/i4790.scala new file mode 100644 index 000000000000..5213870944b6 --- /dev/null +++ b/tests/pos/i4790.scala @@ -0,0 +1,21 @@ +class Test: + def foo(as: Seq[Int]) = + val List(_, bs: _*) = as: @unchecked + val cs: Seq[Int] = bs + +class Test2: + def foo(as: SSeq[Int]) = + val LList(_, tail) = as: @unchecked + val cs: SSeq[Int] = tail + +trait SSeq[+A] +sealed trait LList[+A] extends SSeq[A] +final case class CCons[+A](head: A, tail: LList[A]) extends LList[A] +case object NNil extends LList[Nothing] +object LList: + def unapply[A](xs: LList[A]): Extractor[A] = Extractor[A](xs) + final class Extractor[A](private val xs: LList[A]) extends AnyVal: + def get: this.type = this + def isEmpty: Boolean = xs.isInstanceOf[CCons[?]] + def _1: A = xs.asInstanceOf[CCons[A]].head + def _2: SSeq[A] = xs.asInstanceOf[CCons[A]].tail