RL4J - Extracted TD Target calculations (StandardDQN and DoubleDQN) #8267

Merged: 1 commit (Oct 9, 2019)
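For context: this change pulls the TD-target computation out of QLearningDiscrete.setTarget() into a new TDTargetAlgorithm package, with StandardDQN and DoubleDQN implementations chosen at construction time from QLConfiguration.isDoubleDQN(). In the logic being relocated (visible in the removed setTarget body further down), the standard target bootstraps from max_a Q_target(s', a), while the double-DQN target evaluates the online network's argmax action with the target network, Q_target(s', argmax_a Q_net(s', a)); in both cases the value written into the labels is clamped to the previous Q-value plus or minus errorClamp.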
org/deeplearning4j/rl4j/learning/sync/Transition.java
@@ -16,6 +16,7 @@

package org.deeplearning4j.rl4j.learning.sync;

import lombok.AllArgsConstructor;
import lombok.Value;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -27,6 +28,7 @@
* State, Action, Reward, (isTerminal), State
*/
@Value
@AllArgsConstructor
public class Transition<A> {

INDArray[] observation;
org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
@@ -43,7 +43,7 @@
*/
@Slf4j
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN> {
extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource {

// FIXME Changed for refac
// @Getter
@@ -61,28 +61,19 @@ public QLearning(QLConfiguration conf) {

public abstract MDP<O, A, AS> getMdp();

protected abstract IDQN getCurrentDQN();
public abstract IDQN getQNetwork();

protected abstract IDQN getTargetDQN();
public abstract IDQN getTargetQNetwork();

protected abstract void setTargetDQN(IDQN dqn);

protected INDArray dqnOutput(INDArray input) {
return getCurrentDQN().output(input);
}

protected INDArray targetDqnOutput(INDArray input) {
return getTargetDQN().output(input);
}
protected abstract void setTargetQNetwork(IDQN dqn);

protected void updateTargetNetwork() {
log.info("Update target network");
setTargetDQN(getCurrentDQN().clone());
setTargetQNetwork(getQNetwork().clone());
}


public IDQN getNeuralNet() {
return getCurrentDQN();
return getQNetwork();
}

public abstract QLConfiguration getConfiguration();
org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java
@@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.network.dqn.IDQN;

/**
* An interface for all implementations capable of supplying a Q-Network
*
* @author Alexandre Boulanger
*/
public interface QNetworkSource {
IDQN getQNetwork();
}
org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java
@@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.network.dqn.IDQN;

/**
* An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
*
* @author Alexandre Boulanger
*/
public interface TargetQNetworkSource extends QNetworkSource {
IDQN getTargetQNetwork();
}
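A minimal sketch (not part of this PR) of what an implementer of these two interfaces can look like; QLearningDiscrete below satisfies them through its Lombok-generated getQNetwork()/getTargetQNetwork() accessors, and any other holder of two IDQN references would work the same way. The class name here is illustrative only.

    import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
    import org.deeplearning4j.rl4j.network.dqn.IDQN;

    // Illustrative only: a trivial holder that satisfies TargetQNetworkSource.
    public class FixedTargetQNetworkSource implements TargetQNetworkSource {
        private final IDQN qNetwork;
        private IDQN targetQNetwork;

        public FixedTargetQNetworkSource(IDQN qNetwork) {
            this.qNetwork = qNetwork;
            this.targetQNetwork = qNetwork.clone(); // start from an identical copy, as QLearningDiscrete does
        }

        @Override
        public IDQN getQNetwork() {
            return qNetwork;
        }

        @Override
        public IDQN getTargetQNetwork() {
            return targetQNetwork;
        }
    }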
org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
@@ -16,23 +16,22 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.ArrayList;
@@ -53,29 +52,38 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final private IDQN currentDQN;
@Getter
private DQNPolicy<O> policy;
@Getter
private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;

@Getter
@Setter
private IDQN targetDQN;
final private IDQN qNetwork;
@Getter
@Setter(AccessLevel.PROTECTED)
private IDQN targetQNetwork;

private int lastAction;
private INDArray[] history = null;
private double accuReward = 0;

ITDTargetAlgorithm tdTargetAlgorithm;

public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
super(conf);
this.configuration = conf;
this.mdp = mdp;
currentDQN = dqn;
targetDQN = dqn.clone();
policy = new DQNPolicy(getCurrentDQN());
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
this);
mdp.getActionSpace().setSeed(conf.getSeed());

tdTargetAlgorithm = conf.isDoubleDQN()
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());

}

public void postEpoch() {
@@ -134,7 +142,7 @@ protected QLStepReturn<O> trainStep(O obs) {
if (hstack.shape().length > 2)
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));

INDArray qs = getCurrentDQN().output(hstack);
INDArray qs = getQNetwork().output(hstack);
int maxAction = Learning.getMaxAction(qs);

maxQ = qs.getDouble(maxAction);
@@ -160,96 +168,31 @@ protected QLStepReturn<O> trainStep(O obs) {
getExpReplay().store(trans);

if (getStepCounter() > updateStart) {
Pair<INDArray, INDArray> targets = setTarget(getExpReplay().getBatch());
getCurrentDQN().fit(targets.getFirst(), targets.getSecond());
DataSet targets = setTarget(getExpReplay().getBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}

history = nhistory;
accuReward = 0;
}


return new QLStepReturn<O>(maxQ, getCurrentDQN().getLatestScore(), stepReply);

return new QLStepReturn<O>(maxQ, getQNetwork().getLatestScore(), stepReply);
}

protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");

int size = transitions.size();

// TODO: Remove once we use DataSets in observations
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
int[] nshape = makeShape(size, shape);
INDArray obs = Nd4j.create(nshape);
INDArray nextObs = Nd4j.create(nshape);
int[] actions = new int[size];
boolean[] areTerminal = new boolean[size];

for (int i = 0; i < size; i++) {
Transition<Integer> trans = transitions.get(i);
areTerminal[i] = trans.isTerminal();
actions[i] = trans.getAction();

INDArray[] obsArray = trans.getObservation();
if (obs.rank() == 2) {
obs.putRow(i, obsArray[0]);
} else {
for (int j = 0; j < obsArray.length; j++) {
obs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
}
}

INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
if (nextObs.rank() == 2) {
nextObs.putRow(i, nextObsArray[0]);
} else {
for (int j = 0; j < nextObsArray.length; j++) {
nextObs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
}
}
}
if (getHistoryProcessor() != null) {
obs.muli(1.0 / getHistoryProcessor().getScale());
nextObs.muli(1.0 / getHistoryProcessor().getScale());
}

INDArray dqnOutputAr = dqnOutput(obs);

INDArray dqnOutputNext = dqnOutput(nextObs);
INDArray targetDqnOutputNext = targetDqnOutput(nextObs);

INDArray tempQ = null;
INDArray getMaxAction = null;
if (getConfiguration().isDoubleDQN()) {
getMaxAction = Nd4j.argMax(dqnOutputNext, 1);
} else {
tempQ = Nd4j.max(targetDqnOutputNext, 1);
}


for (int i = 0; i < size; i++) {
double yTar = transitions.get(i).getReward();
if (!areTerminal[i]) {
double q = 0;
if (getConfiguration().isDoubleDQN()) {
q += targetDqnOutputNext.getDouble(i, getMaxAction.getInt(i));
} else
q += tempQ.getDouble(i);

yTar += getConfiguration().getGamma() * q;

}

double previousV = dqnOutputAr.getDouble(i, actions[i]);
double lowB = previousV - getConfiguration().getErrorClamp();
double highB = previousV + getConfiguration().getErrorClamp();
double clamped = Math.min(highB, Math.max(yTar, lowB));
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape));

dqnOutputAr.putScalar(i, actions[i], clamped);
// TODO: Remove once we use DataSets in observations
if(getHistoryProcessor() != null) {
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale());
}

return new Pair(obs, dqnOutputAr);
return tdTargetAlgorithm.computeTDTargets(transitions);
}
}
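Read together, the constructor and setTarget changes above reduce the training-step interaction with the extracted algorithm to roughly the following (a recap assembled from this diff, not additional code from the PR):

    // Selected once in the constructor, based on the configuration:
    ITDTargetAlgorithm tdTargetAlgorithm = conf.isDoubleDQN()
            ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
            : new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());

    // Each training step past updateStart: turn a replay batch into a DataSet of
    // (observations, TD targets) and fit the online network on it.
    DataSet targets = setTarget(getExpReplay().getBatch());
    getQNetwork().fit(targets.getFeatures(), targets.getLabels());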
org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java
@@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
* The base of all DQN-based algorithms
*
* @author Alexandre Boulanger
*
*/
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

private final TargetQNetworkSource qTargetNetworkSource;

/**
* In the literature, this corresponds to Q{net}(s(t+1), a)
*/
protected INDArray qNetworkNextObservation;

/**
* In the literature, this corresponds to Q{tnet}(s(t+1), a)
*/
protected INDArray targetQNetworkNextObservation;

protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
this.qTargetNetworkSource = qTargetNetworkSource;
}

protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
this.qTargetNetworkSource = qTargetNetworkSource;
}

@Override
protected void initComputation(INDArray observations, INDArray nextObservations) {
super.initComputation(observations, nextObservations);

qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);

IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
}
}
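The concrete StandardDQN and DoubleDQN classes named in the PR title are not part of this excerpt (the rest of the diff did not load). Going by the computation removed from QLearningDiscrete.setTarget() above, the per-transition targets they derive from the two cached arrays come down to the following; these helpers are a hypothetical sketch, not the PR's actual implementation:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    // Hypothetical sketch of the per-transition targets, written against the arrays
    // cached by BaseDQNAlgorithm.initComputation() (one row per transition).
    class TDTargetSketch {

        // Standard DQN: y = r + gamma * max_a Qtarget(s', a), or just r on terminal transitions.
        static double standardDqnTarget(INDArray targetQNextRow, double reward, boolean isTerminal, double gamma) {
            return isTerminal ? reward : reward + gamma * targetQNextRow.maxNumber().doubleValue();
        }

        // Double DQN: select the action with the online network, evaluate it with the target network,
        // y = r + gamma * Qtarget(s', argmax_a Qnet(s', a)).
        static double doubleDqnTarget(INDArray qNextRow, INDArray targetQNextRow,
                                      double reward, boolean isTerminal, double gamma) {
            if (isTerminal) {
                return reward;
            }
            int bestAction = Nd4j.argMax(qNextRow, 1).getInt(0);
            return reward + gamma * targetQNextRow.getDouble(bestAction);
        }
    }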