#!/usr/bin/env python
# pb_correct.py -- forked from jgurtowski/ectools
import sys
from io import *
from seqio import fastaIterator
from operator import itemgetter
from itertools import groupby, repeat, izip_longest, imap, count, chain
from collections import namedtuple
from cov import getMarkedRanges
from misc import create_enum
import copy
# The script needs exactly six positional arguments after its own name;
# anything else prints the usage line and exits with status 1.
if len(sys.argv) != 7:
    print("pb_correct.py in.fa in.snps in.showcoords clr_id(float) min_read_length out_prefix")
    sys.exit(1)

# Contiguous ranges must have a weighted identity strictly above this
# value to be kept (argument 4).
CLR_ID_CUTOFF = float(sys.argv[4])
# Kept ranges must have a range_size() strictly above this value (argument 5).
MIN_READ_LENGTH = int(sys.argv[5])

# One record per read position: its index, the original base, the SNP
# records reported at that position (or None), the name of the aligned
# sequence covering it (or None) and the label of its clear range (or None).
PileupEntry = namedtuple("PileupEntry",
                         ["index", "base", "snps", "utg", "clr"])

# String tags describing how a CoverageRange was produced.
CovStat = dict(COVERED="COVERED", UNCOVERED="UNCOVERED", JOINED="JOINED")
class CoverageRange:
    '''A stretch of a read together with its identity score and origin.

    Attributes:
      begin   -- first position of the range
      end     -- last position of the range
      pctid   -- identity score of the range (this file creates covered
                 ranges with 1.0 and uncovered ranges with 0.7)
      covstat -- one of the CovStat tags (COVERED / UNCOVERED / JOINED)

    Instances are mutable; get_contiguous_ranges updates end, pctid and
    covstat in place on its own copies.
    '''

    def __init__(self, b, e, pid, covstat):
        self.begin = b
        self.end = e
        self.pctid = pid
        self.covstat = covstat

    def __repr__(self):
        return "CoverageRange(%d,%d,%f,%s)" % (self.begin, self.end,
                                               self.pctid, self.covstat)

    def __eq__(self, other):
        # Field-wise equality. Objects lacking the four fields (None, str,
        # ...) are simply "not equal" instead of raising AttributeError.
        try:
            return (self.begin == other.begin and self.end == other.end
                    and self.pctid == other.pctid
                    and self.covstat == other.covstat)
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly to
        # keep the two operators consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
#correction logic
def correct_base(pentry):
    '''Return the corrected base(s) for one pileup entry.

    Only SNP records whose qname equals pentry.utg (the sequence chosen
    to cover this position) are considered.

    Returns a tuple (bases, warning, clr_range):
      bases     -- replacement string for this position; may be empty
                   (deletion) or longer than one character (insertions)
      warning   -- None, or a message describing a suspicious record
      clr_range -- pentry.clr, passed through unchanged
    '''
    candidates = pentry.snps if pentry.snps is not None else []
    # A real list is needed below (len, indexing, in-place reverse), so
    # build one explicitly rather than relying on what filter() returns.
    filt_snps = [snp for snp in candidates if snp.qname == pentry.utg]

    # nothing to correct
    if not filt_snps:
        return (pentry.base, None, pentry.clr)

    ssnp = filt_snps[0]

    if len(filt_snps) > 1:
        # several records at one position: they had better all be insertions
        if all(snp.sbase == '.' for snp in filt_snps):
            # show-snps is strange: on reverse alignments it outputs
            # indels in the forward direction, so flip them back
            if ssnp.r2 == -1:
                filt_snps.reverse()
            inserted = "".join(snp.qbase for snp in filt_snps)
            return (pentry.base + inserted, None, pentry.clr)
        # not everything is an insertion: keep the original base, append
        # only the inserted bases and flag the position
        inserted = "".join(snp.qbase for snp in filt_snps
                           if snp.sbase == ".")
        return (pentry.base + inserted,
                "Multiple SNPs, Not all were Insertions", pentry.clr)

    # single insertion after this base
    if ssnp.sbase == '.':
        return (pentry.base + ssnp.qbase, None, pentry.clr)

    # For deletions and mismatches the record's subject base should agree
    # with the read base; warn when it does not.
    warning = None if ssnp.sbase == pentry.base else "Mismatched Bases"
    if ssnp.qbase == '.':
        # deletion
        return ("", warning, pentry.clr)
    # mismatch
    return (ssnp.qbase, warning, pentry.clr)
