diff --git a/src/ani/CudaANISymmetryFunctions.cu b/src/ani/CudaANISymmetryFunctions.cu index 0463771..c1ebca8 100644 --- a/src/ani/CudaANISymmetryFunctions.cu +++ b/src/ani/CudaANISymmetryFunctions.cu @@ -39,7 +39,7 @@ CudaANISymmetryFunctions::CudaANISymmetryFunctions(int numAtoms, int numSpecies, const std::vector& radialFunctions, const std::vector& angularFunctions, bool torchani) : ANISymmetryFunctions(numAtoms, numSpecies, radialCutoff, angularCutoff, periodic, atomSpecies, radialFunctions, angularFunctions, torchani), positions(0), neighbors(0), neighborCount(0), periodicBoxVectors(0), angularIndex(0), atomSpeciesArray(0), radialFunctionArray(0), angularFunctionArray(0), - radialValues(0), angularValues(0), positionDerivValues(0) { + radialValues(0), angularValues(0), positionDerivValues(0), stream(0) { CHECK_RESULT(cudaMallocManaged(&positions, numAtoms*sizeof(float3))); CHECK_RESULT(cudaMallocManaged(&neighbors, numAtoms*numAtoms*sizeof(int))); CHECK_RESULT(cudaMallocManaged(&neighborCount, numAtoms*sizeof(int))); @@ -96,6 +96,14 @@ CudaANISymmetryFunctions::~CudaANISymmetryFunctions() { cudaFree(positionDerivValues); } +void CudaANISymmetryFunctions::setStream(cudaStream_t stream) { + this->stream = stream; +} + +cudaStream_t CudaANISymmetryFunctions::getStream() const { + return stream; +} + template __device__ void computeDisplacement(float3 pos1, float3 pos2, float3& delta, float& r2, const float* periodicBoxVectors, float3 invBoxSize) { delta.x = pos2.x-pos1.x; @@ -363,31 +371,31 @@ void CudaANISymmetryFunctions::computeSymmetryFunctions(const float* positions, int numAngular = angularFunctions.size(); if (periodic) { if (triclinic) { - computeRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray); + computeRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray); if (torchani) - computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); else - computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); } else { - computeRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray); + computeRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray); if (torchani) - computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); else - computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); } } else { - computeRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray); + computeRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray); if (torchani) - computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); else - computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + computeAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); } // Apply the overall scale factors to the symmetry functions. - scaleSymmetryFunctions<<>>(numAtoms, numSpecies, numRadial, numAngular, torchani, radialPtr, angularPtr, angularFunctionArray); + scaleSymmetryFunctions<<>>(numAtoms, numSpecies, numRadial, numAngular, torchani, radialPtr, angularPtr, angularFunctionArray); // Copy the final values to the destination memory. @@ -633,26 +641,26 @@ void CudaANISymmetryFunctions::backprop(const float* radialDeriv, const float* a int numBlocks = min(maxBlocks, numAtoms); if (periodic) { if (triclinic) { - backpropRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani); + backpropRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani); if (torchani) - backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); else - backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); } else { - backpropRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani); + backpropRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani); if (torchani) - backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); else - backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); } } else { - backpropRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani); + backpropRadialFunctions<<>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani); if (torchani) - backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); else - backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); + backpropAngularFunctions<<>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex); } // Copy the final values to the destination memory. diff --git a/src/ani/CudaANISymmetryFunctions.h b/src/ani/CudaANISymmetryFunctions.h index ed542ed..b9ad6b5 100644 --- a/src/ani/CudaANISymmetryFunctions.h +++ b/src/ani/CudaANISymmetryFunctions.h @@ -24,6 +24,7 @@ * SOFTWARE. */ +#include #include "ANISymmetryFunctions.h" class CudaANISymmetryFunctions : public ANISymmetryFunctions { @@ -78,6 +79,18 @@ class CudaANISymmetryFunctions : public ANISymmetryFunctions { * @param positionDeriv an array of shape [numAtoms][3] to store the derivative of E with respect to the atom positions into */ void backprop(const float* radialDeriv, const float* angularDeriv, float* positionDeriv); + /** + * Set the CUDA stream. By default, it is set to 0, which means the default stream. + * + * @param stream a CUDA stream object + */ + void setStream(cudaStream_t stream); + /** + * Get the CUDA stream. + * + * @return a CUDA stream object + */ + cudaStream_t getStream() const; private: int* neighbors; int* neighborCount; @@ -92,6 +105,7 @@ class CudaANISymmetryFunctions : public ANISymmetryFunctions { float* positionDerivValues; bool triclinic; int maxBlocks; + cudaStream_t stream; }; #endif diff --git a/src/pytorch/SymmetryFunctions.cpp b/src/pytorch/SymmetryFunctions.cpp index b4284cc..6c508f9 100644 --- a/src/pytorch/SymmetryFunctions.cpp +++ b/src/pytorch/SymmetryFunctions.cpp @@ -22,8 +22,8 @@ */ #include -#include #include +#include #include "CpuANISymmetryFunctions.h" #include "CudaANISymmetryFunctions.h" @@ -96,6 +96,8 @@ class Holder : public torch::CustomClassHolder { radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions); angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions); positionsGrad = torch::empty({numAtoms, 3}, tensorOptions); + + cudaSymFunc = dynamic_cast(symFunc.get()); }; tensor_list forward(const Tensor& positions_, const optional& periodicBoxVectors_) { @@ -109,6 +111,11 @@ class Holder : public torch::CustomClassHolder { float* periodicBoxVectorsPtr = periodicBoxVectors.data_ptr(); } + if (cudaSymFunc) { + const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index()); + cudaSymFunc->setStream(stream.stream()); + } + symFunc->computeSymmetryFunctions(positions.data_ptr(), periodicBoxVectorsPtr, radial.data_ptr(), angular.data_ptr()); return {radial, angular}; @@ -119,6 +126,11 @@ class Holder : public torch::CustomClassHolder { const Tensor radialGrad = grads[0].clone(); const Tensor angularGrad = grads[1].clone(); + if (cudaSymFunc) { + const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index()); + cudaSymFunc->setStream(stream.stream()); + } + symFunc->backprop(radialGrad.data_ptr(), angularGrad.data_ptr(), positionsGrad.data_ptr()); return positionsGrad; @@ -134,6 +146,7 @@ class Holder : public torch::CustomClassHolder { Tensor radial; Tensor angular; Tensor positionsGrad; + CudaANISymmetryFunctions* cudaSymFunc; }; class AutogradFunctions : public torch::autograd::Function { diff --git a/src/pytorch/TestSymmetryFunctions.py b/src/pytorch/TestSymmetryFunctions.py index fa51de0..d1a2b62 100644 --- a/src/pytorch/TestSymmetryFunctions.py +++ b/src/pytorch/TestSymmetryFunctions.py @@ -104,5 +104,41 @@ def test_model_serialization(deviceString, molFile): energy_error = torch.abs((energy - energy_ref)/energy_ref) grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref)) + assert energy_error < 5e-7 + assert grad_error < 5e-3 + +@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99']) +def test_non_default_stream(molFile): + + if not torch.cuda.is_available(): + pytest.skip('CUDA is not available') + + from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions + + device = torch.device('cuda') + + mol = mdtraj.load(os.path.join(molecules, f'{molFile}_ligand.mol2')) + atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device) + atomicPositions = torch.tensor(mol.xyz * 10, dtype=torch.float32, requires_grad=True, device=device) + + nnp = torchani.models.ANI2x(periodic_table_index=True).to(device) + nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer) + + energy_ref = nnp((atomicNumbers, atomicPositions)).energies + energy_ref.backward() + grad_ref = atomicPositions.grad.clone() + + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + energy = nnp((atomicNumbers, atomicPositions)).energies + atomicPositions.grad.zero_() + energy.backward() + grad = atomicPositions.grad.clone() + torch.cuda.current_stream().wait_stream(stream) + + energy_error = torch.abs((energy - energy_ref)/energy_ref) + grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref)) + assert energy_error < 5e-7 assert grad_error < 5e-3 \ No newline at end of file