-
Notifications
You must be signed in to change notification settings - Fork 1
/
constraint.py
1565 lines (1380 loc) · 67 KB
/
constraint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import logging
from typing import Dict, List, Tuple, Union
import hail as hl
from gnomad.resources.resource_utils import DataException
from gnomad.utils.file_utils import file_exists
from gnomad_lof.constraint_utils.generic import annotate_variant_types
from rmc.resources.basics import temp_path
from rmc.utils.generic import get_coverage_correction_expr
# Configure module-level logging: timestamped messages that include the
# logger name and source line number, e.g. "01/02/2021 10:30:00 AM (constraint_utils 42): ...".
logging.basicConfig(
    format="%(asctime)s (%(name)s %(lineno)s): %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
)
logger = logging.getLogger("constraint_utils")
logger.setLevel(logging.INFO)
# Grouping fields used by `calculate_exp_per_transcript` when aggregating
# mutation rate to compute expected variant counts per variant type.
GROUPINGS = [
    "context",
    "ref",
    "alt",
    "methylation_level",
    "annotation",
    "modifier",
    "transcript",
    "gene",
    "exome_coverage",
]
"""
Core fields to group by when calculating expected variants per variant type.
Fields taken from gnomAD LoF repo.
"""
def calculate_observed(ht: hl.Table) -> hl.Table:
    """
    Aggregate the observed variant count for each transcript.

    Groups the input Table by transcript ID and counts the rows in each group.

    .. note::
        Assumes input HT has been filtered using `keep_criteria`.

    :param hl.Table ht: Input Table.
    :return: Table annotated with observed variant counts.
    :rtype: hl.Table
    """
    grouped = ht.group_by(ht.transcript_consequences.transcript_id)
    return grouped.aggregate(observed=hl.agg.count())
def get_cumulative_mu_expr(
    transcript_expr: hl.expr.StringExpression, mu_expr: hl.expr.Float64Expression,
) -> hl.expr.DictExpression:
    """
    Build a scan expression holding the cumulative mutation rate probability, shifted by one.

    Because `hl.scan` lags one row behind, the value must be corrected afterwards
    (see `adjust_mu_expr`). Usable both when searching for the first break and when
    searching for additional break(s).

    :param hl.expr.StringExpression transcript_expr: StringExpression containing transcript information.
    :param hl.expr.Float64Expression mu_expr: FloatExpression containing mutation rate probability per variant.
    :return: DictExpression containing scan expressions for cumulative mutation rate probability.
    :rtype: hl.expr.DictExpression
    """
    running_sum = hl.scan.sum(mu_expr)
    return hl.scan.group_by(transcript_expr, running_sum)
def adjust_mu_expr(
    cumulative_mu_expr: hl.expr.DictExpression,
    # NOTE: annotation fixed from Int32Expression — mutation rate probabilities are floats
    # (see docstring and `get_cumulative_mu_expr`, which sums Float64 values).
    mu_expr: hl.expr.Float64Expression,
    transcript_expr: hl.expr.StringExpression,
) -> hl.expr.DictExpression:
    """
    Adjust the scan with the cumulative mutation rate probability.

    This adjustment is necessary because scans are always one line behind, and we want the values to match per line.
    This function can correct the scan created when searching for the first break or when searching for additional break(s).

    :param hl.expr.DictExpression cumulative_mu_expr: DictExpression containing scan expressions for cumulative mutation rate probability.
    :param hl.expr.Float64Expression mu_expr: FloatExpression containing mutation rate probability per variant.
    :param hl.expr.StringExpression transcript_expr: StringExpression containing transcript information.
    :return: Adjusted cumulative mutation rate expression.
    :rtype: hl.expr.DictExpression
    """
    return hl.if_else(
        # Check if the current transcript/section exists in the _mu_scan dictionary
        # If it doesn't exist, that means this is the first line in the HT for that particular transcript/section
        # The first line of a scan is always missing, but we want it to exist
        # Thus, set the cumulative_mu equal to the current mu_snp value
        hl.is_missing(cumulative_mu_expr.get(transcript_expr)),
        {transcript_expr: mu_expr},
        # Otherwise, add the current mu_snp to the scan to make sure the cumulative value isn't one line behind
        {transcript_expr: cumulative_mu_expr[transcript_expr] + mu_expr},
    )
def translate_mu_to_exp_expr(
    cumulative_mu_expr: hl.expr.DictExpression,
    transcript_expr: hl.expr.StringExpression,
    total_mu_expr: hl.expr.Float64Expression,
    total_exp_expr: hl.expr.Float64Expression,
    # NOTE: return annotation fixed from DictExpression — the body indexes the dict
    # and returns a float expression, not a dict.
) -> hl.expr.Float64Expression:
    """
    Translate cumulative mutation rate probability into cumulative expected count per base.

    Expected variants counts are produced per base by first calculating the fraction of probability of mutation per base,
    then multiplying that fraction by the total expected variants count for a transcript or transcript sub-section.
    The expected variants count for section of interest is mutation rate per SNP adjusted by location in the genome/CpG status
    (plateau model) and coverage (coverage model).

    :param hl.expr.DictExpression cumulative_mu_expr: DictExpression containing scan expressions for cumulative mutation rate probability.
    :param hl.expr.StringExpression transcript_expr: StringExpression containing transcript information.
    :param hl.expr.Float64Expression total_mu_expr: FloatExpression describing total sum of mutation rate probabilities per transcript.
    :param hl.expr.Float64Expression total_exp_expr: FloatExpression describing total expected variant counts per transcript.
    :return: Cumulative expected variants count expression.
    :rtype: hl.expr.Float64Expression
    """
    return (cumulative_mu_expr[transcript_expr] / total_mu_expr) * total_exp_expr
