feat(torch): native models can load weights from any jit file
Added a function "copy_weights" to copy common weights from a
script module to a nn::Module. Weights that do not match are
ignored, and a warning is printed to the logs.
Bycob authored and mergify[bot] committed Dec 22, 2020
1 parent 0358c4a commit 69af7f4
Showing 3 changed files with 78 additions and 1 deletion.
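
For illustration (an editorial sketch, not part of the commit): with this change, a hand-written torch::nn::Module can pull its weights out of any TorchScript file. The module class MyNativeModel and the path "model.pt" below are hypothetical:

    // Minimal usage sketch for the new helper; assumes a hypothetical
    // torch::nn::Module subclass MyNativeModel and weights file "model.pt".
    #include "torchutils.h"

    void example(std::shared_ptr<spdlog::logger> logger)
    {
      auto model = std::make_shared<MyNativeModel>();
      torch::Device device(torch::kCPU);
      try
        {
          // Copies every parameter whose name and shape match; throws
          // MLLibBadParamException when nothing matches at all.
          dd::torch_utils::load_weights(*model, "model.pt", device, logger);
        }
      catch (const std::exception &e)
        {
          logger->error("weight loading failed: " + std::string(e.what()));
        }
    }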
3 changes: 2 additions & 1 deletion src/backends/torch/torchmodule.cc
@@ -122,7 +122,8 @@ namespace dd
        _logger->info("loading " + tmodel._native);
        try
          {
-           torch::load(_native, tmodel._native, _device);
+           torch_utils::load_weights(*_native, tmodel._native, _device,
+                                     _logger);
          }
        catch (std::exception &e)
          {
68 changes: 68 additions & 0 deletions src/backends/torch/torchutils.cc
@@ -26,6 +26,7 @@
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <fcntl.h>
#include <unordered_set>

using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;
@@ -90,6 +91,7 @@ namespace dd
        }
      for (auto child : module->children())
        {
          // XXX(louis): why is "requires_grad" not passed here?
          add_parameters(std::make_shared<torch::jit::script::Module>(child),
                         params);
        }
@@ -119,5 +121,71 @@ namespace dd
          return { output };
        }
    }

    void copy_weights(const torch::jit::script::Module &from,
                      torch::nn::Module &to, const torch::Device &device,
                      std::shared_ptr<spdlog::logger> logger)
    {
      auto from_params = from.named_parameters();
      auto to_params = to.named_parameters();
      std::unordered_set<std::string> copied_params;

      for (const auto &from_item : from_params)
        {
          torch::Tensor *to_value_ptr = to_params.find(from_item.name);

          if (to_value_ptr == nullptr)
            {
              if (logger)
                {
                  logger->warn("skipped " + from_item.name
                               + ": not found in destination module");
                }
              continue;
            }
          torch::Tensor &to_value = *to_value_ptr;

          if (from_item.value.sizes() != to_value.sizes())
            {
              if (logger)
                {
                  std::stringstream sstream;
                  sstream << "skipped " << from_item.name
                          << ": cannot copy tensor of size "
                          << from_item.value.sizes()
                          << " into tensor of size " << to_value.sizes();
                  logger->warn(sstream.str());
                }
              continue;
            }

          to_value.set_data(from_item.value.to(device));
          copied_params.insert(from_item.name);
          if (logger)
            logger->info("copied " + from_item.name);
        }

      if (copied_params.empty())
        {
          throw MLLibBadParamException(
              "No weights were copied: models do not match.");
        }

      for (const auto &param_name : to_params.keys())
        {
          if (copied_params.find(param_name) == copied_params.end())
            {
              // guard: logger defaults to nullptr in the header
              if (logger)
                logger->warn(param_name
                             + " was not found in source module.");
            }
        }
    }

    void load_weights(torch::nn::Module &module, const std::string &filename,
                      const torch::Device &device,
                      std::shared_ptr<spdlog::logger> logger)
    {
      auto jit_module = torch::jit::load(filename, device);
      torch_utils::copy_weights(jit_module, module, device, logger);
    }
  }
}
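
To make the name/shape matching concrete, here is another editorial sketch (not part of the commit); TinyNet and the hand-built script module are hypothetical:

    #include <torch/script.h>
    #include <torch/torch.h>

    // A two-parameter native module; only "weight" exists in the source.
    struct TinyNet : torch::nn::Module
    {
      TinyNet()
      {
        weight = register_parameter("weight", torch::zeros({ 2, 2 }));
        bias = register_parameter("bias", torch::zeros({ 2 }));
      }
      torch::Tensor weight, bias;
    };

    void example(std::shared_ptr<spdlog::logger> logger)
    {
      // Build a script module in-process instead of loading a .pt file.
      torch::jit::Module scripted("scripted");
      scripted.register_parameter("weight", torch::ones({ 2, 2 }),
                                  /*is_buffer=*/false);

      TinyNet net;
      // "weight" matches by name and size -> copied ("copied weight" logged).
      // "bias" is missing from the source -> a warning is logged.
      // A size mismatch is likewise skipped with a warning; the
      // MLLibBadParamException fires only when nothing was copied.
      dd::torch_utils::copy_weights(scripted, net, torch::Device(torch::kCPU),
                                    logger);
    }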
8 changes: 8 additions & 0 deletions src/backends/torch/torchutils.h
@@ -94,6 +94,14 @@ namespace dd
                        bool requires_grad = true);

    std::vector<c10::IValue> unwrap_c10_vector(const c10::IValue &output);

    void copy_weights(const torch::jit::script::Module &from,
                      torch::nn::Module &to, const torch::Device &device,
                      std::shared_ptr<spdlog::logger> logger = nullptr);

    void load_weights(torch::nn::Module &module, const std::string &filename,
                      const torch::Device &device,
                      std::shared_ptr<spdlog::logger> logger = nullptr);
  }
}
#endif
