# Enformer-GTEx results

Explore Enformer predictions on GTEx variants

## Setup

In [1]:
import polars as pl
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import seaborn as sns
from kipoi_enformer.utils import get_tss_from_genome_annotation
import numpy as np
import statsmodels.api as sm
import plotnine as pn
from scipy.stats import ranksums
import lightgbm as lgb
from datetime import datetime

%load_ext autoreload
%autoreload 2

2024-09-04 13:14:35.950921: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
plt.rcParams['figure.dpi'] = 150

In [3]:
pl.Config.set_fmt_str_lengths(100)

polars.config.Config

In [4]:
pn.theme_set(pn.theme_bw())
pn.theme_update(dpi=150)

In [5]:
# Files relevant to the GTEx analysis
enformer_run_name = 'enformer_gtexv8_elasticnet_cage_canonical_2000_500'
# enformer variant effect on GTEx variants
enformer_path = f'/s/project/promoter_prediction/kipoi_expression_prediction/veff.parquet/run={enformer_run_name}'
# GTEx variants
variant_path = '/s/project/rep/processed/training_results_v15/gtex_v8_old_dna/private_variants.parquet/rare_variants.vcf.parquet/**/*.parquet'
# abexp benchmark dataset
gtex_benchmark_with_annotation_path = "/s/project/rep/processed/training_results_v15/gtex_benchmark_with_annotation.parquet/*.parquet"

# Output path
output_path = Path(f'/data/nasif12/home_if12/tsi/output/enformer/{enformer_run_name}/')
output_path.mkdir(exist_ok=True, parents=True)

## Analysis

We are only loading ensembl canonical transcripts for this analysis. So there is 1 transcript per gene!

In [6]:
veff_transcript_ldf = (pl.scan_parquet(Path(enformer_path) / '*.parquet').
                       select(pl.col(['tissue', 'gene_id', 'transcript_id', 'transcript_start', 'transcript_end']), 
                              pl.col('strand').cast(pl.Enum(['-', '+'])),
                              pl.col(['chrom', 'variant_start', 'variant_end', 'ref', 'alt', 'veff_score', 'ref_score', 'alt_score'])).
                       rename({'gene_id': 'gene', 'transcript_id': 'transcript'}).
                       with_columns(pl.col('gene').str.replace(r'([^\.]+)\..+$', "${1}").alias('gene'),
                                    pl.col('transcript').str.replace(r'([^\.]+)\..+$', "${1}").alias('transcript')))


# It is possible that a gene comes multiple times (different versions)

In [7]:
variant_ldf = pl.scan_parquet(variant_path).select(['sampleId', 'chrom', 'start', 'end', 'ref', 'alt']).rename({'sampleId': 'individual',
                                                                                                                'start': 'variant_start',
                                                                                                                'end': 'variant_end'})

In [8]:
training_benchmark_ldf = (pl.scan_parquet(gtex_benchmark_with_annotation_path)
                          .select(['gene', 'individual', 'tissue', 'FDR', 
                                   'mu', 'zscore', 'l2fc', 'is_obvious_outlier'])
                          .unique()
                          .rename({'l2fc': 'l2fc_outrider'})
                          .with_columns((pl.when(pl.col('FDR') > 0.2)
                                         .then(pl.lit('normal'))
                                         .otherwise(
                                             pl.when(pl.col('zscore') > 0)
                                             .then(pl.lit('overexpressed'))
                                             .otherwise(
                                                 pl.when(pl.col('zscore') < 0)
                                                 .then(pl.lit('underexpressed'))
                                                 # this should never be the case
                                                 .otherwise(pl.lit('CHECK'))
                                             ))).cast(pl.Enum(['underexpressed', 'normal', 'overexpressed'])).alias('outlier_state')))

### What is the enformer variant-effect-score distribution around the TSS?

In [9]:
upstream=2000
downstream=500

# join variants with enformer_veff on veff
veff_variant_ldf = variant_ldf.join(veff_transcript_ldf, how='inner', on=['chrom', 'variant_start', 'variant_end', 'ref', 'alt'])
# calculate variant position relative to the tss: rel_var_pos
veff_variant_ldf = veff_variant_ldf.with_columns((pl.when(pl.col('strand') == '+')
                                 .then(pl.col('variant_start') - pl.col('transcript_start'))
                                 .otherwise(pl.col('transcript_end') - pl.col('variant_start'))
                            ).alias('rel_var_pos'))
# filter out variants far from the TSS (-2000, 500)
veff_variant_ldf = veff_variant_ldf.filter((pl.col('rel_var_pos') >= -upstream) & (pl.col('rel_var_pos') < downstream))
# rename scores to canonical_scores
veff_variant_ldf = (veff_variant_ldf.
                              select(['individual', 'chrom', 'variant_start', 'variant_end', 'ref', 'alt', 'tissue', 'gene', 'transcript', 'strand', 'veff_score', 'ref_score', 'alt_score', 'rel_var_pos']))

