Skip to content

Commit

Permalink
[unittest/Sema] Cover transitive protocol inference with unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xedin committed Oct 15, 2020
1 parent a3c3981 commit 9598f19
Showing 1 changed file with 131 additions and 15 deletions.
146 changes: 131 additions & 15 deletions unittests/Sema/BindingInferenceTests.cpp
Expand Up @@ -46,23 +46,41 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
ASSERT_TRUE(binding.hasDefaultedLiteralProtocol());
}

// Given a set of inferred protocol requirements, make sure that
// all of the expected types are present.
static void verifyProtocolInferenceResults(
const llvm::SmallPtrSetImpl<Constraint *> &protocols,
ArrayRef<Type> expectedTypes) {
ASSERT_TRUE(protocols.size() >= expectedTypes.size());

llvm::SmallPtrSet<Type, 2> inferredProtocolTypes;
for (auto *protocol : protocols)
inferredProtocolTypes.insert(protocol->getSecondType());

for (auto expectedTy : expectedTypes) {
ASSERT_TRUE(inferredProtocolTypes.count(expectedTy));
}
}

TEST_F(SemaTest, TestTransitiveProtocolInference) {
ConstraintSystemOptions options;
ConstraintSystem cs(DC, options);

auto *PD1 =
new (Context) ProtocolDecl(DC, SourceLoc(), SourceLoc(),
Context.getIdentifier("P1"), /*Inherited=*/{},
/*trailingWhere=*/nullptr);
PD1->setImplicit();
auto *protocolTy1 = createProtocol("P1");
auto *protocolTy2 = createProtocol("P2");

auto *protocolTy1 = ProtocolType::get(PD1, Type(), Context);
auto *GPT1 = cs.createTypeVariable(cs.getConstraintLocator({}),
/*options=*/TVO_CanBindToNoEscape);
auto *GPT2 = cs.createTypeVariable(cs.getConstraintLocator({}),
/*options=*/TVO_CanBindToNoEscape);

auto *GPT = cs.createTypeVariable(cs.getConstraintLocator({}),
/*options=*/TVO_CanBindToNoEscape);
cs.addConstraint(
ConstraintKind::ConformsTo, GPT1, protocolTy1,
cs.getConstraintLocator({}, LocatorPathElt::TypeParameterRequirement(
0, RequirementKind::Conformance)));

cs.addConstraint(
ConstraintKind::ConformsTo, GPT, protocolTy1,
ConstraintKind::ConformsTo, GPT2, protocolTy2,
cs.getConstraintLocator({}, LocatorPathElt::TypeParameterRequirement(
0, RequirementKind::Conformance)));

Expand All @@ -73,16 +91,114 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
/*options=*/0);

cs.addConstraint(
ConstraintKind::Conversion, typeVar, GPT,
ConstraintKind::Conversion, typeVar, GPT1,
cs.getConstraintLocator({}, LocatorPathElt::ContextualType()));

auto bindings = inferBindings(cs, typeVar);
ASSERT_TRUE(bindings.Protocols.empty());
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
{protocolTy1});
}

// Now, let's make sure that protocol requirements could be propagated
// down conversion/equality chains through multiple hops.
{
// GPT1 is a subtype of GPT2 and GPT2 is convertible to a target type
// variable, target should get both protocols inferred - P1 & P2.

const auto &inferredProtocols = bindings.TransitiveProtocols;
ASSERT_TRUE(bool(inferredProtocols));
ASSERT_EQ(inferredProtocols->size(), (unsigned)1);
ASSERT_TRUE(
(*inferredProtocols->begin())->getSecondType()->isEqual(protocolTy1));
auto *typeVar = cs.createTypeVariable(cs.getConstraintLocator({}),
/*options=*/0);

cs.addConstraint(ConstraintKind::Subtype, GPT1, GPT2,
cs.getConstraintLocator({}));

cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1,
cs.getConstraintLocator({}));

