Benchmarking SchNet #14

Open
peastman opened this issue Oct 21, 2020 · 16 comments
@peastman (Member)

I've been trying to figure out how to write a benchmarking script for SchNet. Here's what I have so far with SchNetPack. It loads a PDB file and computes the energy 1000 times with one of the pre-trained QM9 models. I haven't yet figured out how to get it to compute forces, so any advice on that would be appreciated. There are probably other ways this could be improved too.

import torch
import schnetpack as spk
import schnetpack.md.calculators
import sys
import ase.io
import time

device = torch.device('cuda')

# Load one of the pre-trained QM9 models distributed with SchNetPack.
model = torch.load("trained_schnet_models/qm9_energy_U0/best_model", map_location=device)

# Read the molecule given on the command line and wrap it in an MD System
# with a single replica.
atoms = ase.io.read(sys.argv[1])
system = spk.md.System(1, device=device)
system.load_molecules([atoms])

calculator = spk.md.calculators.SchnetPackCalculator(
    model,
    required_properties=['energy_U0'],
    force_handle=spk.Properties.forces,
    position_conversion='A',
    force_conversion='kcal/mol/A'
)

# Build the model inputs once and reuse them for every evaluation.
inputs = calculator._generate_input(system)

# Warm-up evaluation so the timed loop doesn't include one-time initialization.
model(inputs)
t1 = time.time()
for i in range(1000):
    results = model(inputs)
# Printing the results copies them to the CPU, which also waits for the queued
# GPU work to finish before the elapsed time is read.
print(results)
print(time.time()-t1)

Testing a 60 atom system on a Titan V, it takes about 3.6 ms per energy evaluation. Testing a 2269 atom system, it runs out of memory on the GPU and crashes.

While the test is running, nvidia-smi shows that the GPU is only 28% busy. nvvp shows a lot of short kernels with larger gaps between them. The two most significant kernels are volta_sgemm_32x128_tn (19.8% of GPU time) and volta_sgemm_32x32_sliced1x4_tn (16% of GPU time). After that comes a long tail of kernels with uninformative mangled names like _ZN2at6native6legacy18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_EEvS5_RKT_EUliE2_EEviT1_.
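As a cross-check, PyTorch's built-in autograd profiler gives a similar per-kernel breakdown without leaving Python. This is only a sketch against the script above (it reuses model and inputs from there), and the exact table layout depends on the PyTorch version.

# Profile a shorter run of the same forward pass and list the ops sorted by
# total CUDA time; assumes model and inputs from the benchmark script above.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for i in range(100):
        model(inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))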

@raimis (Contributor) commented Oct 22, 2020

> Testing a 60 atom system on a Titan V, it takes about 3.6 ms per energy evaluation.

The speed of SchNet is comparable to ANI. I thought these graph convolutions would be much slower. So it seems the ultimate bottleneck is the matrix multiplications, which is the case for ANI too.

@peastman (Member, Author)

Any idea how I can get it to compute forces? The documentation implies I should just be able to call calculate() on the Calculator, but that throws an exception because the property names in this model don't match what it's expecting.

@raimis (Contributor) commented Oct 23, 2020

@stefdoerr might know.

@stefdoerr

Sorry, I don't have experience with SchNetPack. Maybe @giadefa knows.

@stefdoerr

Maybe check the available_properties attribute of the class?

@giadefa (Member) commented Oct 23, 2020

@peastman (Member, Author)

That code assumes the model has an output called "forces". The pretrained QM9 model only has a single output called "energy_U0".
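For example, with the script from the first comment, printing the keys of the model output shows exactly what the loaded model produces (just a sketch reusing model and inputs from there):

# Print the property names the pretrained QM9 U0 model actually returns;
# this shows only 'energy_U0' and no 'forces' entry.
results = model(inputs)
print(list(results.keys()))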

I think I managed to correctly get it to compute forces, though I don't know if I'm doing it in the best way. After loading the system I added the line

system.positions.requires_grad_()

And then I changed the loop to

for i in range(1000):
    # Clear the gradient accumulated on the positions in the previous iteration.
    if system.positions.grad is not None:
        system.positions.grad.zero_()
    results = model(inputs)
    # Backpropagate the energy to get its gradient (the negative of the forces)
    # in system.positions.grad; retain_graph=True keeps the graph linking the
    # cached inputs to the positions so it can be reused on the next iteration.
    results['energy_U0'].backward(retain_graph=True)

It now takes 7.7 ms per iteration, which is still quite respectable compared to TorchANI. nvidia-smi now shows the GPU as 46% busy and nvvp still shows lots of gaps between kernels, so there ought to be lots of room for speedups.
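An alternative, for what it's worth, is to ask autograd for the gradient directly instead of accumulating into .grad. This is only a sketch built on the same model, inputs, and system as above; retain_graph=True is kept for the same reason as in the backward() version, since the cached inputs are reused every iteration.

# Sketch: forces as the negative gradient of the energy with respect to the
# positions, via torch.autograd.grad instead of backward() and .grad buffers.
results = model(inputs)
energy = results['energy_U0'].sum()
forces = -torch.autograd.grad(energy, system.positions, retain_graph=True)[0]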

@giadefa (Member) commented Oct 24, 2020 via email

@peastman (Member, Author)

It appears SchNetPack does support neighbor lists (specified with the System's neighborlist argument), but the only implementation it provides just lists every atom as a neighbor of every other atom. Creating a proper implementation might improve performance, and should also help with running out of memory on large molecules.
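As a rough back-of-the-envelope illustration of why the all-pairs default runs out of memory on the 2269 atom system (the 128-feature filter width here is just the usual SchNet setting, assumed rather than read from the model above):

# With every atom listed as a neighbor of every other atom, the dense neighbor
# tensor alone holds n_atoms * (n_atoms - 1) entries...
n_atoms = 2269
pairs = n_atoms * (n_atoms - 1)
print(pairs)  # 5146092
# ...and each interaction layer builds per-pair filter features on top of that,
# e.g. ~5.1e6 pairs * 128 features * 4 bytes ≈ 2.6 GB per intermediate tensor,
# before counting the other layers and the autograd graph.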

@giadefa (Member) commented Oct 29, 2020 via email

@peastman (Member, Author)

I wrote a proper neighbor list implementation. There's probably a more efficient way of building it, but it works. I can now run the 2269 atom system. It takes 76 ms per iteration. On the 60 atom system, there's no change in speed.

import torch
import schnetpack as spk

class NeighborList(spk.md.neighbor_lists.MDNeighborList):
    """Cutoff-based neighbor list built by filtering SchNetPack's all-pairs SimpleNeighborList."""

    def __init__(self, system, cutoff=None):
        self.simple = spk.md.neighbor_lists.SimpleNeighborList(system, cutoff)
        super(NeighborList, self).__init__(system, cutoff)

    def _construct_neighbor_list(self):
        # Start from the dense all-pairs list and compute all pair distances.
        self.simple._construct_neighbor_list()
        neighbors = self.simple.neighbor_list.view(-1, self.system.max_n_atoms, self.system.max_n_atoms-1)
        positions = self.system.positions.view(-1, self.system.max_n_atoms, 3)
        n_copies = neighbors.shape[0]
        n_atoms = neighbors.shape[1]
        r_ij = spk.nn.neighbors.atom_distances(positions, neighbors, None)
        # Keep only the neighbors within the cutoff, one atom at a time (slow Python loops).
        lists = []
        for i in range(n_copies):
            for j in range(n_atoms):
                lists.append(neighbors[i,j][r_ij[i,j]<self.cutoff])
        # Pad every atom's neighbor row to the same width and mask out the padding.
        max_neighbors = max(len(l) for l in lists)
        self.neighbor_list = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device, dtype=torch.int64)
        self.neighbor_mask = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device)
        iterator = iter(lists)
        for i in range(n_copies):
            for j in range(n_atoms):
                l = next(iterator)
                size = len(l)
                self.neighbor_list[i,j,:size] = l
                self.neighbor_mask[i,j,:size] = 1
        # Reshape to the (replicas, molecules, atoms, neighbors) layout SchNetPack expects.
        n_replicas = self.simple.neighbor_list.shape[0]
        n_molecules = self.simple.neighbor_list.shape[1]
        self.neighbor_list = self.neighbor_list.view(n_replicas, n_molecules, n_atoms, max_neighbors)
        self.neighbor_mask = self.neighbor_mask.view(n_replicas, n_molecules, n_atoms, max_neighbors)

    def update_neighbors(self):
        self._construct_neighbor_list()

@peastman (Member, Author)

When running the large system, the GPU is 99% busy. That's compared to only 43% when running the small system.

@stefdoerr

Never mind my comments (which I deleted). They were irrelevant to this project; I got confused by another discussion. Sorry.

@peastman (Member, Author) commented Nov 2, 2020

I realized there was a mistake in the numbers above: I wasn't rebuilding the neighbor list for each iteration. With the default implementation you don't need to, because it just includes every interaction, but if you want to use a real neighbor list you would need to rebuild it for every step of a simulation. Building the list with the code above is very slow, so I came up with a much faster implementation.

import torch
import schnetpack as spk

class NeighborList(spk.md.neighbor_lists.MDNeighborList):
    """Cutoff-based neighbor list; same idea as above, but built with vectorized tensor ops."""

    def __init__(self, system, cutoff=None):
        self.simple = spk.md.neighbor_lists.SimpleNeighborList(system, cutoff)
        super(NeighborList, self).__init__(system, cutoff)

    def _construct_neighbor_list(self):
        # Start from the dense all-pairs list and compute all pair distances.
        self.simple._construct_neighbor_list()
        neighbors = self.simple.neighbor_list.view(-1, self.system.max_n_atoms, self.system.max_n_atoms-1)
        positions = self.system.positions.view(-1, self.system.max_n_atoms, 3)
        n_copies = neighbors.shape[0]
        n_atoms = neighbors.shape[1]
        r_ij = spk.nn.neighbors.atom_distances(positions, neighbors, None)
        # Mark the pairs within the cutoff and pack them to the front of each atom's
        # row without any Python loops: nonzero() gives the source indices, and the
        # running count from cumsum() gives each entry's destination slot.
        mask = r_ij < self.cutoff
        max_neighbors = int(torch.count_nonzero(mask, dim=2).max())
        copy_index, atom_index, neighbor_index = torch.nonzero(mask, as_tuple=True)
        cumsum = torch.cumsum(mask, dim=2)-1
        target_index = cumsum[copy_index, atom_index, neighbor_index]
        self.neighbor_list = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device, dtype=torch.int64)
        self.neighbor_mask = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device)
        self.neighbor_list[copy_index, atom_index, target_index] = neighbors[copy_index, atom_index, neighbor_index]
        self.neighbor_mask[copy_index, atom_index, target_index] = 1
        # Reshape to the (replicas, molecules, atoms, neighbors) layout SchNetPack expects.
        n_replicas = self.simple.neighbor_list.shape[0]
        n_molecules = self.simple.neighbor_list.shape[1]
        self.neighbor_list = self.neighbor_list.view(n_replicas, n_molecules, n_atoms, max_neighbors)
        self.neighbor_mask = self.neighbor_mask.view(n_replicas, n_molecules, n_atoms, max_neighbors)

    def update_neighbors(self):
        self._construct_neighbor_list()
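As an aside, the indexing trick in _construct_neighbor_list (nonzero plus a running cumsum) packs each atom's variable-length neighbor set into a fixed-width, zero-padded row entirely with tensor ops. A tiny standalone illustration of the same idea:

import torch

# For each row of the mask, pack the column indices where the mask is True into
# the front of a fixed-width tensor, zero-padded on the right.
mask = torch.tensor([[True, False, True],
                     [False, False, True]])
row, col = torch.nonzero(mask, as_tuple=True)
target = (torch.cumsum(mask, dim=1) - 1)[row, col]  # destination slot within each row
packed = torch.zeros(2, int(mask.sum(dim=1).max()), dtype=torch.int64)
packed[row, target] = col
print(packed)  # tensor([[0, 2],
               #         [2, 0]])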

Here are some new benchmarks for the 60 atom system.

Default neighbor list, don't rebuild it: 8.2 ms
Default neighbor list, rebuild each iteration: 8.4 ms
"Real" neighbor list, rebuild each iteration: 9.0 ms

For the 2269 atom system:

Default neighbor list: runs out of memory
Real neighbor list: 81 ms

@peastman (Member, Author)

I have benchmarks for the implementation in #18. I tried to make it as close as possible to the SchNetPack QM9 model benchmarked above. I use the same cutoff distance, number of Gaussians, and output width. Each iteration builds the neighbor list then computes the value and gradient six times to match the six layers in the model. This still isn't exactly comparable, since the real model includes other calculations in addition to the cfconv layers. Those are per-atom rather than per-interaction, though, so they should be much faster and only account for a small fraction of the computation time.

For the 60 atom system, it takes 2 ms/iteration, so roughly four times faster than SchNetPack. For the 2269 atom system it takes 86 ms/iteration, so slightly slower than SchNetPack (using the neighbor list implementation above, not the standard one).

@peastman (Member, Author)

With my latest optimizations, the 60 atom system is down to only 0.82 ms/iteration. The 2269 atom system is basically unchanged at 88 ms/iteration. Possibly there are ways I could speed that up, but it's also possible the optimal way to structure the calculation is just different for small systems than for larger ones.

@raimis added the "help wanted" label on May 24, 2022