# join outrider with variants on individual
veff_outrider_ldf = (veff_variant_ldf.join(training_benchmark_ldf, how='inner', on=['individual', 'gene', 'tissue']).
                               select(['gene', 'tissue', 'individual', 'rel_var_pos', 'outlier_state', 'zscore', 'FDR', 'veff_score',
                                       'l2fc_outrider', 'mu', 'is_obvious_outlier']))

In [10]:
# pl.Config.set_streaming_chunk_size(100)
# print(veff_outrider_ldf.explain(streaming=True))

In [10]:
veff_outrider_df = veff_outrider_ldf.collect(streaming=True)

In [12]:
# rel_var_pos in bins
bin_size=50
cuts = list(range(-upstream + bin_size, downstream, bin_size))
cut_labels = [str(x) for x in [-upstream, *cuts]]
# cut_labels = [f'[-{upstream}, -{upstream - bin_size})'] + [f'[{cuts[i]}, {cuts[i+1]})' for i in range(len(cuts) - 1)] + [f'[{downstream - bin_size}, {downstream})']
veff_outrider_df = (veff_outrider_df.with_columns(
    (pl.col('rel_var_pos').cut(cuts, labels=cut_labels)).alias('rel_var_pos_bin').cast(pl.Enum(cut_labels))
).with_columns((pl.col('outlier_state') == 'underexpressed').alias('is_underexpressed')))

veff_outrider_df = veff_outrider_df.with_columns(rel_var_pos_bin_label = pl.col("rel_var_pos_bin").cast(pl.String), bin_size=pl.lit(50))
for c in cut_labels:
    veff_outrider_df = veff_outrider_df.with_columns(rel_var_pos_bin_label = pl.col("rel_var_pos_bin_label")
                                                     .replace(str(c), f'[{c}, {int(c) + bin_size})'))

In [13]:
# set new bins
# new_bins = [(-2000, -500),
#             *[(i, i + 100) for i in range(-500, -100, 100)],
#             (-100, -50),
#             (-50, 0),
#             (0, 50),
#             (50, 100),
#             *[(i, i + 100) for i in range(100, 500, 100)],]
new_bins = [(-2000, -1500),
            (-1500, -1000),
            (-1000, -500),
            *[(i, i + 100) for i in range(-500, -100, 100)],
            (-100, -50),
            (-50, 0),
            (0, 50),
            (50, 100),
            *[(i, i + 100) for i in range(100, 500, 100)],]
new_bin_labels  = [f'[{start}, {stop})' for start, stop in new_bins]

for start, stop in new_bins:
    veff_outrider_df = veff_outrider_df.with_columns(
        rel_var_pos_bin_label = pl.when((pl.col('rel_var_pos_bin').cast(pl.Int16) >= start) & (pl.col('rel_var_pos_bin').cast(pl.Int16) < stop))
                                  .then(pl.lit(f'[{start}, {stop})'))
                                  .otherwise(pl.col('rel_var_pos_bin_label')),
        bin_size = pl.when((pl.col('rel_var_pos_bin').cast(pl.Int16) >= start) & (pl.col('rel_var_pos_bin').cast(pl.Int16) < stop))
                                  .then(pl.lit(stop - start))
                                  .otherwise(pl.col('bin_size')))
veff_outrider_df = veff_outrider_df.with_columns(rel_var_pos_bin_label=pl.col('rel_var_pos_bin_label').cast(pl.Enum(new_bin_labels)))

In [14]:
# calculate counts per bins
bin_count_df = veff_outrider_df.group_by(['outlier_state', 'rel_var_pos_bin_label', 'bin_size']). \
    agg((pl.len()).alias('count'))

# # calculate mean of each bin and then sum the means in each outlier state
totals_df = bin_count_df.group_by('outlier_state').agg(pl.sum('count').alias('total_count'))
# # normalize each count by the mean calculated above
enrichment_df = bin_count_df.join(totals_df, on='outlier_state').with_columns((pl.col('count') / pl.col('total_count')).alias('enrichment'))
ci_low, ci_high = sm.stats.proportion_confint(enrichment_df["count"], enrichment_df["total_count"])
enrichment_df = enrichment_df.with_columns(pl.Series(ci_low).alias('ci_low'), pl.Series(ci_high).alias('ci_high'))
# normalize by bin size
enrichment_df = enrichment_df.with_columns(enrichment = pl.col('enrichment') / pl.col('bin_size'),
                                           ci_low = pl.col('ci_low') / pl.col('bin_size'),
                                           ci_high = pl.col('ci_high') / pl.col('bin_size'))

In [15]:
score_column = 'veff_score'

In [17]:
115691 / 2749843

0.04207185646598733

In [16]:
enrichment_df

