In [None]:
import os
import sys
import subprocess
import hail as hl
from pyspark.sql import SparkSession

In [None]:
# Start (or reuse) a Hive-enabled Spark session and attach Hail to its context.
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
hl.init(sc=spark.sparkContext)

# The BGEN files are indexed with reference_genome="GRCh37" and the loadings
# table is the GRCh37 version, so keep the session default on the same build
# (it was GRCh38, which disagreed with every locus handled in this notebook).
hl.default_reference("GRCh37")
print("Hail version:", hl.__version__)

In [None]:
!hdfs dfs -mkdir -p /tmp/pca_bgen
!hdfs dfs -put -f /mnt/project/dcm_pgs/pca_variants/dcm_pca_chr*_subset.bgen /tmp/pca_bgen/
!hdfs dfs -ls /tmp/pca_bgen | head -n 10

In [None]:
# Contig renaming handed to Hail: two-digit keys "01".."22" map to "1".."22".
contig_recoding = dict((f"{i:02d}", str(i)) for i in range(1, 23))

# One staged BGEN per autosome, keyed by chromosome number.
bgen_by_chrom = {
    c: f"hdfs://master:9000/tmp/pca_bgen/dcm_pca_chr{c}_subset.bgen"
    for c in range(1, 23)
}

# Build a Hail index for each file, written next to it as "<bgen>.idx2".
for chrom, bgen in bgen_by_chrom.items():
    print(f"chr{chrom}: indexing")
    hl.index_bgen(
        bgen,
        index_file_map={bgen: bgen + ".idx2"},
        contig_recoding=contig_recoding,
        reference_genome="GRCh37",
    )

print("Done indexing.")

In [None]:
# Load the 22 indexed autosome BGEN files as a single MatrixTable that carries
# only the per-entry `dosage` field; sample IDs come from the .sample file.
bgen_paths = []
for c in range(1, 23):
    bgen_paths.append(f"hdfs://master:9000/tmp/pca_bgen/dcm_pca_chr{c}_subset.bgen")

sample_file = "file:///mnt/project/Bulk/Imputation/UKB imputation from genotype/ukb22828_c1_b0_v3.sample"

mt = hl.import_bgen(
    path=bgen_paths,
    entry_fields=["dosage"],
    sample_file=sample_file,
)

In [None]:
# Join the PCA loadings onto the variants by (locus, alleles) and keep only
# variants that have a loadings record, i.e. the intersection of both sets.
ld = hl.read_table(
    "file:///mnt/project/dcm_pgs/loadings/gnomad.v3.1.pca_loadings_grch37.ht"
).key_by("locus", "alleles")

mt = mt.annotate_rows(l=ld[mt.row_key])
has_loadings = hl.is_defined(mt.l)
mt = mt.filter_rows(has_loadings)

In [None]:
# Project every sample onto PC1..PC10 using the joined loadings.
#
# For each variant the dosage is standardised with the loadings' allele
# frequency p:  x = (dosage - 2p) / sqrt(2p(1 - p)),  and each PC score is the
# per-sample sum over variants of x * loading[k].
#
# NOTE(review): some reference projection code additionally divides by
# sqrt(number of loading variants); that scaling is not applied here, matching
# what this notebook has always produced. Confirm the scale expected downstream.

N_PCS = 10
pc_names = [f"PC{k}" for k in range(1, N_PCS + 1)]

# Variants with p == 0, p == 1 or missing p make sqrt(2p(1 - p)) zero/missing;
# dividing by that yields inf/NaN, and a single NaN poisons a sample's whole
# sum. Drop them up front (filter_rows also drops rows whose predicate is
# missing). The result is unchanged when all frequencies are strictly in (0, 1).
mt_af = mt.filter_rows((mt.l.pca_af > 0.0) & (mt.l.pca_af < 1.0))

p = mt_af.l.pca_af
mu = 2.0 * p
sigma = hl.sqrt(2.0 * p * (1.0 - p))
x = (hl.float64(mt_af.dosage) - mu) / sigma

# `mt` itself is deliberately not rebound, so this cell can be re-run without
# having to rebuild the joined MatrixTable first.
mt_scored = mt_af.annotate_cols(
    **{
        name: hl.agg.sum(x * mt_af.l.loadings[i])
        for i, name in enumerate(pc_names)
    }
)

# Drop all row and entry fields before extracting the column (sample) table.
pcs = mt_scored.select_rows().select_entries().cols()

# Expose the sample ID as `eid` followed by PC1..PC10. The table key `s` is
# kept by select(), so it remains present alongside `eid`.
pcs = pcs.select(eid=pcs.s, **{name: pcs[name] for name in pc_names})

# Spread the table over 2000 partitions before writing it out.
pcs = pcs.repartition(2000)

# Materialise the result so later cells read it back instead of recomputing.
pcs = pcs.checkpoint("/tmp/ukb_gnomad_projected_pcs.ht", overwrite=True)

In [None]:
# Export the checkpointed PC table as ONE block-gzipped TSV with a single
# header line.
#
# This used parallel="header_per_shard", which writes a directory of shards
# that each start with their own header. The following
# `hdfs dfs -getmerge` step concatenates those shards, so the merged file ended
# up with a header row repeated once per shard in the middle of the data.
# Writing a single file avoids that; getmerge on a single file simply copies it.
pcs = hl.read_table("/tmp/ukb_gnomad_projected_pcs.ht")

pcs.export("/tmp/ukb_gnomad_projected_pcs.tsv.bgz")

In [None]:
# Copy the exported TSV from HDFS to local disk; getmerge concatenates whatever
# is under the source path into the single local destination file.
!hdfs dfs -getmerge /tmp/ukb_gnomad_projected_pcs.tsv.bgz /opt/notebooks/ukb_gnomad_projected_pcs.tsv.bgz

In [None]:
# Preview: decompress the local file to stdout and show its first five lines.
!gunzip -c /opt/notebooks/ukb_gnomad_projected_pcs.tsv.bgz | head -n 5

In [None]:
!dx upload /opt/notebooks/ukb_gnomad_projected_pcs.tsv.bgz /mnt/project/dcm_pgs/ukb_gnomad_projected_pcs.tsv.bgz