def calculate_exp_per_transcript(
    context_ht: hl.Table, locus_type: str, groupings: List[str] = GROUPINGS,
) -> hl.Table:
    """
    Compute the total expected variant count and aggregate mutation rate per transcript.

    .. note::
        - Assumes that context_ht is annotated with all of the fields in `groupings` and that the names match exactly.
        - Assumes that input table is filtered to autosomes/PAR only, X nonPAR only, or Y nonPAR only.
        - Expects that input table contains coverage and plateau models in its global annotations (`coverage_model`, `plateau_models`).
        - Expects that input table has multiple fields for mutation rate probabilities:
            `mu_snp` for mutation rate probability adjusted by coverage, and
            `raw_mu_snp` for raw mutation rate probability.

    :param hl.Table context_ht: Context Table.
    :param str locus_type: Locus type of input table. One of "X", "Y", or "autosomes".
        NOTE: will treat any input other than "X" or "Y" as autosomes.
    :param List[str] groupings: List of Table fields used to group Table to adjust mutation rate.
        Table must be annotated with these fields. Default is GROUPINGS.
    :return: Table grouped by transcript with expected counts per search field.
    :rtype: hl.Table
    """
    logger.info("Grouping by %s...", groupings)
    grouped = context_ht.group_by(*groupings).aggregate(
        mu_agg=hl.agg.sum(context_ht.raw_mu_snp)
    )

    logger.info("Adding CpG annotations...")
    grouped = annotate_variant_types(grouped)

    logger.info("Adjusting aggregated mutation rate with plateau model...")
    # Pick the plateau model matching the locus type; anything that is not
    # "X" or "Y" falls through to the autosomal models.
    if locus_type == "X":
        plateau = grouped.plateau_x_models
    elif locus_type == "Y":
        plateau = grouped.plateau_y_models
    else:
        plateau = grouped.plateau_models
    model = plateau["total"][grouped.cpg]
    # model[1] is the slope and model[0] the intercept of the linear plateau fit
    grouped = grouped.annotate(mu_adj=grouped.mu_agg * model[1] + model[0])

    logger.info(
        "Adjusting aggregated mutation rate with coverage correction to get expected counts..."
    )
    grouped = grouped.annotate(
        coverage_correction=get_coverage_correction_expr(
            grouped.exome_coverage, grouped.coverage_model
        ),
    )
    grouped = grouped.annotate(
        _exp=grouped.mu_adj * grouped.coverage_correction,
        mu=grouped.mu_agg * grouped.coverage_correction,
    )

    logger.info("Getting expected counts per transcript and returning...")
    return grouped.group_by("transcript").aggregate(
        expected=hl.agg.sum(grouped._exp), mu_agg=hl.agg.sum(grouped.mu)
    )
def get_obs_exp_expr(
    cond_expr: hl.expr.BooleanExpression,
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
) -> hl.expr.Float64Expression:
    """
    Build an observed/expected annotation, capped at 1, guarded by `cond_expr`.

    Can produce obs/exp values for an entire transcript or a transcript section,
    in either the 'forward' (smaller to larger positions) or 'reverse' (larger to
    smaller positions) direction, depending on the inputs.

    .. note::
        `cond_expr` should vary depending on size/direction of section being annotated.

    :param hl.expr.BooleanExpression cond_expr: Condition to check prior to adding obs/exp expression.
    :param hl.expr.Int64Expression obs_expr: Expression containing number of observed variants.
    :param hl.expr.Float64Expression exp_expr: Expression containing number of expected variants.
    :return: Observed/expected expression.
    :rtype: hl.expr.Float64Expression
    """
    oe = obs_expr / exp_expr
    return hl.or_missing(cond_expr, hl.min(oe, 1))
def get_cumulative_obs_expr(
    transcript_expr: hl.expr.StringExpression, observed_expr: hl.expr.Int64Expression,
) -> hl.expr.DictExpression:
    """
    Return annotation with the cumulative number of observed variants, shifted by one.

    Value is shifted by one due to the nature of `hl.scan` and needs to be corrected later
    (see `adjust_obs_expr`).
    This function can produce the scan when searching for the first break or when searching for additional break(s).

    :param hl.expr.StringExpression transcript_expr: StringExpression containing transcript information.
    :param hl.expr.Int64Expression observed_expr: Observed variants expression.
    :return: DictExpression containing scan expressions for cumulative observed variant counts per transcript.
    :rtype: hl.expr.DictExpression
    """
    return hl.scan.group_by(transcript_expr, hl.scan.sum(observed_expr))
