Merge pull request #2 from jonathanking/parallel_combine_data
Combine SidechainNet and ProteinNet in parallel
jonathanking committed May 11, 2020
2 parents b023385 + b4a0047 commit 5e460c7
Showing 3 changed files with 22 additions and 54 deletions.
17 changes: 15 additions & 2 deletions sidechainnet/create.py
@@ -9,6 +9,7 @@
 
 import prody as pr
 from tqdm import tqdm
+from multiprocessing import Pool, cpu_count
 
 from sidechainnet.utils.alignment import can_be_directly_merged, expand_data_with_mask
 
@@ -103,8 +104,12 @@ def combine_datasets(proteinnet_out, sc_data, training_set):
 
     aligner = init_aligner()
     # for pnid in error_ids:
-    for pnid in tqdm(sc_data.keys(), dynamic_ncols=True):
-        combined_result, warning = combine(pn_data[pnid], sc_data[pnid], aligner)
+    with Pool(cpu_count()) as p:
+        tuples = (get_tuple(pn_data, sc_data, pnid) for pnid in sc_data.keys())
+        results_warnings = list(tqdm(p.imap(combine_wrapper, tuples), total=len(sc_data.keys()), dynamic_ncols=True))
+
+    # for pnid in tqdm(sc_data.keys(), dynamic_ncols=True):
+    for (combined_result, warning), pnid in zip(results_warnings, sc_data.keys()):
         if combined_result:
             pn_data[pnid] = combined_result
         else:
@@ -136,6 +141,14 @@ def combine_datasets(proteinnet_out, sc_data, training_set):
           f"{len(errors['failed'])} IDs failed to combine successfully.")
     return pn_data
 
+def get_tuple(pndata, scdata, pnid):
+    return pndata[pnid], scdata[pnid]
+
+def combine_wrapper(pndata_scdata):
+    pn_data, sc_data = pndata_scdata
+    aligner = init_aligner()
+    return combine(pn_data, sc_data, aligner)
+
 
 def main():
     # First, create PyTorch versions of raw proteinnet files for convenience
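A note on the pattern adopted above: multiprocessing.Pool.imap consumes an iterable of per-protein argument tuples and yields results in input order, and wrapping that iterator in tqdm gives a live progress bar (total= must be passed explicitly because imap returns a lazy iterator with no len()). combine_wrapper is defined at module level because Pool workers can only receive picklable callables, and each call builds its own aligner rather than sharing one across processes. A minimal, self-contained sketch of the same pattern; the names slow_square and items are illustrative, not from the repository:

from multiprocessing import Pool, cpu_count

from tqdm import tqdm


def slow_square(x):
    # Stand-in for combine_wrapper: Pool workers need a picklable,
    # module-level callable.
    return x * x


if __name__ == "__main__":
    items = list(range(200))
    with Pool(cpu_count()) as p:
        # imap yields results lazily and in input order; total= is needed
        # because the returned iterator has no len() for tqdm to inspect.
        results = list(tqdm(p.imap(slow_square, items), total=len(items)))
    print(results[:5])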
4 changes: 2 additions & 2 deletions sidechainnet/download_and_parse.py
@@ -119,8 +119,8 @@ def get_sidechain_data(pnids, limit):
         that failed to download.
     """
     with multiprocessing.Pool(multiprocessing.cpu_count()) as p:
-        results = list(tqdm.tqdm(p.imap(process_id, pnids[:limit]),
-                                 total=len(pnids[:limit]), dynamic_ncols=True))
+        results = tqdm.tqdm(p.imap(process_id, pnids[:limit]),
+                            total=len(pnids[:limit]), dynamic_ncols=True)
     all_errors = []
     all_data = dict()
     with open("errors/MODIFIED_MODEL_WARNING.txt", "a") as model_warning_file:
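One behavioral note on this change: without list(...), results is a lazy iterator, so downstream consumption drives both the worker pool and the progress bar, and that consumption must finish before the Pool context exits. A tiny sketch of the eager-versus-lazy distinction, with illustrative names:

from multiprocessing import Pool


def work(x):
    return x + 1


if __name__ == "__main__":
    with Pool(2) as p:
        eager = list(p.imap(work, range(4)))  # materialized immediately
        lazy = p.imap(work, range(4))         # iterator: results arrive
        consumed = list(lazy)                 # only as it is consumed, and
                                              # consumption must finish while
                                              # the Pool is still alive
    print(eager, consumed)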
55 changes: 5 additions & 50 deletions sidechainnet/utils/alignment.py
@@ -9,7 +9,6 @@
 """
 
 import numpy as np
-import torch
 from Bio import Align
 from tqdm import tqdm
 
@@ -78,12 +77,10 @@ def masks_match(pn, new):
     found_a_match = False
     best_alignment = None
     best_idx = 0
-    if len(a) >= 50:
-        many_alignments = True
-        a = list(a)[:50]
-    else:
-        many_alignments = False
+    has_many_alignments = len(a) >= 50
     for i, a0 in enumerate(a):
+        if has_many_alignments and i >= 50:
+            break
         computed_mask = get_mask_from_alignment(a0)
         if not best_mask:
             best_mask = computed_mask
@@ -98,13 +95,13 @@ def masks_match(pn, new):
            break
     if found_a_match:
         warning = "multiple alignments, found matching mask"
-        if many_alignments:
+        if has_many_alignments:
             warning += ", many alignments"
         return True, best_mask, best_alignment, warning
     else:
         mask = get_mask_from_alignment(a[0])
         warning = "multiple alignments, mask mismatch"
-        if many_alignments:
+        if has_many_alignments:
             warning += ", many alignments"
         return True, mask, a[0], warning
 
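The refactor in the two hunks above trades slicing for an early break: the old code forced list(a)[:50], which materializes every alternative alignment before discarding all but 50, whereas the new code checks len(a) once and simply stops the loop at 50. With Biopython's pairwise aligner the number of equally scoring alignments can be astronomically large, so avoiding the full materialization matters. An equivalent formulation with itertools.islice, assuming a is any iterable of alignments (a sketch, not the repository's code):

from itertools import islice

MAX_ALIGNMENTS = 50  # mirrors the hard-coded cap in the diff


def capped_alignments(alignments, cap=MAX_ALIGNMENTS):
    # Yield at most `cap` alignments without materializing the full,
    # potentially enormous, set of alternatives first.
    yield from islice(alignments, cap)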
@@ -145,44 +142,6 @@ def binary_mask_to_str(m):
     return "".join(m)
 
 
-def find_how_many_entries_can_be_directly_merged():
-    """
-    Counts the number of entries that can be successfully aligned between the
-    sidechain dataset and the protein dataset.
-    """
-    d = torch.load(
-        "/home/jok120/protein-transformer/data/proteinnet/casp12_200218_30.pt")
-    pn = torch.load("/home/jok120/proteinnet/data/casp12/torch/training_30.pt")
-    aligner = init_aligner()
-    total = 0
-    successful = 0
-    with open("merging_problems.csv", "w") as f, open("merging_success.csv",
-                                                      "w") as sf:
-        for i, my_id in enumerate(tqdm(d["train"]["ids"])):
-            my_seq, pn_seq, pn_mask = d["train"]["seq"][i], pn[my_id][
-                "primary"], binary_mask_to_str(pn[my_id]["mask"])
-            my_seq = unmask_seq(d["train"]["ang"][i], my_seq)
-            result, computed_mask, alignment = can_be_directly_merged(
-                aligner, pn_seq, my_seq, pn_mask)
-            if result:
-                successful += 1
-                sf.write(",".join([my_id, my_seq, computed_mask]) + "\n")
-            else:
-                if pn_mask.count("+") < len(my_seq):
-                    size_comparison = "<"
-                elif pn_mask.count("+") > len(my_seq):
-                    size_comparison = ">"
-                else:
-                    size_comparison = "=="
-                f.write(
-                    f"{my_id}: (PN {size_comparison} Obs)\n{str(alignment)}")
-                f.write(f"PN Mask:\n{pn_mask}\n\n")
-            total += 1
-    print(
-        f"{successful} out of {total} ({successful / total}) sequences can be merged successfully."
-    )
-
-
 def unmask_seq(ang, seq):
     """
     Given an angle array that is padded with np.nans, applies this padding to
@@ -201,10 +160,6 @@ def unmask_seq(ang, seq):
     return new_seq
 
 
-if __name__ == '__main__':
-    find_how_many_entries_can_be_directly_merged()
-
-
 def coordinate_iterator(coords, atoms_per_res):
     """Iterates over coordinates in a numpy array grouped by residue.
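The context above ends at coordinate_iterator, whose body lies outside this diff. For orientation only, a plausible minimal implementation matching its docstring, grouping a flat coordinate array residue by residue; this is an assumed sketch, not the repository's actual code:

import numpy as np


def coordinate_iterator(coords, atoms_per_res):
    """Yield per-residue coordinate blocks from a flat array.

    Assumes coords has shape (n_res * atoms_per_res, 3).
    """
    n_res = coords.shape[0] // atoms_per_res
    for i in range(n_res):
        yield coords[i * atoms_per_res:(i + 1) * atoms_per_res]


# Example: 2 residues, 14 atoms each, 3D coordinates.
blocks = list(coordinate_iterator(np.zeros((28, 3)), atoms_per_res=14))
assert len(blocks) == 2 and blocks[0].shape == (14, 3)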
