Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the execution of ANISymmetryFunctions on a non-default CUDA stream #37

Merged
merged 6 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions src/ani/CudaANISymmetryFunctions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ CudaANISymmetryFunctions::CudaANISymmetryFunctions(int numAtoms, int numSpecies,
const std::vector<RadialFunction>& radialFunctions, const std::vector<AngularFunction>& angularFunctions, bool torchani) :
ANISymmetryFunctions(numAtoms, numSpecies, radialCutoff, angularCutoff, periodic, atomSpecies, radialFunctions, angularFunctions, torchani),
positions(0), neighbors(0), neighborCount(0), periodicBoxVectors(0), angularIndex(0), atomSpeciesArray(0), radialFunctionArray(0), angularFunctionArray(0),
radialValues(0), angularValues(0), positionDerivValues(0) {
radialValues(0), angularValues(0), positionDerivValues(0), stream(0) {
CHECK_RESULT(cudaMallocManaged(&positions, numAtoms*sizeof(float3)));
CHECK_RESULT(cudaMallocManaged(&neighbors, numAtoms*numAtoms*sizeof(int)));
CHECK_RESULT(cudaMallocManaged(&neighborCount, numAtoms*sizeof(int)));
Expand Down Expand Up @@ -96,6 +96,14 @@ CudaANISymmetryFunctions::~CudaANISymmetryFunctions() {
cudaFree(positionDerivValues);
}

// Select the CUDA stream this object uses for its kernel launches. The value is only
// stored here; the kernel launches in computeSymmetryFunctions() and backprop() pass
// the stored stream as the stream argument of their launch configuration.
void CudaANISymmetryFunctions::setStream(cudaStream_t stream) {
// The parameter shadows the member of the same name, hence the explicit this->.
this->stream = stream;
}

// Return the CUDA stream currently stored in this object: the value last given to
// setStream(), or 0 (the value the constructor initializes it to) if it was never set.
cudaStream_t CudaANISymmetryFunctions::getStream() const {
return stream;
}

template <bool PERIODIC, bool TRICLINIC>
__device__ void computeDisplacement(float3 pos1, float3 pos2, float3& delta, float& r2, const float* periodicBoxVectors, float3 invBoxSize) {
delta.x = pos2.x-pos1.x;
Expand Down Expand Up @@ -363,31 +371,31 @@ void CudaANISymmetryFunctions::computeSymmetryFunctions(const float* positions,
int numAngular = angularFunctions.size();
if (periodic) {
if (triclinic) {
computeRadialFunctions<true, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray);
computeRadialFunctions<true, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray);
if (torchani)
computeAngularFunctions<true, true, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
computeAngularFunctions<true, true, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
else
computeAngularFunctions<true, true, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
computeAngularFunctions<true, true, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
}
else {
computeRadialFunctions<true, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray);
computeRadialFunctions<true, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray);
if (torchani)
computeAngularFunctions<true, false, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
computeAngularFunctions<true, false, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
else
computeAngularFunctions<true, false, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
computeAngularFunctions<true, false, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
}
}
else {
computeRadialFunctions<false, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray);
computeRadialFunctions<false, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numRadial, radialCutoff, angularCutoff, radialPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray);
if (torchani)
computeAngularFunctions<false, false, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
computeAngularFunctions<false, false, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
else
computeAngularFunctions<false, false, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
computeAngularFunctions<false, false, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
}

// Apply the overall scale factors to the symmetry functions.

scaleSymmetryFunctions<<<maxBlocks, 128>>>(numAtoms, numSpecies, numRadial, numAngular, torchani, radialPtr, angularPtr, angularFunctionArray);
scaleSymmetryFunctions<<<maxBlocks, 128, 0, stream>>>(numAtoms, numSpecies, numRadial, numAngular, torchani, radialPtr, angularPtr, angularFunctionArray);

// Copy the final values to the destination memory.

Expand Down Expand Up @@ -633,26 +641,26 @@ void CudaANISymmetryFunctions::backprop(const float* radialDeriv, const float* a
int numBlocks = min(maxBlocks, numAtoms);
if (periodic) {
if (triclinic) {
backpropRadialFunctions<true, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani);
backpropRadialFunctions<true, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani);
if (torchani)
backpropAngularFunctions<true, true, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
backpropAngularFunctions<true, true, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
else
backpropAngularFunctions<true, true, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
backpropAngularFunctions<true, true, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
}
else {
backpropRadialFunctions<true, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani);
backpropRadialFunctions<true, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani);
if (torchani)
backpropAngularFunctions<true, false, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
backpropAngularFunctions<true, false, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
else
backpropAngularFunctions<true, false, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
backpropAngularFunctions<true, false, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
}
}
else {
backpropRadialFunctions<false, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani);
backpropRadialFunctions<false, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numRadial, radialCutoff, radialPtr, posPtr, (float3*) this->positions, this->periodicBoxVectors, radialFunctionArray, atomSpeciesArray, torchani);
if (torchani)
backpropAngularFunctions<false, false, true><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
backpropAngularFunctions<false, false, true><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
else
backpropAngularFunctions<false, false, false><<<numBlocks, blockSize>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
backpropAngularFunctions<false, false, false><<<numBlocks, blockSize, 0, stream>>>(numAtoms, numSpecies, numAngular, angularCutoff, angularPtr, posPtr, neighbors, neighborCount, (float3*) this->positions, this->periodicBoxVectors, angularFunctionArray, atomSpeciesArray, angularIndex);
}

