Merge pull request BVLC#1228 from longjon/solver-step
Refactor Solver to allow interactive stepping
longjon committed Jan 1, 2015 · 2 parents 58e0d33 + 033bafe · commit dcbe129
Showing 4 changed files with 45 additions and 28 deletions.
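The headline change is a public Solver::Step(int iters): it runs a fixed number of iterations and returns with all solver state intact, and the Python bindings expose it as step. A minimal sketch of the interactive loop this enables, assuming a pycaffe build that includes this commit ('solver.prototxt' is a placeholder path):

    from caffe import SGDSolver

    solver = SGDSolver('solver.prototxt')  # placeholder solver definition

    for _ in range(10):
        solver.step(1)              # one forward/backward/update cycle
        print('iter', solver.iter)  # state persists between calls, so the
                                    # net can be inspected or modified here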
include/caffe/solver.hpp (5 changes: 4 additions & 1 deletion)
@@ -26,6 +26,7 @@ class Solver {
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
@@ -36,7 +37,7 @@
protected:
// PreSolve is run before any solving iteration starts, allowing one to
// put up some scaffold.
virtual void PreSolve() {}
virtual void PreSolve();
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
@@ -57,9 +58,11 @@

SolverParameter param_;
int iter_;
int start_iter_;
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;
bool initialized_;

DISABLE_COPY_AND_ASSIGN(Solver);
};
python/caffe/_caffe.cpp (3 changes: 2 additions & 1 deletion)
@@ -197,7 +197,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("test_nets", &PySGDSolver::test_nets)
.add_property("iter", &PySGDSolver::iter)
.def("solve", &PySGDSolver::Solve)
.def("solve", &PySGDSolver::SolveResume);
.def("solve", &PySGDSolver::SolveResume)
.def("step", &PySGDSolver::Step);

bp::class_<vector<shared_ptr<PyNet> > >("NetVec")
.def(bp::vector_indexing_suite<vector<shared_ptr<PyNet> >, true>());
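The binding change extends the existing chained .def calls: both solve overloads keep the same Python name (boost::python dispatches on the argument), and step is registered alongside them. A sketch of the three call forms ('snap.solverstate' is a placeholder):

    solver.solve()                    # run to max_iter, as before
    solver.solve('snap.solverstate')  # resume from a saved solver state
    solver.step(20)                   # run exactly 20 iterations, then return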
python/caffe/_caffe.hpp (1 change: 1 addition & 0 deletions)
@@ -181,6 +181,7 @@ class PySGDSolver {
vector<shared_ptr<PyNet> > test_nets() { return test_nets_; }
int iter() { return solver_->iter(); }
void Solve() { return solver_->Solve(); }
void Step(int iters) { solver_->Step(iters); }
void SolveResume(const string& resume_file);

protected:
src/caffe/solver.cpp (64 changes: 38 additions & 26 deletions)
@@ -29,9 +29,12 @@ Solver<Dtype>::Solver(const string& param_file)

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
initialized_ = false;
iter_ = 0;
LOG(INFO) << "Initializing solver from parameters: " << std::endl
<< param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
if (param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
@@ -155,35 +158,20 @@ void Solver<Dtype>::InitTestNets() {
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
PreSolve();

iter_ = 0;
current_step_ = 0;
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
void Solver<Dtype>::Step(int iters) {
if (!initialized_) {
PreSolve();
}
// Remember the initial iter_ value; will be non-zero if we loaded from a
// resume_file above.
const int start_iter = iter_;

vector<Blob<Dtype>*> bottom_vec;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();

CHECK_GE(average_loss, 1) << "average_loss should be non-negative.";

vector<Dtype> losses;
Dtype smoothed_loss = 0;

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
for (; iter_ < param_.max_iter(); ++iter_) {
for (; iter_ < stop_iter; ++iter_) {
// Save a snapshot if needed.
if (param_.snapshot() && iter_ > start_iter &&
if (param_.snapshot() && iter_ > start_iter_ &&
iter_ % param_.snapshot() == 0) {
Snapshot();
}
@@ -226,7 +214,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
int size = losses.size();
smoothed_loss = (smoothed_loss * (size - 1) + loss) / size;
} else {
int idx = (iter_ - start_iter) % average_loss;
int idx = (iter_ - start_iter_) % average_loss;
smoothed_loss += (loss - losses[idx]) / average_loss;
losses[idx] = loss;
}
@@ -252,10 +240,33 @@
}
}
}

ComputeUpdateValue();
net_->Update();
}
}

template <typename Dtype>
void Solver<Dtype>::PreSolve() {
initialized_ = true;
start_iter_ = iter_ = 0;
current_step_ = 0;
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

PreSolve();
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
Step(param_.max_iter() - iter_);
// Always save a snapshot after optimization, unless overridden by setting
// snapshot_after_train := false.
if (param_.snapshot_after_train()) { Snapshot(); }
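With the loop moved into Step, Solve reduces to set-up plus delegation: set the phase, PreSolve, optionally Restore, then Step(max_iter - iter_) and a final snapshot. A single solve() and a sequence of step() calls therefore cover the same ground; sketched in Python terms (max_iter stands in for the SolverParameter field, which these bindings do not expose):

    max_iter = 10000  # assumed value of max_iter in solver.prototxt

    # One shot, as before:
    solver.solve()    # internally PreSolve + Step(max_iter - iter_) + Snapshot

    # ... or the same work in resumable chunks:
    while solver.iter < max_iter:
        solver.step(100)  # any chunk size; iter and history carry over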
@@ -267,7 +278,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == 0) {
Dtype loss;
net_->Forward(bottom_vec, &loss);
net_->ForwardPrefilled(&loss);
LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
@@ -377,7 +388,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
net_->CopyTrainedLayersFrom(net_param);
}
iter_ = state.iter();
start_iter_ = iter_ = state.iter();
current_step_ = state.current_step();
RestoreSolverState(state);
}
@@ -439,6 +450,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {

template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
Solver<Dtype>::PreSolve();
// Initialize the history
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
history_.clear();
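SGDSolver::PreSolve now chains to the base-class PreSolve before rebuilding history_, the per-parameter buffers holding momentum state. For orientation, the update rule those buffers implement is plain momentum SGD; a sketch under that assumption (this commit does not touch the update math):

    import numpy as np

    def sgd_update(w, grad, v, lr, momentum):
        # v plays the role of one history_ blob.
        v[:] = momentum * v + lr * grad  # accumulate the update in history
        w -= v                           # the net update subtracts the diff
        return w, v

    w, v = np.zeros(4), np.zeros(4)
    w, v = sgd_update(w, grad=np.ones(4), v=v, lr=0.01, momentum=0.9)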
