Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
dylan djian
committed
Jun 6, 2018
1 parent
6b9ac64
commit 5613f2c
Showing
23 changed files
with
1,921 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,45 @@ | ||
# retro-sonic-contest | ||
World Models applied to the Open AI Sonic Retro Contest | ||
# retro-contest-sonic | ||
|
||
A student implementation of the World Models paper with documentation. | ||
|
||
Ongoing project. | ||
|
||
|
||
# TODO | ||
|
||
|
||
# CURRENTLY DOING | ||
|
||
* Submit learned agents | ||
* Improving the controller training and model to get a decent transfer between levels | ||
|
||
# DONE | ||
|
||
* β-VAE for the Visual model | ||
* MDN-LSTM for the Memory model | ||
* CMA-ES for the Controller model | ||
* Training pipelines for the 3 models | ||
* Human recordings to generate data | ||
* MongoDB to store data | ||
* LSTM and VAE trained "successfully" | ||
* Multiprocessing of the evaluation of a set of parameters given by the CMA-ES | ||
|
||
|
||
# LONG-TERM PLAN? | ||
|
||
* Cleaner code, more optimized and documented | ||
* Game agnostic | ||
* Online training instead of using a database | ||
|
||
|
||
# Resources | ||
|
||
* [My write-up on the code and concepts of this repository](https://dylandjian.github.io/world-models/) | ||
* [World Models paper](https://arxiv.org/pdf/1803.10122.pdf) | ||
|
||
|
||
# Differences with the official paper | ||
|
||
* No temperature | ||
* No flipping of the loss sign during training (to encourage exploration) | ||
* β-VAE instead of VAE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import torch | ||
import math | ||
|
||
|
||
# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Print tensors with 10 digits of precision.
torch.set_printoptions(precision=10)

# Whether Torch reports an available CUDA device.
CUDA = torch.cuda.is_available()
# torch.backends.cudnn.deterministic = True

# Device used for tensors: "cuda" when available, "cpu" otherwise.
DEVICE = torch.device("cuda" if CUDA else "cpu")

# Small epsilon for logarithms.
EPSILON = 1e-6


# --- VAE (visual model) ----------------------------------------------------
LATENT_VEC = 200   # size of the latent vector
BETA = 4           # presumably the beta weight of the beta-VAE loss -- verify
VAE_LOSS = "bce"


# --- RNN (memory model) ----------------------------------------------------
OFFSET = 1
HIDDEN_UNITS = 1024
HIDDEN_DIM = 1024
TEMPERATURE = 1.25
GAUSSIANS = 8
NUM_LAYERS = 1
SEQUENCE = 100
PARAMS_FC1 = 2 * NUM_LAYERS * HIDDEN_UNITS
# 1 / sqrt(2 * pi)
MDN_CONST = 1.0 / math.sqrt(2.0 * math.pi)


# --- Controller ------------------------------------------------------------
PARALLEL = 4
SIGMA_INIT = 4
POPULATION = 64
SCORE_CAP = 8000
REPEAT_ROLLOUT = 4
RENDER_TICK = 64
MAX_TIMESTEPS = 1800
TIMESTEP_DECAY = 150
TIMESTEP_DECAY_TICK = 5
REWARD_BUFFER = 300
SAVE_SOLVER_TICK = 1
MIN_REWARD = 10


# --- Image size ------------------------------------------------------------
HEIGHT = 128
WIDTH = 128


# --- Dataset ---------------------------------------------------------------
SIZE = 50000
MAX_REPLACEMENT = 0.1
REPEAT = 0


# --- Play ------------------------------------------------------------------
PARALLEL_PER_GAME = 2
PLAYOUTS = 2500
PLAYOUTS_PER_LEVEL = 10000
ACTION_SPACE = 4
ACTION_SPACE_DISCRETE = 10


# --- Training --------------------------------------------------------------
MOMENTUM = 0.9  # SGD
ADAM = True
LR = 1e-3
L2_REG = 1e-4
LR_DECAY = 0.1
BATCH_SIZE = 2
SAMPLE_SIZE = 500


# --- Refresh ---------------------------------------------------------------
LOSS_TICK = 2
REFRESH_TICK = 40
SAVE_PIC_TICK = 10
SAVE_TICK = 100
LR_DECAY_TICK = 100000


# --- Jerk ------------------------------------------------------------------
EXPLOIT_BIAS = 0.25
TOTAL_TIMESTEPS = 1e6
|
||
# --- Env -------------------------------------------------------------------
# Identifiers of the games used by the project.
GAMES = [
    "SonicTheHedgehog-Genesis",
    "SonicTheHedgehog2-Genesis",
    "SonicAndKnuckles3-Genesis",
]
|
||
# Level names available for each game, keyed by game identifier.
LEVELS = {
    "SonicTheHedgehog-Genesis": [
        "GreenHillZone.Act1",
        "GreenHillZone.Act2",
        "GreenHillZone.Act3",
        "SpringYardZone.Act1",
        "SpringYardZone.Act2",
        "SpringYardZone.Act3",
        "StarLightZone.Act1",
        "StarLightZone.Act2",
        "StarLightZone.Act3",
        "MarbleZone.Act1",
        "MarbleZone.Act2",
        "MarbleZone.Act3",
        "ScrapBrainZone.Act1",
        "ScrapBrainZone.Act2",
        "LabyrinthZone.Act1",
        "LabyrinthZone.Act2",
        "LabyrinthZone.Act3",
    ],
    "SonicTheHedgehog2-Genesis": [
        "EmeraldHillZone.Act1",
        "EmeraldHillZone.Act2",
        "ChemicalPlantZone.Act1",
        "ChemicalPlantZone.Act2",
        "MetropolisZone.Act1",
        "MetropolisZone.Act2",
        "MetropolisZone.Act3",
        "OilOceanZone.Act1",
        "OilOceanZone.Act2",
        "MysticCaveZone.Act1",
        "MysticCaveZone.Act2",
        "HillTopZone.Act1",
        "HillTopZone.Act2",
        "CasinoNightZone.Act1",
        "CasinoNightZone.Act2",
        # Act2 is listed before Act1 here; the order is kept as-is.
        "AquaticRuinZone.Act2",
        "AquaticRuinZone.Act1",
        "WingFortressZone",
    ],
    "SonicAndKnuckles3-Genesis": [
        "LavaReefZone.Act1",
        "LavaReefZone.Act2",
        "CarnivalNightZone.Act1",
        "CarnivalNightZone.Act2",
        "MarbleGardenZone.Act1",
        "MarbleGardenZone.Act2",
        "MushroomHillZone.Act1",
        "MushroomHillZone.Act2",
        "DeathEggZone.Act1",
        "DeathEggZone.Act2",
        "FlyingBatteryZone.Act1",
        "FlyingBatteryZone.Act2",
        "SandopolisZone.Act1",
        "SandopolisZone.Act2",
        "HydrocityZone.Act1",
        "HydrocityZone.Act2",
        "IcecapZone.Act1",
        "IcecapZone.Act2",
        "AngelIslandZone.Act1",
        "AngelIslandZone.Act2",
        "LaunchBaseZone.Act1",
        "LaunchBaseZone.Act2",
        "HiddenPalaceZone",
    ],
}
|
||
# Per-game level subset; going by the name, presumably the validation
# levels -- confirm against the code that consumes it.
LEVELS_VALID = {
    "SonicTheHedgehog-Genesis": [
        "SpringYardZone.Act1",
        "GreenHillZone.Act2",
        "StarLightZone.Act3",
        "ScrapBrainZone.Act1",
    ],
    "SonicTheHedgehog2-Genesis": [
        "MetropolisZone.Act3",
        "HillTopZone.Act2",
        "CasinoNightZone.Act2",
    ],
    "SonicAndKnuckles3-Genesis": [
        "LavaReefZone.Act1",
        "FlyingBatteryZone.Act2",
        "HydrocityZone.Act1",
        "AngelIslandZone.Act2",
    ],
}
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import numpy as np | ||
import torch.multiprocessing as multiprocessing | ||
import timeit | ||
import time | ||
import torch | ||
from .env import create_env | ||
from .play_utils import _formate_img | ||
from const import * | ||
|
||
|
||
class VAECGame(multiprocessing.Process):
    """Worker process that scores one controller on one game level.

    The process plays ``REPEAT_ROLLOUT`` episodes. On every step the current
    frame is encoded by the VAE, the controller picks buttons from the latent
    vector concatenated with the LSTM hidden state, and the LSTM is then
    stepped with the latent vector and the action taken. When all episodes
    are done, ``{process_id: (mean_reward, elapsed_seconds)}`` is put on
    ``result_queue``.
    """

    def __init__(self, current_time, process_id, vae, lstm, controller, game, level, result_queue, max_timestep):
        """Store the models, the target level and the output queue.

        Args:
            current_time: Stored as-is; not used by the code in this class.
            process_id: Key used for this worker's entry in the result dict;
                also decides whether this worker renders (see ``_rollout``).
            vae: Callable as ``vae(frame, encode=True)`` to get a latent vector.
            lstm: Model exposing ``hidden`` and ``init_hidden(batch_size)``.
            controller: Callable mapping the concatenated features to actions.
            game: Game identifier passed to ``create_env``.
            level: Level identifier passed to ``create_env``.
            result_queue: Queue that receives the final result dict.
            max_timestep: An episode is cut once its step counter exceeds this.
        """

        super(VAECGame, self).__init__()
        self.process_id = process_id
        self.game = game
        self.level = level
        self.current_time = current_time
        self.vae = vae
        self.lstm = lstm
        self.controller = controller
        self.result_queue = result_queue
        self.max_timestep = max_timestep
        ## Controller output index -> index in the 12-slot environment action.
        ## Presumably B, DOWN, LEFT, RIGHT in Gym Retro's Genesis button
        ## order -- TODO confirm against the environment.
        self.convert = {0: 0, 1: 5, 2: 6, 3: 7}

    def _convert(self, predicted_actions):
        """ Convert predicted action into an environment action

        Args:
            predicted_actions: CPU tensor whose first row holds one score per
                controller output; outputs strictly above 0.5 are pressed.

        Returns:
            Boolean numpy array of shape (12,) with the mapped buttons set.
        """

        scores = predicted_actions.numpy()[0]
        ## `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
        ## supported spelling and yields the same dtype.
        final_action = np.zeros((12,), dtype=bool)
        for i in np.where(scores > 0.5)[0]:
            final_action[self.convert[int(i)]] = True
        return final_action

    def _rollout(self, env):
        """Play a single episode in `env` and return its accumulated reward."""

        from collections import deque

        obs = env.reset()
        done = False
        total_reward = 0
        total_steps = 0
        ## Sliding window over the most recent REWARD_BUFFER rewards. A deque
        ## with maxlen always drops the oldest entry; the previous
        ## append / insert(0) / pop() mix dropped the newest of the old ones.
        recent_rewards = deque(maxlen=REWARD_BUFFER)

        with torch.no_grad():
            while not done:
                ## Reset the LSTM hidden state every SEQUENCE steps
                if total_steps % SEQUENCE == 0:
                    self.lstm.hidden = self.lstm.init_hidden(1)

                ## Predict the latent representation of the current frame
                frame = torch.tensor(_formate_img(obs), dtype=torch.float, device=DEVICE).div(255)
                z = self.vae(frame.view(1, 3, HEIGHT, WIDTH), encode=True)

                ## Use the latent representation and the hidden state of the LSTM
                ## to predict an action vector
                controller_input = torch.cat((z,
                                              self.lstm.hidden[0].view(1, -1),
                                              self.lstm.hidden[1].view(1, -1)), dim=1)
                actions = self.controller(controller_input)
                final_action = self._convert(actions.cpu())
                obs, reward, done, _ = env.step(final_action)

                ## Update the hidden state of the LSTM. The call is made for
                ## its effect on the model; its return value is not needed.
                action = torch.tensor(env.get_act(final_action), dtype=torch.float, device=DEVICE)\
                              .div(ACTION_SPACE_DISCRETE)
                lstm_input = torch.cat((z, action.view(1, 1)), dim=1)
                self.lstm(lstm_input.view(1, 1, LATENT_VEC + 1))

                ## Stop early when the mean reward over a full window is too
                ## low. As before, the current reward is not counted on exit.
                if len(recent_rewards) == REWARD_BUFFER and np.mean(recent_rewards) < MIN_REWARD:
                    break
                recent_rewards.append(reward)
                total_reward += reward

                ## Check for rendering
                ## NOTE(review): nesting reconstructed from flattened source;
                ## assumed to render every step and print every 200 steps.
                if (self.process_id + 1) % RENDER_TICK == 0:
                    if total_steps % 200 == 0:
                        print(actions)
                    env.render()

                ## Check for custom timelimit
                if total_steps > self.max_timestep:
                    break

                total_steps += 1

        return total_reward

    def run(self):
        """Play REPEAT_ROLLOUT episodes and report the mean reward and time."""

        final_reward = []
        start_time = timeit.default_timer()

        for _ in range(REPEAT_ROLLOUT):
            env = create_env(self.game, self.level)
            try:
                final_reward.append(self._rollout(env))
            finally:
                ## Always release the environment, even if the rollout raises
                env.close()

        final_time = timeit.default_timer() - start_time
        ## Guard against an empty list (REPEAT_ROLLOUT == 0) to avoid a NaN mean
        mean_reward = np.mean(final_reward) if final_reward else 0.0
        print("[{} / {}] Final mean reward: {}" \
                    .format(self.process_id + 1, POPULATION, mean_reward))
        result = {}
        result[self.process_id] = (mean_reward, final_time)
        self.result_queue.put(result)
Oops, something went wrong.