Add storing and restoring RL agent checkpoints #280
Changes from 2 commits
First file: the example training script (examples/run_rl.py).

@@ -16,6 +16,7 @@
     plot_curve,
 )

 def train(args):

     # Check whether gpu is available
@@ -35,21 +36,36 @@ def train(args):
     # Initialize the agent and use random agents as opponents
     if args.algorithm == 'dqn':
         from rlcard.agents import DQNAgent
-        agent = DQNAgent(
-            num_actions=env.num_actions,
-            state_shape=env.state_shape[0],
-            mlp_layers=[64,64],
-            device=device,
-        )
+        if args.load_checkpoint_path != "":
+            dict = torch.load(args.load_checkpoint_path)
+            agent = DQNAgent.from_checkpoint(checkpoint = dict)
+            del dict
+        else:
+            agent = DQNAgent(
+                num_actions=env.num_actions,
+                state_shape=env.state_shape[0],
+                mlp_layers=[64,64],
+                device=device,
+                save_path=args.log_dir,
+                save_every=args.save_every
+            )
+
     elif args.algorithm == 'nfsp':
         from rlcard.agents import NFSPAgent
-        agent = NFSPAgent(
-            num_actions=env.num_actions,
-            state_shape=env.state_shape[0],
-            hidden_layers_sizes=[64,64],
-            q_mlp_layers=[64,64],
-            device=device,
-        )
+        if args.load_checkpoint_path != "":
+            dict = torch.load(args.load_checkpoint_path)
+            agent = NFSPAgent.from_checkpoint(checkpoint = dict)
+            del dict
+        else:
+            agent = NFSPAgent(
+                num_actions=env.num_actions,
+                state_shape=env.state_shape[0],
+                hidden_layers_sizes=[64,64],
+                q_mlp_layers=[64,64],
+                device=device,
+                save_path=args.log_dir,
+                save_every=500
+            )
     agents = [agent]
     for _ in range(1, env.num_players):
         agents.append(RandomAgent(num_actions=env.num_actions))

Review comments on this hunk:
- On the `dict = torch.load(...)` lines: see the comment at the bottom of this page about shadowing the built-in `dict`.
- On `save_every=500` in the NFSP branch: Should this be `args.save_every`?
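To see the restore path in isolation, here is a minimal sketch of what the new branch does when a checkpoint path is supplied (the path below is hypothetical; `torch.load` and `from_checkpoint` are used exactly as in the diff):

    import torch
    from rlcard.agents import DQNAgent

    # Load the attribute dict written by save_checkpoint() in an earlier run
    # (the path is illustrative only).
    checkpoint = torch.load('experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt')
    agent = DQNAgent.from_checkpoint(checkpoint=checkpoint)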
@@ -95,7 +111,7 @@ def train(args):
     torch.save(agent, save_path)
     print('Model saved in', save_path)

-if __name__ == '__main__':
+if __name__ == '__main__':
     parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
     parser.add_argument(
         '--env',

Review comment (on the whitespace-only change at `if __name__ == '__main__':`): This blank modification seems not needed.
@@ -152,6 +168,17 @@ def train(args):
         type=str,
         default='experiments/leduc_holdem_dqn_result/',
     )

+    parser.add_argument(
+        "--load_checkpoint_path",
+        type=str,
+        default="",
+    )
+
+    parser.add_argument(
+        "--save_every",
+        type=int,
+        default=-1)
+
     args = parser.parse_args()
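As a hedged sketch of how the new flags thread into `train()` (the flag values are hypothetical, and the `--algorithm` flag is assumed from its use as `args.algorithm` above), the parser can be exercised with an explicit argv list:

    # Equivalent to running the example script with the new flags, e.g.
    #   python3 run_rl.py --algorithm dqn \
    #       --load_checkpoint_path experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt \
    #       --save_every 100
    args = parser.parse_args([
        '--algorithm', 'dqn',
        '--load_checkpoint_path', 'experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt',
        '--save_every', '100',
    ])
    train(args)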
Second file: the DQN agent module (rlcard/agents/dqn_agent.py).
@@ -56,7 +56,9 @@ def __init__(self,
                  train_every=1,
                  mlp_layers=None,
                  learning_rate=0.00005,
-                 device=None):
+                 device=None,
+                 save_path=None,
+                 save_every=float('inf'),):

         '''
         Q-Learning algorithm for off-policy TD control using Function Approximation.
@@ -81,6 +83,8 @@ def __init__(self,
             mlp_layers (list): The layer number and the dimension of each layer in MLP
             learning_rate (float): The learning rate of the DQN agent.
             device (torch.device): whether to use the cpu or gpu
+            save_path (str): The path to save the model checkpoints
+            save_every (int): Save the model every X training steps
         '''
         self.use_raw = False
         self.replay_memory_init_size = replay_memory_init_size
@@ -105,7 +109,7 @@ def __init__(self,

         # The epsilon decay scheduler
         self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps)
-
+
         # Create estimators
         self.q_estimator = Estimator(num_actions=num_actions, learning_rate=learning_rate, state_shape=state_shape, \
                                      mlp_layers=mlp_layers, device=self.device)

Review comment (on the changed blank line, a whitespace-only change): Please clean this blank.
@@ -114,6 +118,10 @@ def __init__(self,

         # Create replay memory
         self.memory = Memory(replay_memory_size, batch_size)

+        # Checkpoint saving parameters
+        self.save_path = save_path
+        self.save_every = save_every
+
     def feed(self, ts):
         ''' Store data in to replay buffer and train the agent. There are two stages.
@@ -218,9 +226,16 @@ def train(self):
         if self.train_t % self.update_target_estimator_every == 0:
             self.target_estimator = deepcopy(self.q_estimator)
             print("\nINFO - Copied model parameters to target network.")
-
+
         self.train_t += 1

+        if self.save_path and self.train_t % self.save_every == 0:
+            # To preserve every checkpoint separately,
+            # add another argument to the function call parameterized by self.train_t
+            self.save_checkpoint(self.save_path)
+            print("\nINFO - Saved model checkpoint.")
+
     def feed_memory(self, state, action, reward, next_state, legal_actions, done):
         ''' Feed transition to memory

Review comment (on the blank line after the "Copied model parameters" print, a whitespace-only change): Please clean this.
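As a self-contained illustration of the cadence this guard produces (the numbers are hypothetical; `save_every=500` matches the NFSP branch in the example script):

    # With save_every=500, a checkpoint is written on every 500th training step.
    save_path, save_every = 'experiments/demo', 500
    for train_t in range(1, 2001):
        if save_path and train_t % save_every == 0:
            print('checkpoint at train_t =', train_t)   # 500, 1000, 1500, 2000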
@@ -239,6 +254,83 @@ def set_device(self, device):
         self.q_estimator.device = device
         self.target_estimator.device = device

+    def checkpoint_attributes(self):
+        '''
+        Return the current checkpoint attributes (dict)
+        Checkpoint attributes are used to save and restore the model in the middle of training
+        Saves the model state dict, optimizer state dict, and all other instance variables
+        '''
+        return {
+            'agent_type': 'DQNAgent',
+            'q_estimator': self.q_estimator.checkpoint_attributes(),
+            'memory': self.memory.checkpoint_attributes(),
+            'total_t': self.total_t,
+            'train_t': self.train_t,
+            'epsilon_start': self.epsilons.min(),
+            'epsilon_end': self.epsilons.max(),
+            'epsilon_decay_steps': self.epsilon_decay_steps,
+            'discount_factor': self.discount_factor,
+            'update_target_estimator_every': self.update_target_estimator_every,
+            'batch_size': self.batch_size,
+            'num_actions': self.num_actions,
+            'train_every': self.train_every,
+            'device': self.device
+        }
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint):
+        '''
+        Restore the model from a checkpoint
+
+        Args:
+            checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes()
+        '''
+        print("\nINFO - Restoring model from checkpoint...")
+        agent_instance = cls(
+            replay_memory_size=checkpoint['memory']['memory_size'],
+            update_target_estimator_every=checkpoint['update_target_estimator_every'],
+            discount_factor=checkpoint['discount_factor'],
+            epsilon_start=checkpoint['epsilon_start'],
+            epsilon_end=checkpoint['epsilon_end'],
+            epsilon_decay_steps=checkpoint['epsilon_decay_steps'],
+            batch_size=checkpoint['batch_size'],
+            num_actions=checkpoint['num_actions'],
+            device=checkpoint['device'],
+            state_shape=checkpoint['q_estimator']['state_shape'],
+            mlp_layers=checkpoint['q_estimator']['mlp_layers'],
+            train_every=checkpoint['train_every']
+        )
+
+        agent_instance.total_t = checkpoint['total_t']
+        agent_instance.train_t = checkpoint['train_t']
+
+        agent_instance.q_estimator = Estimator.from_checkpoint(checkpoint['q_estimator'])
+        agent_instance.target_estimator = deepcopy(agent_instance.q_estimator)
+        agent_instance.memory = Memory.from_checkpoint(checkpoint['memory'])
+
+        return agent_instance
+
+    def save(self, path):
+        ''' Save the model (q_estimator weights only)
+
+        Args:
+            path (str): the path to save the model
+        '''
+        torch.save(self.q_estimator.model.state_dict(), path)
+
+    def save_checkpoint(self, path, filename='checkpoint_dqn.pt'):
+        ''' Save the model checkpoint (all attributes)
+
+        Args:
+            path (str): the path to save the model
+        '''
+        torch.save(self.checkpoint_attributes(), path + '/' + filename)
+
 class Estimator(object):
     '''
     Approximate clone of rlcard.agents.dqn_agent.Estimator that

Review comment (on `def save`): Is this function used?
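Taken together, `save_checkpoint` and `from_checkpoint` give a round trip. A minimal sketch (the directory is hypothetical and `agent` is assumed to be an existing DQNAgent):

    import torch

    agent.save_checkpoint('experiments/leduc_holdem_dqn_result')   # writes checkpoint_dqn.pt
    restored = DQNAgent.from_checkpoint(
        torch.load('experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt')
    )
    assert restored.train_t == agent.train_t   # training progress survives the round trip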
@@ -334,6 +426,35 @@ def update(self, s, a, y):
         self.qnet.eval()

         return batch_loss

+    def checkpoint_attributes(self):
+        ''' Return the attributes needed to restore the model from a checkpoint
+        '''
+        return {
+            'qnet': self.qnet.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'num_actions': self.num_actions,
+            'learning_rate': self.learning_rate,
+            'state_shape': self.state_shape,
+            'mlp_layers': self.mlp_layers,
+            'device': self.device
+        }
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint):
+        ''' Restore the model from a checkpoint
+        '''
+        estimator = cls(
+            num_actions=checkpoint['num_actions'],
+            learning_rate=checkpoint['learning_rate'],
+            state_shape=checkpoint['state_shape'],
+            mlp_layers=checkpoint['mlp_layers'],
+            device=checkpoint['device']
+        )
+
+        estimator.qnet.load_state_dict(checkpoint['qnet'])
+        estimator.optimizer.load_state_dict(checkpoint['optimizer'])
+        return estimator
+
 class EstimatorNetwork(nn.Module):
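Because the optimizer state dict is saved alongside the network weights, a restored Estimator resumes with its optimizer statistics intact rather than starting from fresh moments. A hedged sketch (`est` is assumed to be an existing Estimator):

    state = est.checkpoint_attributes()        # plain dict: tensors plus constructor config
    est2 = Estimator.from_checkpoint(state)    # rebuilds the network, reloads weights and optimizer
    assert est2.learning_rate == est.learning_rate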
@@ -415,3 +536,29 @@ def sample(self):
         samples = random.sample(self.memory, self.batch_size)
         samples = tuple(zip(*samples))
         return tuple(map(np.array, samples[:-1])) + (samples[-1],)
+
+    def checkpoint_attributes(self):
+        ''' Returns the attributes that need to be checkpointed
+        '''
+        return {
+            'memory_size': self.memory_size,
+            'batch_size': self.batch_size,
+            'memory': self.memory
+        }
+
+    @classmethod
+    def from_checkpoint(cls, checkpoint):
+        '''
+        Restores the attributes from the checkpoint
+
+        Args:
+            checkpoint (dict): the checkpoint dictionary
+
+        Returns:
+            instance (Memory): the restored instance
+        '''
+        instance = cls(checkpoint['memory_size'], checkpoint['batch_size'])
+        instance.memory = checkpoint['memory']
+        return instance
Review comment (on the `dict = torch.load(args.load_checkpoint_path)` lines in the NFSP branch): `dict` is a built-in name in Python, so shadowing it is not good practice. How about condensing these three lines into one:

    agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
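Presumably the same condensation applies to the DQN branch, which follows an identical pattern:

    agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))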