auto bindings = inferBindings(cs, typeVar);
ASSERT_TRUE(bindings.Protocols.empty());
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
{protocolTy1, protocolTy2});
}
}

/// Let's try a more complicated situation where there protocols
/// are inferred from multiple sources on different levels of
/// convertion chain.
///
/// (P1) T0 T4 (T3) T6 (P4)
/// \ / /
/// T3 = T1 (P2) = T5
/// \ /
/// T2

TEST_F(SemaTest, TestComplexTransitiveProtocolInference) {
ConstraintSystemOptions options;
ConstraintSystem cs(DC, options);

auto *protocolTy1 = createProtocol("P1");
auto *protocolTy2 = createProtocol("P2");
auto *protocolTy3 = createProtocol("P3");
auto *protocolTy4 = createProtocol("P4");

auto *nilLocator = cs.getConstraintLocator({});

auto typeVar0 = cs.createTypeVariable(nilLocator, /*options=*/0);
auto typeVar1 = cs.createTypeVariable(nilLocator, /*options=*/0);
auto typeVar2 = cs.createTypeVariable(nilLocator, /*options=*/0);
// Allow this type variable to be bound to l-value type to prevent
// it from being merged with the rest of the type variables.
auto typeVar3 =
cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue);
auto typeVar4 = cs.createTypeVariable(nilLocator, /*options=*/0);
auto typeVar5 =
cs.createTypeVariable(nilLocator, /*options=*/TVO_CanBindToLValue);
auto typeVar6 = cs.createTypeVariable(nilLocator, /*options=*/0);

cs.addConstraint(ConstraintKind::ConformsTo, typeVar0, protocolTy1,
nilLocator);
cs.addConstraint(ConstraintKind::ConformsTo, typeVar1, protocolTy2,
nilLocator);
cs.addConstraint(ConstraintKind::ConformsTo, typeVar4, protocolTy3,
nilLocator);
cs.addConstraint(ConstraintKind::ConformsTo, typeVar6, protocolTy4,
nilLocator);

// T3 <: T0, T3 <: T4
cs.addConstraint(ConstraintKind::Conversion, typeVar3, typeVar0, nilLocator);
cs.addConstraint(ConstraintKind::Conversion, typeVar3, typeVar4, nilLocator);

// T2 <: T3, T2 <: T1, T3 == T1
cs.addConstraint(ConstraintKind::Subtype, typeVar2, typeVar3, nilLocator);
cs.addConstraint(ConstraintKind::Conversion, typeVar2, typeVar1, nilLocator);
cs.addConstraint(ConstraintKind::Equal, typeVar3, typeVar1, nilLocator);
// T1 == T5, T <: T6
cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar5, nilLocator);
cs.addConstraint(ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator);

auto bindingsForT1 = inferBindings(cs, typeVar1);
auto bindingsForT2 = inferBindings(cs, typeVar2);
auto bindingsForT3 = inferBindings(cs, typeVar3);
auto bindingsForT5 = inferBindings(cs, typeVar5);

ASSERT_TRUE(bool(bindingsForT1.TransitiveProtocols));
verifyProtocolInferenceResults(*bindingsForT1.TransitiveProtocols,
{protocolTy1, protocolTy3, protocolTy4});

ASSERT_TRUE(bool(bindingsForT2.TransitiveProtocols));
verifyProtocolInferenceResults(
*bindingsForT2.TransitiveProtocols,
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});

ASSERT_TRUE(bool(bindingsForT3.TransitiveProtocols));
verifyProtocolInferenceResults(
*bindingsForT3.TransitiveProtocols,
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});

ASSERT_TRUE(bool(bindingsForT5.TransitiveProtocols));
verifyProtocolInferenceResults(
*bindingsForT5.TransitiveProtocols,
{protocolTy1, protocolTy2, protocolTy3, protocolTy4});
}

0 comments on commit 9598f19

Please sign in to comment.