Skip to content

Commit

Permalink
PLANNER-2798 Allow multiple entities with chained variables (#2182)
Browse files Browse the repository at this point in the history
  • Loading branch information
yurloc committed Nov 10, 2022
1 parent aab2d60 commit 985ccac
Show file tree
Hide file tree
Showing 22 changed files with 465 additions and 52 deletions.
Expand Up @@ -67,6 +67,20 @@ public void setMoveSelectorConfigList(List<MoveSelectorConfig> moveSelectorConfi
this.moveSelectorConfigList = moveSelectorConfigList;
}

// ************************************************************************
// With methods
// ************************************************************************

public QueuedEntityPlacerConfig withEntitySelectorConfig(EntitySelectorConfig entitySelectorConfig) {
this.entitySelectorConfig = entitySelectorConfig;
return this;
}

public QueuedEntityPlacerConfig withMoveSelectorConfigList(List<MoveSelectorConfig> moveSelectorConfigList) {
this.moveSelectorConfigList = moveSelectorConfigList;
return this;
}

@Override
public QueuedEntityPlacerConfig inherit(QueuedEntityPlacerConfig inheritedConfig) {
entitySelectorConfig = ConfigUtils.inheritConfig(entitySelectorConfig, inheritedConfig.getEntitySelectorConfig());
Expand Down
Expand Up @@ -47,6 +47,20 @@ public void setValueSelectorConfig(ValueSelectorConfig valueSelectorConfig) {
this.valueSelectorConfig = valueSelectorConfig;
}

// ************************************************************************
// With methods
// ************************************************************************

public TailChainSwapMoveSelectorConfig withEntitySelectorConfig(EntitySelectorConfig entitySelectorConfig) {
this.entitySelectorConfig = entitySelectorConfig;
return this;
}

public TailChainSwapMoveSelectorConfig withValueSelectorConfig(ValueSelectorConfig valueSelectorConfig) {
this.valueSelectorConfig = valueSelectorConfig;
return this;
}

@Override
public TailChainSwapMoveSelectorConfig inherit(TailChainSwapMoveSelectorConfig inheritedConfig) {
super.inherit(inheritedConfig);
Expand Down
Expand Up @@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -378,36 +379,34 @@ public static Class<?> extractCollectionGenericTypeParameterStrictly(
return extractCollectionGenericTypeParameter(
parentClassConcept, parentClass,
type, genericType,
annotationClass, memberName,
true);
annotationClass, memberName).orElseThrow(
() -> new IllegalArgumentException("The " + parentClassConcept + " (" + parentClass + ") has a "
+ (annotationClass == null ? "auto discovered"
: "@" + annotationClass.getSimpleName() + " annotated")
+ " member (" + memberName
+ ") with a member type (" + type
+ ") which has no generic parameters.\n"
+ "Maybe the member (" + memberName + ") should return a parameterized "
+ type.getSimpleName()
+ "."));
}

public static Class<?> extractCollectionGenericTypeParameterLeniently(
public static Optional<Class<?>> extractCollectionGenericTypeParameterLeniently(
String parentClassConcept, Class<?> parentClass,
Class<?> type, Type genericType,
Class<? extends Annotation> annotationClass, String memberName) {
return extractCollectionGenericTypeParameter(
parentClassConcept, parentClass,
type, genericType,
annotationClass, memberName,
false);
annotationClass, memberName);
}

private static Class<?> extractCollectionGenericTypeParameter(
private static Optional<Class<?>> extractCollectionGenericTypeParameter(
String parentClassConcept, Class<?> parentClass,
Class<?> type, Type genericType,
Class<? extends Annotation> annotationClass, String memberName, boolean strict) {
Class<? extends Annotation> annotationClass, String memberName) {
if (!(genericType instanceof ParameterizedType)) {
if (strict) {
throw new IllegalArgumentException("The " + parentClassConcept + " (" + parentClass + ") has a "
+ (annotationClass == null ? "auto discovered" : "@" + annotationClass.getSimpleName() + " annotated")
+ " member (" + memberName
+ ") with a member type (" + type
+ ") which has no generic parameters.\n"
+ "Maybe the member (" + memberName + ") should return a parameterized " + type.getSimpleName() + ".");
} else {
return Object.class;
}
return Optional.empty();
}
ParameterizedType parameterizedType = (ParameterizedType) genericType;
Type[] typeArguments = parameterizedType.getActualTypeArguments();
Expand Down Expand Up @@ -444,10 +443,10 @@ private static Class<?> extractCollectionGenericTypeParameter(
}
}
if (typeArgument instanceof Class) {
return ((Class<?>) typeArgument);
return Optional.of((Class<?>) typeArgument);
} else if (typeArgument instanceof ParameterizedType) {
// Turns SomeGenericType<T> into SomeGenericType.
return (Class<?>) ((ParameterizedType) typeArgument).getRawType();
return Optional.of((Class<?>) ((ParameterizedType) typeArgument).getRawType());
} else {
throw new IllegalArgumentException("The " + parentClassConcept + " (" + parentClass + ") has a "
+ (annotationClass == null ? "auto discovered" : "@" + annotationClass.getSimpleName() + " annotated")
Expand Down
Expand Up @@ -13,6 +13,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Predicate;

import org.optaplanner.core.api.domain.entity.PinningFilter;
Expand Down Expand Up @@ -541,7 +542,13 @@ public String buildInvalidVariableNameExceptionMessage(String variableName) {
// ************************************************************************

public List<Object> extractEntities(Solution_ solution) {
return solutionDescriptor.getEntityListByEntityClass(solution, entityClass);
List<Object> entityList = new ArrayList<>();
visitAllEntities(solution, entityList::add);
return entityList;
}

public void visitAllEntities(Solution_ solution, Consumer<Object> visitor) {
solutionDescriptor.visitEntitiesByEntityClass(solution, entityClass, visitor);
}

public long getMaximumValueCount(Solution_ solution, Object entity) {
Expand Down
Expand Up @@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -338,7 +339,7 @@ private Class<? extends Annotation> extractFactEntityOrScoreAnnotationClassOrAut
elementType = ConfigUtils.extractCollectionGenericTypeParameterLeniently(
"solutionClass", solutionClass,
type, genericType,
null, member.getName());
null, member.getName()).orElse(Object.class);
} else {
elementType = type.getComponentType();
}
Expand Down Expand Up @@ -571,7 +572,7 @@ private Set<Class<?>> collectEntityAndProblemFactClasses() {
.map(accessor -> ConfigUtils.extractCollectionGenericTypeParameterLeniently(
"solutionClass", getSolutionClass(),
accessor.getType(), accessor.getGenericType(), ProblemFactCollectionProperty.class,
accessor.getName()));
accessor.getName()).orElse(Object.class));
problemFactOrEntityClassStream = concat(problemFactOrEntityClassStream, factCollectionClassStream);
// Add constraint configuration, if configured.
if (constraintConfigurationDescriptor != null) {
Expand Down Expand Up @@ -927,27 +928,43 @@ private void visitAllEntities(Solution_ solution, Consumer<Object> visitor,
}
}

public List<Object> getEntityListByEntityClass(Solution_ solution, Class<?> entityClass) {
List<Object> entityList = new ArrayList<>();
public void visitEntitiesByEntityClass(Solution_ solution, Class<?> entityClass, Consumer<Object> visitor) {
for (MemberAccessor entityMemberAccessor : entityMemberAccessorMap.values()) {
if (entityMemberAccessor.getType().isAssignableFrom(entityClass)) {
if (entityClass.isAssignableFrom(entityMemberAccessor.getType())) {
Object entity = extractMemberObject(entityMemberAccessor, solution);
if (entity != null && entityClass.isInstance(entity)) {
entityList.add(entity);
if (entity != null) {
visitor.accept(entity);
}
}
}
for (MemberAccessor entityCollectionMemberAccessor : entityCollectionMemberAccessorMap.values()) {
// TODO if (entityCollectionPropertyAccessor.getPropertyType().getElementType().isAssignableFrom(entityClass)) {
Collection<Object> entityCollection = extractMemberCollectionOrArray(entityCollectionMemberAccessor, solution,
false);
for (Object entity : entityCollection) {
if (entityClass.isInstance(entity)) {
entityList.add(entity);
Optional<Class<?>> optionalTypeParameter = ConfigUtils.extractCollectionGenericTypeParameterLeniently(
"solutionClass", entityCollectionMemberAccessor.getDeclaringClass(),
entityCollectionMemberAccessor.getType(),
entityCollectionMemberAccessor.getGenericType(),
null,
entityCollectionMemberAccessor.getName());
if (optionalTypeParameter.isPresent()) {
// In a typical case, typeParameter is specified, so we can skip the collection if typeParam
// is not assignable to entityClass.
Class<?> typeParameter = optionalTypeParameter.get();
if (entityClass.isAssignableFrom(typeParameter)) {
Collection<Object> entityCollection =
extractMemberCollectionOrArray(entityCollectionMemberAccessor, solution, false);
entityCollection.forEach(visitor);
}
} else {
// If the collection is raw, we have to visit its elements and check if each element
// is an instance of entityClass.
Collection<Object> entityCollection =
extractMemberCollectionOrArray(entityCollectionMemberAccessor, solution, false);
for (Object entity : entityCollection) {
if (entityClass.isInstance(entity)) {
visitor.accept(entity);
}
}
}
}
return entityList;
}

/**
Expand Down
Expand Up @@ -36,8 +36,7 @@ public VariableDescriptor<Solution_> getSourceVariableDescriptor() {
@Override
public void resetWorkingSolution(ScoreDirector<Solution_> scoreDirector) {
anchorMap = new IdentityHashMap<>();
previousVariableDescriptor.getEntityDescriptor().getSolutionDescriptor()
.visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
previousVariableDescriptor.getEntityDescriptor().visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
}

@Override
Expand Down
Expand Up @@ -35,8 +35,7 @@ public VariableDescriptor<Solution_> getSourceVariableDescriptor() {
@Override
public void resetWorkingSolution(ScoreDirector<Solution_> scoreDirector) {
indexMap = new IdentityHashMap<>();
sourceVariableDescriptor.getEntityDescriptor().getSolutionDescriptor()
.visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
sourceVariableDescriptor.getEntityDescriptor().visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
}

@Override
Expand Down
Expand Up @@ -35,8 +35,7 @@ public VariableDescriptor<Solution_> getSourceVariableDescriptor() {
@Override
public void resetWorkingSolution(ScoreDirector<Solution_> scoreDirector) {
inverseEntitySetMap = new IdentityHashMap<>();
sourceVariableDescriptor.getEntityDescriptor().getSolutionDescriptor()
.visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
sourceVariableDescriptor.getEntityDescriptor().visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
}

@Override
Expand Down
Expand Up @@ -32,8 +32,7 @@ public VariableDescriptor<Solution_> getSourceVariableDescriptor() {
@Override
public void resetWorkingSolution(ScoreDirector<Solution_> scoreDirector) {
inverseEntityMap = new IdentityHashMap<>();
sourceVariableDescriptor.getEntityDescriptor().getSolutionDescriptor()
.visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
sourceVariableDescriptor.getEntityDescriptor().visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
}

@Override
Expand Down
Expand Up @@ -34,8 +34,7 @@ public VariableDescriptor<Solution_> getSourceVariableDescriptor() {
@Override
public void resetWorkingSolution(ScoreDirector<Solution_> scoreDirector) {
inverseEntityMap = new IdentityHashMap<>();
sourceVariableDescriptor.getEntityDescriptor().getSolutionDescriptor()
.visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
sourceVariableDescriptor.getEntityDescriptor().visitAllEntities(scoreDirector.getWorkingSolution(), this::insert);
}

@Override
Expand Down
Expand Up @@ -63,7 +63,7 @@ private void linkShadowSources(DescriptorPolicy descriptorPolicy) {
sourceClass = ConfigUtils.extractCollectionGenericTypeParameterLeniently(
"entityClass", entityDescriptor.getEntityClass(),
variablePropertyType, genericType,
InverseRelationShadowVariable.class, variableMemberAccessor.getName());
InverseRelationShadowVariable.class, variableMemberAccessor.getName()).orElse(Object.class);
singleton = false;
} else {
sourceClass = variablePropertyType;
Expand Down
Expand Up @@ -220,10 +220,8 @@ private List<Phase<Solution_>> buildPhaseList(HeuristicConfigPolicy<Solution_> c
DefaultConstructionHeuristicPhaseFactory.buildListVariableQueuedValuePlacerConfig(configPolicy,
listVariableDescriptorList.get(0));
} else {
QueuedEntityPlacerConfig queuedEntityPlacerConfig = new QueuedEntityPlacerConfig();
queuedEntityPlacerConfig.setEntitySelectorConfig(AbstractFromConfigFactory
entityPlacerConfig = new QueuedEntityPlacerConfig().withEntitySelectorConfig(AbstractFromConfigFactory
.getDefaultEntitySelectorConfigForEntity(configPolicy, genuineEntityDescriptor));
entityPlacerConfig = queuedEntityPlacerConfig;
}

constructionHeuristicPhaseConfig.setEntityPlacerConfig(entityPlacerConfig);
Expand Down
Expand Up @@ -3,11 +3,16 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

import java.util.List;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.optaplanner.core.api.score.director.ScoreDirector;
import org.optaplanner.core.impl.domain.solution.descriptor.SolutionDescriptor;
import org.optaplanner.core.impl.heuristic.selector.common.decorator.SelectionFilter;
import org.optaplanner.core.impl.testdata.domain.TestdataEntity;
import org.optaplanner.core.impl.testdata.domain.extended.TestdataUnannotatedExtendedEntity;
import org.optaplanner.core.impl.testdata.domain.extended.entity.TestdataExtendedEntitySolution;
import org.optaplanner.core.impl.testdata.domain.pinned.TestdataPinnedEntity;
import org.optaplanner.core.impl.testdata.domain.pinned.TestdataPinnedSolution;
import org.optaplanner.core.impl.testdata.domain.pinned.extended.TestdataExtendedPinnedEntity;
Expand Down Expand Up @@ -90,4 +95,32 @@ void extendedMovableEntitySelectionFilterUsedByChildSelector() {
new TestdataExtendedPinnedEntity("e8", null, true, true, null, true, true))).isFalse();
}

@Test
void extractExtendedEntities() {
TestdataExtendedEntitySolution solution = new TestdataExtendedEntitySolution();

TestdataEntity entity = new TestdataEntity("entity-singleton");
solution.setEntity(entity);

TestdataUnannotatedExtendedEntity subEntity = new TestdataUnannotatedExtendedEntity("subEntity-singleton");
solution.setSubEntity(subEntity);

TestdataEntity e1 = new TestdataEntity("entity1");
TestdataEntity e2 = new TestdataEntity("entity2");
solution.setEntityList(List.of(e1, e2));

TestdataUnannotatedExtendedEntity s1 = new TestdataUnannotatedExtendedEntity("subEntity1");
TestdataUnannotatedExtendedEntity s2 = new TestdataUnannotatedExtendedEntity("subEntity2");
TestdataUnannotatedExtendedEntity s3 = new TestdataUnannotatedExtendedEntity("subEntity3");
solution.setSubEntityList(List.of(s1, s2, s3));

TestdataUnannotatedExtendedEntity r1 = new TestdataUnannotatedExtendedEntity("subEntity1-R");
TestdataUnannotatedExtendedEntity r2 = new TestdataUnannotatedExtendedEntity("subEntity2-R");
solution.setRawEntityList(List.of(r1, r2));

EntityDescriptor<TestdataExtendedEntitySolution> entityDescriptor =
TestdataExtendedEntitySolution.buildEntityDescriptor();
assertThat(entityDescriptor.extractEntities(solution))
.containsExactlyInAnyOrder(entity, subEntity, e1, e2, s1, s2, s3, r1, r2);
}
}

0 comments on commit 985ccac

Please sign in to comment.