Add storing and restoring RL agent checkpoints #280
Conversation
@kaiks Thank you for the contribution! I have carefully reviewed the PR. It looks great. I just added some minor comments. Your proposal of modifying …

Hi @daochenzha, thank you for your feedback! I'm glad to hear the change looks useful. I'm not seeing the comments you mentioned. Maybe you forgot to publish the review? I'll try to submit a follow-up PR with the discussed changes later, toward the end of this week or next.
examples/run_rl.py (Outdated)

@@ -95,7 +111,7 @@ def train(args):
        torch.save(agent, save_path)
        print('Model saved in', save_path)

    if __name__ == '__main__':

This blank modification seems not needed.
examples/run_rl.py (Outdated)

        q_mlp_layers=[64,64],
        device=device,
        save_path=args.log_dir,
        save_every=500

Should this be save_every=args.save_every?
rlcard/agents/dqn_agent.py (Outdated)

@@ -105,7 +109,7 @@ def __init__(self,

        # The epsilon decay scheduler
        self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps)

Please clean this blank line.
rlcard/agents/dqn_agent.py (Outdated)

@@ -218,9 +226,16 @@ def train(self):
        if self.train_t % self.update_target_estimator_every == 0:
            self.target_estimator = deepcopy(self.q_estimator)
            print("\nINFO - Copied model parameters to target network.")

Please clean this blank line.
rlcard/agents/dqn_agent.py (Outdated)

    def save(self, path):

Is this function used?
rlcard/agents/nfsp_agent.py (Outdated)

        # Step counter to keep track of learning.
        self._step_counter = 0

Please remove this tab.
examples/run_rl.py (Outdated)

        device=device,
    )
    if args.load_checkpoint_path != "":
        dict = torch.load(args.load_checkpoint_path)

dict is a built-in name in Python; shadowing it is not good practice. How about condensing these three lines into one:

    agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
@kaiks Yes, I forgot to publish it. You should be able to see it now.

@daochenzha Thanks for the review. I addressed your feedback.

@kaiks LGTM, thank you!
This PR introduces checkpoints for RL agents (DQN and NFSP).

Checkpoints are data describing the complete agent state (weights and parameters) at a point during training.

To save an agent automatically, set a checkpoint path and an interval during agent instantiation: specifying the save_path and save_every parameters will save the training progress every 500 steps to a single file.
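The periodic-save mechanism can be sketched independently of rlcard. The class below is illustrative, not the PR's actual code; it only mirrors the pattern of a train-step counter triggering a single-file checkpoint every save_every steps (the PR uses torch.save, pickle stands in for it here):

```python
import os
import pickle
import tempfile

class CheckpointingAgent:
    """Illustrative sketch: write a checkpoint every `save_every` training steps."""

    def __init__(self, save_path, save_every=500):
        self.save_path = save_path
        self.save_every = save_every
        self.train_t = 0          # training-step counter
        self.weights = [0.0]      # stand-in for model parameters

    def checkpoint_attributes(self):
        # Everything needed to restore the agent later
        return {'train_t': self.train_t, 'weights': self.weights}

    def train_step(self):
        self.train_t += 1
        self.weights[0] += 0.1    # pretend learning happened
        if self.train_t % self.save_every == 0:
            self.save_checkpoint()

    def save_checkpoint(self):
        # Single checkpoint file, overwritten on each save
        with open(os.path.join(self.save_path, 'checkpoint.pkl'), 'wb') as f:
            pickle.dump(self.checkpoint_attributes(), f)

with tempfile.TemporaryDirectory() as log_dir:
    agent = CheckpointingAgent(save_path=log_dir, save_every=5)
    for _ in range(10):
        agent.train_step()        # checkpoints written at steps 5 and 10
    with open(os.path.join(log_dir, 'checkpoint.pkl'), 'rb') as f:
        state = pickle.load(f)
    print(state['train_t'])       # -> 10
```

Overwriting one file keeps disk usage constant; the trade-off is that only the most recent state survives.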
A complete example of loading an agent looks as follows:

    agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))

Training can then be resumed, or the agent's attributes can be inspected and debugged.
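The save/restore round-trip can also be sketched without rlcard or torch. ToyAgent below is an assumption, not the PR's code; it only illustrates the checkpoint_attributes / from_checkpoint pattern the PR adds, with pickle standing in for torch.save and torch.load:

```python
import pickle

class ToyAgent:
    """Illustrative agent following the PR's checkpoint pattern."""

    def __init__(self, epsilon=1.0, train_t=0):
        self.epsilon = epsilon
        self.train_t = train_t

    def checkpoint_attributes(self):
        # Complete state needed to reconstruct the agent
        return {'epsilon': self.epsilon, 'train_t': self.train_t}

    @classmethod
    def from_checkpoint(cls, checkpoint):
        # Rebuild the agent from the saved state dict
        return cls(epsilon=checkpoint['epsilon'], train_t=checkpoint['train_t'])

# Save, then restore
original = ToyAgent(epsilon=0.05, train_t=1234)
blob = pickle.dumps(original.checkpoint_attributes())    # torch.save in the PR
restored = ToyAgent.from_checkpoint(pickle.loads(blob))  # torch.load in the PR

print(restored.train_t)   # -> 1234
```

Because the checkpoint is a plain dict, the restored agent's attributes can be read directly, which is what makes the inspection and debugging use case cheap.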
Use cases for checkpoints that I can see include resuming interrupted training, inspecting or debugging agent attributes, and keeping the best-performing agent during evaluation.
You can also, for instance, save the highest-reward agent at every evaluation step when using the run_rl.py script by manually saving the agent state.

As a next step, it would be possible to streamline loading the models, for instance by extracting and modifying the load_model function found in evaluate.py. But this requires more thought and would blow the PR up a bit.
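A generalized loader could dispatch on an agent-type tag stored inside the checkpoint. Everything in this sketch (the tag key, the registry, the agent classes) is an assumption for illustration, not rlcard's actual load_model:

```python
import os
import pickle
import tempfile

class DQNLike:
    """Stand-in for a checkpointable agent class."""
    @classmethod
    def from_checkpoint(cls, checkpoint):
        agent = cls()
        agent.state = checkpoint['state']
        return agent

class NFSPLike:
    """Stand-in for a second agent type."""
    @classmethod
    def from_checkpoint(cls, checkpoint):
        agent = cls()
        agent.state = checkpoint['state']
        return agent

# Hypothetical registry: checkpoint tag -> class that restores it
AGENT_CLASSES = {'dqn': DQNLike, 'nfsp': NFSPLike}

def load_model(path):
    """Read the checkpoint, then dispatch on the agent type recorded inside it."""
    with open(path, 'rb') as f:
        checkpoint = pickle.load(f)
    agent_cls = AGENT_CLASSES[checkpoint['agent_type']]
    return agent_cls.from_checkpoint(checkpoint)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'checkpoint.pkl')
    with open(path, 'wb') as f:
        pickle.dump({'agent_type': 'dqn', 'state': {'train_t': 7}}, f)
    agent = load_model(path)
    print(type(agent).__name__, agent.state['train_t'])  # -> DQNLike 7
```

Storing the type tag in the checkpoint itself means callers never need to know in advance which agent a file contains, which is what would let one loader serve every agent type.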
In principle there's nothing stopping us from adding similar serialization for every other agent type and generalizing the loading of saved agent checkpoints.