In [None]:
from pyspark.sql import SparkSession
import hail as hl
import os
import time
import dxpy
import logging
import pandas as pd
import re


# Start (or reuse) a Hive-enabled Spark session and initialise Hail on its SparkContext.
# idempotent=True turns repeat hl.init calls into a no-op, so this cell can be re-run.
builder = SparkSession.builder.enableHiveSupport()
spark = builder.getOrCreate()
hl.init(idempotent=True, sc=spark.sparkContext)

# Read the UKB SNP array data

In [None]:
snp_data_pre = "file:///mnt/project/notebooks/snp/liftover/ukb_c1-22_GRCh38_full_analysis_set_plus_decoy_hla_merged"

# Read the PLINK bed/bim/fam trio that shares the prefix above.
# The file name suggests chr1-22 lifted over to GRCh38.
geno_mt = hl.import_plink(
    reference_genome='GRCh38',
    **{ext: f"{snp_data_pre}.{ext}" for ext in ("bed", "bim", "fam")},
)

# Quality control of UKB data

In [None]:
geno_sample_qc_file = "file:///mnt/project/fields/data/sample_qc/sample_qc_info.tsv"
# Per-participant QC fields, keyed on "sample_names" so rows can be looked up by the
# genotype MatrixTable's column key `s` (see geno_sample_qc).
geno_sample_qc_table = hl.import_table(geno_sample_qc_file).key_by("sample_names")


In [None]:
def geno_sample_qc(mt, sample_qc_table):
    """Attach UKB sample-QC fields to each sample and keep only samples that pass them.

    Parameters
    ----------
    mt : hl.MatrixTable
        Genotype matrix whose column key ``s`` matches ``sample_qc_table``'s key.
    sample_qc_table : hl.Table
        Per-sample QC table, keyed by sample ID. String-valued flags are compared to
        the literals below.

    Returns
    -------
    hl.MatrixTable
        ``mt`` with a ``sample_qc_ukb`` column struct, restricted to samples where:
        reported sex equals genetic sex; ``sex_chromosome_aneuploidy`` is not "Yes";
        ``genetic_kinship_to_other_participants`` is not "Ten or more third-degree
        relatives identified"; and ``out_hetz_missing`` is not "Yes".
        Samples missing from the table yield a missing condition. Hail removes such
        columns (presumably the intent; confirm).
    """
    annotated = mt.annotate_cols(sample_qc_ukb=sample_qc_table[mt.s])
    qc = annotated.sample_qc_ukb

    sex_matches = qc.sex == qc.genetic_sex
    no_aneuploidy = qc.sex_chromosome_aneuploidy != "Yes"
    kinship_ok = qc.genetic_kinship_to_other_participants != "Ten or more third-degree relatives identified"
    # Presumably the heterozygosity / missing-rate outlier flag; verify against the field dictionary.
    not_outlier = qc.out_hetz_missing != "Yes"

    return annotated.filter_cols(sex_matches & no_aneuploidy & kinship_ok & not_outlier)


In [None]:
# Apply the UKB sample-level QC filters defined above to the genotype matrix.
geno_mt = geno_sample_qc(geno_mt, geno_sample_qc_table)


In [None]:
# Sanity check: geno_mt dimensions after sample QC.
geno_mt.count()


In [None]:
def geno_variant_qc(mt, min_af=0.001, min_call_rate=0.99):
    """Run Hail variant QC on the genotype matrix and keep well-called variants.

    Parameters
    ----------
    mt : hl.MatrixTable
        Genotype matrix. ``hl.variant_qc`` reads its ``GT`` entry field.
    min_af : float, default 0.001
        Keep variants whose ``variant_qc.AF[1]`` (frequency of the second allele)
        is strictly greater than this.
    min_call_rate : float, default 0.99
        Keep variants whose ``variant_qc.call_rate`` is strictly greater than this.

    Returns
    -------
    hl.MatrixTable
        ``mt`` annotated with the ``variant_qc`` row struct and filtered to passing rows.
    """
    mt = hl.variant_qc(mt)
    # NOTE(review): AF[1] is not the minor-allele frequency. Variants whose second
    # allele is near-fixed (AF[1] > 1 - min_af) also pass. If a MAF cut was intended,
    # filter on hl.min(mt.variant_qc.AF) instead.
    passes_af = mt.variant_qc.AF[1] > min_af
    passes_call_rate = mt.variant_qc.call_rate > min_call_rate
    return mt.filter_rows(passes_af & passes_call_rate)


