Commit

feat: add trained model

kingyuluk committed Nov 29, 2020
1 parent 8aabc2b commit 43895c6
Showing 12 changed files with 109 additions and 112 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -24,16 +24,17 @@ The following command will start to train without graphics:
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird"
```

You can also run with arguments. The following command will run the game with graphics using the pre-trained weights.
You can also run with arguments. The following command will test the model with graphics.
```
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-g -p"
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-g -p -t"
```

| Argument | Comments |
| ---------- | --------------------------------------- |
| `-g` | Training with graphics. |
| `-b` | Batch size to use for training. |
| `-p` | Use pre-trained weights. |
| `-t` | Test the trained model. |

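For example, the following command would train with graphics and an explicitly chosen batch size (the value `32` here is only illustrative):

```
mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-g -b 32"
```
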
## Deep Q-Network Algorithm

@@ -58,7 +59,9 @@ end for
```
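
The heart of the loop above is the temporal-difference target. As a minimal plain-Java sketch (illustrative only, not the project's DJL implementation; `gamma` plays the role of the reward discount, 0.9 in training):

```
public class DqnTarget {
    /**
     * TD target for one transition:
     * reward if the episode terminated, otherwise reward + gamma * max_a Q(s', a).
     */
    public static float tdTarget(float reward, boolean terminal, float[] nextQ, float gamma) {
        if (terminal) {
            return reward; // no bootstrapping past a terminal state
        }
        float max = Float.NEGATIVE_INFINITY;
        for (float q : nextQ) {
            max = Math.max(max, q); // greedy value of the next state
        }
        return reward + gamma * max;
    }
}
```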

## Notes

Trained Model

* It may take 10+ hours to train a bird to a perfect state. You can find the model trained with three million steps in the project resources folder: ```src/main/resources/model/dqn-trained-0000.params```
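
For illustration, here is a minimal DJL snippet that loads these weights outside the trainer. The path follows the note above, and `TrainBird.getBlock()` is the network definition this commit uses; treat the class name and path as assumptions:

```
import ai.djl.MalformedModelException;
import ai.djl.Model;
import com.kingyu.rlbird.ai.TrainBird;
import java.io.IOException;
import java.nio.file.Paths;

public class LoadTrainedModel {
    public static void main(String[] args) throws IOException, MalformedModelException {
        try (Model model = Model.newInstance("QNetwork")) {
            model.setBlock(TrainBird.getBlock()); // same block as training
            // Resolves src/main/resources/model/dqn-trained-0000.params
            model.load(Paths.get("src/main/resources/model"), "dqn-trained");
        }
    }
}
```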

Troubleshooting

* [X11 error](https://github.com/aws-samples/d2l-java/blob/master/documentation/troubleshoot.md#1-x11-error-when-running-object-detection-notebooks-on-ec2-instances)
82 changes: 21 additions & 61 deletions src/main/java/com/kingyu/rlbird/ai/TrainBird.java
@@ -24,6 +24,7 @@
import com.kingyu.rlbird.game.FlappyBird;
import com.kingyu.rlbird.rl.env.RlEnv;
import com.kingyu.rlbird.util.Arguments;
import com.kingyu.rlbird.util.Constant;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -38,8 +39,9 @@
public class TrainBird {
private static final Logger logger = LoggerFactory.getLogger(TrainBird.class);

public final static int OBSERVE = 1000; // timeSteps to observe before training
public final static int OBSERVE = 1000; // gameSteps to observe before training
public final static int EXPLORE = 3000000; // frames over which to anneal epsilon
public final static int SAVE_EVERY_STEPS = 100000; // save model every 100,000 steps
static RlEnv.Step[] batchSteps;

private TrainBird() {
@@ -51,11 +53,11 @@ public static void main(String[] args) throws ParseException {

public static void train(String[] args) throws ParseException {
Arguments arguments = Arguments.parseArgs(args);

boolean withGraphics = arguments.isWithGraphics();
boolean withGraphics = arguments.withGraphics();
boolean training = !arguments.isTesting();
int batchSize = arguments.getBatchSize(); // size of mini batch
String modelParamsPath = "model";
String modelParamsName = "dqn-latest";
String modelParamsPath = Constant.MODEL_PATH;
String modelParamsName = "dqn-trained";

int replayBufferSize = 50000; // number of previous transitions to remember;
float rewardDiscount = 0.9f; // decay rate of past observations
@@ -68,8 +70,8 @@
try (Model model = Model.newInstance("QNetwork")) {
model.setBlock(block);

if(arguments.isPreTrained()) {
File file = new File(modelParamsPath + "/" + modelParamsName + "-0000.params");
if (arguments.isPreTrained()) {
File file = new File(modelParamsPath + "/" + modelParamsName + "-0000.params");
if (file.exists()) {
try {
model.load(Paths.get(modelParamsPath), modelParamsName);
@@ -80,7 +82,7 @@
} else {
logger.info("Model doesn't exist");
}
}else{
} else {
logger.info("Start training");
}

@@ -101,8 +103,10 @@

int numOfThreads = 2;
List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent));
callables.add(new TrainerCallable(model, agent));
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}
ExecutorService executorService = Executors.newFixedThreadPool(numOfThreads);
try {
try {
@@ -119,17 +123,6 @@
} finally {
executorService.shutdown();
}
// while (true) {
// game.runEnvironment(agent, true);
// }

//// print the structure of the neural network
// Shape currentShape = new Shape(1, 4, 80, 80);
// for (int i = 0; i < block.getChildren().size(); i++) {
// Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(NDManager.newBaseManager(), new Shape[]{currentShape});
// currentShape = newShape[0];
// System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
// }
}
}
}
@@ -147,11 +140,11 @@ public TrainerCallable(Model model, RlAgent agent) {
public Object call() throws Exception {
while (FlappyBird.trainStep < EXPLORE) {
Thread.sleep(0);
if (FlappyBird.timeStep > OBSERVE) {
if (FlappyBird.gameStep > OBSERVE) {
this.agent.trainBatch(batchSteps);
FlappyBird.trainStep++;
if (FlappyBird.trainStep > 0 && FlappyBird.trainStep % 100000 == 0) {
model.save(Paths.get("model"), "dqn-" + FlappyBird.trainStep);
if (FlappyBird.trainStep > 0 && FlappyBird.trainStep % SAVE_EVERY_STEPS == 0) {
model.save(Paths.get(Constant.MODEL_PATH), "dqn-" + FlappyBird.trainStep);
}
}
}
@@ -162,16 +155,18 @@
private static class GeneratorCallable implements Callable<Object> {
private final FlappyBird game;
private final RlAgent agent;
private final boolean training;

public GeneratorCallable(FlappyBird game, RlAgent agent) {
public GeneratorCallable(FlappyBird game, RlAgent agent, boolean training) {
this.game = game;
this.agent = agent;
this.training = training;
}

@Override
public Object call() {
while (FlappyBird.trainStep < EXPLORE) {
batchSteps = game.runEnvironment(agent, true);
batchSteps = game.runEnvironment(agent, training);
}
return null;
}
@@ -211,41 +206,6 @@ public static SequentialBlock getBlock() {
.setUnits(2).build());
}

// public static SequentialBlock _getBlock() {
// return new SequentialBlock()
// .add(Conv2d.builder()
// .setKernelShape(new Shape(8, 8))
// .optStride(new Shape(4, 4))
// .optPadding(new Shape(3, 3))
// .setFilters(32).build())
// .add(Pool.maxPool2dBlock(new Shape(2, 2)))
// .add(Activation::relu)
//
// .add(Conv2d.builder()
// .setKernelShape(new Shape(4, 4))
// .optStride(new Shape(2, 2))
// .optPadding(new Shape(1, 1))
// .setFilters(64).build())
// .add(Activation::relu)
//
// .add(Conv2d.builder()
// .setKernelShape(new Shape(3, 3))
// .optStride(new Shape(1, 1))
// .optPadding(new Shape(1, 1))
// .setFilters(64).build())
// .add(Activation::relu)
//
// .add(Blocks.batchFlattenBlock())
// .add(Linear
// .builder()
// .setUnits(512).build())
// .add(Activation::relu)
//
// .add(Linear
// .builder()
// .setUnits(2).build());
// }

public static DefaultTrainingConfig setupTrainingConfig() {
return new DefaultTrainingConfig(Loss.l2Loss())
.optOptimizer(Adam.builder().optLearningRateTracker(Tracker.fixed(1e-6f)).build())
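The loop above splits work between two callables: a generator that plays the game to produce environment steps, and, only when training, a trainer that consumes sampled batches. A minimal self-contained sketch of that producer/consumer shape (class and lambda bodies illustrative):

```
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TwoThreadTrainingLoop {
    public static void main(String[] args) throws InterruptedException {
        List<Callable<Object>> callables = new ArrayList<>(2);
        callables.add(() -> { /* loop: game.runEnvironment(agent, training) */ return null; });
        callables.add(() -> { /* loop: agent.trainBatch(batchSteps) */ return null; });
        ExecutorService executorService = Executors.newFixedThreadPool(callables.size());
        try {
            executorService.invokeAll(callables); // blocks until both callables return
        } finally {
            executorService.shutdown();
        }
    }
}
```
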
43 changes: 27 additions & 16 deletions src/main/java/com/kingyu/rlbird/game/FlappyBird.java
@@ -58,7 +58,7 @@ public FlappyBird(NDManager manager, int batchSize, int replayBufferSize, boolea
this(manager, new LruReplayBuffer(batchSize, replayBufferSize));
this.withGraphics = withGraphics;
if (this.withGraphics) {
initFrame(); // initialize the game window
initFrame();
this.setVisible(true);
}
background = new GameBackground();
@@ -79,7 +79,7 @@ public FlappyBird(NDManager manager, ReplayBuffer replayBuffer) {
this.currentState = new State(createObservation(currentImg), currentReward, currentTerminal);
}

public static int timeStep = 0;
public static int gameStep = 0;
public static int trainStep = 0;
private static boolean currentTerminal = false;
private static float currentReward = 0.2f;
@@ -90,19 +90,24 @@
*/
@Override
public Step[] runEnvironment(RlAgent agent, boolean training) {
Step[] batchSteps = new Step[0];
reset();

// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
Step[] batchSteps = this.getBatch();
if (timeStep % 5000 == 0){
if(training) {
batchSteps = this.getBatch();
}
if (gameStep % 5000 == 0){
this.closeStep();
}
if (timeStep <= OBSERVE) {
if (gameStep <= OBSERVE) {
trainState = "observe";
} else {
trainState = "explore";
}
timeStep++;
gameStep++;
return batchSteps;
}

@@ -113,8 +118,8 @@ public Step[] runEnvironment(RlAgent agent, boolean training) {
*/
@Override
public void step(NDList action, boolean training) {
currentReward = 0.2f;
currentTerminal = false;
// currentReward = 0.2f;
// currentTerminal = false;
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
Expand All @@ -135,7 +140,7 @@ public void step(NDList action, boolean training) {
if (training) {
replayBuffer.addStep(step);
}
logger.info("TIME_STEP " + timeStep +
logger.info("GAME_STEP " + gameStep +
" / " + "TRAIN_STEP " + trainStep +
" / " + getTrainState() +
" / " + "ACTION " + (Arrays.toString(action.singletonOrThrow().toArray())) +
@@ -245,30 +250,36 @@ public NDList getAction() {
return action;
}

/**
* {@inheritDoc}
*/
@Override
public NDList getPreObservation(NDManager manager) {
return preState.getObservation(manager);
}

/**
* {@inheritDoc}
*/
@Override
public NDList getPostObservation(NDManager manager) {
return postState.getObservation(manager);
}

/**
* {@inheritDoc}
*/
@Override
public void attachPostStateManager(NDManager manager) {
postState.attachManager(manager);
}

public void attachPreStateManager(NDManager manager) {
preState.attachManager(manager);
}


/**
* {@inheritDoc}
*/
@Override
public ActionSpace getPostActionSpace() {
return postState.getActionSpace(manager);
public void attachPreStateManager(NDManager manager) {
preState.attachManager(manager);
}

/**
8 changes: 4 additions & 4 deletions src/main/java/com/kingyu/rlbird/rl/LruReplayBuffer.java
@@ -24,10 +24,8 @@
public class LruReplayBuffer implements ReplayBuffer {

private final int batchSize;
private final int bufferSize;
private final RlEnv.Step[] steps;
// private RlEnv.Step[] stepToClose;
private ArrayList<RlEnv.Step> stepToClose;
private final ArrayList<RlEnv.Step> stepToClose;
private int firstStepIndex;
private int stepsActualSize;

@@ -39,7 +37,6 @@
*/
public LruReplayBuffer(int batchSize, int bufferSize) {
this.batchSize = batchSize;
this.bufferSize = bufferSize;
steps = new RlEnv.Step[bufferSize];
stepToClose = new ArrayList<>(bufferSize);
firstStepIndex = 0;
@@ -61,6 +58,9 @@ public RlEnv.Step[] getBatch() {
return batch;
}

/**
* {@inheritDoc}
*/
public void closeStep() {
for (RlEnv.Step step : stepToClose) {
step.close();
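For context, a hypothetical sketch of the ring-buffer eviction this class implies, based only on the fields visible above (`steps`, `firstStepIndex`, `stepsActualSize`, `stepToClose`); the real `addStep` is not part of this diff:

```
// Overwrite the oldest step once the buffer is full, deferring
// resource cleanup of evicted steps to closeStep().
public void addStep(RlEnv.Step step) {
    if (stepsActualSize == steps.length) {
        stepToClose.add(steps[firstStepIndex]);           // remember evicted step
        steps[firstStepIndex] = step;                     // overwrite oldest slot
        firstStepIndex = (firstStepIndex + 1) % steps.length;
    } else {
        steps[stepsActualSize] = step;                    // buffer still filling
        stepsActualSize++;
    }
}
```
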
5 changes: 4 additions & 1 deletion src/main/java/com/kingyu/rlbird/rl/ReplayBuffer.java
@@ -29,7 +29,10 @@ public interface ReplayBuffer {
*/
RlEnv.Step[] getBatch();

public void closeStep();
/**
* Closes the steps that are no longer referenced by the buffer.
*/
void closeStep();

/**
* Adds a new step to the buffer.
5 changes: 2 additions & 3 deletions src/main/java/com/kingyu/rlbird/rl/agent/EpsilonGreedy.java
@@ -32,9 +32,8 @@
*/
public class EpsilonGreedy implements RlAgent {

private RlAgent baseAgent;
private Tracker exploreRate;

private final RlAgent baseAgent;
private final Tracker exploreRate;
private int counter;

/**
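For reference, a sketch of the decision an epsilon-greedy agent makes over these fields (method names follow DJL's RL API and are an assumption here, not this file's actual body):

```
// With probability exploreRate take a random action; otherwise exploit
// the base agent's greedy choice. Exploration decays via the Tracker.
@Override
public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction(); // explore
    }
    return baseAgent.chooseAction(env, training); // exploit
}
```
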
10 changes: 6 additions & 4 deletions src/main/java/com/kingyu/rlbird/rl/agent/QAgent.java
@@ -81,13 +81,15 @@ public void trainBatch(Step[] batchSteps) {
BatchData batchData =
new BatchData(null, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());

NDManager manager = NDManager.newBaseManager();
// temporary manager for attaching NDArrays, to reduce GPU memory usage
NDManager temporaryManager = NDManager.newBaseManager();

NDList preObservationBatch = new NDList();
Arrays.stream(batchSteps).forEach(step -> preObservationBatch.addAll(step.getPreObservation(manager)));
Arrays.stream(batchSteps).forEach(step -> preObservationBatch.addAll(step.getPreObservation(temporaryManager)));
NDList preInput = new NDList(NDArrays.concat(preObservationBatch, 0));

NDList postObservationBatch = new NDList();
Arrays.stream(batchSteps).forEach(step -> postObservationBatch.addAll(step.getPostObservation(manager)));
Arrays.stream(batchSteps).forEach(step -> postObservationBatch.addAll(step.getPostObservation(temporaryManager)));
NDList postInput = new NDList(NDArrays.concat(postObservationBatch, 0));

NDList actionBatch = new NDList();
@@ -131,6 +133,6 @@ public void trainBatch(Step[] batchSteps) {
step.attachPostStateManager(step.getManager());
step.attachPreStateManager(step.getManager());
}
manager.close();
temporaryManager.close(); // close the temporary manager
}
}
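
The temporary-manager change above follows a general DJL pattern: scope intermediate NDArrays to a short-lived NDManager so their native (GPU) memory is freed as soon as the batch is processed. A self-contained sketch of the pattern (example values only):

```
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class ScopedManagerExample {
    public static void main(String[] args) {
        // Arrays created under this manager live only inside the block.
        try (NDManager temporaryManager = NDManager.newBaseManager()) {
            NDArray batch = temporaryManager.create(new float[] {1f, 2f, 3f});
            System.out.println(batch.sum()); // use the batch for one step
        } // closing the manager frees every array attached to it
    }
}
```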