def adjust_obs_expr(
    cumulative_obs_expr: hl.expr.DictExpression,
    obs_expr: hl.expr.Int64Expression,
    transcript_expr: hl.expr.StringExpression,
) -> hl.expr.DictExpression:
    """
    Adjust the scan with the cumulative number of observed variants.

    This adjustment is necessary because scans are always one line behind, and we want the values to match per line.
    This function can correct the scan created when searching for the first break or when searching for additional break(s).

    .. note::
        This function expects:
            - cumulative_obs_expr is a DictExpression keyed by transcript.

    :param hl.expr.DictExpression cumulative_obs_expr: DictExpression containing scan expression with cumulative observed counts per base.
    :param hl.expr.Int64Expression obs_expr: IntExpression with value of either 0 (no observed variant at site) or 1 (variant found in gnomAD).
    :param hl.expr.StringExpression transcript_expr: StringExpression containing transcript information.
    :return: Adjusted cumulative observed counts expression.
    :rtype: hl.expr.DictExpression
    """
    return hl.if_else(
        # Check if the current transcript/section exists in the _obs_scan dictionary
        # If it doesn't exist, that means this is the first line in the HT for that particular transcript
        # The first line of a scan is always missing, but we want it to exist
        # Thus, set the cumulative_obs equal to the current observed value
        hl.is_missing(cumulative_obs_expr.get(transcript_expr)),
        {transcript_expr: obs_expr},
        # Otherwise, add the current obs to the scan to make sure the cumulative value isn't one line behind
        {transcript_expr: cumulative_obs_expr[transcript_expr] + obs_expr},
    )
def get_reverse_obs_exp_expr(
    total_obs_expr: hl.expr.Int64Expression,
    total_exp_expr: hl.expr.Float64Expression,
    # NOTE: annotation fixed from DictExpression — this value is subtracted from
    # `total_obs_expr` below, and callers pass an already-indexed scan value
    # (e.g. `ht.cumulative_obs[ht.transcript]`).
    scan_obs_expr: hl.expr.Int64Expression,
    cumulative_exp_expr: hl.expr.Float64Expression,
) -> hl.expr.StructExpression:
    """
    Return the "reverse" section observed and expected variant counts.

    The reverse counts are the counts moving from larger to smaller positions
    (backwards from the end of the transcript back to the beginning of the transcript).

    reverse value = total value - cumulative value

    .. note::
        This function is designed to run on one transcript at a time.

    :param hl.expr.Int64Expression total_obs_expr: Expression containing total number of observed variants for transcript.
    :param hl.expr.Float64Expression total_exp_expr: Expression containing total number of expected variants for transcript.
    :param hl.expr.Int64Expression scan_obs_expr: Expression containing cumulative number of observed variants for transcript.
    :param hl.expr.Float64Expression cumulative_exp_expr: Expression containing cumulative number of expected variants for transcript.
    :return: Struct with reverse observed and expected variant counts.
    :rtype: hl.expr.StructExpression
    """
    return hl.struct(
        # NOTE: Adding hl.max to exp expression to make sure reverse exp is never negative
        # Without this, ran into errors where reverse exp was -5e-14
        # Picked 1e-09 here as tiny number that is not 0
        # ExAC code also did not allow reverse exp to be zero, as this breaks the likelihood ratio tests
        obs=total_obs_expr - scan_obs_expr,
        exp=hl.max(total_exp_expr - cumulative_exp_expr, 1e-09),
    )
def get_fwd_exprs(
    ht: hl.Table,
    transcript_str: str,
    obs_str: str,
    mu_str: str,
    total_mu_str: str,
    total_exp_str: str,
) -> hl.Table:
    """
    Annotate input Table with the forward section cumulative observed, expected, and observed/expected values.

    .. note::
        'Forward' refers to moving through the transcript from smaller to larger chromosomal positions.
        Expects:
        - Input HT is annotated with transcript, observed, mutation rate, total mutation rate (per section),
        and total expected counts (per section).

    :param hl.Table ht: Input Table.
    :param str transcript_str: Name of field containing transcript information.
    :param str obs_str: Name of field containing observed variants counts.
    :param str mu_str: Name of field containing mutation rate probability per variant.
    :param str total_mu_str: Name of field containing total mutation rate per section of interest (transcript or sub-section of transcript).
    :param str total_exp_str: Name of field containing total expected variants count per section of interest (transcript or sub-section of transcript).
    :return: Table with forward values (cumulative obs, exp, and forward o/e) annotated.
    :rtype: hl.Table
    """
    logger.info("Getting cumulative observed variant counts...")
    ht = ht.annotate(
        _obs_scan=get_cumulative_obs_expr(
            transcript_expr=ht[transcript_str], observed_expr=ht[obs_str],
        )
    )
    ht = ht.annotate(
        cumulative_obs=adjust_obs_expr(
            cumulative_obs_expr=ht._obs_scan,
            obs_expr=ht[obs_str],
            transcript_expr=ht[transcript_str],
        )
    )
    logger.info("Getting cumulative expected variant counts...")
    # Get scan of mu_snp
    ht = ht.annotate(_mu_scan=get_cumulative_mu_expr(ht[transcript_str], ht[mu_str]))
    # Adjust scan of mu_snp
    ht = ht.annotate(
        _mu_scan=adjust_mu_expr(ht._mu_scan, ht[mu_str], ht[transcript_str])
    )
    ht = ht.annotate(
        cumulative_exp=translate_mu_to_exp_expr(
            ht._mu_scan, ht[transcript_str], ht[total_mu_str], ht[total_exp_str]
        )
    )
    logger.info("Getting forward observed/expected count and returning...")
    # NOTE: adding cond_expr here because get_obs_exp_expr expects it
    # cond_expr is necessary for reverse obs/exp, which is why the function has it
    ht = ht.annotate(cond_expr=True)
    ht = ht.annotate(
        forward_oe=get_obs_exp_expr(
            ht.cond_expr, ht.cumulative_obs[ht[transcript_str]], ht.cumulative_exp,
        )
    )
    return ht.drop("cond_expr")
