Skip to content

Commit

Permalink
PLANNER-713 loadBalanceByCount() - for now only on the tennis example…
Browse files Browse the repository at this point in the history
… because it's deviation from zero instead of mean (see blog)
  • Loading branch information
ge0ffrey committed Feb 3, 2017
1 parent 62b6e27 commit 2f30be0
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 26 deletions.
4 changes: 4 additions & 0 deletions optaplanner-examples/pom.xml
Expand Up @@ -96,6 +96,10 @@
<artifactId>optaplanner-test</artifactId> <artifactId>optaplanner-test</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.kie</groupId>
<artifactId>kie-api</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.drools</groupId> <groupId>org.drools</groupId>
<artifactId>drools-decisiontables</artifactId> <artifactId>drools-decisiontables</artifactId>
Expand Down
Expand Up @@ -14,9 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


package org.optaplanner.core.impl.score.director.drools.functions; package org.optaplanner.examples.tennis.solver.drools.functions;


import java.io.IOException;
import java.io.ObjectInput; import java.io.ObjectInput;
import java.io.ObjectOutput; import java.io.ObjectOutput;
import java.io.Serializable; import java.io.Serializable;
Expand All @@ -25,12 +24,13 @@


import org.kie.api.runtime.rule.AccumulateFunction; import org.kie.api.runtime.rule.AccumulateFunction;


public class LoadBalanceAccumulateFunction implements AccumulateFunction { public class LoadBalanceByCountAccumulateFunction implements AccumulateFunction {


protected static class LoadBalanceData implements Serializable { protected static class LoadBalanceData implements Serializable {


private Map<Object, Long> groupWeightMap; private Map<Object, Long> groupWeightMap;
private long variance; // the sum of squared deviation from zero
private long squaredDeviation;


} }


Expand All @@ -40,52 +40,77 @@ public Serializable createContext() {
} }


@Override @Override
public void init(Serializable context) throws Exception { public void init(Serializable context) {
LoadBalanceData data = (LoadBalanceData) context; LoadBalanceData data = (LoadBalanceData) context;
data.groupWeightMap = new HashMap<>(); data.groupWeightMap = new HashMap<>();
data.variance = 0L; data.squaredDeviation = 0L;
} }


@Override @Override
public void accumulate(Serializable context, Object groupBy) { public void accumulate(Serializable context, Object groupBy) {
LoadBalanceData data = (LoadBalanceData) context; LoadBalanceData data = (LoadBalanceData) context;
long count = data.groupWeightMap.compute(groupBy, long count = data.groupWeightMap.compute(groupBy,
(key, value) -> (value == null) ? 1L : value + 1L); (key, value) -> (value == null) ? 1L : value + 1L);
// variance = variance - (count - 1)² + count² // squaredDeviation = squaredDeviation - (count - 1)² + count²
// <=> variance = variance + (2 * count - 1) // <=> squaredDeviation = squaredDeviation + (2 * count - 1)
data.variance += (2 * count - 1); data.squaredDeviation += (2 * count - 1);
} }


@Override @Override
public void reverse(Serializable context, Object groupBy) throws Exception { public boolean supportsReverse() {
return true;
}

@Override
public void reverse(Serializable context, Object groupBy) {
LoadBalanceData data = (LoadBalanceData) context; LoadBalanceData data = (LoadBalanceData) context;
Long count = data.groupWeightMap.compute(groupBy, Long count = data.groupWeightMap.compute(groupBy,
(key, value) -> (value.longValue() == 1L) ? null : value - 1L); (key, value) -> (value.longValue() == 1L) ? null : value - 1L);
data.variance -= (count == null) ? 1L : (2 * count + 1); data.squaredDeviation -= (count == null) ? 1L : (2 * count + 1);
} }


@Override @Override
public Double getResult(Serializable context) throws Exception { public Class<LoadBalanceResult> getResultType() {
LoadBalanceData data = (LoadBalanceData) context; return LoadBalanceResult.class;
return Math.sqrt((double) data.variance);
} }


@Override @Override
public boolean supportsReverse() { public LoadBalanceResult getResult(Serializable context) {
return true; LoadBalanceData data = (LoadBalanceData) context;
return new LoadBalanceResult(data.squaredDeviation);
} }


@Override @Override
public Class<Double> getResultType() { public void writeExternal(ObjectOutput out) {
return Double.class;
} }


@Override @Override
public void writeExternal(ObjectOutput out) throws IOException { public void readExternal(ObjectInput in) {
} }


@Override public static class LoadBalanceResult implements Serializable {
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
private final long squaredDeviation;

public LoadBalanceResult(long squaredDeviation) {
this.squaredDeviation = squaredDeviation;
}

public long getSquaredDeviation() {
return squaredDeviation;
}

public long getRootSquaredDeviationMillis() {
return getRootSquaredDeviation(1_000.0);
}

public long getRootSquaredDeviationMicros() {
return getRootSquaredDeviation(1_000_000.0);
}

public long getRootSquaredDeviation(double scaleMultiplier) {
return (long) (Math.sqrt((double) squaredDeviation) * scaleMultiplier);
}
} }


} }
Expand Up @@ -25,7 +25,7 @@ import org.optaplanner.examples.tennis.domain.UnavailabilityPenalty;
import org.optaplanner.examples.tennis.domain.TeamAssignment; import org.optaplanner.examples.tennis.domain.TeamAssignment;


