
Precision loss when saving ACPolicy using RL4J #5543

Closed
borkox opened this issue Jun 9, 2018 · 13 comments


@borkox commented Jun 9, 2018

I trained an RL agent using A3CDiscreteConv and then saved the policy to a file. I ran the agent with the policy still held in memory, and afterwards with the policy loaded from the file. The results were very different, and I got a message telling me to expect precision loss:

2018-06-09 11:14:17 INFO  c.r.r.a.EvaluatorA3C - Loading agent from ./rl4j/a3c_policy.bin
2018-06-09 11:14:17 INFO  o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
2018-06-09 11:14:18 INFO  o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 2
2018-06-09 11:14:19 INFO  o.n.n.Nd4jBlas - Number of threads used for BLAS: 2
2018-06-09 11:14:19 INFO  o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10]
2018-06-09 11:14:19 INFO  o.n.l.a.o.e.DefaultOpExecutioner - Cores: [4]; Memory: [1.8GB];
2018-06-09 11:14:19 INFO  o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [OPENBLAS]
2018-06-09 11:14:19 WARN  o.n.l.a.b.BaseDataBuffer - Loading a data stream with opType different from what is set globally. Expect precision loss
2018-06-09 11:14:19 WARN  o.n.l.a.b.BaseDataBuffer - Loading a data stream with opType different from what is set globally. Expect precision loss
2018-06-09 11:14:20 INFO  o.d.n.g.ComputationGraph - Starting ComputationGraph with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]

I didn't change any of the default data types in the global configuration. How can I make the model loaded from the file behave the same as the one in memory?
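From the WARN lines it looks like the saved buffers use a different opType (e.g. FLOAT) than the global default (e.g. DOUBLE). If the fix is to pin the global data type before loading, a minimal sketch would be something like this (assuming DataTypeUtil.setDTypeForContext is the right API for this ND4J version):

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;

public class GlobalDataTypeSetup {
    public static void main(String[] args) {
        // Assumption: the policy was trained and saved with FLOAT buffers, so the
        // loading JVM should use the same global type to avoid the
        // "Expect precision loss" warning; use DOUBLE if it was trained with doubles.
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
    }
}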

Here is some code:


package com.rltrader.rl.a3c;

import com.rltrader.hprices.CachingPricesService;
import com.rltrader.model.CommonConfig;
import com.rltrader.model.ConfigurationModel;
import com.rltrader.model.LearnConfig;
import com.rltrader.rl.PricesObservation;
import com.rltrader.rl.mdp.LearningPricesMDP;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteConv;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.util.DataManager;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.lang.reflect.Field;

@Component
@Slf4j
public class LearnerA3C {



    @Value("${rl4j.work.folder}")
    private String rl4jDataFolder;

    @Value("${rl4j.work.a3c.policy}")
    private String policySaveFile;

    @Autowired
    private CachingPricesService cachingPricesService;

    @Autowired
    private ConfigurationModel configurationModel;

    @Autowired
    private EvaluatorA3C evaluatorA3C;


    public void startLearn() throws IOException {
        LearnConfig learn = configurationModel.getLearn();
        CommonConfig common = configurationModel.getCommon();

        //prepare UI
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        int listenerFrequency = 10;
        TrainingListener[] listeners = new TrainingListener[]{
                new StatsListener(statsStorage, listenerFrequency),
                new ScoreIterationListener(listenerFrequency)
        };

        final A3CDiscrete.A3CConfiguration A3C =
                new A3CDiscrete.A3CConfiguration(
                        123,            //Random seed
                        learn.getMaxStepByEpoch(),          //Max step By epoch
                        learn.getMaxStep(),        //Max step
                        8,              //Number of threads
                        32,             //t_max
                        500,            //num step noop warmup
                        1,            //reward scaling
                        0.99,           //gamma
                        10.0            //td-error clipping
                );

        final ActorCriticFactory.Configuration NET_A3C =
                new ActorCriticFactory.Configuration(
                        0.00,   //l2 regularization
                        new Nesterovs(0.00001, 0.78), //learning rate
                        listeners,
                        true
                );


        log.info("============ LEARN (Actor Critic)===========");
        cachingPricesService.cacheData();

        DataManager manager = new DataManager(rl4jDataFolder, true);
        LearningPricesMDP mdp = new LearningPricesMDP(configurationModel, cachingPricesService);
        ActorCriticFactory actorCriticFactory
                = new ActorCriticFactory(NET_A3C);
        //define the training
//        ActorCriticCompGraph actorCritic = actorCriticFactory.buildActorCritic(
//                mdp.getObservationSpace().getShape(),
//                mdp.getActionSpace().getSize());

//        A3CDiscrete<PricesObservation> asyncLearning = new A3CDiscrete(
//                mdp,
//                actorCritic,
//                A3C,
//                manager){};
        IHistoryProcessor.Configuration hpconf = new IHistoryProcessor.Configuration(
                4,                                      //history length
                common.getWindowSize(),                 //rescaled width
                common.getPortfolioSymbols().length,    //rescaled height
                common.getWindowSize(),                 //cropped width
                common.getPortfolioSymbols().length,    //cropped height
                0,                                      //offset x
                0,                                      //offset y
                1                                       //skip frame
        );
        A3CDiscrete<PricesObservation> asyncLearning = new A3CDiscreteConv(
                mdp,
                actorCriticFactory,
                hpconf,
                A3C,
                manager){
            @Override
            protected void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
                try {
                    Field field = Learning.class.getDeclaredField("historyProcessor");
                    field.setAccessible(true);
                    field.set(this, new HistoryProcessorStorage(conf));
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            @Override
            public AsyncThread newThread(int i) {
                try {
                    AsyncThread at = super.newThread(i);
                    Field field = AsyncThread.class.getDeclaredField("historyProcessor");
                    field.setAccessible(true);
                    field.set(at, new HistoryProcessorStorage(hpconf));
                    return at;
                }  catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        };

        //train
        asyncLearning.train();

        //get the final policy
        ACPolicy<PricesObservation> pol = asyncLearning.getPolicy();

        //serialize and save (serialization showcase, but not required)
        log.info("Saving agent to file {}", policySaveFile);
        pol.save(policySaveFile);

        //close the mdp (close http)
        mdp.close();
        log.info("End learn");

        // ALERT: The evaluation is not the same if the policy is loaded from a file
        evaluatorA3C.evaluate(pol);

        System.exit(0);
    }
}

@saudet (Member) commented Jun 10, 2018

saudet closed this Jun 10, 2018

@borkox (Author) commented Jun 10, 2018

[screenshot]
The seed is stored in the file, so I'm sure that is not the problem. I also think the coefficients are the same, but I suspect the connections between nodes are not recovered, so the loaded network only looks the same while producing completely different calculations.

saudet reopened this Jun 10, 2018

@saudet (Member) commented Jun 10, 2018

Ok, does this happen with simple MDPs like the A3CCartpole example?

@borkox (Author) commented Jun 10, 2018

I will try to make a failing test case, because otherwise I cannot explain it. But I took most of the code from A3CALE.

@borkox (Author) commented Jun 10, 2018

[screenshot]
I couldn't make a failing unit test, but I started my app in debug mode: I trained a policy, saved it, and loaded it from a file. The first difference I noticed is that the Random is not loaded; see the picture.

@borkox (Author) commented Jun 10, 2018

[screenshot]
So I set the same Random object on both policies, compared the outputs in the logs, and they were exactly the same. That was the problem: the Random object is not persisted, and I hadn't noticed that. I don't know how you will classify this, whether it is a bug or just not well explained. To me it seems logical for the Random object to be saved; otherwise the policy behaves oddly.

saudet added the Bug label and removed the Question label Jun 11, 2018

@saudet (Member) commented Jun 11, 2018

Yeah, I'd call that a bug :)

@saudet (Member) commented Jun 11, 2018

But I'm pretty sure ACPolicy.load("/path/to/model.zip", new Random(123)) does what you need...

Anyway, the seed used for the network configuration is different from the one used by ACPolicy... Should we have it reuse the network's seed by default, instead of using no randomness by default?

@borkox (Author) commented Jun 11, 2018

Yes, I use this as a workaround, but I would expect that if I save a policy with a Random inside it, it gets restored with that Random inside it. It's up to your internal conventions whether to consider this a bug or not.
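For reference, the workaround looks roughly like this (a minimal sketch; the save path is the one from my training code above, and whether load expects java.util.Random or the ND4J Random depends on the RL4J version):

import java.util.Random;

import org.deeplearning4j.rl4j.policy.ACPolicy;

public class EvaluateSavedPolicy {
    public static void main(String[] args) throws Exception {
        // Load the saved policy with an explicit Random, seeded the same way as the
        // in-memory policy, so the restored policy samples actions identically.
        ACPolicy restored = ACPolicy.load("./rl4j/a3c_policy.bin", new Random(123));
        // ... run the same evaluation against 'restored' as against the in-memory policy ...
    }
}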

@saudet (Member) commented Jun 12, 2018

@borkox Right, let's try to be consistent with DL4J, at least...

@AlexDBlack What is DL4J doing to restore the state of Random objects?

@raver119 (Contributor) commented Jun 12, 2018

We don't store Random objects in DL4J models; we store the seed. And with the same seed we get the same sequence on both backends. That's how it works there.
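As a quick illustration of that (a sketch, not tied to any particular model):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SeedDemo {
    public static void main(String[] args) {
        // Seeding ND4J's RNG with the same value yields the same draws.
        Nd4j.getRandom().setSeed(123);
        INDArray a = Nd4j.rand(1, 5);

        Nd4j.getRandom().setSeed(123);
        INDArray b = Nd4j.rand(1, 5);

        System.out.println(a);
        System.out.println(b);  // same values as 'a'
    }
}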

@saudet (Member) commented Jun 12, 2018

@raver119 Ok, but then if we save a model, reload it and restart training, we're going to get the same sequence of random numbers, right?

@lock (bot) commented Sep 21, 2018

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

lock bot locked and limited conversation to collaborators Sep 21, 2018