def get_reverse_exprs(
    ht: hl.Table,
    total_obs_expr: hl.expr.Int64Expression,
    total_exp_expr: hl.expr.Float64Expression,
    # NOTE: annotations fixed from Dict[...] — callers (see `process_transcripts`)
    # pass already-indexed scan values (e.g. `ht.cumulative_obs[ht.transcript]`),
    # which `get_reverse_obs_exp_expr` subtracts directly from the totals.
    scan_obs_expr: hl.expr.Int64Expression,
    scan_exp_expr: hl.expr.Float64Expression,
) -> hl.Table:
    """
    Call `get_reverse_obs_exp_expr` and `get_obs_exp_expr` to add the reverse section cumulative observed, expected, and observed/expected values.

    .. note::
        'Reverse' refers to moving through the transcript from larger to smaller chromosomal positions.

    :param hl.Table ht: Input Table.
    :param hl.expr.Int64Expression total_obs_expr: Expression containing total number of observed variants per transcript (if searching for first break)
        or per section (if searching for additional breaks).
    :param hl.expr.Float64Expression total_exp_expr: Expression containing total number of expected variants per transcript (if searching for first break)
        or per section (if searching for additional breaks).
    :param hl.expr.Int64Expression scan_obs_expr: Expression containing cumulative number of observed variants per transcript
        (if searching for first break) or per section (if searching for additional breaks).
    :param hl.expr.Float64Expression scan_exp_expr: Expression containing cumulative number of expected variants per transcript
        (if searching for first break) or per section (if searching for additional breaks).
    :return: Table with reverse values annotated
    :rtype: hl.Table
    """
    # reverse value = total value - cumulative value
    ht = ht.annotate(
        reverse=get_reverse_obs_exp_expr(
            total_obs_expr=total_obs_expr,
            total_exp_expr=total_exp_expr,
            scan_obs_expr=scan_obs_expr,
            cumulative_exp_expr=scan_exp_expr,
        )
    )
    # Set reverse o/e to missing if reverse expected value is 0 (to avoid NaNs)
    return ht.annotate(
        reverse_obs_exp=get_obs_exp_expr(
            (ht.reverse.exp != 0), ht.reverse.obs, ht.reverse.exp
        )
    )
def get_dpois_expr(
    cond_expr: hl.expr.BooleanExpression,
    section_oe_expr: hl.expr.Float64Expression,
    obs_expr: Union[
        Dict[hl.expr.StringExpression, hl.expr.Int64Expression], hl.expr.Int64Expression
    ],
    exp_expr: Union[
        Dict[hl.expr.StringExpression, hl.expr.Float64Expression],
        hl.expr.Float64Expression,
    ],
    # NOTE: return annotation fixed from StructExpression — the body returns a single
    # Poisson density (float) wrapped in `hl.or_missing`, not a struct.
) -> hl.expr.Float64Expression:
    """
    Calculate a null or alt Poisson density in preparation for chi-squared test to find significant breaks.

    All parameter values depend on the direction of calculation (forward/reverse) and
    number of breaks (searching for first break or searching for additional break).

    For forward null/alts, values for obs_expr and and exp_expr should be:
        - Expression containing cumulative numbers for entire transcript.
        - Expression containing cumulative numbers for section of transcript
            between the beginning or end of the transcript and the first breakpoint.
    For reverse null/alts, values for obs_expr and and exp_expr should be:
        - Reverse counts for entire transcript.
        - Reverse counts for section of transcript.

    For forward null/alts, values for section_oe_expr should be:
        - Expression containing observed/expected value for entire transcript (null) or
            observed/expected value calculated on cumulative observed and expected
            variants at each position (alt).
    For reverse null/alts, values for section_oe_expr should be:
        - Expression containing observed/expected value for entire transcript (null) or
            reverse observed/expected value for section of transcript (alt).

    For forward null/alts, cond_expr should check:
        - That the length of the obs_expr isn't 0 when searching for the first break.
        - That the length of the obs_expr is 2 when searching for a second additional break.
    For reverse null/alts, cond_expr should check:
        - That the reverse observed value for the entire transcript is defined when searching for the first break.
        - That the reverse observed value for the section between the first breakpoint and the end of the transcript
             is defined when searching for a second additional break.

    :param hl.expr.BooleanExpression cond_expr: Conditional expression to check before calculating null and alt values.
    :param hl.expr.Float64Expression section_oe_expr: Expression of section observed/expected value.
    :param Union[Dict[hl.expr.StringExpression, hl.expr.Int64Expression], hl.expr.Int64Expression] obs_expr: Expression containing observed variants count.
    :param Union[Dict[hl.expr.StringExpression, hl.expr.Float64Expression], hl.expr.Float64Expression] exp_expr: Expression containing expected variants count.
    :return: Poisson density of the observed count given expected * section O/E, or missing if `cond_expr` is False.
    :rtype: hl.expr.Float64Expression
    """
    return hl.or_missing(cond_expr, hl.dpois(obs_expr, exp_expr * section_oe_expr))