// Copy the final values to the destination memory.
Expand Down
14 changes: 14 additions & 0 deletions src/ani/CudaANISymmetryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* SOFTWARE.
*/

#include <cuda_runtime.h>
#include "ANISymmetryFunctions.h"

class CudaANISymmetryFunctions : public ANISymmetryFunctions {
Expand Down Expand Up @@ -78,6 +79,18 @@ class CudaANISymmetryFunctions : public ANISymmetryFunctions {
* @param positionDeriv an array of shape [numAtoms][3] to store the derivative of E with respect to the atom positions into
*/
void backprop(const float* radialDeriv, const float* angularDeriv, float* positionDeriv);
/**
* Set the CUDA stream on which this object launches its kernels. By default, it is set
* to 0, which means the default stream.
*
* @param stream a CUDA stream object
*/
void setStream(cudaStream_t stream);
/**
* Get the CUDA stream currently set on this object.
*
* @return the CUDA stream object last passed to setStream(), or 0 if it has never been called
*/
cudaStream_t getStream() const;
private:
int* neighbors;
int* neighborCount;
Expand All @@ -92,6 +105,7 @@ class CudaANISymmetryFunctions : public ANISymmetryFunctions {
float* positionDerivValues;
bool triclinic;
int maxBlocks;
cudaStream_t stream;
};

#endif
15 changes: 14 additions & 1 deletion src/pytorch/SymmetryFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
*/

#include <stdexcept>
#include <cuda_runtime.h>
#include <torch/script.h>
#include <c10/cuda/CUDAStream.h>
#include "CpuANISymmetryFunctions.h"
#include "CudaANISymmetryFunctions.h"

Expand Down Expand Up @@ -96,6 +96,8 @@ class Holder : public torch::CustomClassHolder {
radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions);
angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions);
positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);

cudaSymFunc = dynamic_cast<CudaANISymmetryFunctions*>(symFunc.get());
};

tensor_list forward(const Tensor& positions_, const optional<Tensor>& periodicBoxVectors_) {
Expand All @@ -109,6 +111,11 @@ class Holder : public torch::CustomClassHolder {
float* periodicBoxVectorsPtr = periodicBoxVectors.data_ptr<float>();
}

if (cudaSymFunc) {
const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index());
cudaSymFunc->setStream(stream.stream());
}

symFunc->computeSymmetryFunctions(positions.data_ptr<float>(), periodicBoxVectorsPtr, radial.data_ptr<float>(), angular.data_ptr<float>());

return {radial, angular};
Expand All @@ -119,6 +126,11 @@ class Holder : public torch::CustomClassHolder {
const Tensor radialGrad = grads[0].clone();
const Tensor angularGrad = grads[1].clone();

if (cudaSymFunc) {
const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index());
cudaSymFunc->setStream(stream.stream());
}

symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

return positionsGrad;
Expand All @@ -134,6 +146,7 @@ class Holder : public torch::CustomClassHolder {
Tensor radial;
Tensor angular;
Tensor positionsGrad;
CudaANISymmetryFunctions* cudaSymFunc;
};

class AutogradFunctions : public torch::autograd::Function<AutogradFunctions> {
Expand Down
36 changes: 36 additions & 0 deletions src/pytorch/TestSymmetryFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,41 @@ def test_model_serialization(deviceString, molFile):
energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
assert grad_error < 5e-3

# Evaluates the model once to obtain a reference energy and position gradient, repeats the
# evaluation using a newly created CUDA stream, and asserts that both results agree within
# relative tolerances. Parametrized over several ligand files.
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
def test_non_default_stream(molFile):

# The test only makes sense with a CUDA device; skip otherwise.
if not torch.cuda.is_available():
pytest.skip('CUDA is not available')

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')

# Load the ligand and build the model inputs directly on the CUDA device.
# NOTE(review): the factor 10 presumably converts mdtraj nanometers to Angstroms - confirm.
mol = mdtraj.load(os.path.join(molecules, f'{molFile}_ligand.mol2'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)

# Replace the model's AEV computer with the TorchANISymmetryFunctions wrapper under test.
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer)

# Reference energy and gradient, computed before the new stream is created.
energy_ref = nnp((atomicNumbers, atomicPositions)).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()

# Create a new stream and make it wait for the work already queued on the current stream.
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
energy = nnp((atomicNumbers, atomicPositions)).energies
# Clear the gradient left by the reference backward pass so that grad reflects only this pass.
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()
# NOTE(review): presumably executed after leaving the with block, so that the original
# stream waits for the work queued on the new stream before the comparison - confirm.
torch.cuda.current_stream().wait_stream(stream)

# Relative errors with respect to the reference values.
# NOTE(review): the division is by energy_ref / grad_ref, so a zero reference component
# would produce inf or nan - confirm this cannot occur for the test molecules.
energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
assert grad_error < 5e-3