In [None]:
# Apply variant-level QC (allele-frequency and call-rate thresholds) to the sample-filtered matrix.
geno_mt = geno_variant_qc(geno_mt)


# Checkpoint 1

In [None]:
# Checkpoint 1 target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "geno_filtered.mt"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# Set RERUN=True to (re)write the QC'd genotype matrix to `url`. When False, this cell
# does nothing and the next cell reads the previously written checkpoint.
RERUN=False
if RERUN:
    geno_mt.write(url, overwrite=True)

In [None]:
# Rebind geno_mt to the stored checkpoint instead of the lazy QC pipeline.
geno_mt = hl.read_matrix_table(url)

In [None]:
# Sanity check: dimensions of the checkpointed, QC'd genotype matrix.
geno_mt.count()

# Load the gnomAD HGDP + 1000 Genomes reference panel with Hail's `load_dataset`
The full list of Hail-hosted datasets is at https://hail.is/docs/0.2/datasets.html

In [None]:
# gnomAD HGDP + 1KG dense subset (v3.1.2, GRCh38), read from the AWS "us" copy of
# Hail's dataset registry.
ref_mt = hl.experimental.load_dataset(
    cloud='aws',
    region='us',
    reference_genome='GRCh38',
    version="3.1.2",
    name="gnomad_hgdp_1kg_subset_dense",
)


# Only keep the sites in ref which are observed in the SNP array data

In [None]:
# Keep only reference variants whose row key also appears among geno_mt's rows.
ref_mt = ref_mt.semi_join_rows(geno_mt.rows())

# Checkpoint 2

In [None]:
# Checkpoint 2 target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "ref_overlap_unfiltered.mt"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# Set RERUN=True to (re)write the geno-overlapping reference matrix to `url`. When
# False, this cell does nothing and the next cell reads the previously written checkpoint.
RERUN=False
if RERUN:
    ref_mt.write(url, overwrite=True)

In [None]:
# Reload the unfiltered reference-overlap checkpoint.
ref_mt = hl.read_matrix_table(url)

In [None]:
# Sanity check: ref_mt dimensions before reference QC.
ref_mt.count()

# Quality control of reference data

In [None]:
def ref_sample_qc(mt):
    """Keep reference samples that pass both quality flags and are not marked related.

    Parameters
    ----------
    mt : hl.MatrixTable
        Reference matrix with boolean column fields ``high_quality``,
        ``gnomad_high_quality`` and ``relatedness_inference.related``.

    Returns
    -------
    hl.MatrixTable
        ``mt`` restricted to columns where both quality flags are True and
        ``related`` is False. Columns with a missing flag evaluate to missing
        and are dropped by ``filter_cols``.
    """
    is_high_quality = mt.high_quality & mt.gnomad_high_quality
    is_unrelated = ~mt.relatedness_inference.related
    return mt.filter_cols(is_high_quality & is_unrelated)

def ref_variant_qc(mt, min_af=0.001, min_call_rate=0.99):
    """Reference variant QC: allele-frequency and call-rate filters, then drop chrX/chrY.

    Parameters
    ----------
    mt : hl.MatrixTable
        Reference genotypes. ``hl.variant_qc`` reads the ``GT`` entry field.
    min_af : float, default 0.001
        Keep variants whose ``variant_qc.AF[1]`` (frequency of the second allele)
        is strictly greater than this.
    min_call_rate : float, default 0.99
        Keep variants whose ``variant_qc.call_rate`` is strictly greater than this.

    Returns
    -------
    hl.MatrixTable
        ``mt`` with a ``variant_qc`` row annotation, restricted to passing variants
        whose contig is neither "chrX" nor "chrY".
    """
    mt = hl.variant_qc(mt)
    # NOTE(review): AF[1] is not the minor-allele frequency. Variants whose second
    # allele is near-fixed also pass. Use hl.min(mt.variant_qc.AF) if a MAF cut was intended.
    mt = mt.filter_rows((mt.variant_qc.AF[1] > min_af) & (mt.variant_qc.call_rate > min_call_rate))
    # Remove sex-chromosome variants (keep=False drops rows matching the condition).
    mt = mt.filter_rows((mt.locus.contig == "chrX") | (mt.locus.contig == "chrY"), keep=False)
    return mt

def ref_ld_prune(mt, r2=0.1):
    """LD-prune the reference matrix, keeping only the variants ``hl.ld_prune`` retains.

    Parameters
    ----------
    mt : hl.MatrixTable
        Matrix with a ``GT`` entry field.
    r2 : float, default 0.1
        Squared-correlation threshold passed to ``hl.ld_prune``.

    Returns
    -------
    hl.MatrixTable
        ``mt`` restricted to rows whose key appears in the pruned-variant table.
    """
    pruned_variants = hl.ld_prune(mt.GT, r2=r2)
    return mt.filter_rows(hl.is_defined(pruned_variants[mt.row_key]))