def get_section_expr(dpois_expr: hl.expr.ArrayExpression,) -> hl.expr.Float64Expression:
    """
    Build null or alt model by multiplying all section null or alt distributions.

    For example, when checking for the first break in a transcript, the transcript is broken into two sections:
    pre-breakpoint and post-breakpoint. Each section's null and alt distributions must be multiplied
    to later test the significance of the break.

    :param hl.expr.ArrayExpression dpois_expr: ArrayExpression that contains all section nulls or alts.
        Needs to be reduced to single float by multiplying each number in the array.
    :return: Overall section distribution.
    :rtype: hl.expr.Float64Expression
    """
    # Fold the array down to a single product, starting from the multiplicative identity
    return hl.fold(lambda acc, value: acc * value, 1, dpois_expr)
def search_for_break(
    ht: hl.Table,
    search_field: hl.str,
    simul_break: bool = False,
    chisq_threshold: float = 10.8,
) -> hl.Table:
    """
    Search for breakpoints in a transcript or within a transcript subsection.

    Expects input context HT to contain the following fields:
        - locus
        - alleles
        - transcript or section
        - coverage (median)
        - mu_snp
        - scan_counts struct
        - overall_oe
        - forward_oe
        - reverse struct
        - reverse_obs_exp
        - total_obs
        - total_exp
    If searching for simultaneous breaks, expects HT to have the following fields:
        - pre_obs
        - pre_exp
        - pre_oe
        - window_obs
        - window_exp
        - window_oe
        - post_obs
        - post_exp
        - post_oe
    Also expects:
        - multiallelic variants in input HT have been split.
        - Input HT is autosomes/PAR only, X non-PAR only, or Y non-PAR only.
    Returns HT filtered to lines with maximum chisq if chisq >= max_value, otherwise returns None.

    :param hl.Table ht: Input context Table.
    :param hl.expr.StringExpression search_field: Field of table to search. Value should be either 'transcript' or 'section'.
    :param bool simul_break: Whether this function is searching for simultaneous breaks. Default is False.
    :param float chisq_threshold: Chi-square significance threshold.
        Value should be 10.8 (single break) and 13.8 (two breaks) (values from ExAC RMC code).
        Default is 10.8.
    :return: Table annotated with whether position is a breakpoint.
    :rtype: hl.Table
    """
    logger.info(
        "Creating section null (no regional variability in missense depletion)\
        and alt (evidence of domains of missense constraint) expressions..."
    )

    # Split transcript into three sections if searching for simultaneous breaks
    if simul_break:
        ht = ht.annotate(
            section_nulls=[
                # Get null expression for section of transcript pre-window
                get_dpois_expr(
                    cond_expr=True,
                    section_oe_expr=ht.overall_oe,
                    obs_expr=ht.pre_obs,
                    exp_expr=ht.pre_exp,
                ),
                # Get null expression for window of constraint
                get_dpois_expr(
                    cond_expr=True,
                    section_oe_expr=ht.overall_oe,
                    obs_expr=ht.window_obs,
                    exp_expr=ht.window_exp,
                ),
                # Get null expression for section of transcript post-window
                get_dpois_expr(
                    cond_expr=True,
                    section_oe_expr=ht.overall_oe,
                    obs_expr=ht.post_obs,
                    exp_expr=ht.post_exp,
                ),
            ]
        )
        ht = ht.annotate(
            section_alts=[
                # Get alt expression for section of transcript pre-window
                get_dpois_expr(
                    cond_expr=True,
                    section_oe_expr=ht.pre_oe,
                    obs_expr=ht.pre_obs,
                    exp_expr=ht.pre_exp,
                ),
                # Get alt expression for window of constraint
                get_dpois_expr(
                    cond_expr=True,
                    section_oe_expr=ht.window_oe,
                    obs_expr=ht.window_obs,
                    exp_expr=ht.window_exp,
                ),
                # Get alt expression for section of transcript post-window
                # BUGFIX: previously used `ht.next_values.reverse_exp` here; the
                # post-window alt must use the same expected counts as the
                # post-window null above (`ht.post_exp`).
                get_dpois_expr(
                    cond_expr=True,
                    section_oe_expr=ht.post_oe,
                    obs_expr=ht.post_obs,
                    exp_expr=ht.post_exp,
                ),
            ]
        )

    # Otherwise, split transcript only into two sections (when searching for first/additional breaks)
    else:
        ht = ht.annotate(
            section_nulls=[
                # Add forwards section null (going through positions from smaller to larger)
                # section_null = stats.dpois(section_obs, section_exp*overall_obs_exp)[0]
                get_dpois_expr(
                    cond_expr=hl.len(ht.cumulative_obs) != 0,
                    section_oe_expr=ht.overall_oe,
                    obs_expr=ht.cumulative_obs[ht[search_field]],
                    exp_expr=ht.cumulative_exp,
                ),
                # Add reverse section null (going through positions larger to smaller)
                get_dpois_expr(
                    cond_expr=hl.is_defined(ht.reverse.obs),
                    section_oe_expr=ht.overall_oe,
                    obs_expr=ht.reverse.obs,
                    exp_expr=ht.reverse.exp,
                ),
            ],
            section_alts=[
                # Add forward section alt
                # section_alt = stats.dpois(section_obs, section_exp*section_obs_exp)[0]
                get_dpois_expr(
                    cond_expr=hl.len(ht.cumulative_obs) != 0,
                    section_oe_expr=ht.forward_oe,
                    obs_expr=ht.cumulative_obs[ht[search_field]],
                    exp_expr=ht.cumulative_exp,
                ),
                # Add reverse section alt
                get_dpois_expr(
                    cond_expr=hl.is_defined(ht.reverse.obs),
                    section_oe_expr=ht.reverse_obs_exp,
                    obs_expr=ht.reverse.obs,
                    exp_expr=ht.reverse.exp,
                ),
            ],
        )

    logger.info("Multiplying all section nulls and all section alts...")
    # Kaitlin stores all nulls/alts in section_null and section_alt and then multiplies
    # e.g., p1 = prod(section_null_ps)
    ht = ht.annotate(
        total_null=get_section_expr(ht.section_nulls),
        total_alt=get_section_expr(ht.section_alts),
    )

    logger.info("Adding chisq value and getting max chisq...")
    # BUGFIX: the likelihood ratio test statistic is 2 * (ln(alt) - ln(null));
    # the previous `hl.log10` understated chisq by a factor of ln(10) ~= 2.303,
    # which does not match the natural-log-based thresholds (10.8 / 13.8) below.
    ht = ht.annotate(chisq=(2 * (hl.log(ht.total_alt) - hl.log(ht.total_null))))

    # "The default chi-squared value for one break to be considered significant is
    # 10.8 (p ~ 10e-3) and is 13.8 (p ~ 10e-4) for two breaks. These currently cannot
    # be adjusted."
    group_ht = ht.group_by(search_field).aggregate(max_chisq=hl.agg.max(ht.chisq))
    # BUGFIX: index by `search_field` rather than hardcoded `ht.transcript` so this
    # function also works when searching within transcript sections.
    ht = ht.annotate(max_chisq=group_ht[ht[search_field]].max_chisq)
    return ht.annotate(
        is_break=((ht.chisq == ht.max_chisq) & (ht.chisq >= chisq_threshold))
    )
