From 1488e95ae4e5fc135a742440fe81204c14e5b7e4 Mon Sep 17 00:00:00 2001 From: kaiks Date: Sat, 15 Apr 2023 19:14:02 +0200 Subject: [PATCH 1/3] Add storing and restoring RL agent checkpoints --- examples/run_rl.py | 55 +++++++++---- rlcard/agents/dqn_agent.py | 153 +++++++++++++++++++++++++++++++++++- rlcard/agents/nfsp_agent.py | 138 +++++++++++++++++++++++++++++++- 3 files changed, 325 insertions(+), 21 deletions(-) diff --git a/examples/run_rl.py b/examples/run_rl.py index ab3a23fa0..822b070bb 100644 --- a/examples/run_rl.py +++ b/examples/run_rl.py @@ -16,6 +16,7 @@ plot_curve, ) + def train(args): # Check whether gpu is available @@ -35,21 +36,36 @@ def train(args): # Initialize the agent and use random agents as opponents if args.algorithm == 'dqn': from rlcard.agents import DQNAgent - agent = DQNAgent( - num_actions=env.num_actions, - state_shape=env.state_shape[0], - mlp_layers=[64,64], - device=device, - ) + if args.load_checkpoint_path != "": + dict = torch.load(args.load_checkpoint_path) + agent = DQNAgent.from_checkpoint(checkpoint = dict) + del dict + else: + agent = DQNAgent( + num_actions=env.num_actions, + state_shape=env.state_shape[0], + mlp_layers=[64,64], + device=device, + save_path=args.log_dir, + save_every=args.save_every + ) + elif args.algorithm == 'nfsp': from rlcard.agents import NFSPAgent - agent = NFSPAgent( - num_actions=env.num_actions, - state_shape=env.state_shape[0], - hidden_layers_sizes=[64,64], - q_mlp_layers=[64,64], - device=device, - ) + if args.load_checkpoint_path != "": + dict = torch.load(args.load_checkpoint_path) + agent = NFSPAgent.from_checkpoint(checkpoint = dict) + del dict + else: + agent = NFSPAgent( + num_actions=env.num_actions, + state_shape=env.state_shape[0], + hidden_layers_sizes=[64,64], + q_mlp_layers=[64,64], + device=device, + save_path=args.log_dir, + save_every=500 + ) agents = [agent] for _ in range(1, env.num_players): agents.append(RandomAgent(num_actions=env.num_actions)) @@ -95,7 +111,7 @@ def train(args): torch.save(agent, save_path) print('Model saved in', save_path) -if __name__ == '__main__': +if __name__ == '__main__': parser = argparse.ArgumentParser("DQN/NFSP example in RLCard") parser.add_argument( '--env', @@ -152,6 +168,17 @@ def train(args): type=str, default='experiments/leduc_holdem_dqn_result/', ) + + parser.add_argument( + "--load_checkpoint_path", + type=str, + default="", + ) + + parser.add_argument( + "--save_every", + type=int, + default=-1) args = parser.parse_args() diff --git a/rlcard/agents/dqn_agent.py b/rlcard/agents/dqn_agent.py index c42dedd2d..53b1c4a7a 100644 --- a/rlcard/agents/dqn_agent.py +++ b/rlcard/agents/dqn_agent.py @@ -56,7 +56,9 @@ def __init__(self, train_every=1, mlp_layers=None, learning_rate=0.00005, - device=None): + device=None, + save_path=None, + save_every=-1): ''' Q-Learning algorithm for off-policy TD control using Function Approximation. @@ -81,6 +83,8 @@ def __init__(self, mlp_layers (list): The layer number and the dimension of each layer in MLP learning_rate (float): The learning rate of the DQN agent. 
device (torch.device): whether to use the cpu or gpu + save_path (str): The path to save the model checkpoints + save_every (int): Save the model every X training steps ''' self.use_raw = False self.replay_memory_init_size = replay_memory_init_size @@ -105,7 +109,7 @@ def __init__(self, # The epsilon decay scheduler self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps) - + # Create estimators self.q_estimator = Estimator(num_actions=num_actions, learning_rate=learning_rate, state_shape=state_shape, \ mlp_layers=mlp_layers, device=self.device) @@ -114,6 +118,10 @@ def __init__(self, # Create replay memory self.memory = Memory(replay_memory_size, batch_size) + + # Checkpoint saving parameters + self.save_path = save_path + self.save_every = save_every def feed(self, ts): ''' Store data in to replay buffer and train the agent. There are two stages. @@ -218,9 +226,16 @@ def train(self): if self.train_t % self.update_target_estimator_every == 0: self.target_estimator = deepcopy(self.q_estimator) print("\nINFO - Copied model parameters to target network.") - + self.train_t += 1 + if self.save_path and self.train_t % self.save_every == 0: + # To preserve every checkpoint separately, + # add another argument to the function call parameterized by self.train_t + self.save_checkpoint(self.save_path) + print("\nINFO - Saved model checkpoint.") + + def feed_memory(self, state, action, reward, next_state, legal_actions, done): ''' Feed transition to memory @@ -239,6 +254,83 @@ def set_device(self, device): self.q_estimator.device = device self.target_estimator.device = device + def checkpoint_attributes(self): + ''' + Return the current checkpoint attributes (dict) + Checkpoint attributes are used to save and restore the model in the middle of training + Saves the model state dict, optimizer state dict, and all other instance variables + ''' + + return { + 'agent_type': 'DQNAgent', + 'q_estimator': self.q_estimator.checkpoint_attributes(), + 'memory': self.memory.checkpoint_attributes(), + 'total_t': self.total_t, + 'train_t': self.train_t, + 'epsilon_start': self.epsilons.min(), + 'epsilon_end': self.epsilons.max(), + 'epsilon_decay_steps': self.epsilon_decay_steps, + 'discount_factor': self.discount_factor, + 'update_target_estimator_every': self.update_target_estimator_every, + 'batch_size': self.batch_size, + 'num_actions': self.num_actions, + 'train_every': self.train_every, + 'device': self.device + } + + @classmethod + def from_checkpoint(cls, checkpoint): + ''' + Restore the model from a checkpoint + + Args: + checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes() + ''' + + print("\nINFO - Restoring model from checkpoint...") + agent_instance = cls( + replay_memory_size=checkpoint['memory']['memory_size'], + update_target_estimator_every=checkpoint['update_target_estimator_every'], + discount_factor=checkpoint['discount_factor'], + epsilon_start=checkpoint['epsilon_start'], + epsilon_end=checkpoint['epsilon_end'], + epsilon_decay_steps=checkpoint['epsilon_decay_steps'], + batch_size=checkpoint['batch_size'], + num_actions=checkpoint['num_actions'], + device=checkpoint['device'], + state_shape=checkpoint['q_estimator']['state_shape'], + mlp_layers=checkpoint['q_estimator']['mlp_layers'], + train_every=checkpoint['train_every'] + ) + + agent_instance.total_t = checkpoint['total_t'] + agent_instance.train_t = checkpoint['train_t'] + + agent_instance.q_estimator = Estimator.from_checkpoint(checkpoint['q_estimator']) + agent_instance.target_estimator 
= deepcopy(agent_instance.q_estimator) + agent_instance.memory = Memory.from_checkpoint(checkpoint['memory']) + + + return agent_instance + + + + def save(self, path): + ''' Save the model (q_estimator weights only) + + Args: + path (str): the path to save the model + ''' + torch.save(self.q_estimator.model.state_dict(), path) + + def save_checkpoint(self, path, filename='checkpoint_dqn.pt'): + ''' Save the model checkpoint (all attributes) + + Args: + path (str): the path to save the model + ''' + torch.save(self.checkpoint_attributes(), path + '/' + filename) + class Estimator(object): ''' Approximate clone of rlcard.agents.dqn_agent.Estimator that @@ -334,6 +426,35 @@ def update(self, s, a, y): self.qnet.eval() return batch_loss + + def checkpoint_attributes(self): + ''' Return the attributes needed to restore the model from a checkpoint + ''' + return { + 'qnet': self.qnet.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'num_actions': self.num_actions, + 'learning_rate': self.learning_rate, + 'state_shape': self.state_shape, + 'mlp_layers': self.mlp_layers, + 'device': self.device + } + + @classmethod + def from_checkpoint(cls, checkpoint): + ''' Restore the model from a checkpoint + ''' + estimator = cls( + num_actions=checkpoint['num_actions'], + learning_rate=checkpoint['learning_rate'], + state_shape=checkpoint['state_shape'], + mlp_layers=checkpoint['mlp_layers'], + device=checkpoint['device'] + ) + + estimator.qnet.load_state_dict(checkpoint['qnet']) + estimator.optimizer.load_state_dict(checkpoint['optimizer']) + return estimator class EstimatorNetwork(nn.Module): @@ -415,3 +536,29 @@ def sample(self): samples = random.sample(self.memory, self.batch_size) samples = tuple(zip(*samples)) return tuple(map(np.array, samples[:-1])) + (samples[-1],) + + def checkpoint_attributes(self): + ''' Returns the attributes that need to be checkpointed + ''' + + return { + 'memory_size': self.memory_size, + 'batch_size': self.batch_size, + 'memory': self.memory + } + + @classmethod + def from_checkpoint(cls, checkpoint): + ''' + Restores the attributes from the checkpoint + + Args: + checkpoint (dict): the checkpoint dictionary + + Returns: + instance (Memory): the restored instance + ''' + + instance = cls(checkpoint['memory_size'], checkpoint['batch_size']) + instance.memory = checkpoint['memory'] + return instance diff --git a/rlcard/agents/nfsp_agent.py b/rlcard/agents/nfsp_agent.py index 0e98b2c4f..258ba8775 100644 --- a/rlcard/agents/nfsp_agent.py +++ b/rlcard/agents/nfsp_agent.py @@ -62,7 +62,9 @@ def __init__(self, q_train_every=1, q_mlp_layers=None, evaluate_with='average_policy', - device=None): + device=None, + save_path=None, + save_every=-1): ''' Initialize the NFSP agent. Args: @@ -112,9 +114,9 @@ def __init__(self, # Total timesteps self.total_t = 0 - - # Step counter to keep track of learning. 
- self._step_counter = 0 + + # Total training step + self.train_t = 0 # Build the action-value network self._rl_agent = DQNAgent(q_replay_memory_size, q_replay_memory_init_size, \ @@ -126,6 +128,10 @@ def __init__(self, self._build_model() self.sample_episode_policy() + + # Checkpoint saving parameters + self.save_path = save_path + self.save_every = save_every def _build_model(self): ''' Build the average policy network @@ -282,11 +288,90 @@ def train_sl(self): ce_loss = ce_loss.item() self.policy_network.eval() + self.train_t += 1 + + if self.save_path and self.train_t % self.save_every == 0: + # To preserve every checkpoint separately, + # add another argument to the function call parameterized by self.train_t + self.save_checkpoint(self.save_path) + print("\nINFO - Saved model checkpoint.") + return ce_loss def set_device(self, device): self.device = device self._rl_agent.set_device(device) + + def checkpoint_attributes(self): + ''' + Return the current checkpoint attributes (dict) + Checkpoint attributes are used to save and restore the model in the middle of training + Saves the model state dict, optimizer state dict, and all other instance variables + ''' + + return { + 'agent_type': 'NFSPAgent', + 'policy_network': self.policy_network.checkpoint_attributes(), + 'reservoir_buffer': self._reservoir_buffer.checkpoint_attributes(), + 'rl_agent': self._rl_agent.checkpoint_attributes(), + 'policy_network_optimizer': self.policy_network_optimizer.state_dict(), + 'device': self.device, + 'anticipatory_param': self._anticipatory_param, + 'batch_size': self._batch_size, + 'min_buffer_size_to_learn': self._min_buffer_size_to_learn, + 'num_actions': self._num_actions, + 'mode': self._mode, + 'evaluate_with': self.evaluate_with, + 'total_t': self.total_t, + 'train_t': self.train_t, + 'sl_learning_rate': self._sl_learning_rate, + 'train_every': self._train_every, + } + + @classmethod + def from_checkpoint(cls, checkpoint): + ''' + Restore the model from a checkpoint + + Args: + checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes() + ''' + print("\nINFO - Restoring model from checkpoint...") + agent = cls( + anticipatory_param=checkpoint['anticipatory_param'], + batch_size=checkpoint['batch_size'], + min_buffer_size_to_learn=checkpoint['min_buffer_size_to_learn'], + num_actions=checkpoint['num_actions'], + sl_learning_rate=checkpoint['sl_learning_rate'], + train_every=checkpoint['train_every'], + evaluate_with=checkpoint['evaluate_with'], + device=checkpoint['device'], + q_mlp_layers=checkpoint['rl_agent']['q_estimator']['mlp_layers'], + state_shape=checkpoint['rl_agent']['q_estimator']['state_shape'], + hidden_layers_sizes=[], + ) + + agent.policy_network = AveragePolicyNetwork.from_checkpoint(checkpoint['policy_network']) + agent._reservoir_buffer = ReservoirBuffer.from_checkpoint(checkpoint['reservoir_buffer']) + agent._mode = checkpoint['mode'] + agent.total_t = checkpoint['total_t'] + agent.train_t = checkpoint['train_t'] + agent.policy_network.to(agent.device) + agent.policy_network.eval() + agent.policy_network_optimizer = torch.optim.Adam(agent.policy_network.parameters(), lr=agent._sl_learning_rate) + agent.policy_network_optimizer.load_state_dict(checkpoint['policy_network_optimizer']) + agent._rl_agent.from_checkpoint(checkpoint['rl_agent']) + agent._rl_agent.set_device(agent.device) + return agent + + def save_checkpoint(self, path, filename='checkpoint_nfsp.pt'): + ''' Save the model checkpoint (all attributes) + + Args: + path (str): the path to save the 
model + ''' + torch.save(self.checkpoint_attributes(), path + '/' + filename) + class AveragePolicyNetwork(nn.Module): ''' @@ -333,6 +418,37 @@ def forward(self, s): logits = self.mlp(s) log_action_probs = F.log_softmax(logits, dim=-1) return log_action_probs + + def checkpoint_attributes(self): + ''' + Return the current checkpoint attributes (dict) + Checkpoint attributes are used to save and restore the model in the middle of training + ''' + + return { + 'num_actions': self.num_actions, + 'state_shape': self.state_shape, + 'mlp_layers': self.mlp_layers, + 'mlp': self.mlp.state_dict(), + } + + @classmethod + def from_checkpoint(cls, checkpoint): + ''' + Restore the model from a checkpoint + + Args: + checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes() + ''' + + agent = cls( + num_actions=checkpoint['num_actions'], + state_shape=checkpoint['state_shape'], + mlp_layers=checkpoint['mlp_layers'], + ) + + agent.mlp.load_state_dict(checkpoint['mlp']) + return agent class ReservoirBuffer(object): ''' Allows uniform sampling over a stream of data. @@ -386,6 +502,20 @@ def clear(self): ''' self._data = [] self._add_calls = 0 + + def checkpoint_attributes(self): + return { + 'data': self._data, + 'add_calls': self._add_calls, + 'reservoir_buffer_capacity': self._reservoir_buffer_capacity, + } + + @classmethod + def from_checkpoint(cls, checkpoint): + reservoir_buffer = cls(checkpoint['reservoir_buffer_capacity']) + reservoir_buffer._data = checkpoint['data'] + reservoir_buffer._add_calls = checkpoint['add_calls'] + return reservoir_buffer def __len__(self): return len(self._data) From 5e6d71546e5d6ddab9a1f7ad128dde3f3475ca11 Mon Sep 17 00:00:00 2001 From: kaiks Date: Sun, 16 Apr 2023 15:29:56 +0200 Subject: [PATCH 2/3] Don't save models by default (fix) --- rlcard/agents/dqn_agent.py | 2 +- rlcard/agents/nfsp_agent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rlcard/agents/dqn_agent.py b/rlcard/agents/dqn_agent.py index 53b1c4a7a..14d3f5b90 100644 --- a/rlcard/agents/dqn_agent.py +++ b/rlcard/agents/dqn_agent.py @@ -58,7 +58,7 @@ def __init__(self, learning_rate=0.00005, device=None, save_path=None, - save_every=-1): + save_every=float('inf'),): ''' Q-Learning algorithm for off-policy TD control using Function Approximation. diff --git a/rlcard/agents/nfsp_agent.py b/rlcard/agents/nfsp_agent.py index 258ba8775..dc5d59ad3 100644 --- a/rlcard/agents/nfsp_agent.py +++ b/rlcard/agents/nfsp_agent.py @@ -64,7 +64,7 @@ def __init__(self, evaluate_with='average_policy', device=None, save_path=None, - save_every=-1): + save_every=float('inf')): ''' Initialize the NFSP agent. 
Args: From 9ec4e9a528273e39c38c1d076a07eb4b2f066ecb Mon Sep 17 00:00:00 2001 From: kaiks Date: Tue, 18 Apr 2023 22:52:01 +0200 Subject: [PATCH 3/3] PR feedback --- examples/run_rl.py | 13 ++++--------- rlcard/agents/dqn_agent.py | 16 +++------------- rlcard/agents/nfsp_agent.py | 2 +- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/examples/run_rl.py b/examples/run_rl.py index 822b070bb..3727f3ae8 100644 --- a/examples/run_rl.py +++ b/examples/run_rl.py @@ -16,7 +16,6 @@ plot_curve, ) - def train(args): # Check whether gpu is available @@ -37,9 +36,7 @@ def train(args): if args.algorithm == 'dqn': from rlcard.agents import DQNAgent if args.load_checkpoint_path != "": - dict = torch.load(args.load_checkpoint_path) - agent = DQNAgent.from_checkpoint(checkpoint = dict) - del dict + agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path)) else: agent = DQNAgent( num_actions=env.num_actions, @@ -53,9 +50,7 @@ def train(args): elif args.algorithm == 'nfsp': from rlcard.agents import NFSPAgent if args.load_checkpoint_path != "": - dict = torch.load(args.load_checkpoint_path) - agent = NFSPAgent.from_checkpoint(checkpoint = dict) - del dict + agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path)) else: agent = NFSPAgent( num_actions=env.num_actions, @@ -64,7 +59,7 @@ def train(args): q_mlp_layers=[64,64], device=device, save_path=args.log_dir, - save_every=500 + save_every=args.save_every ) agents = [agent] for _ in range(1, env.num_players): @@ -111,7 +106,7 @@ def train(args): torch.save(agent, save_path) print('Model saved in', save_path) -if __name__ == '__main__': +if __name__ == '__main__': parser = argparse.ArgumentParser("DQN/NFSP example in RLCard") parser.add_argument( '--env', diff --git a/rlcard/agents/dqn_agent.py b/rlcard/agents/dqn_agent.py index 14d3f5b90..c33160cc1 100644 --- a/rlcard/agents/dqn_agent.py +++ b/rlcard/agents/dqn_agent.py @@ -109,7 +109,7 @@ def __init__(self, # The epsilon decay scheduler self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps) - + # Create estimators self.q_estimator = Estimator(num_actions=num_actions, learning_rate=learning_rate, state_shape=state_shape, \ mlp_layers=mlp_layers, device=self.device) @@ -226,7 +226,7 @@ def train(self): if self.train_t % self.update_target_estimator_every == 0: self.target_estimator = deepcopy(self.q_estimator) print("\nINFO - Copied model parameters to target network.") - + self.train_t += 1 if self.save_path and self.train_t % self.save_every == 0: @@ -312,17 +312,7 @@ def from_checkpoint(cls, checkpoint): return agent_instance - - - - def save(self, path): - ''' Save the model (q_estimator weights only) - - Args: - path (str): the path to save the model - ''' - torch.save(self.q_estimator.model.state_dict(), path) - + def save_checkpoint(self, path, filename='checkpoint_dqn.pt'): ''' Save the model checkpoint (all attributes) diff --git a/rlcard/agents/nfsp_agent.py b/rlcard/agents/nfsp_agent.py index dc5d59ad3..60f772dd9 100644 --- a/rlcard/agents/nfsp_agent.py +++ b/rlcard/agents/nfsp_agent.py @@ -114,7 +114,7 @@ def __init__(self, # Total timesteps self.total_t = 0 - + # Total training step self.train_t = 0
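
Taken together, the series exposes an end-to-end save/resume workflow. Below is a minimal sketch of that workflow with the patches applied; the environment name and checkpoint directory are illustrative, not part of the patch.

    import torch
    import rlcard
    from rlcard.agents import DQNAgent

    # Illustrative setup: any RLCard environment works the same way.
    env = rlcard.make('leduc-holdem')
    save_dir = 'experiments/leduc_holdem_dqn_result'

    # With save_path set, the agent writes <save_path>/checkpoint_dqn.pt
    # every `save_every` training steps via save_checkpoint().
    agent = DQNAgent(
        num_actions=env.num_actions,
        state_shape=env.state_shape[0],
        mlp_layers=[64, 64],
        save_path=save_dir,
        save_every=500,
    )

    # ... run the usual training/evaluation loop from examples/run_rl.py ...

    # Resume later: from_checkpoint() rebuilds the estimators, optimizer
    # state, replay memory, and step counters from the saved dict.
    checkpoint = torch.load(save_dir + '/checkpoint_dqn.pt')
    agent = DQNAgent.from_checkpoint(checkpoint=checkpoint)

NFSPAgent follows the same pattern, writing checkpoint_nfsp.pt and restoring via NFSPAgent.from_checkpoint(); in examples/run_rl.py both paths are driven by the new --save_every and --load_checkpoint_path arguments.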