Skip to content

Commit

Permalink
Fix GPU relax for longer chains by pinning large memory ops to cpu.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501105389
Change-Id: I6c981d1d3231e008ebae192edb4586479eb5eb34
  • Loading branch information
alexanderimanicowenrivers authored and Copybara-Service committed Jan 10, 2023
1 parent 420fb08 commit 8f1ebd5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion alphafold/relax/amber_minimize.py
Expand Up @@ -26,6 +26,7 @@
from alphafold.relax import utils
import ml_collections
import numpy as np
import jax
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
Expand Down Expand Up @@ -486,7 +487,9 @@ def run_pipeline(
pdb_string = clean_protein(prot, checks=True)
else:
pdb_string = ret["min_pdb"]
ret.update(get_violation_metrics(prot))
# Calculation of violations can cause CUDA errors for some JAX versions.
with jax.default_device(jax.devices("cpu")[0]):
ret.update(get_violation_metrics(prot))
ret.update({
"num_exclusions": len(exclude_residues),
"iteration": iteration,
Expand Down

0 comments on commit 8f1ebd5

Please sign in to comment.