Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RL4J - Added a unit test to help refac QLearningDiscrete.trainStep() #8065

Merged
merged 2 commits into from Aug 2, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -45,8 +45,11 @@
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN> {

@Getter
final private IExpReplay<A> expReplay;
// FIXME Changed for refac
// @Getter
// final private IExpReplay<A> expReplay;
@Getter @Setter
private IExpReplay<A> expReplay;
aboulang2002 marked this conversation as resolved.
Show resolved Hide resolved

public QLearning(QLConfiguration conf) {
super(conf);
Expand Down
@@ -1,17 +1,16 @@
package org.deeplearning4j.rl4j.learning.async;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.deeplearning4j.rl4j.support.MockDataManager;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
Expand Down Expand Up @@ -93,7 +92,7 @@ public void refac_withHistoryProcessor_isSaveFalse_checkDataManagerCallsRemainTh
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(1.0, entry.getReward(), 0.0);
assertEquals(79.0, entry.getReward(), 0.0);
}

assertEquals(10, dataManager.isSaveDataCallCount);
Expand Down Expand Up @@ -128,7 +127,7 @@ public void refac_withHistoryProcessor_isSaveTrue_checkDataManagerCallsRemainThe
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
assertEquals(i + 1, entry.getStepCounter());
assertEquals(i, entry.getEpochCounter());
assertEquals(1.0, entry.getReward(), 0.0);
assertEquals(79.0, entry.getReward(), 0.0);
}

assertEquals(1, dataManager.isSaveDataCallCount);
Expand Down Expand Up @@ -308,91 +307,6 @@ public void save(String filename) throws IOException {
}
}

/**
 * Minimal {@link Encodable} test double: wraps a single integer and
 * exposes it as a one-element encoding array.
 */
public static class MockEncodable implements Encodable {

    private final int value;

    public MockEncodable(int value) {
        this.value = value;
    }

    @Override
    public double[] toArray() {
        double[] encoded = new double[1];
        encoded[0] = value;
        return encoded;
    }
}

/**
 * Test double for {@link ObservationSpace}: reports a fixed one-dimensional
 * shape; name and bounds are unused by the tests and left null.
 */
public static class MockObservationSpace implements ObservationSpace {

    @Override
    public String getName() {
        return null; // not exercised by the tests
    }

    @Override
    public int[] getShape() {
        int[] shape = { 1 };
        return shape;
    }

    @Override
    public INDArray getLow() {
        return null; // bounds unused
    }

    @Override
    public INDArray getHigh() {
        return null; // bounds unused
    }
}

/**
 * Scripted MDP test double over a 5-action discrete space. Every reset()
 * produces an observation whose value increases by one, and step() echoes
 * the chosen action back as both the next observation and the reward.
 */
public static class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {

    private final DiscreteSpace actionSpace = new DiscreteSpace(5);
    private final ObservationSpace observationSpace;
    private int currentObsValue = 0;

    public MockMDP(ObservationSpace observationSpace) {
        this.observationSpace = observationSpace;
    }

    @Override
    public ObservationSpace getObservationSpace() {
        return observationSpace;
    }

    @Override
    public DiscreteSpace getActionSpace() {
        return actionSpace;
    }

    @Override
    public MockEncodable reset() {
        // Episodes start with the next value in the sequence 1, 2, 3, ...
        currentObsValue++;
        return new MockEncodable(currentObsValue);
    }

    @Override
    public void close() {
        // nothing to release
    }

    @Override
    public StepReply<MockEncodable> step(Integer obs) {
        // Echo the action: observation value == reward == action.
        double reward = obs;
        return new StepReply<>(new MockEncodable(obs), reward, isDone(), null);
    }

    @Override
    public boolean isDone() {
        return false; // episodes never terminate on their own
    }

    @Override
    public MDP newInstance() {
        return null; // cloning is not exercised by these tests
    }
}

public static class MockAsyncConfiguration implements AsyncConfiguration {

private final int nStep;
Expand Down
@@ -0,0 +1,142 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.*;

/**
 * Characterization test written ahead of refactoring
 * {@code QLearningDiscrete.trainStep()}: it pins the exact interactions with the
 * history processor, DQN, MDP and experience replay, plus the per-step return
 * values, so a refactored implementation can be verified against identical
 * expectations. The magic numbers below were captured from the current
 * implementation with the mocks in org.deeplearning4j.rl4j.support.
 */
public class QLearningDiscreteTest {
    @Test
    public void refac_QLearningDiscrete_trainStep() {
        // Arrange
        MockObservationSpace observationSpace = new MockObservationSpace();
        MockMDP mdp = new MockMDP(observationSpace);
        MockDQN dqn = new MockDQN();
        // NOTE(review): positional QLConfiguration — batch size 5, gamma 1.0,
        // doubleDQN true; confirm field order against QLConfiguration if it changes.
        QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
                0, 1.0, 0, 0, 0, 0, true);
        MockDataManager dataManager = new MockDataManager(false);
        TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10);
        // History processor: window length 5, skip/stride parameters 4 (see
        // IHistoryProcessor.Configuration for exact positional meaning).
        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
        sut.setHistoryProcessor(hp);
        MockExpReplay expReplay = new MockExpReplay();
        sut.setExpReplay(expReplay);
        MockEncodable obs = new MockEncodable(1);
        List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();

