From 8aed2a5d41123304243f14bd7e99509ddb4785fc Mon Sep 17 00:00:00 2001 From: maorz1998 Date: Wed, 9 Nov 2022 09:42:30 +0800 Subject: [PATCH] fix a bug in using libtorch --- src/dfChemistryModel/dfChemistryModel.C | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/dfChemistryModel/dfChemistryModel.C b/src/dfChemistryModel/dfChemistryModel.C index 858c864c..dcf45832 100644 --- a/src/dfChemistryModel/dfChemistryModel.C +++ b/src/dfChemistryModel/dfChemistryModel.C @@ -121,18 +121,21 @@ Foam::dfChemistryModel::dfChemistryModel GPUsPerNode = this->subDict("torchParameters1").lookupOrDefault("GPUsPerNode", 4); // initialization the Inferencer (if use multi GPU) - if(!(Pstream::myProcNo() % coresPerGPU)) // Now is a master + if(torchSwitch_) { - int CUDANo = (Pstream::myProcNo() / coresPerGPU) % GPUsPerNode; - std::string device_ = "cuda:" + std::to_string(CUDANo); - Info << "location 0" << endl; - Info << "torchModelName1_ = " << torchModelName1_ << endl; - torch::jit::script::Module torchModel1_ = torch::jit::load(torchModelName1_); - torch::jit::script::Module torchModel2_ = torch::jit::load(torchModelName2_); - torch::jit::script::Module torchModel3_ = torch::jit::load(torchModelName3_); - Info << "location 1" << endl; - DNNInferencer DNNInferencer(torchModel1_, torchModel2_, torchModel3_, device_); - DNNInferencer_ = DNNInferencer; + if(!(Pstream::myProcNo() % coresPerGPU)) // Now is a master + { + int CUDANo = (Pstream::myProcNo() / coresPerGPU) % GPUsPerNode; + std::string device_ = "cuda:" + std::to_string(CUDANo); + Info << "location 0" << endl; + Info << "torchModelName1_ = " << torchModelName1_ << endl; + torch::jit::script::Module torchModel1_ = torch::jit::load(torchModelName1_); + torch::jit::script::Module torchModel2_ = torch::jit::load(torchModelName2_); + torch::jit::script::Module torchModel3_ = torch::jit::load(torchModelName3_); + Info << "location 1" << endl; + DNNInferencer DNNInferencer(torchModel1_, torchModel2_, torchModel3_, device_); + DNNInferencer_ = DNNInferencer; + } } #endif