def process_transcripts(ht: hl.Table, chisq_threshold: float) -> hl.Table:
    """
    Annotate each position in Table with whether that position is a significant breakpoint.

    Also annotate input Table with cumulative observed, expected, and observed/expected values
    for both forward (moving from smaller to larger positions) and reverse (moving from larger to
    smaller positions) directions.

    Expects input Table to have the following fields:
        - locus
        - alleles
        - mu_snp
        - transcript
        - observed
        - expected
        - coverage_correction
        - cpg
        - cumulative_obs
        - cumulative_exp
        - forward_oe
    Also expects:
        - multiallelic variants in context HT have been split.
        - Global annotations contains plateau models.

    :param hl.Table ht: Input Table. annotated with observed and expected variants counts per transcript.
    :param float chisq_threshold: Chi-square significance threshold.
        Value should be 10.8 (single break) and 13.8 (two breaks) (values from ExAC RMC code).
    :return: Table with cumulative observed, expected, and observed/expected values annotated for forward and reverse directions.
        Table also annotated with boolean for whether each position is a breakpoint.
    :rtype: hl.Table
    """
    logger.info(
        "Annotating HT with reverse observed and expected counts for each transcript...\n"
        "(transcript-level reverse (moving from larger to smaller positions) values)"
    )
    ht = get_reverse_exprs(
        ht=ht,
        total_obs_expr=ht.total_obs,
        total_exp_expr=ht.total_exp,
        scan_obs_expr=ht.cumulative_obs[ht.transcript],
        scan_exp_expr=ht.cumulative_exp,
    )
    return search_for_break(ht, "transcript", False, chisq_threshold)
def get_subsection_exprs(
    ht: hl.Table,
    section_str: str = "section",
    obs_str: str = "observed",
    mu_str: str = "mu_snp",
    total_mu_str: str = "total_mu",
    total_exp_str: str = "total_exp",
) -> hl.Table:
    """
    Annotate total observed, expected, and observed/expected (OE) counts for each section of a transcript.

    .. note::
        Assumes input Table is annotated with:
            - section
            - observed variants count per site
            - mutation rate probability per site
            - total mutation rate probability per transcript
            - total expected variant counts per transcript

        Names of annotations must match section_str, obs_str, mu_str, total_mu_str, and total_exp_str.

    :param hl.Table ht: Input Table.
    :param str section_str: Name of section annotation.
    :param str obs_str: Name of observed variant counts annotation.
    :param str mu_str: Name of mutation rate probability per site annotation.
    :param str total_mu_str: Name of annotation containing sum of mutation rate probabilities per transcript.
    :param str total_exp_str: Name of annotation containing total expected variant counts per transcript.
    :return: Table annotated with section observed, expected, and OE counts.
    :rtype: hl.Table
    """
    logger.info(
        "Getting total observed and expected counts for each transcript section..."
    )
    # Aggregate observed counts and mutation rate probabilities per section
    per_section = ht.group_by(ht[section_str]).aggregate(
        obs=hl.agg.sum(ht[obs_str]),
        mu=hl.agg.sum(ht[mu_str]),
    )
    section_struct = per_section[ht[section_str]]
    # Scale the section's share of the transcript mu into expected counts:
    # section_exp = (section_mu / transcript_mu) * transcript_expected
    ht = ht.annotate(
        section_obs=section_struct.obs,
        section_exp=(section_struct.mu / ht[total_mu_str]) * ht[total_exp_str],
    )
    logger.info("Getting observed/expected value for each transcript section...")
    return ht.annotate(
        break_oe=get_obs_exp_expr(
            cond_expr=hl.is_defined(ht[section_str]),
            obs_expr=ht.section_obs,
            exp_expr=ht.section_exp,
        )
    )
