
Optimize calculate_radial_contributions to reduce GPU memory usage#316

Merged
wiederm merged 4 commits into main from dev-memory-aimnet2
Nov 9, 2024

Conversation

@wiederm
Member

@wiederm wiederm commented Nov 9, 2024

Pull Request Summary

This PR addresses high GPU memory usage caused by a large intermediate tensor created in the calculate_radial_contributions function of the AIMNet2InteractionModule. The fix restructures the computation to reduce memory consumption without affecting the model's performance.
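For a back-of-envelope sense of scale, the sketch below compares the size of the full (number_of_pairs, G, F_atom) intermediate against the pairwise (number_of_pairs, F_atom) result. The sizes are illustrative assumptions, not values from the repository:

```python
# Illustrative sizes only (assumed, not from the repo): 100k pairs,
# 64 radial features (G), 256 atom features (F_atom), float32 storage.
number_of_pairs, G, F_atom = 100_000, 64, 256

bytes_intermediate = number_of_pairs * G * F_atom * 4  # (P, G, F) tensor
bytes_pairwise = number_of_pairs * F_atom * 4          # (P, F) tensor

print(f"{bytes_intermediate / 1e9:.2f} GB vs {bytes_pairwise / 1e6:.1f} MB")
# → 6.55 GB vs 102.4 MB
```

At these (made-up) sizes, the broadcast intermediate is G times larger than the reduced result, which is the memory spike this PR removes.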

The original implementation:

def calculate_radial_contributions(
    self,
    gs: Tensor,
    a_j: Tensor,
    number_of_atoms: int,
    idx_j: Tensor,
) -> Tensor:
    # Compute radial contributions
    avf_s = gs.unsqueeze(-1) * a_j.unsqueeze(1)  # Shape: (number_of_pairs, G, F_atom)
    avf_s = avf_s.sum(dim=1)  # Sum over G

    # Aggregate per atom (F_atom inferred from a_j; it was undefined as written)
    F_atom = a_j.shape[-1]
    radial_contributions = torch.zeros(
        (number_of_atoms, F_atom),
        device=avf_s.device,
        dtype=avf_s.dtype,
    )
    radial_contributions.index_add_(0, idx_j, avf_s)

    return radial_contributions

is changed to

def calculate_radial_contributions(
    self,
    gs: Tensor,
    a_j: Tensor,
    number_of_atoms: int,
    idx_j: Tensor,
) -> Tensor:
    # Map gs to match the dimension of a_j
    mapped_gs = self.gs_to_fatom(gs)  # Linear layer mapping: (number_of_pairs, G) -> (number_of_pairs, F_atom)

    # Element-wise multiplication without expanding dimensions
    avf_s = a_j * mapped_gs  # Shape: (number_of_pairs, F_atom)

    # Aggregate per atom (F_atom inferred from a_j; it was undefined as written)
    F_atom = a_j.shape[-1]
    radial_contributions = torch.zeros(
        (number_of_atoms, F_atom),
        device=avf_s.device,
        dtype=avf_s.dtype,
    )
    radial_contributions.index_add_(0, idx_j, avf_s)

    return radial_contributions
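A minimal NumPy sketch of the rewritten pairwise step, with toy sizes and with `W` standing in for the `self.gs_to_fatom` linear layer (all names and sizes here are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
P, G, F = 1_000, 64, 256  # toy sizes: pairs, radial features, atom features

gs = rng.standard_normal((P, G)).astype(np.float32)
a_j = rng.standard_normal((P, F)).astype(np.float32)
W = rng.standard_normal((G, F)).astype(np.float32)  # stand-in for gs_to_fatom

mapped_gs = gs @ W       # (P, F): radial features mapped to atom-feature width
avf_s = a_j * mapped_gs  # (P, F): elementwise product; no (P, G, F) intermediate

assert avf_s.shape == (P, F)
```

Peak memory in this step now stays proportional to P * F rather than P * G * F.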

Key changes

  • Modified calculate_radial_contributions to compute radial contributions without creating a large intermediate tensor.
  • Replaced the original tensor operations with a more memory-efficient approach using a linear layer.
  • Updated the calculation of self.number_of_input_features to reflect the correct dimensions.
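The per-atom aggregation itself is unchanged in both versions: index_add_ scatter-adds each pair's contribution into the row of its receiving atom. A NumPy equivalent of that step (np.add.at), using made-up toy values to show how duplicate indices accumulate:

```python
import numpy as np

# Toy demo of the per-atom scatter-add that Tensor.index_add_ performs.
number_of_atoms, F = 3, 2
avf_s = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
idx_j = np.array([0, 0, 2])  # pair -> receiving atom index

radial_contributions = np.zeros((number_of_atoms, F), dtype=np.float32)
np.add.at(radial_contributions, idx_j, avf_s)  # duplicate indices accumulate

print(radial_contributions)
# rows: atom0 = [3, 3] (pairs 0 and 1), atom1 = [0, 0], atom2 = [3, 3]
```

np.add.at is used rather than plain fancy-index assignment because, like index_add_, it accumulates contributions when the same atom index appears more than once.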

Associated Issue(s)

Pull Request Checklist

  • Issue(s) raised/addressed and linked
  • Includes appropriate unit test(s)
  • Appropriate docstring(s) added/updated
  • Appropriate .rst doc file(s) added/updated
  • PR is ready for review

…ze (nr_of_pairs, F, G) with F number of atom features and G number of radial features. The generation of this internal representation can be avoided, which is addressed in this PR
@wiederm wiederm self-assigned this Nov 9, 2024
@wiederm wiederm merged commit 426171a into main Nov 9, 2024
@wiederm wiederm deleted the dev-memory-aimnet2 branch November 9, 2024 22:37
@codecov-commenter

codecov-commenter commented Nov 9, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.54%. Comparing base (cf5b7c3) to head (bac77c8).
Report is 5 commits behind head on main.

