
Add storing and restoring RL agent checkpoints #280

Merged 3 commits into datamllab:master on Apr 19, 2023

Conversation

@kaiks (Contributor) commented Apr 15, 2023

This PR introduces checkpoints for RL agents (DQN and NFSP).

Checkpoints are data describing the complete agent state (weights and parameters) during training.

To save an agent, you can either set a checkpoint path so that it is saved automatically every n training steps, or save it manually (shown further below). For example, specifying the save_path and save_every parameters during agent instantiation:

agent = DQNAgent(
    num_actions=env.num_actions,
    state_shape=env.state_shape[0],
    mlp_layers=[64, 64],
    device=device,
    save_path=args.log_dir,
    save_every=500,
)

will save the training progress every 500 steps to a single file.

A complete example of loading an agent looks as follows:

import torch
from rlcard.agents import DQNAgent

checkpoint = torch.load('experiments/dqn_checkpoint_v2/checkpoint.pt')
agent = DQNAgent.from_checkpoint(checkpoint=checkpoint)

Training can then be resumed, or the agent's attributes can be inspected and debugged.
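For illustration, inspecting the loaded dict before restoring might look like this. The keys shown below are assumptions for the sake of the example; the real ones come from the agent's checkpoint serialization in this PR and may differ:

```python
# Hypothetical checkpoint dict standing in for the result of torch.load();
# the real keys are defined by the agent's checkpoint serialization.
checkpoint = {
    'agent_type': 'DQNAgent',
    'total_t': 12500,   # environment timesteps taken so far
    'train_t': 500,     # gradient-update steps taken so far
}

# Inspect before restoring: useful when debugging a training run.
print(sorted(checkpoint))
print(checkpoint['agent_type'], checkpoint['train_t'])
```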

Use cases for checkpoints that I can see:

  • ensuring that training progress is not lost if training gets interrupted (the host crashes, you need to shut down your computer, etc.)
  • debugging an agent at training time
  • longer term, experimenting with training pipelines for sparse-reward environments: first training the agent on a simplified variant of the game, then switching to increasingly difficult opponents, etc.

You can also, for instance, save the highest-reward agent at every evaluation step when using the run_rl.py script, by manually saving the agent state:

if current_reward > best_reward:
    agent.save_checkpoint(path, filename='best_checkpoint.pt')
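A self-contained sketch of how that manual save could sit in an evaluation loop. The agent class and reward bookkeeping here are hypothetical stand-ins, not run_rl.py's actual code:

```python
class StubAgent:
    """Hypothetical stand-in for a DQNAgent; records save calls."""
    def __init__(self):
        self.saved = []

    def save_checkpoint(self, path, filename='checkpoint.pt'):
        # A real agent would serialize its full state here.
        self.saved.append((path, filename))


def evaluate_and_keep_best(agent, rewards, path='experiments/demo'):
    """Save a checkpoint whenever the evaluation reward improves."""
    best_reward = float('-inf')
    for current_reward in rewards:
        if current_reward > best_reward:
            best_reward = current_reward
            agent.save_checkpoint(path, filename='best_checkpoint.pt')
    return best_reward


agent = StubAgent()
best = evaluate_and_keep_best(agent, [1.0, 0.5, 2.0])
print(best, len(agent.saved))  # saves on 1.0 and again on 2.0
```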

As a next step, it would be possible to streamline loading the models, for instance by extracting and modifying the load_model function found in evaluate.py:

def load_agent_from_checkpoint(checkpoint_dict):
    if checkpoint_dict["agent_type"] == "NFSPAgent":
        return NFSPAgent.from_checkpoint(checkpoint_dict)
    elif checkpoint_dict["agent_type"] == "DQNAgent":
        return DQNAgent.from_checkpoint(checkpoint_dict)
    else:
        raise ValueError("Unknown agent type")

def load_model(model_path, env=None, position=None, device=None):
    if os.path.isfile(model_path):  # Torch model
        import torch
        agent = torch.load(model_path, map_location=device)
        if "agent_type" in agent:
            agent = load_agent_from_checkpoint(agent)
        agent.set_device(device)
    elif os.path.isdir(model_path):  # CFR model
        from rlcard.agents import CFRAgent
        agent = CFRAgent(env, model_path)
        agent.load()
    elif model_path == 'random':  # Random model
        from rlcard.agents import RandomAgent
        agent = RandomAgent(num_actions=env.num_actions)
    else:  # A model in the model zoo
        from rlcard import models
        agent = models.load(model_path).agents[position]
    return agent

But this requires more thought and would expand the scope of this PR.

In principle there's nothing stopping us from adding similar serialization for every other agent type and generalizing the loading of saved agent checkpoints.
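One way to generalize the dispatch is a registry keyed by agent_type, so new agent classes opt in without touching the loader. This is a sketch with hypothetical stand-in classes, not rlcard's actual API:

```python
AGENT_REGISTRY = {}

def register_agent(cls):
    """Register a class under its name for checkpoint-based dispatch."""
    AGENT_REGISTRY[cls.__name__] = cls
    return cls

@register_agent
class DQNAgent:
    """Hypothetical stand-in for rlcard's DQNAgent."""
    @classmethod
    def from_checkpoint(cls, checkpoint):
        # A real implementation would restore weights and parameters here.
        return cls()

def load_agent_from_checkpoint(checkpoint_dict):
    agent_type = checkpoint_dict['agent_type']
    if agent_type not in AGENT_REGISTRY:
        raise ValueError(f'Unknown agent type: {agent_type}')
    return AGENT_REGISTRY[agent_type].from_checkpoint(checkpoint_dict)

agent = load_agent_from_checkpoint({'agent_type': 'DQNAgent'})
print(type(agent).__name__)  # DQNAgent
```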

@daochenzha daochenzha self-requested a review April 17, 2023 19:04
@daochenzha (Member) commented:

@kaiks Thank you for the contribution! I have carefully reviewed the PR. It looks great. I just added some minor comments. Your proposal of modifying evaluate.py also makes lots of sense. Please consider submitting another PR as well. Have a great day

@kaiks (Contributor, Author) commented Apr 18, 2023

Hi @daochenzha, thank you for your feedback! I'm glad to hear the change looks useful. I'm not seeing the comments you mentioned. Maybe you forgot to publish the review?

I'll try to submit a follow-up PR with the discussed changes later, toward the end of this week or next.

@@ -95,7 +111,7 @@ def train(args):
torch.save(agent, save_path)
print('Model saved in', save_path)

if __name__ == '__main__':
if __name__ == '__main__':
Review comment (Member):

This blank modification seems not needed

q_mlp_layers=[64,64],
device=device,
save_path=args.log_dir,
save_every=500
Review comment (Member):

Should this be save_every=args.save_every?

@@ -105,7 +109,7 @@ def __init__(self,

# The epsilon decay scheduler
self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps)

Review comment (Member):

Please clean this blank

@@ -218,9 +226,16 @@ def train(self):
if self.train_t % self.update_target_estimator_every == 0:
self.target_estimator = deepcopy(self.q_estimator)
print("\nINFO - Copied model parameters to target network.")

Review comment (Member):

Please clean this




def save(self, path):
Review comment (Member):

Is this function used?


# Step counter to keep track of learning.
self._step_counter = 0

Review comment (Member):

Please remove this tab

device=device,
)
if args.load_checkpoint_path != "":
dict = torch.load(args.load_checkpoint_path)
Review comment (Member):

dict is a built-in name in Python, so shadowing it is not good practice. How about condensing these three lines into one:

agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))


@daochenzha (Member) commented:

@kaiks Yes, I forgot to publish it. You should be able to see it now.

@kaiks (Contributor, Author) commented Apr 18, 2023

@daochenzha thanks for the review. I addressed your feedback.

@kaiks kaiks requested a review from daochenzha April 18, 2023 20:53
@daochenzha (Member) commented:

@kaiks LGTM, thank you!

@daochenzha daochenzha merged commit f4ae4fc into datamllab:master Apr 19, 2023