The output from the current DQN and target DQN #8107

Closed
flyheroljg opened this issue Aug 17, 2019 · 6 comments


@flyheroljg flyheroljg commented Aug 17, 2019

Issue Description

While studying the file ‘QLearningDiscrete.java’, I found a problem in the function ‘setTarget’.
setTarget is supposed to prepare the data for updating the current DQN, which approximates the action values of the behaviour policy. Therefore, in the returned pair (obs, dqnOutputAr), the first element is the training input and the second is the label.
The standard DQN algorithm requires that obs be fed to the current DQN while the label values in dqnOutputAr come from the target DQN. However, in your code the variable ‘yTar’, which should be the TD target value, is inferred from the current DQN rather than the target DQN, and ‘dqnOutputAr’ is also produced by the current DQN, so both parts of the pair depend only on the current DQN. Why is that? Or did I not understand the standard DQN algorithm correctly?
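To make my understanding explicit, here is a rough sketch (pseudo-Java only, reusing the existing targetDqnOutput() helper; reward, gamma, isTerminal and i stand in for the corresponding values inside the batch loop) of how I expected the target to be built:

    // Sketch of my understanding of the standard DQN target (not the current code).
    INDArray targetOutputNext = targetDqnOutput(nextObs);   // Q_target(s', .) for the batch
    INDArray maxNextQ = Nd4j.max(targetOutputNext, 1);      // max_a Q_target(s', a)

    double yTar = reward;                                   // r
    if (!isTerminal) {
        yTar += gamma * maxNextQ.getDouble(i);              // r + gamma * max_a Q_target(s', a)
    }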

Version Information

Please indicate relevant versions, including, if relevant:
RL4J, file: QLearningDiscrete.java

Additional Information

Where applicable, please also provide:

protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition> transitions) {
    if (transitions.size() == 0)
        throw new IllegalArgumentException("too few transitions");

    int size = transitions.size();

    int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
                    : getHistoryProcessor().getConf().getShape();
    int[] nshape = makeShape(size, shape);
    INDArray obs = Nd4j.create(nshape);
    INDArray nextObs = Nd4j.create(nshape);
    int[] actions = new int[size];
    boolean[] areTerminal = new boolean[size];

    for (int i = 0; i < size; i++) {
        Transition<Integer> trans = transitions.get(i);
        areTerminal[i] = trans.isTerminal();
        actions[i] = trans.getAction();

        INDArray[] obsArray = trans.getObservation();
        if (obs.rank() == 2) {
            obs.putRow(i, obsArray[0]);
        } else {
            for (int j = 0; j < obsArray.length; j++) {
                obs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
            }
        }

        INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
        if (nextObs.rank() == 2) {
            nextObs.putRow(i, nextObsArray[0]);
        } else {
            for (int j = 0; j < nextObsArray.length; j++) {
                nextObs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
            }
        }
    }
    if (getHistoryProcessor() != null) {
        obs.muli(1.0 / getHistoryProcessor().getScale());
        nextObs.muli(1.0 / getHistoryProcessor().getScale());
    }

    INDArray dqnOutputAr = dqnOutput(obs);

    INDArray dqnOutputNext = dqnOutput(nextObs);
    INDArray targetDqnOutputNext = null;

    INDArray tempQ = null;
    INDArray getMaxAction = null;
    if (getConfiguration().isDoubleDQN()) {
        targetDqnOutputNext = targetDqnOutput(nextObs);
        getMaxAction = Nd4j.argMax(dqnOutputNext, 1);
    } else {
        tempQ = Nd4j.max(dqnOutputNext, 1);
    }

    for (int i = 0; i < size; i++) {
        double yTar = transitions.get(i).getReward();
        if (!areTerminal[i]) {
            double q = 0;
            if (getConfiguration().isDoubleDQN()) {
                q += targetDqnOutputNext.getDouble(i, getMaxAction.getInt(i));
            } else
                q += tempQ.getDouble(i);

            yTar += getConfiguration().getGamma() * q;

        }

        double previousV = dqnOutputAr.getDouble(i, actions[i]);
        double lowB = previousV - getConfiguration().getErrorClamp();
        double highB = previousV + getConfiguration().getErrorClamp();
        double clamped = Math.min(highB, Math.max(yTar, lowB));

        dqnOutputAr.putScalar(i, actions[i], clamped);
    }

    return new Pair(obs, dqnOutputAr);
}

Contributing

If you'd like to help us fix the issue by contributing some code, but would like guidance or help in doing so, please mention it!

@saudet saudet commented Aug 17, 2019

It's entirely possible that there is a bug. Could you send a pull request with the proposed changes?

@flyheroljg flyheroljg commented Aug 17, 2019

Because I am a new user, I do not know how to send a pull request with the proposed changes. I think one change is needed in the file ‘QLearningDiscrete.java’:
The variable ‘dqnOutputAr’ should be produced by the target DQN, so the code can be changed to the following: “INDArray dqnOutputAr = targetDqnOutput(obs)”.
That way, in the returned pair (obs, dqnOutputAr), obs is the training input and the target DQN produces dqnOutputAr, which is used as the label.
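In context, the change I have in mind is only this one line (a sketch; everything else in setTarget() stays the same):

    // Build the label array from the target DQN instead of the current DQN.
    INDArray dqnOutputAr = targetDqnOutput(obs);   // was: INDArray dqnOutputAr = dqnOutput(obs);

    // ... the rest of setTarget() is unchanged, including:
    // return new Pair(obs, dqnOutputAr);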

@aboulang2002 aboulang2002 commented Aug 17, 2019

@saudet Please assign me this issue, I'll make the changes.

@saudet saudet commented Aug 18, 2019

@aboulang2002 Thanks! Feel free to roll this in your latest PR for simplicity :)

@aboulang2002 aboulang2002 commented Aug 18, 2019

I agree there is a problem here but I think it's more than just dqnOutputAr.

The Q-Learning algorithm says we should minimize the loss function:
L = SquaredErrorLoss(yTar - Q(s, a))    (Eq. 1)
where yTar = reward + gamma * greedy(Q(nextS))
and where greedy() simply evaluates the Q-value for all actions and takes the highest one.

In Eq. 1, the Q-Network should be used for the computation of Q(s, a).
As for the computation of yTar, it differs depending on whether we're doing standard DQN or Double-DQN:
  • Standard DQN: we should use the Target-Network only. Everywhere.
  • Double-DQN: the Q-Network decides the action and the Target-Network computes its Q-value.

So, since in all cases it's the Target-Network that computes the Q-value, I think we should use targetDqnOutput() for dqnOutputAr (as proposed by @flyheroljg) and use targetDqnOutputNext to compute Q-values.
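In code, what I have in mind looks roughly like this (only a sketch; names mirror the existing ones in setTarget()):

    // Labels: start from the Target-Network's output for the current observations.
    INDArray targetDqnOutputAr = targetDqnOutput(obs);        // was: dqnOutput(obs)

    // TD target: the Target-Network evaluates the next observations in both modes.
    INDArray targetDqnOutputNext = targetDqnOutput(nextObs);

    INDArray getMaxAction = null;
    INDArray tempQ = null;
    if (getConfiguration().isDoubleDQN()) {
        // Double-DQN: the Q-Network picks the action, the Target-Network values it
        // (later: targetDqnOutputNext.getDouble(i, getMaxAction.getInt(i))).
        getMaxAction = Nd4j.argMax(dqnOutput(nextObs), 1);
    } else {
        // Standard DQN: max over the Target-Network's Q-values.
        tempQ = Nd4j.max(targetDqnOutputNext, 1);
    }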

At line 262 there is some kind of Q-value clamping:

            double previousV = dqnOutputAr.getDouble(i, actions[i]);
            double lowB = previousV - getConfiguration().getErrorClamp();
            double highB = previousV + getConfiguration().getErrorClamp();
            double clamped = Math.min(highB, Math.max(yTar, lowB));

I suppose its purpose is to reduce the impact of outliers and/or very large updates, but I couldn't find much information about it online.
Anyway, it seems more logical to me to use targetDqnOutput() there too.
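With that change, the clamping would simply be done around the Target-Network's estimate instead (sketch, reusing the renamed array from above):

    // Clamp yTar around the Target-Network's estimate for the action that was taken.
    double previousV = targetDqnOutputAr.getDouble(i, actions[i]);   // was: dqnOutputAr
    double lowB = previousV - getConfiguration().getErrorClamp();
    double highB = previousV + getConfiguration().getErrorClamp();
    double clamped = Math.min(highB, Math.max(yTar, lowB));
    targetDqnOutputAr.putScalar(i, actions[i], clamped);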

TL;DR:

  • Make dqnOutputAr use the target network and rename it targetDqnOutputArray or targetDqnOutput
  • Move the assignment of targetDqnOutputNext (line 240) to its declaration (line 235)
  • Make sure we use targetDqnOutputNext everywhere, except where we decide the action in Double-DQN (line 241), where the Q-Network should be used

Am I mistaken?

@saudet saudet commented Oct 1, 2019

Fixed with pull #8250. Thank you!

@saudet saudet closed this Oct 1, 2019