Add padding #199
Merged
Commits (12)
045adb0 (E-Rum): Add pair_mask parameter to potential calculations in Calculator and C…
1170639 (E-Rum): small typing fix and adding helper function that generates batched k-…
ee63dbb (E-Rum): first iteration that forward passes (broken as hell)
f7f17e9 (E-Rum): cleaning up
1a96505 (E-Rum): add docs and linter
71765de (E-Rum): fix existing tests
4c51838 (E-Rum): Add padding tests
34d34d8 (E-Rum): Refactor error handling and improve tests for potential padding
e0600e4 (E-Rum): Add examples and input files for batched Ewald computation with padding
64f5002 (E-Rum): Fix formatting in padding example documentation
5e045b9 (E-Rum): Enhance padding example with additional systems and performance compa…
c5e8ac4 (E-Rum): Update changelog
New example file (+204 lines):
| """ | ||
| Batched Ewald Computation with Padding | ||
| ====================================== | ||
|
|
||
| This example demonstrates how to compute Ewald potentials for a batch of systems with | ||
| different numbers of atoms using padding. The idea is to pad atomic positions, charges, | ||
| and neighbor lists to the same length and use masks to ignore padded entries during | ||
| computation. Note that batching systems of varying sizes in this way can increase the | ||
| computational cost during model training, since padded atoms are included in the batched | ||
| operations even though they don't contribute physically. | ||
| """ | ||
|
|
||
| # %% | ||
| import time | ||
|
|
||
| import torch | ||
| import vesin | ||
| from torch.nn.utils.rnn import pad_sequence | ||
|
|
||
| import torchpme | ||
|
|
||
| dtype = torch.float64 | ||
| cutoff = 4.4 | ||
|
|
||
| # %% | ||
# Example: five systems with different numbers of atoms
systems = [
    {
        "symbols": ("Cs", "Cl"),
        "positions": torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype),
        "charges": torch.tensor([[1.0], [-1.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 3.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("Na", "Cl", "Cl"),
        "positions": torch.tensor(
            [(0, 0, 0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)], dtype=dtype
        ),
        "charges": torch.tensor([[1.0], [-1.0], [-1.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 4.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("K", "Br", "Br", "K"),
        "positions": torch.tensor(
            [(0, 0, 0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25), (0.75, 0.75, 0.75)],
            dtype=dtype,
        ),
        "charges": torch.tensor([[1.0], [-1.0], [-1.0], [1.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 5.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("Mg", "O", "O", "Mg", "O"),
        "positions": torch.tensor(
            [
                (0, 0, 0),
                (0.5, 0.5, 0.5),
                (0.25, 0.25, 0.25),
                (0.75, 0.75, 0.75),
                (0.1, 0.1, 0.1),
            ],
            dtype=dtype,
        ),
        "charges": torch.tensor([[2.0], [-2.0], [-2.0], [2.0], [-2.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 6.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("Al", "O", "O", "Al", "O", "O"),
        "positions": torch.tensor(
            [
                (0, 0, 0),
                (0.5, 0.5, 0.5),
                (0.25, 0.25, 0.25),
                (0.75, 0.75, 0.75),
                (0.1, 0.1, 0.1),
                (0.9, 0.9, 0.9),
            ],
            dtype=dtype,
        ),
        "charges": torch.tensor(
            [[3.0], [-2.0], [-2.0], [3.0], [-2.0], [-2.0]], dtype=dtype
        ),
        "cell": torch.eye(3, dtype=dtype) * 7.0,
        "pbc": torch.tensor([True, True, True]),
    },
]
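# The structures above are toy configurations chosen for illustration, not
# physically meaningful crystals.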

# %%
# Compute neighbor lists for each system
i_list, j_list, d_list = [], [], []
pos_list, charges_list, cell_list, periodic_list = [], [], [], []

nl = vesin.NeighborList(cutoff=cutoff, full_list=False)
||
| for sys in systems: | ||
| neighbor_indices, neighbor_distances = nl.compute( | ||
| points=sys["positions"], | ||
| box=sys["cell"], | ||
| periodic=sys["pbc"][0], | ||
| quantities="Pd", | ||
| ) | ||
| i_list.append(torch.tensor(neighbor_indices[:, 0], dtype=torch.int64)) | ||
| j_list.append(torch.tensor(neighbor_indices[:, 1], dtype=torch.int64)) | ||
| d_list.append(torch.tensor(neighbor_distances, dtype=dtype)) | ||
| pos_list.append(sys["positions"]) | ||
| charges_list.append(sys["charges"]) | ||
| cell_list.append(sys["cell"]) | ||
| periodic_list.append(sys["pbc"]) | ||
|

# %%
# Pad positions, charges, and neighbor lists
max_atoms = max(pos.shape[0] for pos in pos_list)
pos_batch = pad_sequence(pos_list, batch_first=True)
charges_batch = pad_sequence(charges_list, batch_first=True)
cell_batch = torch.stack(cell_list)
periodic_batch = torch.stack(periodic_list)
i_batch = pad_sequence(i_list, batch_first=True, padding_value=0)
j_batch = pad_sequence(j_list, batch_first=True, padding_value=0)
d_batch = pad_sequence(d_list, batch_first=True, padding_value=0.0)

# Masks for ignoring padded atoms and neighbor entries
node_mask = (
    torch.arange(max_atoms)[None, :]
    < torch.tensor([p.shape[0] for p in pos_list])[:, None]
)
pair_mask = (
    torch.arange(i_batch.shape[1])[None, :]
    < torch.tensor([len(i) for i in i_list])[:, None]
)
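
# %%
# Quick sanity check: by construction, each row of node_mask contains exactly
# as many True entries as that system has real atoms, and each row of
# pair_mask as many as it has neighbor pairs.
assert node_mask.sum(dim=1).tolist() == [p.shape[0] for p in pos_list]
assert pair_mask.sum(dim=1).tolist() == [len(i) for i in i_list]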

# %%
# Initialize Ewald calculator
calculator = torchpme.EwaldCalculator(
    torchpme.CoulombPotential(smearing=0.5),
    lr_wavelength=4.0,
)
calculator.to(dtype=dtype)

# %%
# Compute potentials in a batched manner using vmap
kvectors = torchpme.lib.compute_batched_kvectors(
    lr_wavelength=calculator.lr_wavelength, cells=cell_batch
)

potentials_batch = torch.vmap(calculator.forward)(
    charges_batch,
    cell_batch,
    pos_batch,
    torch.stack((i_batch, j_batch), dim=-1),
    d_batch,
    periodic_batch,
    node_mask,
    pair_mask,
    kvectors,
)
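# The number of k-vectors depends on each system's cell, so they are generated
# up front with the batched helper and passed to the vmapped call explicitly.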

# %%
print("Batched potentials shape:", potentials_batch.shape)
print(potentials_batch)
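
# %%
# Consistency check (a sketch): each row of the batched result, truncated to
# the system's real atoms, should agree with the per-system computation using
# the unbatched call signature from the timing loop below.
for k in range(len(pos_list)):
    n_atoms = pos_list[k].shape[0]
    reference = calculator.forward(
        charges_list[k],
        cell_list[k],
        pos_list[k],
        torch.stack((i_list[k], j_list[k]), dim=-1),
        d_list[k],
        periodic_list[k],
    )
    assert torch.allclose(potentials_batch[k, :n_atoms], reference)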

# %%
# Compare performance of batched vs. looped computation
n_iter = 100

t0 = time.perf_counter()
for _ in range(n_iter):
    _ = torch.vmap(calculator.forward)(
        charges_batch,
        cell_batch,
        pos_batch,
        torch.stack((i_batch, j_batch), dim=-1),
        d_batch,
        periodic_batch,
        node_mask,
        pair_mask,
        kvectors,
    )
t_batch = (time.perf_counter() - t0) / n_iter

t0 = time.perf_counter()
for _ in range(n_iter):
    for k in range(len(pos_list)):
        _ = calculator.forward(
            charges_list[k],
            cell_list[k],
            pos_list[k],
            torch.stack((i_list[k], j_list[k]), dim=-1),
            d_list[k],
            periodic_list[k],
        )
t_loop = (time.perf_counter() - t0) / n_iter

print(f"Average time per batched call: {t_batch:.6f} s")
print(f"Average time per loop call: {t_loop:.6f} s")
print("Batched is faster" if t_batch < t_loop else "Loop is faster")
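# Note that timings are hardware and size dependent; with systems this small,
# differences largely reflect Python-loop and dispatch overhead rather than
# the Ewald computation itself.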
Review comment:
We could also show a speed comparison for a batched one and a looped one...
Might be nice :-)