outlier_state,rel_var_pos_bin_label,bin_size,count,total_count,enrichment,ci_low,ci_high
enum,enum,i32,u32,u32,f64,f64,f64
"""normal""","""[-400, -300)""",100,115691,2749843,0.000421,0.000418,0.000423
"""underexpressed""","""[400, 500)""",100,38,2051,0.000185,0.000127,0.000244
"""normal""","""[300, 400)""",100,112672,2749843,0.00041,0.000407,0.000412
"""normal""","""[400, 500)""",100,114017,2749843,0.000415,0.000412,0.000417
"""normal""","""[-1000, -500)""",500,508802,2749843,0.00037,0.000369,0.000371
…,…,…,…,…,…,…,…
"""overexpressed""","""[100, 200)""",100,95,1177,0.000807,0.000652,0.000963
"""normal""","""[100, 200)""",100,109829,2749843,0.000399,0.000397,0.000402
"""overexpressed""","""[0, 50)""",50,104,1177,0.001767,0.001443,0.002091
"""overexpressed""","""[50, 100)""",50,67,1177,0.001138,0.000874,0.001403


In [18]:
totals_df

outlier_state,total_count
enum,u32
"""underexpressed""",2051
"""overexpressed""",1177
"""normal""",2749843


In [37]:
# cnt = veff_outrider_df.group_by(['outlier_state']).len().with_columns(
#         (pl.col("len").map_elements(lambda x: format(x, ','), str))
#     )
# p1 = (
#         pn.ggplot(veff_outrider_df, pn.aes(x="outlier_state", fill="outlier_state"))
#         + pn.geom_boxplot(pn.aes(y=score_column))
#         + pn.theme(
#             figure_size=(8, 6),
#             axis_text_x=pn.element_text(angle=90),
#             dpi=150
#         )
#         + pn.labs(
#             x="Outlier state", 
#             y="Enformer log2FC",
#             color="",
#             fill="",
#             title=f"Enformer variant-effect of rare promoter variants in GTEx v7 (-{upstream}, +{downstream})",
#         )
#         + pn.guides(size='none', fill = pn.guide_legend(reverse = True))
#         + pn.geom_label(cnt, pn.aes(label='len', y=1, size=8), show_legend=False)
#         + pn.coord_flip(ylim=[-1, 1])
#     )
# p1

In [38]:
# p2 = boxplot_per_bin(veff_bin_df, score_column=score_column)
# cnt = tmp.group_by(['rel_var_pos_bin_label', 'outlier_state']).len().with_columns(
#         (pl.when(pl.col('outlier_state') == 'underexpressed').then(1).otherwise(
#             pl.when(pl.col('outlier_state')=='normal').then(0.9).otherwise(0.8)
#         )).alias('pos'),
#         (pl.col("len").map_elements(lambda x: format(x, ','), str))
#     )
# p2 = (
#         pn.ggplot(veff_outrider_df, pn.aes(x="rel_var_pos_bin_label", fill="outlier_state"))
#         + pn.geom_boxplot(pn.aes(y=score_column))
#         + pn.theme(
#             figure_size=(12, 6),
#             axis_text_x=pn.element_text(angle=90),
#             dpi=150
#         )
#         + pn.labs(
#             x="Distance to TSS", 
#             y="Enformer log2FC",
#             color="",
#             fill="",
#             title="Enformer variant-effect of rare promoter variants in GTEx v7",
#         )
#         + pn.coord_cartesian(ylim=[-1, 1])
#         + pn.guides(size='none', fill = pn.guide_legend(reverse = True))
#     )
# p2

In [39]:
# p3 = (
#     pn.ggplot(enrichment_df, pn.aes(x="rel_var_pos_bin_label", y="enrichment", fill="outlier_state", color="outlier_state"))
#     + pn.geom_line(pn.aes(group="outlier_state"), linetype="dashed")
#     + pn.geom_point()
#     + pn.geom_errorbar(pn.aes(ymin="ci_low", ymax="ci_high"))
#     + pn.theme(
#         figure_size=(8, 6),
#         axis_text_x=pn.element_text(angle=90),
#         dpi=150
#     )
#     + pn.labs(
#         x="Distance to TSS", 
#         y="Variant enrichment",
#         color="",
#         fill="",
#         title="Enrichment of rare promoter variants in GTEx v7",
#     )
# )
# p3

In [40]:
base_path = Path('/data/nasif12/home_if12/tsi/kipoi_expression_prediction/etc/enformer_bins') / datetime.now().strftime("%Y%m%d%H%M")
base_path.mkdir(parents=True)
veff_outrider_df.write_parquet(base_path / 'veff_bin.parquet', use_pyarrow=True)
enrichment_df.write_parquet(base_path / 'enrichment.parquet', use_pyarrow=True)
base_path

PosixPath('/data/nasif12/home_if12/tsi/kipoi_expression_prediction/etc/enformer_bins/202407252235')