Add storing and restoring RL agent checkpoints #280

Merged
3 commits merged on Apr 19, 2023
Changes from 2 commits
55 changes: 41 additions & 14 deletions examples/run_rl.py
@@ -16,6 +16,7 @@
plot_curve,
)


def train(args):

# Check whether gpu is available
@@ -35,21 +36,36 @@ def train(args):
# Initialize the agent and use random agents as opponents
if args.algorithm == 'dqn':
from rlcard.agents import DQNAgent
agent = DQNAgent(
num_actions=env.num_actions,
state_shape=env.state_shape[0],
mlp_layers=[64,64],
device=device,
)
if args.load_checkpoint_path != "":
dict = torch.load(args.load_checkpoint_path)
Member: dict is a built-in name in Python, so shadowing it is not good practice. How about condensing these three lines into one:

agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))

agent = DQNAgent.from_checkpoint(checkpoint = dict)
del dict
else:
agent = DQNAgent(
num_actions=env.num_actions,
state_shape=env.state_shape[0],
mlp_layers=[64,64],
device=device,
save_path=args.log_dir,
save_every=args.save_every
)

elif args.algorithm == 'nfsp':
from rlcard.agents import NFSPAgent
agent = NFSPAgent(
num_actions=env.num_actions,
state_shape=env.state_shape[0],
hidden_layers_sizes=[64,64],
q_mlp_layers=[64,64],
device=device,
)
if args.load_checkpoint_path != "":
dict = torch.load(args.load_checkpoint_path)
Member: dict is a built-in name in Python, so shadowing it is not good practice. How about condensing these three lines into one:

agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))

agent = NFSPAgent.from_checkpoint(checkpoint = dict)
del dict
else:
agent = NFSPAgent(
num_actions=env.num_actions,
state_shape=env.state_shape[0],
hidden_layers_sizes=[64,64],
q_mlp_layers=[64,64],
device=device,
save_path=args.log_dir,
save_every=500
Member: Should this be save_every=args.save_every?

)
agents = [agent]
for _ in range(1, env.num_players):
agents.append(RandomAgent(num_actions=env.num_actions))
@@ -95,7 +111,7 @@ def train(args):
torch.save(agent, save_path)
print('Model saved in', save_path)

if __name__ == '__main__':
if __name__ == '__main__':
Member: This whitespace-only change seems unnecessary.

parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
parser.add_argument(
'--env',
@@ -152,6 +168,17 @@ def train(args):
type=str,
default='experiments/leduc_holdem_dqn_result/',
)

parser.add_argument(
"--load_checkpoint_path",
type=str,
default="",
)

parser.add_argument(
"--save_every",
type=int,
default=-1)

args = parser.parse_args()
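For orientation, here is a brief sketch of how a resumed run could consume the two new flags, following the reviewer's one-line restore suggestion above; the concrete checkpoint path is an illustrative assumption rather than something fixed by this PR:

import argparse

import torch
from rlcard.agents import DQNAgent

parser = argparse.ArgumentParser("Resume-from-checkpoint sketch")
parser.add_argument("--load_checkpoint_path", type=str, default="")
parser.add_argument("--save_every", type=int, default=-1)
args = parser.parse_args()

if args.load_checkpoint_path != "":
    # Rebuild the agent with its full training state in a single call
    agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))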

153 changes: 150 additions & 3 deletions rlcard/agents/dqn_agent.py
@@ -56,7 +56,9 @@ def __init__(self,
train_every=1,
mlp_layers=None,
learning_rate=0.00005,
device=None):
device=None,
save_path=None,
save_every=float('inf'),):

'''
Q-Learning algorithm for off-policy TD control using Function Approximation.
@@ -81,6 +83,8 @@ def __init__(self,
mlp_layers (list): The layer number and the dimension of each layer in MLP
learning_rate (float): The learning rate of the DQN agent.
device (torch.device): whether to use the cpu or gpu
save_path (str): The path to save the model checkpoints
save_every (int): Save the model every X training steps
'''
self.use_raw = False
self.replay_memory_init_size = replay_memory_init_size
@@ -105,7 +109,7 @@ def __init__(self,

# The epsilon decay scheduler
self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps)

Member: Please clean up this blank line.

# Create estimators
self.q_estimator = Estimator(num_actions=num_actions, learning_rate=learning_rate, state_shape=state_shape, \
mlp_layers=mlp_layers, device=self.device)
@@ -114,6 +118,10 @@

# Create replay memory
self.memory = Memory(replay_memory_size, batch_size)

# Checkpoint saving parameters
self.save_path = save_path
self.save_every = save_every

def feed(self, ts):
''' Store data in to replay buffer and train the agent. There are two stages.
@@ -218,9 +226,16 @@ def train(self):
if self.train_t % self.update_target_estimator_every == 0:
self.target_estimator = deepcopy(self.q_estimator)
print("\nINFO - Copied model parameters to target network.")

Member: Please clean this up.

self.train_t += 1

if self.save_path and self.train_t % self.save_every == 0:
# To preserve every checkpoint separately,
# add another argument to the function call parameterized by self.train_t
self.save_checkpoint(self.save_path)
print("\nINFO - Saved model checkpoint.")


def feed_memory(self, state, action, reward, next_state, legal_actions, done):
''' Feed transition to memory

@@ -239,6 +254,83 @@ def set_device(self, device):
self.q_estimator.device = device
self.target_estimator.device = device

def checkpoint_attributes(self):
'''
Return the current checkpoint attributes (dict)
Checkpoint attributes are used to save and restore the model in the middle of training
Saves the model state dict, optimizer state dict, and all other instance variables
'''

return {
'agent_type': 'DQNAgent',
'q_estimator': self.q_estimator.checkpoint_attributes(),
'memory': self.memory.checkpoint_attributes(),
'total_t': self.total_t,
'train_t': self.train_t,
'epsilon_start': self.epsilons.min(),
'epsilon_end': self.epsilons.max(),
'epsilon_decay_steps': self.epsilon_decay_steps,
'discount_factor': self.discount_factor,
'update_target_estimator_every': self.update_target_estimator_every,
'batch_size': self.batch_size,
'num_actions': self.num_actions,
'train_every': self.train_every,
'device': self.device
}
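As a quick sanity check, the dictionary above can be inspected directly from a saved checkpoint file; a small sketch, assuming a checkpoint_dqn.pt previously written by save_checkpoint (defined later in this diff) at an illustrative path:

import torch

checkpoint = torch.load('experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt')
print(checkpoint['agent_type'])                      # 'DQNAgent'
print(checkpoint['train_t'], checkpoint['total_t'])  # training progress counters
print(sorted(checkpoint.keys()))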

@classmethod
def from_checkpoint(cls, checkpoint):
'''
Restore the model from a checkpoint

Args:
checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes()
'''

print("\nINFO - Restoring model from checkpoint...")
agent_instance = cls(
replay_memory_size=checkpoint['memory']['memory_size'],
update_target_estimator_every=checkpoint['update_target_estimator_every'],
discount_factor=checkpoint['discount_factor'],
epsilon_start=checkpoint['epsilon_start'],
epsilon_end=checkpoint['epsilon_end'],
epsilon_decay_steps=checkpoint['epsilon_decay_steps'],
batch_size=checkpoint['batch_size'],
num_actions=checkpoint['num_actions'],
device=checkpoint['device'],
state_shape=checkpoint['q_estimator']['state_shape'],
mlp_layers=checkpoint['q_estimator']['mlp_layers'],
train_every=checkpoint['train_every']
)

agent_instance.total_t = checkpoint['total_t']
agent_instance.train_t = checkpoint['train_t']

agent_instance.q_estimator = Estimator.from_checkpoint(checkpoint['q_estimator'])
agent_instance.target_estimator = deepcopy(agent_instance.q_estimator)
agent_instance.memory = Memory.from_checkpoint(checkpoint['memory'])


return agent_instance



def save(self, path):
Member: Is this function used?

''' Save the model (q_estimator weights only)

Args:
path (str): the path to save the model
'''
torch.save(self.q_estimator.model.state_dict(), path)

def save_checkpoint(self, path, filename='checkpoint_dqn.pt'):
''' Save the model checkpoint (all attributes)

Args:
path (str): the path to save the model
'''
torch.save(self.checkpoint_attributes(), path + '/' + filename)
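Putting the new methods together, a minimal save/restore round trip could look like the sketch below; the environment, directory, and save interval are illustrative assumptions, and real training would normally happen between construction and restore:

import os

import torch
import rlcard
from rlcard.agents import DQNAgent

env = rlcard.make('leduc-holdem')
ckpt_dir = 'experiments/leduc_holdem_dqn_result'
os.makedirs(ckpt_dir, exist_ok=True)

agent = DQNAgent(
    num_actions=env.num_actions,
    state_shape=env.state_shape[0],
    mlp_layers=[64, 64],
    save_path=ckpt_dir,
    save_every=500,
)

# train() also calls save_checkpoint() every save_every training steps;
# here one checkpoint is simply written by hand.
agent.save_checkpoint(ckpt_dir)  # writes ckpt_dir/checkpoint_dqn.pt

# Later, or in a fresh process, rebuild the agent with its full training state.
checkpoint = torch.load(os.path.join(ckpt_dir, 'checkpoint_dqn.pt'))
restored = DQNAgent.from_checkpoint(checkpoint=checkpoint)
assert restored.train_t == agent.train_t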

class Estimator(object):
'''
Approximate clone of rlcard.agents.dqn_agent.Estimator that
@@ -334,6 +426,35 @@ def update(self, s, a, y):
self.qnet.eval()

return batch_loss

def checkpoint_attributes(self):
''' Return the attributes needed to restore the model from a checkpoint
'''
return {
'qnet': self.qnet.state_dict(),
'optimizer': self.optimizer.state_dict(),
'num_actions': self.num_actions,
'learning_rate': self.learning_rate,
'state_shape': self.state_shape,
'mlp_layers': self.mlp_layers,
'device': self.device
}

@classmethod
def from_checkpoint(cls, checkpoint):
''' Restore the model from a checkpoint
'''
estimator = cls(
num_actions=checkpoint['num_actions'],
learning_rate=checkpoint['learning_rate'],
state_shape=checkpoint['state_shape'],
mlp_layers=checkpoint['mlp_layers'],
device=checkpoint['device']
)

estimator.qnet.load_state_dict(checkpoint['qnet'])
estimator.optimizer.load_state_dict(checkpoint['optimizer'])
return estimator
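The Estimator follows the same checkpoint pattern and can be exercised on its own; a minimal sketch, with the action count, state shape, and layer sizes chosen purely for illustration:

import torch
from rlcard.agents.dqn_agent import Estimator

est = Estimator(
    num_actions=4,
    learning_rate=0.00005,
    state_shape=[36],
    mlp_layers=[64, 64],
    device=torch.device('cpu'),
)
restored = Estimator.from_checkpoint(est.checkpoint_attributes())

# The restored network carries the same weights as the original
for a, b in zip(est.qnet.state_dict().values(), restored.qnet.state_dict().values()):
    assert torch.equal(a, b)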


class EstimatorNetwork(nn.Module):
@@ -415,3 +536,29 @@ def sample(self):
samples = random.sample(self.memory, self.batch_size)
samples = tuple(zip(*samples))
return tuple(map(np.array, samples[:-1])) + (samples[-1],)

def checkpoint_attributes(self):
''' Returns the attributes that need to be checkpointed
'''

return {
'memory_size': self.memory_size,
'batch_size': self.batch_size,
'memory': self.memory
}

@classmethod
def from_checkpoint(cls, checkpoint):
'''
Restores the attributes from the checkpoint

Args:
checkpoint (dict): the checkpoint dictionary

Returns:
instance (Memory): the restored instance
'''

instance = cls(checkpoint['memory_size'], checkpoint['batch_size'])
instance.memory = checkpoint['memory']
return instance
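Finally, the replay Memory round trip can also be checked in isolation; a short sketch with an assumed capacity and batch size:

from rlcard.agents.dqn_agent import Memory

mem = Memory(2000, 32)
ckpt = mem.checkpoint_attributes()

restored = Memory.from_checkpoint(ckpt)
assert restored.memory_size == 2000 and restored.batch_size == 32
assert restored.memory == mem.memory  # stored transitions are carried over as-is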