
Commit: PR feedback
kaiks committed Apr 18, 2023
1 parent 5e6d715 commit 9ec4e9a
Showing 3 changed files with 8 additions and 23 deletions.
examples/run_rl.py: 13 changes (4 additions, 9 deletions)
@@ -16,7 +16,6 @@
     plot_curve,
 )
 
-
 def train(args):
 
     # Check whether gpu is available
@@ -37,9 +36,7 @@ def train(args):
     if args.algorithm == 'dqn':
         from rlcard.agents import DQNAgent
         if args.load_checkpoint_path != "":
-            dict = torch.load(args.load_checkpoint_path)
-            agent = DQNAgent.from_checkpoint(checkpoint = dict)
-            del dict
+            agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
         else:
             agent = DQNAgent(
                 num_actions=env.num_actions,
@@ -53,9 +50,7 @@ def train(args):
     elif args.algorithm == 'nfsp':
         from rlcard.agents import NFSPAgent
         if args.load_checkpoint_path != "":
-            dict = torch.load(args.load_checkpoint_path)
-            agent = NFSPAgent.from_checkpoint(checkpoint = dict)
-            del dict
+            agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
         else:
             agent = NFSPAgent(
                 num_actions=env.num_actions,
@@ -64,7 +59,7 @@ def train(args):
                 q_mlp_layers=[64,64],
                 device=device,
                 save_path=args.log_dir,
-                save_every=500
+                save_every=args.save_every
             )
     agents = [agent]
     for _ in range(1, env.num_players):
@@ -111,7 +106,7 @@ def train(args):
     torch.save(agent, save_path)
     print('Model saved in', save_path)
 
-if __name__ == '__main__':
+if __name__ == '__main__':
     parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
     parser.add_argument(
         '--env',
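Note on the run_rl.py changes: the old three-line pattern bound the result of torch.load() to a name that shadowed the builtin dict and then had to del it; folding it into a single from_checkpoint(checkpoint=torch.load(...)) call removes the temporary entirely. The duplicated if __name__ == '__main__': pair in the last hunk is a whitespace-only change. Below is a minimal sketch of the resulting load-or-create flow for the NFSP branch, assuming the constructor arguments visible in this diff; the helper name, state_shape, and hidden_layers_sizes values are illustrative, not part of the commit.

import torch
from rlcard.agents import NFSPAgent

def load_or_create_nfsp(args, env, device):
    # Hypothetical helper mirroring the 'nfsp' branch of examples/run_rl.py.
    if args.load_checkpoint_path != "":
        # Restore the agent in one expression: no temporary name shadowing
        # the builtin dict, and nothing to del afterwards.
        return NFSPAgent.from_checkpoint(
            checkpoint=torch.load(args.load_checkpoint_path)
        )
    return NFSPAgent(
        num_actions=env.num_actions,
        state_shape=env.state_shape[0],   # assumed, as in the stock example
        hidden_layers_sizes=[64, 64],     # assumed policy-network sizes
        q_mlp_layers=[64, 64],
        device=device,
        save_path=args.log_dir,
        save_every=args.save_every,       # previously hard-coded to 500
    )

For args.save_every to come from the command line, the argument parser at the bottom of the file needs a matching --save_every option; that hunk is not visible here.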
rlcard/agents/dqn_agent.py: 16 changes (3 additions, 13 deletions)
@@ -109,7 +109,7 @@ def __init__(self,

         # The epsilon decay scheduler
         self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps)
-
+
         # Create estimators
         self.q_estimator = Estimator(num_actions=num_actions, learning_rate=learning_rate, state_shape=state_shape, \
             mlp_layers=mlp_layers, device=self.device)
@@ -226,7 +226,7 @@ def train(self):
         if self.train_t % self.update_target_estimator_every == 0:
             self.target_estimator = deepcopy(self.q_estimator)
             print("\nINFO - Copied model parameters to target network.")
-
+
         self.train_t += 1
 
         if self.save_path and self.train_t % self.save_every == 0:
@@ -312,17 +312,7 @@ def from_checkpoint(cls, checkpoint):


         return agent_instance
-
-
-
-    def save(self, path):
-        ''' Save the model (q_estimator weights only)
-        Args:
-            path (str): the path to save the model
-        '''
-        torch.save(self.q_estimator.model.state_dict(), path)
-
-
+
     def save_checkpoint(self, path, filename='checkpoint_dqn.pt'):
         ''' Save the model checkpoint (all attributes)
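With the weights-only save() removed, dqn_agent.py persists state through save_checkpoint() and restores it through the from_checkpoint() classmethod alone. A minimal sketch of that round trip, assuming the signatures visible above; the directory and constructor values are illustrative.

import os
import torch
from rlcard.agents import DQNAgent

ckpt_dir = 'experiments/dqn_demo'  # illustrative path
os.makedirs(ckpt_dir, exist_ok=True)

agent = DQNAgent(
    num_actions=2,
    state_shape=[2],
    mlp_layers=[64, 64],
    save_path=ckpt_dir,
    save_every=100,
)

# save_checkpoint persists all attributes, unlike the removed save(),
# which wrote only the Q-estimator's state_dict.
agent.save_checkpoint(path=ckpt_dir, filename='checkpoint_dqn.pt')

# Rebuild a full agent from the loaded checkpoint dict, matching the
# one-line pattern now used in examples/run_rl.py.
restored = DQNAgent.from_checkpoint(
    checkpoint=torch.load(os.path.join(ckpt_dir, 'checkpoint_dqn.pt'))
)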
rlcard/agents/nfsp_agent.py: 2 changes (1 addition, 1 deletion)
@@ -114,7 +114,7 @@ def __init__(self,

         # Total timesteps
         self.total_t = 0
-
+
         # Total training step
         self.train_t = 0

