Merge pull request #850 from sabify:patch-6
PiperOrigin-RevId: 452752859
Change-Id: I18f1a0975180ef140ee6f949ac7d35d40f354547
lanctot committed Jun 4, 2022
2 parents dc25645 + d9e3a85 commit e0b4cff
Showing 3 changed files with 55 additions and 0 deletions.
17 changes: 17 additions & 0 deletions open_spiel/algorithms/dqn_torch/dqn.cc
@@ -272,6 +272,23 @@ void DQN::Learn() {
  optimizer_.step();
}

void DQN::Load(const std::string& data_path,
               const std::string& optimizer_data_path) {
  // Restore the online and target networks from the same checkpoint, so
  // the two start out synchronized.
  torch::load(q_network_, data_path);
  torch::load(target_q_network_, data_path);
  if (!optimizer_data_path.empty()) {
    torch::load(optimizer_, optimizer_data_path);
  }
}

void DQN::Save(const std::string& data_path,
               const std::string& optimizer_data_path) {
  // Only the online Q-network is saved; Load() restores it into both
  // networks.
  torch::save(q_network_, data_path);
  if (!optimizer_data_path.empty()) {
    torch::save(optimizer_, optimizer_data_path);
  }
}

std::vector<double> RunEpisodes(std::mt19937* rng, const Game& game,
                                const std::vector<Agent*>& agents,
                                int num_episodes, bool is_evaluation) {
7 changes: 7 additions & 0 deletions open_spiel/algorithms/dqn_torch/dqn.h
@@ -102,6 +102,13 @@ class DQN : public Agent {
  double GetEpsilon(bool is_evaluation, int power = 1);
  int seed() const { return seed_; }

  // Loads a checkpoint/trained model and, optionally, the optimizer state.
  void Load(const std::string& data_path,
            const std::string& optimizer_data_path = "");
  // Saves a checkpoint/trained model and, optionally, the optimizer state.
  void Save(const std::string& data_path,
            const std::string& optimizer_data_path = "");

 private:
  std::vector<float> GetInfoState(const State& state, Player player_id,
                                  bool use_observation);
31 changes: 31 additions & 0 deletions open_spiel/python/pytorch/dqn.py
@@ -400,3 +400,34 @@ def copy_with_noise(self, sigma=0.0, copy_weights=True):
    for tq_model in target_q_network.model:
      tq_model.weight *= (1 + sigma * torch.randn(tq_model.weight.shape))
    return copied_object

  def save(self, data_path, optimizer_data_path=None):
    """Save checkpoint/trained model and optimizer.

    Args:
      data_path: Path for saving model. It can be relative or absolute but the
        filename should be included. For example: q_network.pt or
        /path/to/q_network.pt
      optimizer_data_path: Path for saving the optimizer states. It can be
        relative or absolute but the filename should be included. For example:
        optimizer.pt or /path/to/optimizer.pt
    """
    torch.save(self._q_network, data_path)
    if optimizer_data_path is not None:
      torch.save(self._optimizer, optimizer_data_path)

  def load(self, data_path, optimizer_data_path=None):
    """Load checkpoint/trained model and optimizer.

    Args:
      data_path: Path for loading model. It can be relative or absolute but the
        filename should be included. For example: q_network.pt or
        /path/to/q_network.pt
      optimizer_data_path: Path for loading the optimizer states. It can be
        relative or absolute but the filename should be included. For example:
        optimizer.pt or /path/to/optimizer.pt
    """
    # torch.load returns the deserialized object, so assign the result. Both
    # the online and target networks are restored from the same checkpoint,
    # so they start out synchronized.
    self._q_network = torch.load(data_path)
    self._target_q_network = torch.load(data_path)
    if optimizer_data_path is not None:
      self._optimizer = torch.load(optimizer_data_path)
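A minimal round-trip sketch of the new Python API follows. This is a sketch, not part of the diff: the game, the checkpoint paths, and the rl_environment setup are assumptions modeled on the surrounding OpenSpiel examples.

from open_spiel.python import rl_environment
from open_spiel.python.pytorch import dqn

# Standard OpenSpiel environment setup (assumed, not part of this commit).
env = rl_environment.Environment("tic_tac_toe")
info_state_size = env.observation_spec()["info_state"][0]
num_actions = env.action_spec()["num_actions"]

agent = dqn.DQN(
    player_id=0,
    state_representation_size=info_state_size,
    num_actions=num_actions)
# ... training loop elided ...
agent.save("q_network.pt", optimizer_data_path="optimizer.pt")

# Later: restore the weights and optimizer state into a fresh agent.
restored = dqn.DQN(
    player_id=0,
    state_representation_size=info_state_size,
    num_actions=num_actions)
restored.load("q_network.pt", optimizer_data_path="optimizer.pt")

Note that load() restores the same checkpoint into both the online and the target network, so a restored agent resumes with the two in sync.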