def process_sections(ht: hl.Table, chisq_threshold: float):
    """
    Search for breaks within given sections of a transcript.

    Expects that input Table has the following annotations:
        - context
        - ref
        - alt
        - cpg
        - observed
        - mu_snp
        - coverage_correction
        - methylation_level
        - section

    Also assumes that Table's globals contain plateau and coverage models.

    :param hl.Table ht: Input Table.
    :param float chisq_threshold: Chi-square significance threshold.
        Value should be 10.8 (single break) and 13.8 (two breaks) (values from ExAC RMC code).
    :return: Table annotated with whether position is a breakpoint.
    :rtype: hl.Table
    """
    ht = get_subsection_exprs(ht)

    # TODO: Move this calculation to release HT generation step
    logger.info("Getting section chi-squared values...")
    ht = ht.annotate(
        section_chisq=calculate_section_chisq(
            obs_expr=ht.section_obs,
            exp_expr=ht.section_exp,
        )
    )

    logger.info("Splitting HT into pre and post breakpoint sections...")
    # Section names are formatted as "<transcript>_pre" / "<transcript>_post"
    pre_break_ht = ht.filter(ht.section.contains("pre"))
    post_break_ht = ht.filter(ht.section.contains("post"))

    logger.info(
        "Annotating post breakpoint HT with cumulative observed and expected counts..."
    )
    post_break_ht = get_fwd_exprs(
        ht=post_break_ht,
        transcript_str="transcript",
        obs_str="observed",
        mu_str="mu_snp",
        total_mu_str="total_mu",
        total_exp_str="total_exp",
    )

    logger.info(
        "Annotating post breakpoint HT with reverse observed and expected counts..."
    )
    post_break_ht = get_reverse_exprs(
        ht=post_break_ht,
        total_obs_expr=post_break_ht.section_obs,
        total_exp_expr=post_break_ht.section_exp,
        scan_obs_expr=post_break_ht.cumulative_obs[post_break_ht.transcript],
        scan_exp_expr=post_break_ht.cumulative_exp,
    )

    logger.info("Searching for a break in each section and returning...")
    pre_break_ht = search_for_break(
        pre_break_ht,
        search_field="transcript",
        simul_break=False,
        chisq_threshold=chisq_threshold,
    )
    # Adjust is_break annotation in the pre-breakpoint HT
    # to prevent this function from continually finding previous significant breaks
    pre_break_ht = pre_break_ht.annotate(
        is_break=hl.if_else(
            pre_break_ht.break_list.any(lambda x: x), False, pre_break_ht.is_break
        )
    )
    post_break_ht = search_for_break(
        post_break_ht,
        search_field="transcript",
        simul_break=False,
        chisq_threshold=chisq_threshold,
    )
    return pre_break_ht.union(post_break_ht)
def process_additional_breaks(
    ht: hl.Table, break_num: int, chisq_threshold: float
) -> hl.Table:
    """
    Search for additional breaks in a transcript after finding one significant break.

    Expects that input Table is filtered to only transcripts already containing one significant break.

    Expects that input Table has the following annotations:
        - cpg
        - observed
        - mu_snp
        - coverage_correction
        - transcript

    Also assumes that Table's globals contain plateau models.

    :param hl.Table ht: Input Table.
    :param int break_num: Number of additional break: 2 for second break, 3 for third, etc.
    :param float chisq_threshold: Chi-square significance threshold.
        Value should be 10.8 (single break) and 13.8 (two breaks) (values from ExAC RMC code).
    :return: Table annotated with whether position is a breakpoint.
    :rtype: hl.Table
    """
    logger.info(
        "Generating table keyed by transcripts (used to get breakpoint position later)..."
    )
    breakpoint_ht = ht.filter(ht.is_break).key_by("transcript")

    logger.info(
        "Renaming is_break field to prepare to search for an additional break..."
    )
    # Rename because this will be overwritten when searching for additional break
    ht = ht.rename({"is_break": f"is_break_{break_num - 1}"})

    logger.info(
        "Splitting each transcript into two sections: pre-first breakpoint and post..."
    )
    # Positions at or before the breakpoint fall into the "pre" section;
    # positions after it fall into the "post" section
    ht = ht.annotate(
        section=hl.if_else(
            ht.locus.position <= breakpoint_ht[ht.transcript].locus.position,
            hl.format("%s_%s", ht.transcript, "pre"),
            hl.format("%s_%s", ht.transcript, "post"),
        ),
    )
    return process_sections(ht, chisq_threshold)
def get_window_end_pos_expr(
    pos_expr: hl.expr.Int32Expression,
    end_pos_expr: hl.expr.Int32Expression,
    window_size_expr: int,
) -> hl.expr.Expression:
    """
    Get window end position.

    Annotates window end position only if the position + (`window_size_expr` - 1) is less than or equal to the transcript end position.

    .. note::
        The subtraction (`window_size_expr` - 1) is necessary to get a window with only `window_size_expr` base pairs.
        Without the subtraction, the window size would be one base pair too many:

        e.g., If position is 8, window size is 3 bp, and end position is 12, then:
        the window end is 8 + (3 - 1) = 10, and the window is [8, 9, 10] (3 bp).
        Without the subtraction, the window would be [8, 9, 10, 11], or 4 bp.

    :param hl.expr.Int32Expression pos_expr: IntExpression representing chromosomal position.
    :param hl.expr.Int32Expression end_pos_expr: IntExpression representing last position in transcript.
    :param int window_size_expr: Size of two break window.
    :return: Expression annotating window end position (missing if the window would extend past the transcript end).
    :rtype: hl.expr.Expression
    """
    # Compute the candidate window end once instead of duplicating the arithmetic
    # in both the condition and the returned value
    window_end_expr = pos_expr + (window_size_expr - 1)
    return hl.or_missing(window_end_expr <= end_pos_expr, window_end_expr)
