-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_panel_genes_from_vcf.py
264 lines (223 loc) · 7.78 KB
/
extract_panel_genes_from_vcf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import subprocess
import sys
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TextIO, Set, Dict, NamedTuple, Optional, List
import click
from cyvcf2 import VCF, Writer
from intervaltree import IntervalTree
from loguru import logger
# Translation table that swaps each DNA base for its complement.
TRANSLATE = str.maketrans("ATGC", "TACG")


def revcomp(s: str) -> str:
    """Return the reverse complement of the DNA sequence `s`.

    The sequence is upper-cased before complementing. Characters other than
    A, T, G and C are kept as they are (only their position is reversed).
    """
    complemented = s.upper().translate(TRANSLATE)
    return "".join(reversed(complemented))
class Genotype(NamedTuple):
    """A genotype call of at most two alleles, stored as allele indices.

    An index of ``-1`` marks a missing allele, ``0`` is the reference allele
    and any value ``> 0`` is the 1-based index of an alternate allele.
    """

    allele1: int
    allele2: int

    def is_null(self) -> bool:
        """Is the genotype null, i.e. ./. (both alleles missing)?"""
        return self.allele1 == -1 and self.allele2 == -1

    def is_hom(self) -> bool:
        """Is the genotype homozygous?

        A call with exactly one missing allele is treated as homozygous for
        the allele that is present.
        """
        if self.is_null():
            return False
        if self.allele1 == -1 or self.allele2 == -1:
            return True
        return self.allele1 == self.allele2

    def is_het(self) -> bool:
        """Is the genotype heterozygous?"""
        return not self.is_null() and not self.is_hom()

    def is_hom_ref(self) -> bool:
        """Is the genotype homozygous reference?"""
        return self.is_hom() and (self.allele1 == 0 or self.allele2 == 0)

    def is_hom_alt(self) -> bool:
        """Is the genotype homozygous alternate?"""
        return self.is_hom() and (self.allele1 > 0 or self.allele2 > 0)

    def alt_index(self) -> Optional[int]:
        """If the genotype is homozygous alternate, returns the 0-based index of the
        alt allele in the alternate allele array. Returns None otherwise.
        """
        if not self.is_hom_alt():
            return None
        # max() skips over a missing (-1) allele in a half-called genotype
        return max(self.allele1, self.allele2) - 1

    def allele_index(self) -> Optional[int]:
        """The index of the called allele (0 for reference or null calls).

        Raises:
            NotImplementedError: if the genotype is heterozygous.
        """
        if self.is_hom_ref() or self.is_null():
            return 0
        elif self.is_hom_alt():
            return self.alt_index() + 1
        else:
            raise NotImplementedError(f"Het Genotype is unexpected: {self}")

    @staticmethod
    def from_arr(arr: List[int]) -> "Genotype":
        """Build a genotype from an array of allele indices.

        Non-integer entries (for example the trailing boolean phasing flag of
        a cyvcf2 genotype array) are ignored. Fewer than two alleles are
        padded with ``-1`` (missing), so an empty array gives a null genotype.

        Raises:
            TypeError: if the array holds more than two integer alleles.
        """
        # `type(a) is int` rather than isinstance(): bool is a subclass of int
        # and boolean entries must not be mistaken for alleles.
        alleles = [a for a in arr if type(a) is int]
        # Pad up to two alleles; a single append would leave an empty input
        # one allele short.
        alleles.extend([-1] * (2 - len(alleles)))
        return Genotype(*alleles)
def extract_genes_from_panel(stream: TextIO) -> Set[str]:
    """Collect the unique gene names from a tab-separated panel.

    The gene name is taken from the first column of every line. Trailing
    whitespace is removed before splitting, and lines whose first column is
    empty (including blank lines) are ignored.
    """
    first_columns = (line.rstrip().split("\t")[0] for line in stream)
    return {gene for gene in first_columns if gene}
def extract_intervals_for_genes_from_gff(
    genes: Set[str], gff_stream: TextIO, padding: int = 0
) -> IntervalTree:
    """Build an interval tree holding the (padded) span of each requested gene.

    Only GFF rows whose type column is ``gene`` are considered. A gene's name
    is read from its ``gene`` attribute, falling back to ``Name``.

    Args:
        genes: Names of the genes to keep.
        gff_stream: Open handle on a GFF3 annotation.
        padding: Number of bases to add on each side of every gene.

    Returns:
        An IntervalTree of 0-based, half-open ``[start, end)`` intervals whose
        data is a ``(gene name, strand)`` tuple.

    Raises:
        ValueError: if a gene row has neither a ``gene`` nor a ``Name`` attribute.
    """
    intervals = []
    for row in map(str.rstrip, gff_stream):
        if row.startswith("##FASTA"):
            # Everything after this directive is sequence data, not features
            break
        if not row or row.startswith("#"):
            continue

        fields = row.split("\t")
        if fields[2].lower() != "gene":
            continue

        attributes = attributes_dict_from_str(fields[8])
        name = attributes.get("gene", attributes.get("Name"))
        if name is None:
            # .get() so that a missing ID cannot mask this error with a KeyError
            feature_id = attributes.get("ID", "<unknown>")
            raise ValueError(f"No gene/Name attribute for ID {feature_id}")
        if name not in genes:
            continue

        start = (int(fields[3]) - 1) - padding  # GFF start is 1-based inclusive
        end = int(fields[4]) + padding  # GFF end is 1-based inclusive
        strand = fields[6]
        intervals.append((start, end, (name, strand)))

    return IntervalTree.from_tuples(intervals)