import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import accumulate org.optaplanner.core.impl.score.director.drools.functions.LoadBalanceAccumulateFunction loadBalance; import accumulate org.optaplanner.examples.tennis.solver.drools.functions.LoadBalanceByCountAccumulateFunction loadBalanceByCount;


global HardMediumSoftScoreHolder scoreHolder; global HardMediumSoftScoreHolder scoreHolder;


Expand Down Expand Up @@ -57,10 +57,10 @@ rule "fairAssignmentCountPerTeam"
when when
accumulate( accumulate(
TeamAssignment(team != null, $t : team); TeamAssignment(team != null, $t : team);
$total : loadBalance($t) $total : loadBalanceByCount($t)
) )
then then
scoreHolder.addMediumConstraintMatch(kcontext, - (int) ($total * 1000000.0)); scoreHolder.addMediumConstraintMatch(kcontext, - (int) $total.getRootSquaredDeviationMillis());
end end


// ############################################################################ // ############################################################################
Expand All @@ -72,8 +72,8 @@ rule "evenlyConfrontationCount"
accumulate( accumulate(
TeamAssignment(team != null, $t1 : team, $d : day) TeamAssignment(team != null, $t1 : team, $d : day)
and TeamAssignment(team != null, $t1.getId() < team.getId(), $t2 : team, day == $d); and TeamAssignment(team != null, $t1.getId() < team.getId(), $t2 : team, day == $d);
$total : loadBalance(Pair.of($t1, $t2)) $total : loadBalanceByCount(Pair.of($t1, $t2))
) )
then then
scoreHolder.addSoftConstraintMatch(kcontext, - (int) ($total * 1000000.0)); scoreHolder.addSoftConstraintMatch(kcontext, - (int) $total.getRootSquaredDeviationMillis());
end end
@@ -0,0 +1,53 @@
/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.optaplanner.examples.tennis.solver.drools.functions;

import java.io.Serializable;

import org.junit.Test;

import static org.junit.Assert.*;

public class LoadBalanceByCountAccumulateFunctionTest {

@Test
public void accumulate() {
LoadBalanceByCountAccumulateFunction function = new LoadBalanceByCountAccumulateFunction();
Serializable context = function.createContext();
function.init(context);
Object a = new Object();
Object b = new Object();
Object c = new Object();
function.accumulate(context, a);
assertEquals(1000, function.getResult(context).getRootSquaredDeviationMillis());
function.accumulate(context, a);
assertEquals(2000, function.getResult(context).getRootSquaredDeviationMillis());
function.accumulate(context, a);
assertEquals(3000, function.getResult(context).getRootSquaredDeviationMillis());
function.reverse(context, a);
assertEquals(2000, function.getResult(context).getRootSquaredDeviationMillis());
function.accumulate(context, b);
assertEquals(2236, function.getResult(context).getRootSquaredDeviationMillis());
function.accumulate(context, c);
assertEquals(2449, function.getResult(context).getRootSquaredDeviationMillis());
function.accumulate(context, c);
assertEquals(3000, function.getResult(context).getRootSquaredDeviationMillis());
function.reverse(context, b);
assertEquals(2828, function.getResult(context).getRootSquaredDeviationMillis());
}

}

0 comments on commit 2f30be0

Please sign in to comment.