Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
add cpu/gpu and full smiles output
  • Loading branch information
hgarrereyn committed Jun 14, 2021
1 parent 670da60 commit eb36305
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 23 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -6,3 +6,7 @@ data/**
!data/README.md

.DS_Store

.store/
dist/**
build/**
7 changes: 4 additions & 3 deletions README.md
Expand Up @@ -7,9 +7,7 @@ DeepFrag is a machine learning model for fragment-based lead optimization. In th

If you use DeepFrag in your research, please cite as:

```
Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep convolutional neural network for fragment-based lead optimization. Chemical Science.
```

```tex
@article{green2021deepfrag,
Expand Down Expand Up @@ -64,11 +62,14 @@ To remove a fragment, you specify a second atom that is contained in the fragmen

By default, DeepFrag will print a list of fragment predictions to stdout similar to the [Browser App](https://durrantlab.pitt.edu/deepfrag/).

- `--out <out.csv>`: Save predictions in CSV format to `out.csv`.
- `--out <out.csv>`: Save predictions in CSV format to `out.csv`. Each line contains the fragment rank, score and SMILES string.

## Miscellaneous (optional)

- `--full`: Generate SMILES strings with the full ligand structure instead of just the fragment.
- `--cpu/--gpu`: DeepFrag will attempt to infer if a Cuda GPU is available and fallback to the CPU if it is not. You can set either the `--cpu` or `--gpu` flag to explicitly specify the target device.
- `--num_grids <num>`: Number of grid rotations to use. Using more will take longer but produce a more stable prediction. (Default: 4)
- `--top_k <k>`: Number of predictions to print in stdout. Use -1 to display all. (Default: 25)

# Reproduce Results

Expand Down
105 changes: 85 additions & 20 deletions deepfrag.py
Expand Up @@ -207,13 +207,13 @@ def get_structures(args):
parent_coords = util.get_coords(lig)
parent_types = np.array(util.get_types(lig)).reshape((-1,1))

return (rec_coords, rec_types, parent_coords, parent_types, conn)
return (rec_coords, rec_types, parent_coords, parent_types, conn, lig)


def get_model(args):
def get_model(args, device):
"""Load a pre-trained DeepFrag model."""
print('[*] Loading model ... ', end='')
model = LeadoptModel.load(str(get_model_path() / 'final_model'), device='cpu')
model = LeadoptModel.load(str(get_model_path() / 'final_model'), device=('cuda' if device == 'gpu' else device))
print('done.')
return model

Expand All @@ -233,7 +233,26 @@ def get_fingerprints(args):
return (f_smiles, f_fingerprints)


def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn):
def get_target_device(args) -> str:
"""Infer the target device or use the argument overrides."""
device = 'gpu' if torch.cuda.device_count() > 0 else 'cpu'

if args.cpu:
if device == 'gpu':
print('[*] Warning: GPU is available but running on CPU due to --cpu flag')
device = 'cpu'
elif args.gpu:
if device == 'cpu':
print('[*] Error: No CUDA-enabled GPU was found. Exiting due to --gpu flag. You can run on the CPU instead with the --cpu flag.')
exit(-1)
device = 'gpu'

print('[*] Running on device: %s' % device)

return device


def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn, device):
start = time.time()

print('[*] Generating grids ... ', end='', flush=True)
Expand All @@ -248,7 +267,7 @@ def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, paren
point_radius=model_args['point_radius'],
point_type=model_args['point_type'],
acc_type=model_args['acc_type'],
cpu=True
cpu=(device == 'cpu')
)
print('done.')
end = time.time()
Expand Down Expand Up @@ -276,44 +295,80 @@ def get_predictions(model, batch, f_smiles, f_fingerprints):
dist = list(dist.numpy())
scores = list(zip(f_smiles, dist))
scores = sorted(scores, key=lambda x:x[1], reverse=True)
scores = [(a.decode('ascii'), b) for a,b in scores]

return scores


def gen_output(args, scores):
if args.top_k != -1:
scores = scores[:args.top_k]

if args.out is None:
# Write results to stdout.
print('%4s %8s %s' % ('#', 'Score', 'Fragment'))
print('%4s %8s %s' % ('#', 'Score', 'SMILES'))
for i in range(len(scores)):
smi, score = scores[i]
print('%4d %8f %s' % (i+1, score, smi.decode('ascii')))
print('%4d %8f %s' % (i+1, score, smi))
else:
# Write csv output.
csv = 'Rank,Fragment SMILES,Score\n'
csv = 'Rank,SMILES,Score\n'
for i in range(len(scores)):
smi, score = scores[i]
csv += '%d,%s,%f\n' % (
i+1, smi.decode('ascii'), score
i+1, smi, score
)

open(args.out, 'w').write(csv)
print('[*] Wrote output to %s' % args.out)


def fuse(lig, frag):
merged = Chem.RWMol(Chem.CombineMols(lig, frag))

conn_atoms = [a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0]
neighbors = [merged.GetAtomWithIdx(x).GetNeighbors()[0].GetIdx() for x in conn_atoms]

bond = merged.AddBond(neighbors[0], neighbors[1], Chem.rdchem.BondType.SINGLE)

merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])
merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])

Chem.SanitizeMol(merged)

return merged


def fuse_fragments(lig, conn, scores):
new_sc = []
for smi, score in scores:
try:
frag = Chem.MolFromSmiles(smi)
fused = fuse(Chem.Mol(lig), frag)
new_sc.append((Chem.MolToSmiles(fused, False), score))
except:
print('[*] Error: couldn\'t process mol.')
new_sc.append(('<err>', score))

return new_sc


def run(args):
model = get_model(args)
device = get_target_device(args)

model = get_model(args, device)
f_smiles, f_fingerprints = get_fingerprints(args)

rec_coords, rec_types, parent_coords, parent_types, conn = get_structures(args)
rec_coords, rec_types, parent_coords, parent_types, conn, lig = get_structures(args)

batch = generate_grids(args, model._args, rec_coords, rec_types,
parent_coords, parent_types, conn)
parent_coords, parent_types, conn, device)

scores = get_predictions(model, batch, f_smiles, f_fingerprints)

if args.top_k != -1:
scores = scores[:args.top_k]

if args.full:
scores = fuse_fragments(lig, conn, scores)

gen_output(args, scores)


Expand Down Expand Up @@ -341,24 +396,34 @@ def main():
parser.add_argument('--rname', type=str, help='Removal point atom name.')

# Misc
parser.add_argument('--num_grids', type=int, default=4, help='Number of grid rotations.')
parser.add_argument('-k', '--top_k', type=int, default=25, help='Number of results to show. Set to -1 to show all.')
parser.add_argument('--out', type=str, help='Path to output CSV file.')
parser.add_argument('--full', action='store_true', default=False,
help='Print the full (fused) ligand structure.')
parser.add_argument('--num_grids', type=int, default=4,
help='Number of grid rotations.')
parser.add_argument('--top_k', type=int, default=25,
help='Number of results to show. Set to -1 to show all.')
parser.add_argument('--out', type=str,
help='Path to output CSV file.')
parser.add_argument('--cpu', action='store_true', default=False,
help='Use the CPU for grid generation and predictions.')
parser.add_argument('--gpu', action='store_true', default=False,
help='Use a (CUDA-capable) GPU for grid generation and predictions.')

args = parser.parse_args()

groupings = [
([('receptor', 'ligand'), ('pdb', 'resnum')], True),
([('cx', 'cy', 'cz'), ('cname',)], True),
([('rx', 'ry', 'rz'), ('rname',)], False)
([('rx', 'ry', 'rz'), ('rname',)], False),
([('cpu',), ('gpu',)], False)
]

for grp, req in groupings:
partial = []
complete = 0

for subset in grp:
res = [getattr(args, name) is not None for name in subset]
res = [not (getattr(args, name) in [None, False]) for name in subset]
partial.append(any(res) and not all(res))
complete += int(all(res))

Expand Down

0 comments on commit eb36305

Please sign in to comment.