def attributes_dict_from_str(s: str) -> Dict[str, str]:
    """Parse a GFF3 attributes column (``key=value;key=value``) into a dict.

    Empty entries, such as the one produced by a trailing ``;``, are skipped.
    Only the first ``=`` of an entry separates key from value, so values may
    themselves contain ``=``.

    Raises:
        ValueError: if a non-empty entry has no ``=`` separator.
        KeyError: if the same key appears more than once.
    """
    attributes: Dict[str, str] = {}
    for pair in s.split(";"):
        if not pair:
            continue
        key, separator, value = pair.partition("=")
        if not separator:
            raise ValueError(f"Attribute {pair!r} is not of the form key=value")
        if key in attributes:
            raise KeyError(f"Attribute key {key} appears twice")
        attributes[key] = value
    return attributes
##########################################################
# MAIN
##########################################################
@click.command()
@click.option(
    "-P",
    "--padding",
    default=100,
    type=int,
    help="Number of bases up and downstream to include for a gene",
    show_default=True,
)
@click.option(
    "-a",
    "--apply-filters",
    is_flag=True,
    help="Whether to skip variants that have a FILTER value that isn't PASS or '.'",
)
@click.option(
    "--only-alt",
    is_flag=True,
    help="Only output variants that have an alternate allele genotype call. The default behaviour is apply all ALTs regardless of whether they are called by any sample",
)
@click.option(
    "--adjust-pos",
    is_flag=True,
    help="Adjust the position of variants that occur in reverse strand genes so they are with respect to the forward strand. This will also reverse complement the alleles",
)
@click.option(
    "-i",
    "--panel",
    required=True,
    help="The panel/catalogue the VCF will be used for. Only genes found in this panel will be extracted.",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
)
@click.option(
    "-g",
    "--gff",
    required=True,
    help="The GFF3 annotation for the reference genome",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
)
@click.option(
    "-b",
    "--vcf",
    required=True,
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    help="The VCF to extract genes from",
)
@click.option(
    "-o",
    "--out-vcf",
    required=True,
    # No exists=True here: the output file is created by this command
    type=click.Path(dir_okay=False, path_type=Path),
    help="The output VCF path",
)
def main(
    padding: int,
    apply_filters: bool,
    only_alt: bool,
    adjust_pos: bool,
    panel: Path,
    gff: Path,
    vcf: Path,
    out_vcf: Path,
):
    """Extract the VCF records that fall within the genes of a panel.

    Each extracted record is rewritten so that its CHROM is the gene name and
    its POS is relative to the (padded) gene. The result is sorted with
    bcftools and written to the output VCF.
    """
    # NOTE: click shows the docstring above as the command's --help text.
    logger.info("Extracting gene names from panel...")
    with open(panel) as istream:
        genes = extract_genes_from_panel(istream)
    logger.success(f"Extracted {len(genes)} genes from the panel")

    logger.info("Extracting intervals for genes from GFF...")
    with open(gff) as istream:
        ivtree = extract_intervals_for_genes_from_gff(genes, istream, padding)
    logger.success(f"Intervals extracted for {len(ivtree)} genes")

    logger.info(
        "Extracting those VCF records that fall within the gene intervals and altering "
        "their CHROM and POS accordingly..."
    )
    vcf_reader = VCF(str(vcf))
    try:
        # Every gene becomes a contig of the output VCF, so declare them in the
        # header before it is used as the template for the writer.
        for iv in ivtree:
            vcf_reader.add_to_header(
                f"##contig=<ID={iv.data[0]},length={iv.end-iv.begin}>"
            )

        with TemporaryDirectory() as tmpdirname:
            tmpvcf = str(Path(tmpdirname) / "tmp.vcf")
            vcf_writer = Writer(tmpvcf, tmpl=vcf_reader)
            try:
                for record in vcf_reader:
                    # cyvcf2 reports a FILTER of PASS or '.' as None
                    if apply_filters and record.FILTER is not None:
                        continue

                    # Only the first sample is inspected, and only when asked
                    # for, so records are not parsed needlessly.
                    if only_alt:
                        gt = Genotype.from_arr(record.genotypes[0])
                        if not gt.is_hom_alt():
                            continue

                    # Genes whose interval contains the record's start position
                    ivs = ivtree[record.start]
                    if len(ivs) > 1:
                        logger.warning(
                            f"VCF record at POS {record.POS} overlaps with more than 1 gene: {ivs}. "
                            f"Duplicating record - one for each gene..."
                        )

                    # The record is mutated in place for each gene, so keep
                    # the original values to derive every copy from.
                    original_record_start = record.start
                    original_ref = record.REF
                    original_alts = record.ALT
                    for iv in ivs:
                        chrom, strand = iv.data
                        if adjust_pos and strand == "-":
                            # Forward position p maps to iv.end - 1 - p on the
                            # reverse strand. The reverse-complemented REF
                            # therefore starts at the image of the *last* REF
                            # base, i.e. iv.end - start - len(REF).
                            norm_pos = (
                                iv.end - original_record_start - len(original_ref)
                            )
                            ref = revcomp(original_ref)
                            alts = [revcomp(s) for s in original_alts]
                        else:
                            norm_pos = original_record_start - iv.begin
                            ref = original_ref
                            alts = original_alts

                        record.set_pos(norm_pos)
                        record.CHROM = chrom
                        record.REF = ref
                        record.ALT = alts
                        vcf_writer.write_record(record)
            finally:
                # Must be closed (flushed) before bcftools reads the file
                vcf_writer.close()

            logger.info("Sorting VCF...")
            subprocess.run(
                [
                    "bcftools",
                    "sort",
                    "-T",
                    tmpdirname,
                    "-o",
                    str(out_vcf),
                    tmpvcf,
                ],
                check=True,
                stderr=sys.stderr,
            )
    finally:
        vcf_reader.close()

    logger.success("Done!")
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()