Skip to content

Commit

Permalink
feat(torch): implement resume mllib option for torchlib: if true, reuse previous solver state
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and mergify[bot] committed Nov 19, 2020
1 parent 72c0269 commit 02e3177
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/api.md
Expand Up @@ -772,6 +772,7 @@ gradient_centralization | bool | yes | false for RANGER, true for RANGER_PLUS| f
test_interval | int | yes | N/A | Number of iterations between testing phases
base_lr | real | yes | N/A | Initial learning rate
iter_size | int | yes | 1 | Number of passes (iter_size * batch_size) at every iteration
resume | bool | yes | false | Whether to resume training from solver state

Net:

Expand Down
23 changes: 11 additions & 12 deletions src/backends/torch/torchlib.cc
Expand Up @@ -351,7 +351,7 @@ namespace dd
TMLModel>::save_if_best(APIData &meas_out,
int64_t elapsed_it,
TorchSolver &tsolver,
int64_t best_to_remove)
int64_t best_iteration_number)
{
double cur_meas = std::numeric_limits<double>::infinity();
std::string meas;
Expand All @@ -374,9 +374,9 @@ namespace dd
if (_best_metric_value == std::numeric_limits<double>::infinity()
|| is_better(cur_meas, _best_metric_value, meas))
{
if (best_to_remove != -1)
if (best_iteration_number != -1)
{
remove_model(best_to_remove);
remove_model(best_iteration_number);
}
_best_metric_value = cur_meas;
this->snapshot(elapsed_it, tsolver);
Expand All @@ -396,7 +396,7 @@ namespace dd
}
return elapsed_it;
}
return best_to_remove;
return best_iteration_number;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand Down Expand Up @@ -535,10 +535,10 @@ namespace dd

// create solver
tsolver.create(_module);
int64_t best_iteration_number = -1;

int it = 0;
// reload solver and set it value accordingly
it = tsolver.load(this->_mlmodel._sstate, _main_device);
int it = tsolver.resume(ad_mllib, this->_mlmodel, _main_device,
_best_metric_value, best_iteration_number);

bool skip_training = it >= iterations;
if (skip_training)
Expand All @@ -560,7 +560,6 @@ namespace dd
std::move(inputc._dataset), data::DataLoaderOptions(batch_size));

int batch_id = 0;
int64_t best_to_remove = -1;

// it is the iteration count (not epoch)
while (it < iterations)
Expand Down Expand Up @@ -717,8 +716,8 @@ namespace dd
APIData meas_obj = meas_out.getobj("measure");
std::vector<std::string> meas_names = meas_obj.list_keys();

best_to_remove = save_if_best(meas_obj, elapsed_it,
tsolver, best_to_remove);
best_iteration_number = save_if_best(
meas_obj, elapsed_it, tsolver, best_iteration_number);

for (auto name : meas_names)
{
Expand Down Expand Up @@ -761,11 +760,11 @@ namespace dd
if ((save_period != 0 && elapsed_it % save_period == 0)
|| elapsed_it == iterations)
{
if (best_to_remove == elapsed_it)
if (best_iteration_number == elapsed_it)
// current model already snapshoted as best model,
// do not remove regular snapshot if it is best
// model
best_to_remove = -1;
best_iteration_number = -1;
else
snapshot(elapsed_it, tsolver);
}
Expand Down
55 changes: 55 additions & 0 deletions src/backends/torch/torchsolver.cc
Expand Up @@ -198,4 +198,59 @@ namespace dd
}
return 0;
}

/**
 * \brief restore solver state according to the mllib "resume" parameter.
 * If resume is true, reloads the solver state from the model repository
 * (throwing if no solverstate file is present) and restores the best
 * metric value / best iteration number from the best-model file when it
 * exists. If resume is false but a solverstate file remains, throws so
 * the user does not silently retrain over an existing training state.
 * \return the iteration number to resume from (0 when not resuming)
 */
