#!/usr/bin/env python

"""Jointly estimate heterozygosity (H) and sequencing error rate (E).

Base counts at each site of a sample's high-depth clusters are fit
to a mixture of binomial models for homozygous and heterozygous
sites, and [H, E] are co-estimated by maximum likelihood using
scipy's L-BFGS-B optimizer.
"""

from typing import TypeVar, Tuple
from itertools import combinations

from loguru import logger
from scipy.optimize import minimize
import scipy.stats
import pandas as pd
import numpy as np
import numba

from ipyrad.assemble.base_step import BaseStep
from ipyrad.core.schema import Stats4
from ipyrad.core.progress_bar import AssemblyProgressBar
from ipyrad.assemble.utils import IPyradError, NoHighDepthClustersError
from ipyrad.assemble.clustmap_within_both import iter_clusters

Assembly = TypeVar("Assembly")
Sample = TypeVar("Sample")
logger = logger.bind(name="ipyrad")

class Step4(BaseStep):
    """Run the step 4 joint estimation of [H, E] for each sample."""
    def __init__(self, data, force, quiet, ipyclient):
        super().__init__(data, 4, quiet, force)
        self.haploid = data.params.max_alleles_consens == 1
        self.ipyclient = ipyclient
        self.lbview = self.ipyclient.load_balanced_view()

    def run(self):
        """Distribute optimization jobs and store results."""
        jobs = {}
        for sname, sample in self.samples.items():
            args = (self.data, sample, self.haploid)
            jobs[sname] = self.lbview.apply(optim2, *args)
        msg = "inferring [H, E]"
        prog = AssemblyProgressBar(jobs, msg, 4, self.quiet)
        prog.block()
        prog.check()

        # collect updated samples and save to JSON
        for sname, result in prog.results.items():
            sample = self.data.samples[sname]
            sample.state = 4
            sample._clear_old_results()
            sample.stats_s5 = None  # reset any stale step 5 stats
            sample.stats_s4 = Stats4(
                hetero_est=result[0],
                error_est=result[1],
                min_depth_stat_during_step4=self.data.params.min_depth_statistical,
            )
        self.data.save_json()

        # build the stats dataframe
        statsdf = pd.DataFrame(
            index=sorted(self.samples),
            columns=["hetero_est", "error_est"],
        )
        # fill stats from the updated samples on the Assembly object
        for sname in self.samples:
            stats = self.data.samples[sname].stats_s4.dict()
            for i in statsdf.columns:
                statsdf.loc[sname, i] = stats[i]

        # log and save stats
        logger.info("\n" + statsdf.to_string())
        outfile = self.data.stepdir / "s4_joint_estimate.txt"
        with open(outfile, 'w', encoding="utf-8") as out:
            out.write(statsdf.to_string())

def optim2(data: Assembly, sample: Sample, haploid: bool) -> Tuple[float, float]:
    """Maximum likelihood optimization with scipy."""
    # get stacked array of site counts from high-depth clusters: (nsites, 4)
    try:
        stacked = get_stack_array(data, sample)
    # if no high depth clusters exist the sample can still proceed, but it
    # will not contribute to estimating the avg H,E and will receive
    # only low depth calls in step 5 unless params are changed.
    except NoHighDepthClustersError:
        return np.nan, np.nan

    # get base frequencies
    bfreqs = stacked.sum(axis=0) / float(stacked.sum())

    # count each unique site count pattern
    ustacks, counts = np.unique(stacked, axis=0, return_counts=True)

    # fit haploid or diploid model to counts
    if haploid:
        fit = minimize(
            get_haploid_loglik,
            x0=(0.001,),
            args=(bfreqs, ustacks, counts),
            method="L-BFGS-B",
            bounds=[(1e-6, 0.1)],
        )
        hetero = 0.0
        error = fit.x[0]
    else:
        fit = minimize(
            nget_diploid_loglik,
            x0=(0.01, 0.001),
            args=(bfreqs, ustacks, counts),
            method="L-BFGS-B",
            bounds=[(1e-6, 0.1), (1e-6, 0.1)],
        )
        hetero, error = fit.x
    return hetero, error
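
# A minimal sketch (illustrative only; toy numbers) of the optimization that
# optim2 performs: site count patterns with one dominant base support a low
# H, while balanced two-base patterns pull H upward.
# >>> bfreqs = np.array([0.25, 0.25, 0.25, 0.25])
# >>> ustacks = np.array([
# ...     [20, 0, 0, 0],     # homozygous-looking pattern
# ...     [10, 10, 0, 0],    # heterozygous-looking pattern
# ... ])
# >>> counts = np.array([100, 5])    # pattern multiplicities
# >>> fit = minimize(
# ...     nget_diploid_loglik, x0=(0.01, 0.001),
# ...     args=(bfreqs, ustacks, counts),
# ...     method="L-BFGS-B", bounds=[(1e-6, 0.1), (1e-6, 0.1)])
# >>> hetero_est, error_est = fit.x
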
def get_haploid_loglik(errors, bfreqs, ustacks, counts):
    """Return negative log likelihood of the site counts given [E]."""
    hetero = 0.
    lik1 = (1. - hetero) * likelihood1(errors, bfreqs, ustacks)
    liks = lik1
    logliks = np.log(liks[liks > 0]) * counts[liks > 0]
    score = -logliks.sum()
    return score

def nget_diploid_loglik(
    pstart: Tuple[float, float],
    bfreqs: np.ndarray,
    ustacks: np.ndarray,
    counts: np.ndarray,
) -> float:
    """Return negative log likelihood of the site counts given [H, E]."""
    hetero, errors = pstart
    lik1 = (1. - hetero) * likelihood1(errors, bfreqs, ustacks)
    lik2 = hetero * nlikelihood2(errors, bfreqs, ustacks)
    liks = lik1 + lik2
    logliks = np.log(liks[liks > 0]) * counts[liks > 0]
    score = -logliks.sum()
    return score
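
# The mixture scored above, written out: for a site with base counts n and
# total depth t,
#     P(n | H, E) = (1 - H) * P_hom(n | E) + H * P_het(n | E)
# where P_hom (likelihood1) marginalizes the binomial error pmf over the
# four possible true bases weighted by bfreqs, and P_het (nlikelihood2)
# marginalizes over the six unordered base pairs. The optimizer minimizes
# -sum(counts * log P) over the unique site patterns.
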
def likelihood1(errors, bfreqs, ustacks):
    """Probability the site counts arose from a homozygous site."""
    # binomial pmf of the mismatched (error) reads at each site,
    # marginalized over the four possible true bases weighted by bfreqs.
    totals = np.array([ustacks.sum(axis=1)] * 4).T
    prob = scipy.stats.binom.pmf(totals - ustacks, totals, errors)
    lik1 = np.sum(bfreqs * prob, axis=1)
    return lik1

def nlikelihood2(errors, bfreqs, ustacks):
    """Probability the site counts arose from a heterozygous site.

    Calls nblik2_build and lik2_calc for a given err.
    """
    # prior of each unordered heterozygous base pair, and the normalizer
    # (1 - prob homozygous) used to condition on heterozygosity.
    one = [2. * bfreqs[i] * bfreqs[j] for i, j in combinations(range(4), 2)]
    four = 1. - np.sum(bfreqs ** 2)
    tots, twos, thrs = nblik2_build(ustacks)
    res2 = lik2_calc(errors, one, tots, twos, thrs, four)
    return res2

@numba.jit(nopython=True)
def nblik2_build(ustacks):
    """JIT'd function that builds arrays used to calc binom pmfs."""
    # arrays to fill for pmf calcs later
    tots = np.empty((ustacks.shape[0], 1))
    twos = np.empty((ustacks.shape[0], 6))
    thrs = np.empty((ustacks.shape[0], 6, 2))

    # fill big arrays
    for idx in range(ustacks.shape[0]):
        ust = ustacks[idx]
        tot = ust.sum()
        tots[idx] = tot

        # fast filling of arrays over the 6 unordered base pairs
        i = 0
        for jdx in range(4):
            for kdx in range(4):
                if jdx < kdx:
                    twos[idx][i] = tot - ust[jdx] - ust[kdx]
                    thrs[idx][i] = ust[jdx], ust[jdx] + ust[kdx]
                    i += 1
    return tots, twos, thrs

def lik2_calc(err, one, tots, twos, thrs, four):
    """Vectorized calc of binom pmf on large arrays."""
    # pmf of the reads matching neither allele: errors at rate 2e/3
    _twos = scipy.stats.binom.pmf(twos, tots, (2. * err) / 3.)

    # pmf of the allelic balance between the two allele counts at p=0.5
    _thrs = thrs.reshape(thrs.shape[0] * thrs.shape[1], thrs.shape[2])
    _thrs = scipy.stats.binom.pmf(_thrs[:, 0], _thrs[:, 1], 0.5)
    _thrs = _thrs.reshape(thrs.shape[0], 6)

    # weighted sum over the 6 possible heterozygous base pairs
    return np.sum(one * _twos * (_thrs / four), axis=1)
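
# Sanity sketch (toy numbers): a balanced 10/10 site should score much
# higher under the heterozygous model than the homozygous one when the
# error rate is small.
# >>> bfreqs = np.array([0.25, 0.25, 0.25, 0.25])
# >>> site = np.array([[10, 10, 0, 0]])
# >>> likelihood1(0.001, bfreqs, site)     # ~1e-25, homozygous fits poorly
# >>> nlikelihood2(0.001, bfreqs, site)    # ~0.03, heterozygous fits well
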
def recal_hidepth_cluster_stats(
    data: Assembly, sample: Sample, majrule: bool = False,
) -> Tuple[np.ndarray, int]:
    """Return a mask for cluster depths, and the max frag length.

    This is useful to run first to get a sense of the depths and
    lengths given the current mindepth param settings.
    Note: this func is used in both steps 4 and 5.
    """
    # calculate depths given the current mindepth settings.
    depths = []   # read depth: sum of 'sizes'
    clens = []    # lengths of clusters
    for clust in iter_clusters(sample.files.clusters, gzipped=True):
        names = clust[::2]
        sizes = [int(i.split(";")[-2][5:]) for i in names]
        depths.append(sum(sizes))
        clens.append(len(clust[1].strip()))
    clens, depths = np.array(clens), np.array(depths)

    # get mask of clusters that are hidepth
    if majrule:
        keep = depths >= data.params.min_depth_majrule
    else:
        keep = depths >= data.params.min_depth_statistical

    # get frag lengths of clusters that are hidepth
    lens_above_st = clens[keep]

    # calculate max frag length from the hidepth lens
    try:
        maxfrag = int(4 + lens_above_st.mean() + (2. * lens_above_st.std()))
    except Exception as inst:
        # raised when no clusters pass the depth filter. In step 4 this is
        # caught to print a warning and set the sample's H,E estimates to
        # nan. In step 5 the nans are caught upstream.
        print(
            f"sample {sample.name} has no clusters above the "
            "`min_depth_statistical` parameter setting, and thus will "
            "include only low depth base calls in step 5.")
        raise NoHighDepthClustersError(f"{sample.name}") from inst
    return keep, maxfrag
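
# A minimal sketch (hypothetical header) of the `;size=N;` parsing above:
# dereplicated read names carry their depth in a trailing size field, so
# splitting on ';' and slicing off the 'size=' prefix recovers the depth.
# >>> name = ">d35ab.0;size=12;*"
# >>> int(name.split(";")[-2][5:])
# 12
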
def get_stack_array(data: Assembly, sample: Sample, size: int = 10_000) -> np.ndarray:
    """Stacks clusters into arrays using at most 10K clusters.

    Uses maxfrag to limit the length of arrays, and also masks the
    first and last 8 bp of each read since these are more prone to
    alignment errors in denovo assemblies and will likely be trimmed
    later.
    """
    # only use clusters with depth > min_depth_statistical for param estimates
    stat_mask, maxfrag = recal_hidepth_cluster_stats(data, sample)

    # sample many (e.g., 10_000) clusters to use for param estimation.
    maxclusts = min(size, stat_mask.sum())
    maxfrag = min(150, maxfrag)
    dims = (maxclusts, maxfrag, 4)
    stacked = np.zeros(dims, dtype=np.uint64)

    # fill stacked
    clustgen = iter_clusters(sample.files.clusters, gzipped=True)
    sidx = 0  # stored row number
    for idx, clust in enumerate(clustgen):

        # skip masked (lowdepth) clusters
        if not stat_mask[idx]:
            continue

        # once maxclusts are stored there is no need to do more.
        if sidx >= maxclusts:
            break

        # parse cluster and expand derep depths
        names = clust[0::2]
        seqs = clust[1::2]
        reps = [int(i.split(";")[-2][5:]) for i in names]
        sseqs = [list(i.strip()) for i in seqs]
        arr = np.concatenate([[seq] * rep for seq, rep in zip(sseqs, reps)])

        # randomly sample at most 500 reads from a cluster
        if arr.shape[0] > 500:
            ridxs = range(arr.shape[0])
            ridxs = np.random.choice(ridxs, size=500, replace=False)
            arr = arr[ridxs]

        # mask edges, indels, and pair inserts, and remove empty columns.
        arr[:, :8] = "N"
        arr[:, -8:] = "N"
        arr[arr == "-"] = "N"
        arr[:, np.any(arr == "n", axis=0)] = "N"
        arr = arr[:, ~np.all(arr == "N", axis=0)]

        # store site counts in stack, shape=(nsites, 4)
        catg = [np.sum(arr == i, axis=0) for i in list("CATG")]
        catg = np.array(catg, dtype=np.uint64).T

        # limit stored catg data to maxfrag len
        stacked[sidx, :catg.shape[0], :] = catg[:maxfrag, :]
        sidx += 1

    # drop empty rows in case there are fewer sites than the array size
    newstack = stacked[stacked.sum(axis=2) > 0]
    assert not np.any(newstack.sum(axis=1) == 0), "zero-depth rows remain"
    return newstack
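
# A minimal sketch (toy reads, not real cluster data) of the count transform
# at the end of get_stack_array: each retained column becomes a CATG count
# vector.
# >>> arr = np.array([list("CCAT"), list("CCAT"), list("CCGT")])
# >>> np.array([np.sum(arr == i, axis=0) for i in list("CATG")]).T
# array([[3, 0, 0, 0],
#        [3, 0, 0, 0],
#        [0, 2, 0, 1],
#        [0, 0, 3, 0]])
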
def make_chunk_files(data, sample, keep_mask, chunksize=5000) -> None:
    """Split cluster file into chunks of <chunksize> hidepth clusters.

    THIS IS CALLED IN STEP 5.

    Parameters
    ----------
    data: Assembly
        params and samples
    sample: Sample
        filepaths and stats
    keep_mask: np.ndarray
        Used to filter which consens reads will be included in chunk
        files used for parallel processing.
    chunksize: int
        Chunksize used for breaking the problem into parallel chunks.
        The chunks are unzipped. Default is 5K, but is set very large
        when using the debug() function so that files are not split.
    """
    # open the cluster generator
    clusters = iter_clusters(sample.files.clusters, gzipped=True)

    # load in high depth clusters and then write to chunk
    chunk = []
    sidx = 0
    for idx, clust in enumerate(clusters):
        if not keep_mask[idx]:
            continue
        chunk.append("".join(clust))

        # write to chunk file and reset
        if len(chunk) == int(chunksize):
            end = sidx + len(chunk)
            handle = data.tmpdir / f"{sample.name}_chunk_{sidx}_{end}"
            with open(handle, 'w', encoding="utf-8") as out:
                out.write("//\n//\n".join(chunk) + "//\n//\n")
            chunk = []
            sidx += int(chunksize)

    # write any remaining clusters
    if chunk:
        end = sidx + len(chunk)
        handle = data.tmpdir / f"{sample.name}_chunk_{sidx}_{end}"
        with open(handle, 'w', encoding="utf-8") as out:
            out.write("//\n//\n".join(chunk) + "//\n//\n")
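
# Naming sketch (hypothetical sample): for a sample named "1A_0" with 12_345
# kept clusters and chunksize=5000, the loop above writes the files
# 1A_0_chunk_0_5000, 1A_0_chunk_5000_10000, and 1A_0_chunk_10000_12345.
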
if __name__ == "__main__":

    import ipyrad as ip
    ip.set_log_level("DEBUG", log_file="/tmp/test.log")

    # for JSON in ["/tmp/TEST1.json", "/tmp/TEST5.json"]:
    #     TEST = ip.load_json(JSON)
    #     TEST.run("4", force=True, quiet=True)

    TEST = ip.load_json("../../pedtest/NEW.json").branch("NEW2")
    TEST.run("3", force=True, quiet=False)
    print(TEST.stats)

    # TEST = ip.load_json("../../pedtest/NEW.json")
    # TEST.run("4", force=True, quiet=False)
    # TEST = ip.load_json("/tmp/TEST3.json")
    # TEST.run("4", force=True, quiet=False)