fix c++ interface bug (#3613)
#3578
#3579

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
CaRoLZhangxy and pre-commit-ci[bot] committed Mar 28, 2024
1 parent 571bd52 commit c2371cd
Showing 2 changed files with 47 additions and 118 deletions.
source/api_cc/src/DeepPotPT.cc: 54 changes (47 additions, 7 deletions)
@@ -3,6 +3,7 @@
 #include "DeepPotPT.h"
 
 #include "common.h"
+#include "device.h"
 using namespace deepmd;
 torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
   std::vector<torch::Tensor> row_tensors;
@@ -36,15 +37,19 @@ void DeepPotPT::init(const std::string& model,
               << std::endl;
     return;
   }
-  gpu_id = gpu_rank;
-  torch::Device device(torch::kCUDA, gpu_rank);
+  int gpu_num = torch::cuda::device_count();
+  if (gpu_num > 0) {
+    gpu_id = gpu_rank % gpu_num;
+  } else {
+    gpu_id = 0;
+  }
+  torch::Device device(torch::kCUDA, gpu_id);
   gpu_enabled = torch::cuda::is_available();
   if (!gpu_enabled) {
     device = torch::Device(torch::kCPU);
-    std::cout << "load model from: " << model << " to cpu " << gpu_rank
-              << std::endl;
+    std::cout << "load model from: " << model << " to cpu " << std::endl;
   } else {
-    std::cout << "load model from: " << model << " to gpu " << gpu_rank
+    std::cout << "load model from: " << model << " to gpu " << gpu_id
               << std::endl;
   }
   module = torch::jit::load(model, device);
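
The hunk above stops passing the raw gpu_rank straight to torch::Device and instead wraps it modulo the number of visible devices, so an MPI rank larger than the device count reuses an existing GPU rather than requesting a nonexistent CUDA ordinal. A minimal standalone sketch of that mapping, assuming LibTorch is available; pick_device is an illustrative helper, not part of the DeepPotPT API:

#include <torch/torch.h>

// Hypothetical helper mirroring the rank-to-device logic in
// DeepPotPT::init: wrap the rank onto an existing CUDA ordinal,
// falling back to CPU when no GPU is usable.
torch::Device pick_device(int gpu_rank) {
  int gpu_num = static_cast<int>(torch::cuda::device_count());
  // Ranks beyond the device count wrap around instead of overflowing.
  int gpu_id = gpu_num > 0 ? gpu_rank % gpu_num : 0;
  if (torch::cuda::is_available()) {
    return torch::Device(torch::kCUDA, gpu_id);
  }
  return torch::Device(torch::kCPU);
}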
@@ -107,7 +112,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
   }
   auto int_options = torch::TensorOptions().dtype(torch::kInt64);
   auto int32_options = torch::TensorOptions().dtype(torch::kInt32);
-
   // select real atoms
   std::vector<VALUETYPE> dcoord, dforce, aparam_, datom_energy, datom_virial;
   std::vector<int> datype, fwd_map, bkw_map;
@@ -116,6 +120,25 @@
   select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
                           bkw_map, nall_real, nloc_real, coord, atype, aparam,
                           nghost, ntypes, 1, daparam, nall, aparam_nall);
+  int nloc = nall_real - nghost_real;
+  int nframes = 1;
+  if (nloc == 0) {
+    // no backward map needed
+    ener.resize(nframes);
+    // dforce of size nall * 3
+    force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
+    fill(force.begin(), force.end(), (VALUETYPE)0.0);
+    // dvirial of size 9
+    virial.resize(static_cast<size_t>(nframes) * 9);
+    fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
+    // datom_energy_ of size nall
+    atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
+    fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
+    // datom_virial_ of size nall * 9
+    atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
+    fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
+    return;
+  }
   std::vector<VALUETYPE> coord_wrapped = dcoord;
   at::Tensor coord_wrapped_Tensor =
       torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
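
The added block returns early when a rank holds no local atoms (nloc == 0), sizing every output from fwd_map.size(), i.e. nall, and zero-filling it so callers still receive well-formed buffers instead of the TorchScript module being invoked on empty tensors. The same pattern as a free function, for illustration only; zero_outputs is hypothetical, and the real code inlines this logic in DeepPotPT::compute:

#include <vector>

// Hypothetical helper: size all outputs for nframes frames and nall
// atoms, then zero them, matching the early-return branch above.
template <typename VALUETYPE>
void zero_outputs(std::vector<VALUETYPE>& ener,
                  std::vector<VALUETYPE>& force,
                  std::vector<VALUETYPE>& virial,
                  std::vector<VALUETYPE>& atom_energy,
                  std::vector<VALUETYPE>& atom_virial,
                  size_t nframes,
                  size_t nall) {
  ener.assign(nframes, (VALUETYPE)0.0);              // one energy per frame
  force.assign(nframes * nall * 3, (VALUETYPE)0.0);  // 3 components per atom
  virial.assign(nframes * 9, (VALUETYPE)0.0);        // 3x3 virial per frame
  atom_energy.assign(nframes * nall, (VALUETYPE)0.0);
  atom_virial.assign(nframes * nall * 9, (VALUETYPE)0.0);
}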
@@ -185,7 +208,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
   datom_virial.assign(
       cpu_atom_virial_.data_ptr<VALUETYPE>(),
       cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());
-  int nframes = 1;
   // bkw map
   force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
   atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
@@ -249,6 +271,24 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
     floatType = torch::kFloat32;
   }
   auto int_options = torch::TensorOptions().dtype(torch::kInt64);
+  int nframes = 1;
+  if (natoms == 0) {
+    // no backward map needed
+    ener.resize(nframes);
+    // dforce of size nall * 3
+    force.resize(static_cast<size_t>(nframes) * natoms * 3);
+    fill(force.begin(), force.end(), (VALUETYPE)0.0);
+    // dvirial of size 9
+    virial.resize(static_cast<size_t>(nframes) * 9);
+    fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
+    // datom_energy_ of size nall
+    atom_energy.resize(static_cast<size_t>(nframes) * natoms);
+    fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
+    // datom_virial_ of size nall * 9
+    atom_virial.resize(static_cast<size_t>(nframes) * natoms * 9);
+    fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
+    return;
+  }
   std::vector<torch::jit::IValue> inputs;
   at::Tensor coord_wrapped_Tensor =
       torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
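
With the natoms == 0 guard in place, a compute call on an empty system should return zero-filled outputs rather than fail inside LibTorch. A hedged usage sketch against the public C++ interface; the header path and compute signature follow the documented deepmd-kit C++ API, while the model filename is illustrative:

#include "deepmd/DeepPot.h"

#include <vector>

int main() {
  deepmd::DeepPot dp("model.pth");  // TorchScript model served by DeepPotPT
  double e = 0.0;
  std::vector<double> f, v;
  std::vector<double> coord;   // empty: this rank owns no atoms
  std::vector<int> atype;      // empty as well
  std::vector<double> cell(9, 0.0);
  dp.compute(e, f, v, coord, atype, cell);
  // e stays 0; f is empty and v holds nine zeros, per the guard above
  return 0;
}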