int TorchSolver::resume(const APIData &ad_mllib, const TorchModel &mlmodel,
                        const torch::Device &main_device,
                        double &best_metric_value,
                        int64_t &best_iteration_number)
{
  // reload solver if asked for and set it value accordingly
  if (ad_mllib.has("resume") && ad_mllib.get("resume").get<bool>())
    {
      std::string bestfilename
          = mlmodel._repo + mlmodel._best_model_filename;
      // NOTE: std::ifstream does not throw on a missing file by default,
      // so the stream state must be checked explicitly; a try/catch here
      // would never fire and a failed read would leave garbage values.
      std::ifstream bestfile(bestfilename, std::ios::in);
      if (bestfile.is_open())
        {
          std::string tmp;
          std::string bin;
          // four whitespace-separated fields: the 2nd is the best
          // iteration number, the 4th the best metric value; the
          // 1st and 3rd are thrown away
          bestfile >> tmp >> bin >> tmp >> tmp;
          bestfile.close();
          if (!bin.empty())
            {
              // integer parsing for an int64 iteration number (atof
              // would go through double and lose precision)
              best_iteration_number = std::atoll(bin.c_str());
              best_metric_value = std::atof(tmp.c_str());
            }
        }
      else
        {
          // keep the caller's sentinel values (-1 / +inf) untouched
          _logger->info("no previous best model file");
        }
      if (mlmodel._sstate.empty())
        {
          throw MLLibBadParamException(
              "resuming a model requires a solverstate file in model "
              "repository");
        }
      try
        {
          // returns the iteration count stored in the solver state
          return load(mlmodel._sstate, main_device);
        }
      catch (std::exception &e)
        {
          this->_logger->error("Failed to load solver state");
          throw;
        }
    }
  else if (!mlmodel._sstate.empty())
    {
      // a leftover solverstate without resume=true is almost certainly a
      // user mistake: fail loudly rather than silently retrain from scratch
      this->_logger->error("not resuming while a solverstate file remains "
                           "in model repository");
      throw MLLibBadParamException(
          "a solverstate file requires a resume argument for training, "
          "otherwise delete existing training state files (with clear=lib) "
          "to cleanup the model repository");
    }
  return 0;
}
}
10 changes: 10 additions & 0 deletions src/backends/torch/torchsolver.h
Expand Up @@ -70,6 +70,16 @@ namespace dd
*/
void save(std::string sfile);

/**
 * \brief restore solver state when the ad_mllib "resume" parameter is
 * true, checking for solverstate presence in the model repository;
 * returns the iteration number to resume from and fills in the best
 * metric value and its corresponding iteration number
 */

int resume(const APIData &ad_mllib, const TorchModel &mlmodel,
const torch::Device &main_device, double &best_metric_value,
int64_t &best_iteration_number);

/**
* \brief zero_grad() indirection in order to mimic native optimizer
* behavior
Expand Down
55 changes: 55 additions & 0 deletions tests/ut-torchapi.cc
Expand Up @@ -547,6 +547,61 @@ TEST(torchapi, service_train_csvts_nbeats)
rmdir(csvts_nbeats_repo.c_str());
}

// Regression test: requesting "resume": true on a freshly-created (empty)
// model repository must fail with a 400 Bad Request, since no solverstate
// file exists to resume from.
TEST(torchapi, service_train_csvts_nbeats_resume_fail)
{
// create service
JsonAPI japi;
std::string sname = "nbeats";
std::string csvts_data = sinus + "train";
std::string csvts_test = sinus + "test";
std::string csvts_predict = sinus + "predict";
std::string csvts_nbeats_repo = "csvts_nbeats";
mkdir(csvts_nbeats_repo.c_str(), 0777);

// nbeats timeseries service over a csvts connector
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"nbeats\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ csvts_nbeats_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"csvts\",\"label\":["
"\"output\"],\"timesteps\":50},\"mllib\":{\"template\":\"nbeats\","
"\"template_params\":[\"t2\",\"s4\",\"g3\",\"b3\"],"
"\"loss\":\"L1\"}}}";

std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// train with "resume": true although the repository contains no
// solverstate file — this must be rejected
std::string jtrainstr
= "{\"service\":\"" + sname
+ "\",\"async\":false,\"parameters\":{\"input\":{\"shuffle\":true,"
"\"separator\":\",\",\"scale\":true,\"timesteps\":50,\"label\":["
"\"output\"]},\"mllib\":{\"resume\":true,\"gpu\":false,\"solver\":{"
"\"iterations\":"
+ iterations_nbeats_cpu
+ ",\"test_interval\":10,\"base_lr\":0.1,\"snapshot\":500,\"test_"
"initialization\":false,\"solver_type\":\"ADAM\"},\"net\":{\"batch_"
"size\":2,\"test_batch_"
"size\":10}},\"output\":{\"measure\":[\"L1\",\"L2\"]}},\"data\":[\""
+ csvts_data + "\",\"" + csvts_test + "\"]}";

std::cerr << "jtrainstr=" << jtrainstr << std::endl;
joutstr = japi.jrender(japi.service_train(jtrainstr));
std::cout << "joutstr=" << joutstr << std::endl;
// expect a 400 status and the exact error message raised by
// TorchSolver::resume when the solverstate file is missing
JDoc jd;
jd.Parse(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_TRUE(jd.HasMember("status"));
ASSERT_EQ(400, jd["status"]["code"].GetInt());
ASSERT_EQ("Service Bad Request Error: resuming a model requires a "
"solverstate file in model repository",
jd["status"]["dd_msg"]);
// remove service
jstr = "{\"clear\":\"full\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
rmdir(csvts_nbeats_repo.c_str());
}

#if !defined(CPU_ONLY)

TEST(torchapi, service_train_csvts_nbeats_gpu)
Expand Down

0 comments on commit 02e3177

Please sign in to comment.