def range_size(arange):
    '''Return the span of a range, computed as end minus begin.

    NOTE(review): other code in this file treats ``end`` as inclusive
    (adjacent ranges satisfy next.begin - prev.end == 1, and positions are
    marked with range(begin, end+1)), so this value is one less than the
    number of positions in the range -- confirm that is intended.
    '''
    first, last = arange.begin, arange.end
    return last - first
def get_contiguous_ranges(ranges):
    '''Merge directly adjacent CoverageRanges into contiguous ranges.

    ranges is a list of CoverageRange objects ordered by position. Two
    consecutive entries are merged when the second begins exactly one
    position after the first ends. A merged range spans both inputs, is
    tagged CovStat["JOINED"], and carries the range_size()-weighted mean
    of the pctid values merged into it.

    Returns a new list; the input objects are deep-copied, never modified.
    '''
    if not ranges:
        return []

    out = [copy.deepcopy(ranges[0])]
    for previous, current in zip(ranges, ranges[1:]):
        if current.begin - previous.end != 1:
            # gap (or overlap): start a new contiguous range
            out.append(copy.deepcopy(current))
            continue

        merged = out[-1]
        merged_weight = range_size(merged)
        current_weight = range_size(current)
        total_weight = merged_weight + current_weight
        if total_weight:
            merged.pctid = ((merged_weight * merged.pctid) +
                            (current_weight * current.pctid)) / total_weight
        else:
            # Both ranges are a single position (begin == end), so
            # range_size() gives 0 for each and the weighted mean would
            # divide by zero; equal-sized ranges get the plain mean.
            merged.pctid = (merged.pctid + current.pctid) / 2.0
        merged.end = current.end
        merged.covstat = CovStat["JOINED"]
    return out
# Inputs: reads (fasta), SNP table and alignment coordinates.
rfh = open(sys.argv[1])
sfh = open(sys.argv[2])
afh = open(sys.argv[3])

# Outputs: per-position debug pileup and the corrected reads.
pout = open(sys.argv[6] + ".cor.pileup", "w")
corout = open(sys.argv[6] + ".cor.fa", "w")

alignment_it = lineRecordIterator(afh, NucRecord, NucRecordTypes)
snp_it = lineRecordIterator(sfh, NucSNPRecord, NucSNPRecordTypes)

# read name -> sequence, fully loaded in memory
reads = dict((str(r.name), str(r.seq)) for r in fastaIterator(rfh))

# read name (sname) -> list of its alignment records, in file order.
# Accumulate per name instead of using dict(groupby(...)): groupby only
# groups consecutive records, so a read whose alignments are not adjacent
# in the file would otherwise keep only its last group.
alignments = {}
for alignment_record in alignment_it:
    alignments.setdefault(alignment_record.sname, []).append(alignment_record)
def _snp_to_str(snp):
    '''Format one SNP record as "spos,sbase,qbase,qname" ("None" if absent).'''
    if snp is None:
        return "None"
    return "%d,%s,%s,%s" % (snp.spos, snp.sbase, snp.qbase, snp.qname)


def _find_clr_ranges(joined_ranges):
    '''Return contiguous ranges that all score above CLR_ID_CUTOFF.

    joined_ranges (covered and uncovered ranges sorted by begin) is edited
    in place: while any contiguous range scores at or below the cutoff,
    the largest UNCOVERED sub-range inside it is removed, which splits the
    range at that gap, and the contiguous ranges are recomputed.

    Raises ValueError when a failing range has no UNCOVERED sub-range left
    to remove, since no further progress is possible.
    '''
    while True:
        clr_ranges = get_contiguous_ranges(joined_ranges)
        failing = [cr for cr in clr_ranges
                   if not cr.pctid > CLR_ID_CUTOFF]
        if not failing:
            return clr_ranges
        for cr in failing:
            # uncovered sub-ranges lying inside this contiguous range
            candidates = [r for r in joined_ranges
                          if r.covstat == CovStat["UNCOVERED"]
                          and r.begin >= cr.begin and r.end <= cr.end]
            if not candidates:
                raise ValueError(
                    "cannot lift %r above the clr_id cutoff %f: "
                    "no UNCOVERED sub-range left to drop"
                    % (cr, CLR_ID_CUTOFF))
            joined_ranges.remove(max(candidates, key=range_size))