def get_min_post_window_pos(ht: hl.Table, pos_ht: hl.Table) -> hl.Table:
    """
    Get the first position of transcript outside of smallest simultaneous break window.

    Run `hl.binary_search` to find index of first post-window position.

    Assumes:
        - ht is annotated with smallest window end position (`min_window_end`)

    :param hl.Table ht: Input Table.
    :param hl.Table pos_ht: Input GroupedTable grouped by transcript with list of all positions per transcript.
    :return: Table annotated with post-window position (`min_post_window_pos`; missing when no qualifying position exists).
    :rtype: hl.Table
    """
    logger.info("Annotating the positions per transcript back onto the input HT...")
    # This is to run `hl.binary_search` using the window end position as input
    # NOTE(review): `join` defaults to an inner join, so rows whose transcript is
    # missing from `pos_ht` are dropped here -- confirm this is intended
    ht = ht.key_by("transcript").join(pos_ht)
    # Sort positions so they satisfy hl.binary_search's sorted-input requirement
    ht = ht.transmute(pos_per_transcript=hl.sorted(ht.positions))
    logger.info(
        "Running hl.binary_search to find the smallest post window positions..."
    )
    # hl.binary search will return the index of the position that is closest to the search element (window_end)
    # If window_end exists in the HT, hl.binary_search will return the index of the window_end
    # If window_end does not exist in the HT, hl.binary_search will return the index of the next largest position
    # e.g., pos_per_transcript = [1, 4, 8]
    # hl.binary_search(pos_per_transcript, 4) will return 1
    # hl.binary_search(pos_per_transcript, 5) will return 2
    ht = ht.annotate(
        min_post_window_index=hl.binary_search(
            ht.pos_per_transcript, ht.min_window_end
        ),
        # Cache the list length; used below to detect out-of-range indices
        n_pos_per_transcript=hl.len(ht.pos_per_transcript),
    )
    # Adjust the post window position to make sure it is larger than the provided window end position
    # The default behavior in this case statement is to return the transcript end position because:
    # 1) `pos_expr` is created using the filtered context table. This means:
    # - `pos_expr` contains only positions in each transcript that had a possible missense variant.
    # - `pos_expr` might not contain the transcript end position.
    # 2) `hl.binary_search` will return an index larger than the length of a list if that position
    # is larger than the largest element in the list.
    # For example: if `pos_expr` is [1, 4, 8], then `hl.binary_search(pos_per_transcript, 10)` will return 3.
    # NOTE: the `when` clauses below are order-dependent -- the last-index case must
    # be checked before the general in-range case
    return ht.annotate(
        min_post_window_pos=hl.case()
        # When the index is the last index in the list
        .when(
            ht.min_post_window_index == ht.n_pos_per_transcript - 1,
            # Return the position pointed to by the index only if it is larger than the window end
            # Otherwise, return missing, since any larger positions in the transcript do not exist in the input HT
            hl.or_missing(
                ht.pos_per_transcript[ht.min_post_window_index] > ht.min_window_end,
                ht.pos_per_transcript[ht.min_post_window_index],
            ),
        )
        # When the index returned is anywhere from the start index to the second to last index in the list:
        .when(
            ht.min_post_window_index < ht.n_pos_per_transcript,
            # Check if the position pointed to by the index is the same as the window end
            # If the positions are the same, then return the next position in the list
            hl.if_else(
                ht.pos_per_transcript[ht.min_post_window_index] == ht.min_window_end,
                ht.pos_per_transcript[ht.min_post_window_index + 1],
                ht.pos_per_transcript[ht.min_post_window_index],
            ),
            # If the index is past the end of the list (window end beyond all positions),
            # no post-window position exists, so return missing
        ).or_missing()
    )
def get_min_two_break_window(
ht: hl.Table,
min_window_size: int,
transcript_percentage: float,
overwrite_pos_ht: bool = False,
annotations: List[str] = [
"mu_snp",
"total_exp",
"_mu_scan",
"total_mu",
"cumulative_obs",
"observed",
"cumulative_exp",
"total_obs",
"reverse",
"forward_oe",
"overall_oe",
],
) -> Tuple[hl.Table, int]:
"""
Annotate input Table with smallest simultaneous breaks window end positions.
Also annotate Table with first position post-window.
Function prepares Table for simultaneous breaks searches.
Assumes input Table has all annotations present in `annotations`.
:param hl.Table ht: Input Table.
:param int min_window_size: Minimum number of bases to search for constraint (window size for simultaneous breaks).
:param float transcript_percentage: Maximum percentage of the transcript that can be included within a window of constraint.
:param bool overwrite_pos_ht: Whether to overwrite positions per transcript HT. Default is False.
:param List[str] annotations: Annotations to keep from input HT. Required to search for significant breakpoint.