Merge branch 'read_trimming_fixes' into 'dev'
Read trimming and smolecule tweaks, faster processing of short regions, and other changes in preparation for the medaka tr entry point.

See merge request research/medaka!549
mwykes committed Aug 2, 2023
2 parents ae8a369 + 291384c commit 71dbb67
Showing 12 changed files with 214 additions and 112 deletions.
12 changes: 9 additions & 3 deletions medaka/align.py
@@ -84,10 +84,16 @@ def parasail_to_sam(result, seq):
rstart = int(first[0])
else:
pre = '{}{}'.format(clip, prefix)

mid = cigstr[len(prefix):]
end_clip = len(seq) - result.end_query - 1
# if last cigar op is I, result.end_query will not consider the insertion
# as aligned hence recalculate end_query here.
end_query = result.end_query + 1 # end_query is inclusive
last = next(cigar_ops_from_end(cigstr))
if last[1] == 'I':
end_query -= int(last[0])
cigstr = cigstr[:-len(''.join(last))]
end_clip = len(seq) - end_query
suf = '{}S'.format(end_clip) if end_clip > 0 else ''
mid = cigstr[len(prefix):]
new_cigstr = ''.join((pre, mid, suf))
return rstart, new_cigstr

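For orientation, a minimal standalone sketch of the new trailing-insertion handling; the cigar_ops_from_end stand-in and all values below are illustrative, not medaka's implementation:

import re

def cigar_ops_from_end(cigstr):
    # illustrative stand-in for medaka.align.cigar_ops_from_end:
    # yield (length, op) tuples starting from the end of the CIGAR
    for op in reversed(re.findall(r'(\d+)([MIDNSHP=X])', cigstr)):
        yield op

seq = 'ACGTACGTACGTA'   # 13-base query whose alignment ends 10M3I
cigstr = '10M3I'
end_query = 12 + 1      # parasail's inclusive end_query, plus one
last = next(cigar_ops_from_end(cigstr))
if last[1] == 'I':
    end_query -= int(last[0])              # inserted bases are not aligned
    cigstr = cigstr[:-len(''.join(last))]  # strip the trailing I op
end_clip = len(seq) - end_query
suf = '{}S'.format(end_clip) if end_clip > 0 else ''
print(cigstr + suf)  # 10M3S: trailing insertion re-expressed as soft clip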
2 changes: 1 addition & 1 deletion medaka/common.py
@@ -416,7 +416,7 @@ def __eq__(self, other):
for field in self._fields:
s = getattr(self, field)
o = getattr(other, field)
if type(s) != type(o):
if type(s) is not type(o):
return False
elif isinstance(s, np.ndarray):
if (s.shape != o.shape or np.any(s != o)):
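The swap to an identity test is the comparison style linters (flake8 E721) prefer; for exact-type checks the outcome is unchanged, as a toy example shows:

import numpy as np

s, o = np.int32(1), np.int64(1)
print(type(s) != type(o))      # True, but flagged by linters
print(type(s) is not type(o))  # True, the preferred identity test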
14 changes: 8 additions & 6 deletions medaka/datastore.py
@@ -407,6 +407,10 @@ def __init__(self, filenames, threads=4):
self._index = None
self._extract_sample_registries()
self.metadata = self._load_metadata()
# in cases where we have a single file containing samples over
# many contigs we might as well open the file once rather than once
# per contig.
self._ds = DataStore(filenames[0])

def _extract_sample_registries(self):
"""."""
@@ -503,7 +507,7 @@ def sorter(x):

return ref_names_ordered

def yield_from_feature_files(self, regions=None, samples=None, workers=8):
def yield_from_feature_files(self, regions=None, samples=None):
"""Yield `medaka.common.Sample` objects from one or more feature files.
:param regions: list of `medaka.common.Region`s for which to yield samples.
@@ -533,9 +537,7 @@ def yield_from_feature_files(self, regions=None, samples=None, workers=8):
samples.append(
(sample['sample_key'], sample['filename']))
# yield samples reusing filehandle where possible
ds, ds_fname = None, None
for key, fname in samples:
if fname != ds_fname:
ds = DataStore(fname)
ds_fname = fname
yield ds.load_sample(key)
if fname != self._ds.filename:
self._ds = DataStore(fname)
yield self._ds.load_sample(key)
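The cached DataStore means consecutive samples drawn from the same file now reuse one handle instead of reopening it per sample. A hypothetical sketch of the pattern, assuming only the .filename attribute and load_sample method used above:

from medaka.datastore import DataStore

class CachedSampleReader:
    """Hypothetical illustration of the handle-reuse pattern above."""

    def __init__(self, filename):
        self._ds = DataStore(filename)

    def load(self, key, fname):
        if fname != self._ds.filename:
            self._ds = DataStore(fname)  # reopen only when the file changes
        return self._ds.load_sample(key)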
6 changes: 4 additions & 2 deletions medaka/features.py
@@ -245,7 +245,7 @@ def _process_region(reg):
def get_trimmed_reads(
region, bam, dtype_prefixes=None, region_split=750, chunk_overlap=150,
workers=8, tag_name=None, tag_value=None, keep_missing=False,
partial=True, num_qstrat=1, read_group=None):
partial=True, num_qstrat=1, read_group=None, min_mapq=1):
"""Fetch reads trimmed to a region.
Overlapping chunks of the input region will be produced, with each chunk
@@ -264,6 +264,8 @@ def _process_region(reg):
:param keep_missing: whether to keep reads when tag is missing.
:param partial: whether to keep reads which don't fully span the region.
:param num_qstrat: number of layers for qscore stratification.
:param read_group: str, bam read group for reads to keep.
:param min_mapq: minimum mapping quality for reads to keep.
:returns: iterator of lists of trimmed reads.
"""
@@ -278,7 +280,7 @@ def _process_region(reg):
region_str = '{}:{}-{}'.format(reg.ref_name, reg.start + 1, reg.end)
stuff = lib.PY_retrieve_trimmed_reads(
region_str.encode(), bam.encode(), num_dtypes, dtypes,
tag_name, tag_value, keep_missing, partial, read_group,
tag_name, tag_value, keep_missing, partial, read_group, min_mapq,
)
# last string is reference
seqs = [(False, ffi.string(stuff.seqs[stuff.n_seqs - 1]).decode())]
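A hedged usage sketch of the extended signature; the BAM path and region are placeholders and min_mapq=10 is an arbitrary threshold:

import medaka.common
import medaka.features

region = medaka.common.Region.from_string('contig1:0-10000')
for reads in medaka.features.get_trimmed_reads(
        region, 'calls_to_draft.bam', partial=True, min_mapq=10):
    # each item is a list of reads trimmed to a chunk of the region
    pass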
62 changes: 35 additions & 27 deletions medaka/prediction.py
@@ -121,6 +121,25 @@ def predict(args):
logger.info('Processing region(s): {}'.format(
' '.join(str(r) for r in bam_regions)))

# split out regions which are smaller than the chunk
# size for processing later.
regions, remainder_regions = [], []
for region in bam_regions:
if region.size < args.chunk_len:
remainder_regions.append(region)
else:
# Split overly long regions to maximum size so as to not create
# massive feature matrices
if region.size > args.bam_chunk:
# chunk_ovlp is mostly used in overlapping pileups (which
# generally end up being expanded compared to the draft
# coordinate system)
regions.extend(region.split(
args.bam_chunk, overlap=args.chunk_ovlp,
fixed_size=False))
else:
regions.append(region)

logger.info("Using model: {}.".format(args.model))
with medaka.models.open_model(args.model) as model_store:
feature_encoder = model_store.get_meta('feature_encoder')
@@ -140,34 +159,24 @@ def predict(args):
"`--disable_cudnn. If OOM (out of memory) errors are found "
"please reduce batch size.")

# Split overly long regions to maximum size so as to not create
# massive feature matrices
regions = []
for region in bam_regions:
if region.size > args.bam_chunk:
# chunk_ovlp is mostly used in overlapping pileups (which
# generally end up being expanded compared to the draft
# coordinate system)
regs = region.split(
args.bam_chunk, overlap=args.chunk_ovlp,
fixed_size=False)
else:
regs = [region]
regions.extend(regs)
bam_pool = medaka.features.BAMHandler(args.bam)

logger.info("Processing {} long region(s) with batching.".format(
len(regions)))
model = model_store.load_model(time_steps=args.chunk_len)
if len(regions) > 0:

bam_pool = medaka.features.BAMHandler(args.bam)
logger.info("Processing {} long region(s) with batching.".format(
len(regions)))
model = model_store.load_model(time_steps=args.chunk_len)

# the returned regions are those where the pileup width is smaller
# than chunk_len (which could happen due to gaps in coverage)
remainder_regions_depth = run_prediction(
args.output, bam_pool, regions, model, feature_encoder,
args.chunk_len, args.chunk_ovlp,
batch_size=args.batch_size, save_features=args.save_features,
bam_workers=args.bam_workers)

# the returned regions are those where the pileup width is smaller than
# chunk_len
remainder_regions = run_prediction(
args.output, bam_pool, regions, model, feature_encoder,
args.chunk_len, args.chunk_ovlp,
batch_size=args.batch_size, save_features=args.save_features,
bam_workers=args.bam_workers)
# run_prediction returns [(region, pileup width)]
remainder_regions.extend([r[0] for r in remainder_regions_depth])

# short/remainder regions: just do things without chunking. We can do
# this here because we now have the size of all pileups (and know they
@@ -183,9 +192,8 @@ def predict(args):
# creating a thread that does not die for every retrace
model = model_store.load_model(time_steps=None)
model.run_eagerly = True
remainers = [r[0] for r in remainder_regions]
new_remainders = run_prediction(
args.output, bam_pool, remainers, model,
args.output, bam_pool, remainder_regions, model,
feature_encoder,
args.chunk_len, args.chunk_ovlp, # these won't be used
batch_size=1, # everything is a different size, cant batch
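Both the old and new code lean on medaka.common.Region.split to cap feature-matrix size; an illustrative call with made-up coordinates:

import medaka.common

region = medaka.common.Region.from_string('contig1:0-2500000')
for r in region.split(1000000, overlap=1000, fixed_size=False):
    print(str(r))  # overlapping sub-regions of at most ~1 Mb each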
66 changes: 48 additions & 18 deletions medaka/smolecule.py
@@ -22,12 +22,11 @@
class Read(object):
"""Functionality to extract information from a read with subreads."""

def __init__(self, name, subreads, initialize=False):
def __init__(self, name, subreads):
"""Initialize repeat read analysis.
:param name: read name.
:param subreads: list of subreads.
:param initialize: initialize subread alignments.
"""
self.name = name
@@ -45,6 +44,31 @@ def __init__(self, name, subreads):
self._initialized = False
# has a consensus been run
self.consensus_run = False
# enable use of alternative aligners by changing the name.
# parasail align functions are not pickleable, so they cause issues
# with multiprocessing if set as a direct attribute; instead, create
# the aligner on demand.
self.parasail_aligner_name = 'sw_trace_striped_16'

@property
def parasail_aligner_name(self):
"""Return parasail alignment function name (str)."""
return self._parasail_aligner_name

@parasail_aligner_name.setter
def parasail_aligner_name(self, value):
if hasattr(parasail, value) and callable(getattr(parasail, value)):
self._parasail_aligner_name = value
else:
raise ValueError(f'{value} is not a valid parasail function.')

@property
def parasail_aligner(self):
"""Return functools.partial-wrapped parasail alignment function."""
return functools.partial(
getattr(parasail, self.parasail_aligner_name),
open=8, extend=4, matrix=parasail.dnafull
)
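Resolving the stored name on demand keeps Read instances pickleable for multiprocessing; done by hand with toy sequences, the property amounts to:

import functools
import parasail

aligner = functools.partial(
    getattr(parasail, 'sw_trace_striped_16'),  # any parasail align function
    open=8, extend=4, matrix=parasail.dnafull)
result = aligner('ACGTACGT', 'ACGTTACGT')
print(result.score)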

def initialize(self):
"""Calculate initial alignments of subreads to scaffold read."""
@@ -151,18 +175,23 @@ def nseqs(self):
"""Return the number of subreads contained in the read."""
return len(self.subreads)

def poa_consensus(self, additional_seq=None, method='spoa'):
def poa_consensus(self, method='spoa'):
"""Create a consensus sequence for the read."""
self.initialize()
seqs = list()
for orient, subread in zip(*self.interleaved_subreads):
if orient:
seq = subread.seq
else:
seq = medaka.common.reverse_complement(subread.seq)
seqs.append(seq)
if method == 'spoa':
seqs = list()
for orient, subread in zip(*self.interleaved_subreads):
if orient:
seq = subread.seq
else:
seq = medaka.common.reverse_complement(subread.seq)
seqs.append(seq)
consensus_seq, _ = spoa.poa(seqs, genmsa=False)
elif method == 'abpoa':
import pyabpoa as pa
abpoa_aligner = pa.msa_aligner(aln_mode='g')
result = abpoa_aligner.msa(seqs, out_cons=True, out_msa=False)
consensus_seq = result.cons_seq[0]
else:
raise ValueError('Unrecognised method: {}.'.format(method))
self.consensus = consensus_seq
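A standalone sketch of the new abpoa branch, using the same pyabpoa calls as above (toy subread sequences; pyabpoa must be available):

import pyabpoa as pa

seqs = ['ACGTTACGT', 'ACGTACGT', 'ACGCTACGT']
aligner = pa.msa_aligner(aln_mode='g')  # global alignment mode
result = aligner.msa(seqs, out_cons=True, out_msa=False)
print(result.cons_seq[0])  # consensus of the subreads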
@@ -182,10 +211,8 @@ def orient_subreads(self):
alignments = []
for sr in self.subreads:
rc_seq = medaka.common.reverse_complement(sr.seq)
result_fwd = parasail.sw_trace_striped_16(
sr.seq, self.consensus, 8, 4, parasail.dnafull)
result_rev = parasail.sw_trace_striped_16(
rc_seq, self.consensus, 8, 4, parasail.dnafull)
result_fwd = self.parasail_aligner(sr.seq, self.consensus)
result_rev = self.parasail_aligner(rc_seq, self.consensus)
is_fwd = result_fwd.score > result_rev.score
self._orient.append(is_fwd)
result = result_fwd if is_fwd else result_rev
@@ -218,8 +245,8 @@ def align_to_template(self, template, template_name):
seq = sr.seq
else:
seq = medaka.common.reverse_complement(sr.seq)
result = parasail.sw_trace_striped_16(
seq, template, 8, 4, parasail.dnafull)
result = self.parasail_aligner(seq, template)

if result.cigar.beg_ref >= result.end_ref or \
result.cigar.beg_query >= result.end_query:
# unsure why this can happen
@@ -280,11 +307,14 @@ def write_bam(fname, alignments, header, bam=True):
"""
mode = 'wb' if bam else 'w'
if isinstance(header, dict):
header = pysam.AlignmentHeader.from_dict(header)
with pysam.AlignmentFile(fname, mode, header=header) as fh:
for ref_id, subreads in enumerate(alignments):
for subreads in alignments:
for aln in sorted(subreads, key=lambda x: x.rstart):
a = medaka.align.initialise_alignment(
aln.qname, ref_id, aln.rstart, aln.seq,
aln.qname, header.get_tid(aln.rname),
aln.rstart, aln.seq,
aln.cigar, aln.flag)
fh.write(a)
if mode == 'wb':
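Using header.get_tid rather than enumerate means the alignment lists no longer need to be supplied in reference order; pysam resolves the tid by name (toy header below):

import pysam

header = pysam.AlignmentHeader.from_dict({
    'HD': {'VN': '1.6'},
    'SQ': [{'SN': 'read_0', 'LN': 1000}, {'SN': 'read_1', 'LN': 800}],
})
print(header.get_tid('read_1'))  # 1, independent of iteration order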