In [None]:
# Drop low-quality and related reference samples.
ref_mt = ref_sample_qc(ref_mt)

In [None]:
# Sanity check: ref_mt dimensions after sample QC.
ref_mt.count()

In [None]:
# Variant QC on the reference (allele frequency, call rate, chrX/chrY removed).
ref_mt = ref_variant_qc(ref_mt)

In [None]:
# Sanity check: ref_mt dimensions after variant QC.
ref_mt.count()

In [None]:
# LD-prune the reference variants (threshold set in ref_ld_prune).
ref_mt = ref_ld_prune(ref_mt)

In [None]:
# Sanity check: ref_mt dimensions after LD pruning.
ref_mt.count()

# Checkpoint 3

In [None]:
# Checkpoint 3 target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "ref_overlap_filtered.mt"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# Set RERUN=True to (re)write the QC'd, LD-pruned reference matrix to `url`. When
# False, this cell does nothing and the next cell reads the previously written checkpoint.
RERUN=False
if RERUN:
    ref_mt.write(url, overwrite=True)

In [None]:
# Reload the filtered reference checkpoint.
ref_mt = hl.read_matrix_table(url)

In [None]:
# Sanity check: dimensions of the checkpointed filtered reference.
ref_mt.count()

# Final overlap: restrict geno to variants in the filtered reference

In [None]:
# Keep only genotype variants whose row key also appears among the filtered reference rows.
geno_mt = geno_mt.semi_join_rows(ref_mt.rows())

# Checkpoint 4

In [None]:
# Checkpoint 4 target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "geno_overlap_filtered.mt"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# Set RERUN=True to (re)write the reference-restricted genotype matrix to `url`. When
# False, this cell does nothing and the next cell reads the previously written checkpoint.
RERUN=False
if RERUN:
    geno_mt.write(url, overwrite=True)
    

In [None]:
# Reload the genotype checkpoint restricted to the filtered reference variants.
geno_mt = hl.read_matrix_table(url)

In [None]:
# Sanity check: geno_mt dimensions after the final overlap with the reference.
geno_mt.count()

# Final overlap: restrict ref to variants remaining in geno

In [None]:
# Keep only reference variants whose row key is still present in the final genotype rows.
ref_mt = ref_mt.semi_join_rows(geno_mt.rows())

# Checkpoint 5

In [None]:
# Checkpoint 5 target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "ref_overlapped_filtered.mt"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# Set RERUN=True to (re)write the final reference matrix to `url`. When False, this
# cell does nothing and the next cell reads the previously written checkpoint.
RERUN=False
if RERUN:
    ref_mt.write(url, overwrite=True)
    

In [None]:
# Reload the final reference matrix used for the PCA below.
ref_mt = hl.read_matrix_table(url)

In [None]:
# Sanity check: dimensions of the reference matrix the PCA is fitted on.
ref_mt.count()

# PCA calculation and projection

In [None]:
# Compute loadings and allele frequency for reference dataset
# Fit k=20 HWE-normalized PCs on the reference genotypes. compute_loadings=True keeps
# per-variant loadings so other genotypes can be projected onto the same axes below.

eigenvalues, scores, loadings_ht = hl.hwe_normalized_pca(ref_mt.GT, k=20, compute_loadings=True)

# Per-variant alternate-allele frequency among the reference samples: mean alt dosage / 2.
ref_mt = ref_mt.annotate_rows(af=hl.agg.mean(ref_mt.GT.n_alt_alleles()) / 2)

# Attach that frequency to each loading row; it is passed to pc_project alongside the loadings.
loadings_ht = loadings_ht.annotate(af=ref_mt.rows()[loadings_ht.key].af)


# Save reference PCA scores and population labels

In [None]:
# Label each reference sample's PC scores with its gnomAD population (gnomad_population_inference.pop).
scores = scores.annotate(ancestry_pred=ref_mt.cols()[scores.s].gnomad_population_inference.pop)