        // Act: run 16 training steps against the scripted MDP.
        sut.initMdp();
        for(int step = 0; step < 16; ++step) {
            results.add(sut.trainStep(obs));
            sut.incrementStep();
        }

        // Assert
        // HistoryProcessor calls: every frame is recorded, but add() only fires
        // on non-skipped steps; monitoring is never started.
        assertEquals(24, hp.recordCallCount);
        assertEquals(13, hp.addCallCount);
        assertEquals(0, hp.startMonitorCallCount);
        assertEquals(0, hp.stopMonitorCallCount);

        // DQN calls: fit() fires once with the fixed target pair produced by
        // TestQLearningDiscrete.setTarget() below (123.0 / 234.0).
        assertEquals(1, dqn.fitParams.size());
        assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
        assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
        assertEquals(14, dqn.outputParams.size());
        // Each row is the sliding 5-frame observation window fed to the DQN,
        // before the 1/255 normalization applied on input.
        double[][] expectedDQNOutput = new double[][] {
                new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
                new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
                new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
                new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
                new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
                new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
                new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
                new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
                new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
                new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },

        };
        for(int i = 0; i < expectedDQNOutput.length; ++i) {
            INDArray outputParam = dqn.outputParams.get(i);

            // 5 stacked frames, 1 channel each.
            assertEquals(5, outputParam.shape()[0]);
            assertEquals(1, outputParam.shape()[1]);

            double[] expectedRow = expectedDQNOutput[i];
            for(int j = 0; j < expectedRow.length; ++j) {
                // Pixel values are scaled by 1/255 before reaching the DQN.
                assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001);
            }
        }

        // MDP calls: the exact sequence of actions chosen across all steps
        // (including the skipped-frame repeats).
        assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0 ,0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());

        // ExpReplay calls: one transition stored per effective (non-skipped)
        // step, with accumulated rewards over the skipped frames.
        double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
        int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 };
        double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 };
        double[][] expectedTrObservations = new double[][] {
                new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
                new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
                new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
                new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
                new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
                new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
                new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
                new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
        };
        for(int i = 0; i < expectedTrRewards.length; ++i) {
            Transition tr = expReplay.transitions.get(i);
            assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
            assertEquals(expectedTrActions[i], tr.getAction());
            assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
            for(int j = 0; j < expectedTrObservations[i].length; ++j) {
                assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
            }
        }

        // trainStep results: with skip-frame 2, even steps do real work (maxQ
        // and reward populated) and odd steps are skipped (maxQ is NaN).
        assertEquals(16, results.size());
        double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
        double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
        for(int i=0; i < 16; ++i) {
            QLearning.QLStepReturn<MockEncodable> result = results.get(i);
            if(i % 2 == 0) {
                assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001);
                assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
            }
            else {
                assertTrue(result.getMaxQ().isNaN());
            }
        }
    }

    /**
     * Subclass under test that stubs {@code setTarget()} with a fixed
     * (input, labels) pair so the arguments received by {@code dqn.fit()}
     * are deterministic and assertable.
     */
    public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
        public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
                                     QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) {
            super(mdp, dqn, conf, dataManager, epsilonNbStep);
        }

        @Override
        protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
            // Always return the same fixed target pair; see fitParams assertions above.
            return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
        }
    }
}
@@ -0,0 +1,100 @@
package org.deeplearning4j.rl4j.support;

import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

/**
 * Recording stub of {@link IDQN} for unit tests: {@code output()} echoes its
 * input while capturing every batch, and {@code fit()} captures each
 * (input, labels) pair. All other operations are inert no-ops.
 */
public class MockDQN implements IDQN {

    // Batches passed to output(), in call order.
    public final List<INDArray> outputParams = new ArrayList<>();
    // (input, labels) pairs passed to fit(), in call order.
    public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();

    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[0];
    }

    @Override
    public boolean isRecurrent() {
        return false;
    }

    @Override
    public void reset() {
        // no internal state to reset
    }

    @Override
    public void fit(INDArray input, INDArray labels) {
        Pair<INDArray, INDArray> call = new Pair<>(input, labels);
        fitParams.add(call);
    }

    @Override
    public void fit(INDArray input, INDArray[] labels) {
        // multi-label fit is not exercised by the tests
    }

    @Override
    public INDArray output(INDArray batch) {
        // Identity output: lets tests inspect exactly what was fed in.
        outputParams.add(batch);
        return batch;
    }

    @Override
    public INDArray[] outputAll(INDArray batch) {
        return new INDArray[0];
    }

    @Override
    public IDQN clone() {
        return null; // cloning unused in tests
    }

    @Override
    public void copy(NeuralNet from) {
        // no-op
    }

    @Override
    public void copy(IDQN from) {
        // no-op
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray label) {
        return new Gradient[0];
    }

    @Override
    public Gradient[] gradient(INDArray input, INDArray[] label) {
        return new Gradient[0];
    }

    @Override
    public void applyGradient(Gradient[] gradient, int batchSize) {
        // no-op
    }

    @Override
    public double getLatestScore() {
        return 0;
    }

    @Override
    public void save(OutputStream os) throws IOException {
        // persistence unused in tests
    }

    @Override
    public void save(String filename) throws IOException {
        // persistence unused in tests
    }
}