# SNP records are streamed and grouped by read name; groupby only groups
# consecutive records, so the SNP file is expected to list each read's
# records together.
for pbname, snp_entries in groupby(snp_it, lambda x: x.sname):
    pblen = len(reads[pbname])

    # no alignments for this pb read
    if pbname not in alignments:
        continue

    # For every read position record the qname of the alignment covering
    # it. Alignments are applied shortest first, so where they overlap the
    # longest one is written last and wins.
    accept_alignment_ranges = [None] * pblen
    alignments[pbname].sort(key=lambda a: a.send - a.sstart)
    for alignment in alignments[pbname]:
        # sstart/send are used as 1-based inclusive coordinates
        for p in range(alignment.sstart - 1, alignment.send):
            accept_alignment_ranges[p] = alignment.qname

    ##
    ## find clr ranges
    ##
    covered_mask = [0 if utg is None else 1
                    for utg in accept_alignment_ranges]
    uncovered_mask = [1 if utg is None else 0
                      for utg in accept_alignment_ranges]
    covered_ranges = [CoverageRange(begin, end, 1.0, CovStat["COVERED"])
                      for begin, end in getMarkedRanges(covered_mask)]
    uncovered_ranges = [CoverageRange(begin, end, 0.7, CovStat["UNCOVERED"])
                        for begin, end in getMarkedRanges(uncovered_mask)]

    # remove uncorrected ends: an uncovered range touching either end of
    # the read can never be bridged by corrected sequence
    uncovered_ranges = [rng for rng in uncovered_ranges
                        if rng.begin != 0 and rng.end != pblen - 1]

    joined_ranges = sorted(covered_ranges + uncovered_ranges,
                           key=lambda x: x.begin)

    # keep only contiguous ranges that pass the identity cutoff and are
    # long enough
    clr_ranges = [rng for rng in _find_clr_ranges(joined_ranges)
                  if range_size(rng) > MIN_READ_LENGTH]

    # label every position inside a kept range with "begin_end"
    clr_range_array = [None] * pblen
    for clr_range in clr_ranges:
        clr_label = str("%d_%d" % (clr_range.begin, clr_range.end))
        for p in range(clr_range.begin, clr_range.end + 1):
            clr_range_array[p] = clr_label

    # collect the SNP records of each position (spos is 1-based)
    merged_snps = [None] * pblen
    for pos, snps in groupby(snp_entries, lambda y: y.spos):
        merged_snps[pos - 1] = list(snps)

    # build the pileup: one entry per base of the read. All per-position
    # arrays have length pblen, so plain indexing replaces the izip call,
    # which depended on a name this file never imports itself.
    pileup = [PileupEntry(idx, base, merged_snps[idx],
                          accept_alignment_ranges[idx], clr_range_array[idx])
              for idx, base in enumerate(reads[pbname])]

    # correct the bases: a list of (bases, warning, clr_label) tuples
    corrected_data = [correct_base(entry) for entry in pileup]

    # debug pileup, one tab-separated line per position:
    # index, base, covering qname, SNPs joined by "|", warning, new bases
    pileup_str_list = []
    for entry, (bases, warning, _clr) in zip(pileup, corrected_data):
        snp_field = "|".join(
            _snp_to_str(snp)
            for snp in ([None] if entry.snps is None else entry.snps))
        pileup_str_list.append("\t".join([
            str(entry.index), entry.base, str(entry.utg), snp_field,
            str(warning), str(bases)]))

    # write one fasta record per clear range; consecutive positions share
    # the same clr label, so groupby yields one group per range
    pbname_corrected_base = pbname + "_corrected2"
    for clr_name, clr_group in groupby(corrected_data, itemgetter(2)):
        # skip non clear ranges
        if clr_name is None:
            continue
        pbname_corrected = pbname_corrected_base + "/" + clr_name
        corout.write(">%s\n%s\n" % (
            pbname_corrected, "".join(item[0] for item in clr_group)))

    pout.write(">%s\n%s\n" % (pbname_corrected_base,
                              "\n".join(pileup_str_list)))
# Release every input and output handle, inputs first, in the same order
# as before.
for handle in (rfh, sfh, afh, corout, pout):
    handle.close()