In [None]:
def upload_file_to_project(filename, proj_dir):
    """Upload a local file into a folder of the DNAnexus project and report completion.

    Parameters
    ----------
    filename : str
        Path of the local file to upload.
    proj_dir : str
        Destination folder. Missing parent folders are created (``parents=True``).
        No explicit project is passed, so dxpy's default project is used.

    Returns
    -------
    dxpy.DXFile
        Handle to the uploaded file. Callers that ignore the return value are unaffected.
    """
    # wait_on_close=True blocks until the platform has finished closing the file, so
    # the "uploaded" message is only printed once the file is actually usable.
    dx_file = dxpy.upload_local_file(filename, folder=proj_dir, parents=True, wait_on_close=True)
    print(f"*********{filename} uploaded!!*********")
    return dx_file

In [None]:
# Collect the reference PC scores (with ancestry_pred) into a pandas DataFrame on the driver.
ref_pca_df = scores.to_pandas()


In [None]:
# Preview the first rows of the reference scores.
ref_pca_df.head()


In [None]:
# Expand the length-20 `scores` arrays into one column per PC (pca_1 ... pca_20); 20 matches k used for the PCA.
ref_pca_df[[f"pca_{i + 1}" for i in range(20)]] = pd.DataFrame(list(ref_pca_df["scores"]), index=ref_pca_df.index)

In [None]:
# The per-PC columns replace the packed array, so drop it.
ref_pca_df = ref_pca_df.drop("scores", axis="columns")

In [None]:
# Write the reference scores locally (pandas infers gzip from the .gz suffix), then upload them.
filename = "ref_pca.csv.gz"
ref_pca_df.to_csv(filename, index=False)

proj_dir = "/notebooks/ancestry_inference/data/"
upload_file_to_project(filename=filename, proj_dir=proj_dir)


# Save PCA projections for geno

In [None]:
# Loadings checkpoint target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "pca_loadings.ht"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# Set RERUN=True to (re)write the PCA loadings table (with af) to `url`. When False,
# this cell does nothing and the next cell reads the previously written table.
RERUN=False
if RERUN:
    loadings_ht.write(url, overwrite=True)


In [None]:
# Reload the stored loadings table from the checkpoint.
loadings_ht = hl.read_table(url)

In [None]:
# Inspect the schema; pc_project below needs the `loadings` and `af` fields.
loadings_ht.describe()

In [None]:
# Project new genotypes onto loadings
# Project the UKB genotypes onto the reference PCs, using the reference loadings and
# the reference allele frequencies (`af`) attached to loadings_ht.

ht = hl.experimental.pc_project(geno_mt.GT, loadings_ht.loadings, loadings_ht.af)

In [None]:
# Inspect the schema of the projected-scores table.
ht.describe()

In [None]:
# Sanity check: dimensions of the genotype matrix that was projected.
geno_mt.count()

In [None]:
# Projected-scores checkpoint target. Ensure the DNAX-backed Spark database exists
# (CREATE ... IF NOT EXISTS), then resolve its object ID to build a dnax:// URL.
db_name = "ancestry_inference"
stmt = "CREATE DATABASE IF NOT EXISTS " + db_name + " LOCATION 'dnax://'"
print(stmt)
spark.sql(stmt).show()

# Look the database up by its lower-cased name to get the object ID used in the URL.
db_uri = dxpy.find_one_data_object(classname="database", name=db_name.lower())["id"]
mt_name = "geno_sample_pca.ht"
url = "dnax://{}/{}".format(db_uri, mt_name)


In [None]:
# NOTE(review): unlike every other checkpoint cell, RERUN is True here, so each run
# recomputes the projection and overwrites the stored table at `url`.
RERUN=True
if RERUN:
    ht.write(url, overwrite=True)


In [None]:
# Reload the projected UKB scores from the checkpoint.
ht = hl.read_table(url)

In [None]:
# Collect the projected scores into pandas on the driver
# (assumes one row per UKB sample fits in driver memory; TODO confirm).
geno_pca_df = ht.to_pandas()

In [None]:
# Expand the length-20 projected `scores` arrays into pca_1 ... pca_20, matching the reference columns.
geno_pca_df[[f"pca_{i + 1}" for i in range(20)]] = pd.DataFrame(list(geno_pca_df["scores"]), index=geno_pca_df.index)


In [None]:
# The per-PC columns replace the packed array, so drop it.
geno_pca_df = geno_pca_df.drop("scores", axis="columns")


In [None]:
# Write the projected UKB scores (geno_pca_df, not the reference scores) locally.
# pandas infers gzip from the .gz suffix. Then upload them to the project folder.
filename = "geno_pca.csv.gz"
geno_pca_df.to_csv(filename, index=False)

proj_dir = "/notebooks/ancestry_inference/data/"
upload_file_to_project(filename, proj_dir)

