From 74cf54cce30b791def7712eabd0c93c31eebf91b Mon Sep 17 00:00:00 2001
From: Guillaume Infantes
Date: Thu, 11 Mar 2021 14:02:18 +0100
Subject: [PATCH] feat(torch): SWA for RANGER/torch
 (https://arxiv.org/abs/1803.05407)

---
 docs/api.md                        |  1 +
 src/backends/torch/optim/ranger.cc | 43 ++++++++++++++++++++++++++++--
 src/backends/torch/optim/ranger.h  | 10 +++++++
 src/backends/torch/torchlib.cc     | 11 ++++++--
 src/backends/torch/torchsolver.cc  |  7 +++++
 src/backends/torch/torchsolver.h   | 19 +++++++++++++
 6 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/docs/api.md b/docs/api.md
index 3794c1c0c..9aea6dff9 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -773,6 +773,7 @@ adabelief | bool | yes | false for RANGER, true for RANGER_PLUS | f
 gradient_centralization | bool | yes | false for RANGER, true for RANGER_PLUS| for RANGER* : enable/disable gradient centralization
 sam | bool | yes | false | Sharpness Aware Minimization (https://arxiv.org/abs/2010.01412)
 sam_rho | real | yes | 0.05 | neighborhood size for SAM (see above)
+swa | bool | yes | false | SWA (https://arxiv.org/abs/1803.05407), implemented only for the RANGER / RANGER_PLUS solver types
 test_interval | int | yes | N/A | Number of iterations between testing phases
 base_lr | real | yes | N/A | Initial learning rate
 iter_size | int | yes | 1 | Number of passes (iter_size * batch_size) at every iteration
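Note: the `swa` flag above travels from the API's `mllib.solver` object into `RangerOptions` (see the ranger.h and torchsolver.cc hunks below). A minimal sketch of that option plumbing, assuming only the TORCH_ARG-generated accessors added by this patch; the `main` wrapper and the numeric values are illustrative, not DeepDetect code:

#include "optim/ranger.h" // RangerOptions, as patched below

#include <tuple>

int main()
{
  // mirrors the construction in torchsolver.cc: RangerOptions(_base_lr).betas(...)
  auto opts = dd::RangerOptions(0.001) // base learning rate
                  .betas(std::make_tuple(0.9, 0.999))
                  .swa(true); // per-parameter swa_buffer is then allocated lazily in step()
  return opts.swa() ? 0 : 1;  // TORCH_ARG also generates the getter
}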
diff --git a/src/backends/torch/optim/ranger.cc b/src/backends/torch/optim/ranger.cc
index 2b938ebb1..3ef5681f4 100644
--- a/src/backends/torch/optim/ranger.cc
+++ b/src/backends/torch/optim/ranger.cc
@@ -52,7 +52,8 @@ namespace dd
            && (lhs.lookahead() == rhs.lookahead())
            && (lhs.adabelief() == rhs.adabelief())
            && (lhs.gradient_centralization() == rhs.gradient_centralization())
-           && (lhs.lsteps() == rhs.lsteps()) && (lhs.lalpha() == rhs.lalpha());
+           && (lhs.lsteps() == rhs.lsteps()) && (lhs.lalpha() == rhs.lalpha())
+           && (lhs.swa() == rhs.swa());
   }
 
   void RangerOptions::serialize(torch::serialize::OutputArchive &archive) const
@@ -68,6 +69,7 @@ namespace dd
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(gradient_centralization);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lsteps);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lalpha);
+    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(swa);
   }
 
   void RangerOptions::serialize(torch::serialize::InputArchive &archive)
@@ -83,6 +85,7 @@ namespace dd
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, gradient_centralization);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int, lsteps);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lalpha);
+    _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, swa);
   }
 
   bool operator==(const RangerParamState &lhs, const RangerParamState &rhs)
@@ -90,7 +93,8 @@ namespace dd
     return ((lhs.step() == rhs.step())
             && torch::equal(lhs.exp_avg(), rhs.exp_avg())
             && torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq())
-            && torch::equal(lhs.slow_buffer(), rhs.slow_buffer()));
+            && torch::equal(lhs.slow_buffer(), rhs.slow_buffer())
+            && torch::equal(lhs.swa_buffer(), rhs.swa_buffer()));
   }
 
   void
@@ -100,6 +104,7 @@ namespace dd
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
     _TORCH_OPTIM_SERIALIZE_TORCH_ARG(slow_buffer);
+    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(swa_buffer);
   }
 
   void RangerParamState::serialize(torch::serialize::InputArchive &archive)
@@ -108,12 +113,14 @@ namespace dd
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, exp_avg);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, exp_avg_sq);
     _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, slow_buffer);
+    _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, swa_buffer);
   }
 
   torch::Tensor Ranger::step(LossClosure closure)
   {
     torch::NoGradGuard no_grad;
     torch::Tensor loss = {};
+
     if (closure != nullptr)
       {
         at::AutoGradMode enable_grad(true);
@@ -151,6 +158,9 @@ namespace dd
                     state->slow_buffer().copy_(p.data());
+                    if (options.swa())
+                      state->swa_buffer(torch::zeros_like(
+                          p.data(), torch::MemoryFormat::Preserve));
                     state_[c10::guts::to_string(p.unsafeGetTensorImpl())]
                         = std::move(state);
                   }
 
                 auto &state = static_cast<RangerParamState &>(
@@ -227,11 +237,40 @@ namespace dd
                     slow_p.add_(p.data() - slow_p, options.lalpha());
                     p.data().copy_(slow_p);
                   }
+
+                if (options.swa())
+                  {
+                    auto &swa_buf = state.swa_buffer();
+                    double swa_decay = 1.0 / (state.step() + 1);
+                    torch::Tensor diff = (p.data() - swa_buf) * swa_decay;
+                    swa_buf.add_(diff);
+                  }
               }
           }
         return loss;
       }
 
+    void Ranger::swap_swa_sgd()
+    {
+      for (auto &group : param_groups_)
+        {
+          auto &options = static_cast<RangerOptions &>(group.options());
+          if (!options.swa())
+            continue;
+          for (auto &p : group.params())
+            {
+              auto &state = static_cast<RangerParamState &>(
+                  *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
+              auto &swa_buf = state.swa_buffer();
+
+              auto tmp = torch::empty_like(p.data());
+              tmp.copy_(p.data());
+              p.data().copy_(swa_buf);
+              swa_buf.copy_(tmp);
+            }
+        }
+    }
+
     void Ranger::save(torch::serialize::OutputArchive &archive) const
     {
       serialize(*this, archive);
diff --git a/src/backends/torch/optim/ranger.h b/src/backends/torch/optim/ranger.h
index 9931f28fb..d600711c1 100644
--- a/src/backends/torch/optim/ranger.h
+++ b/src/backends/torch/optim/ranger.h
@@ -19,6 +19,9 @@
  * along with deepdetect. If not, see <https://www.gnu.org/licenses/>.
  */
 
+#ifndef RANGER_H
+#define RANGER_H
+
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
 #include <torch/torch.h>
@@ -57,6 +60,7 @@ namespace dd
     TORCH_ARG(bool, gradient_centralization) = false;
     TORCH_ARG(int, lsteps) = 6;
     TORCH_ARG(double, lalpha) = 0.5;
+    TORCH_ARG(bool, swa) = false;
 
   public:
     void serialize(torch::serialize::InputArchive &archive) override;
@@ -73,6 +77,7 @@ namespace dd
     TORCH_ARG(torch::Tensor, exp_avg);
     TORCH_ARG(torch::Tensor, exp_avg_sq);
     TORCH_ARG(torch::Tensor, slow_buffer);
+    TORCH_ARG(torch::Tensor, swa_buffer);
 
   public:
     void serialize(torch::serialize::InputArchive &archive) override;
@@ -118,11 +123,16 @@ namespace dd
     void save(torch::serialize::OutputArchive &archive) const override;
     void load(torch::serialize::InputArchive &archive) override;
 
+    void swap_swa_sgd();
+
   private:
     template <typename Self, typename Archive>
     static void serialize(Self &self, Archive &archive)
     {
       _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Ranger);
     }
+
+    bool swa_in_params = false;
   };
 } // namespace dd
+
+#endif
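Two properties of the bookkeeping above are worth making explicit. First, with swa_decay = 1 / (step + 1), the in-place update swa_buf.add_((p.data() - swa_buf) * swa_decay) maintains a cumulative mean of the visited weights. Second, swap_swa_sgd() is its own inverse, so calling it twice restores the original weights; the eval()/train() pair in torchsolver.h below relies on exactly that. A standalone sketch of the averaging recurrence in plain libtorch (no DeepDetect types; the counter starts at zero here so the buffer is exactly the mean of the iterates, whereas step() reuses the Adam step counter):

#include <torch/torch.h>

#include <iostream>

int main()
{
  torch::NoGradGuard no_grad;
  auto p = torch::zeros({3});          // stand-in for a parameter tensor
  auto swa_buf = torch::zeros_like(p); // running average, as in RangerParamState
  auto sum = torch::zeros_like(p);     // reference: explicit sum of iterates

  for (int step = 0; step < 10; ++step)
    {
      p += torch::randn({3}); // stand-in for one optimizer update
      sum += p;
      double swa_decay = 1.0 / (step + 1); // same recurrence as Ranger::step()
      swa_buf.add_((p - swa_buf) * swa_decay);
    }

  // the running average equals the plain mean of the 10 iterates
  std::cout << torch::allclose(swa_buf, sum / 10) << std::endl; // prints 1
  return 0;
}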
diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc
index 0eb4b6981..5f67b18ec 100644
--- a/src/backends/torch/torchlib.cc
+++ b/src/backends/torch/torchlib.cc
@@ -506,10 +506,13 @@ namespace dd
                TMLModel>::snapshot(int64_t elapsed_it, TorchSolver &tsolver)
   {
     this->_logger->info("Saving checkpoint after {} iterations", elapsed_it);
+    // the solver is allowed to modify the net during eval()/train()
+    // => do this call before saving the net itself
+    tsolver.eval();
     this->_module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it));
-    // Save optimizer
     tsolver.save(this->_mlmodel._repo + "/solver-" + std::to_string(elapsed_it)
                  + ".pt");
+    tsolver.train();
   }
 
 template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -757,6 +760,8 @@ namespace dd
             this->_logger->info("Start test");
             tstart = steady_clock::now();
+            tsolver.eval();
             test(ad, inputc, eval_dataset, test_batch_size, meas_out);
+            tsolver.train();
             last_test_time = duration_cast<milliseconds>(
                                  steady_clock::now() - tstart)
                                  .count();
@@ -891,7 +896,9 @@ namespace dd
                   }
               }
             if (!snapshotted)
-              snapshot(elapsed_it, tsolver);
+              {
+                snapshot(elapsed_it, tsolver);
+              }
           }
 
         ++it;
diff --git a/src/backends/torch/torchsolver.cc b/src/backends/torch/torchsolver.cc
index 9d67b1721..bc1b1e339 100644
--- a/src/backends/torch/torchsolver.cc
+++ b/src/backends/torch/torchsolver.cc
@@ -70,12 +70,16 @@ namespace dd
       _sam = ad_solver.get("sam").get<bool>();
     if (ad_solver.has("sam_rho"))
       _sam_rho = ad_solver.get("sam_rho").get<double>();
+    if (ad_solver.has("swa"))
+      _swa = ad_solver.get("swa").get<bool>();
 
     create();
   }
 
   void TorchSolver::create()
   {
+    bool want_swa = _swa;
+    _swa = false;
     this->_logger->info("Selected solver type: {}", _solver_type);
 
     _params = _module.parameters();
@@ -107,6 +111,8 @@ namespace dd
       }
     else if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
       {
+        if (want_swa)
+          _swa = true;
         _optimizer = std::unique_ptr<torch::optim::Optimizer>(
             new Ranger(_params, RangerOptions(_base_lr)
                                     .betas(std::make_tuple(_beta1, _beta2))
@@ -251,6 +257,7 @@ namespace dd
     try
       {
         torch::load(*_optimizer, sstate, device);
+        this->train();
       }
     catch (std::exception &e)
       {
diff --git a/src/backends/torch/torchsolver.h b/src/backends/torch/torchsolver.h
index 15295b47f..9a590db60 100644
--- a/src/backends/torch/torchsolver.h
+++ b/src/backends/torch/torchsolver.h
@@ -30,6 +30,7 @@
 #include "apidata.h"
 #include "torchmodule.h"
 #include "torchloss.h"
+#include "optim/ranger.h"
 
 #define DEFAULT_CLIP_VALUE 5.0
 #define DEFAULT_CLIP_NORM 100.0
@@ -105,6 +106,16 @@ namespace dd
       return _base_lr;
     }
 
+    void eval()
+    {
+      swap_swa_sgd();
+    }
+
+    void train()
+    {
+      swap_swa_sgd();
+    }
+
   protected:
     /**
     * \brief allocates solver for real
@@ -115,6 +126,12 @@ namespace dd
     void sam_first_step();
     void sam_second_step();
 
+    void swap_swa_sgd()
+    {
+      if (_swa)
+        (reinterpret_cast<Ranger *>(_optimizer.get()))->swap_swa_sgd();
+    }
+
     std::vector<torch::Tensor> _sam_ew;
     std::vector<torch::Tensor> _params; /**< list of parameter to optimize,
@@ -141,6 +158,8 @@ namespace dd
     bool _sam = false;
     double _sam_rho = DEFAULT_SAM_RHO;
 
+    bool _swa = false; /**< stochastic weight averaging, arXiv:1803.05407 */
+
     TorchModule &_module;
     TorchLoss &_tloss;
     std::vector
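Design note on the solver side: TorchSolver::eval() and TorchSolver::train() are the same swap, so they only make sense in matched pairs, which is how torchlib.cc uses them: swap the averaged weights in, save or test, swap back before the next optimizer step. The same logic appears to explain the this->train() call after torch::load() in torchsolver.cc: snapshots are written in eval() state, so a freshly loaded solver swaps once to resume training on the live weights while the buffer keeps the average. Note also that create() re-enables _swa only for RANGER / RANGER_PLUS, so for other solver types both calls are no-ops. A toy sketch of that swap discipline; ToySolver is hypothetical, not the real TorchSolver:

#include <torch/torch.h>

#include <iostream>

// Both eval() and train() perform the same swap; only matched pairs keep
// the live weights and the SWA average in their expected slots.
struct ToySolver
{
  torch::Tensor weights = torch::ones({2});  // live SGD weights
  torch::Tensor swa_buf = torch::zeros({2}); // SWA running average

  void swap_swa_sgd() // an involution: calling it twice is a no-op
  {
    auto tmp = torch::empty_like(weights);
    tmp.copy_(weights);
    weights.copy_(swa_buf);
    swa_buf.copy_(tmp);
  }
  void eval()
  {
    swap_swa_sgd();
  }
  void train()
  {
    swap_swa_sgd();
  }
};

int main()
{
  ToySolver s;
  s.eval();                            // averaged weights are now live
  std::cout << s.weights << std::endl; // test or snapshot here
  s.train();                           // back to SGD weights for training
  return 0;
}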