In [1]:
import os
import sys
import os.path as osp
# Necessary for workers to use same conda environment
# Otherwise you get "module 'glow' not found" errors when actually running jobs
os.environ["PYSPARK_PYTHON"] = sys.executable
# export SPARK_DRIVER_MEMORY=64g
# export ARROW_PRE_0_15_IPC_FORMAT=1
import pyspark.sql.functions as F
from pyspark.sql.session import SparkSession
from glow import *
from glow.wgr.functions import *
from glow.wgr.linear_model import *
# Explicit imports are kept AFTER the star imports on purpose, so that these
# names are guaranteed to refer to the modules below regardless of what the
# star imports happen to export.
import glow
import numpy as np
import pandas as pd

In [8]:
# PLINK .bed file of the HapMap tutorial data, resolved relative to $WORK_DIR
# (raises KeyError if the WORK_DIR environment variable is not set).
path = osp.join(
    os.environ['WORK_DIR'],
    'data/gwas/tutorial/1_QC_GWAS/HapMap_3_r3_1.bed',
)
path

'/home/jovyan/work/data/gwas/tutorial/1_QC_GWAS/HapMap_3_r3_1.bed'

In [9]:
# Requesting the Glow package through PYSPARK_SUBMIT_ARGS did not work here:
#   os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages io.projectglow:glow_2.11:0.5.0'
# The original note recommends configuring it in
# $SPARK_HOME/conf/spark-defaults.conf instead.
# NOTE(review): `spark.sql.execution.arrow.pyspark.enabled` looks like the
# Spark 3.x key name, while the `glow_2.11` artifact suggests a Scala 2.11 /
# Spark 2.4 build (where the key is `spark.sql.execution.arrow.enabled`) --
# confirm which Spark version this runs on.
spark = (
    SparkSession.builder
    .config('spark.jars.packages', 'io.projectglow:glow_2.11:0.5.0')
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

In [10]:
# `glow` is imported in the notebook's first cell together with the other
# dependencies, so this cell only registers Glow against the Spark session
# created above.
glow.register(spark)

In [11]:
# Read the PLINK fileset at `path` into a Spark DataFrame.
# All reader settings are passed in one `options` call; each keyword is
# forwarded to the reader exactly like an individual `.option(key, value)`.
df = (
    spark.read
    .format('plink')
    .options(
        bimDelimiter="\t",
        famDelimiter=" ",
        includeSampleIds=True,
        mergeFidIid=False,
    )
    .load(path)
)

In [12]:
# Inspect the column layout of the DataFrame returned by the PLINK reader.
df.printSchema()

root
 |-- contigName: string (nullable = true)
 |-- names: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- position: double (nullable = true)
 |-- start: long (nullable = true)
 |-- end: long (nullable = true)
 |-- referenceAllele: string (nullable = true)
 |-- alternateAlleles: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- genotypes: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- sampleId: string (nullable = true)
 |    |    |-- calls: array (nullable = true)
 |    |    |    |-- element: integer (containsNull = true)



In [13]:
# Preview the first three rows of the loaded data.
df.show(n=3)

+----------+------------+--------+------+------+---------------+----------------+--------------------+
|contigName|       names|position| start|   end|referenceAllele|alternateAlleles|           genotypes|
+----------+------------+--------+------+------+---------------+----------------+--------------------+
|         1| [rs2185539]|     0.0|556737|556738|              C|             [T]|[[NA06989, [0, 0]...|
|         1|[rs11510103]|     0.0|557615|557616|              A|             [G]|[[NA06989, [0, 0]...|
|         1|[rs11240767]|     0.0|718813|718814|              C|             [T]|[[NA06989, [0, 0]...|
+----------+------------+--------+------+------+---------------+----------------+--------------------+
only showing top 3 rows



In [16]:
# Build the variant table used for the regression steps:
#   1. add a numeric `values` array derived from `genotypes`
#      (presumably per-sample genotype states with missing entries replaced
#      by the mean -- confirm against the Glow documentation),
#   2. drop rows whose `values` contain only a single distinct value,
#   3. keep only rows whose contigName is '22'.
dfv = (
    df
    .withColumn('values', mean_substitute(genotype_states(F.col('genotypes'))))
    .where(F.size(F.array_distinct('values')) > 1)
    .where(F.col('contigName') == '22')
)
dfv.printSchema()

root
 |-- contigName: string (nullable = true)
 |-- names: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- position: double (nullable = true)
 |-- start: long (nullable = true)
 |-- end: long (nullable = true)
 |-- referenceAllele: string (nullable = true)
 |-- alternateAlleles: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- genotypes: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- sampleId: string (nullable = true)
 |    |    |-- calls: array (nullable = true)
 |    |    |    |-- element: integer (containsNull = true)
 |-- values: array (nullable = true)
 |    |-- element: double (containsNull = true)



In [17]:
# Preview the first three filtered rows, including the new `values` column.
dfv.show(n=3)

+----------+------------+--------+--------+--------+---------------+----------------+--------------------+--------------------+
|contigName|       names|position|   start|     end|referenceAllele|alternateAlleles|           genotypes|              values|
+----------+------------+--------+--------+--------+---------------+----------------+--------------------+--------------------+
|        22|[rs11089128]|     0.0|14560202|14560203|              A|             [G]|[[NA06989, [0, 0]...|[0.0, 1.0, 0.0, 0...|
|        22| [rs7288972]|     0.0|14564327|14564328|              T|             [C]|[[NA06989, [0, 0]...|[0.0, 1.0, 1.0, 0...|
|        22|[rs11167319]|     0.0|14850624|14850625|              T|             [G]|[[NA06989, [0, 1]...|[1.0, 0.0, 0.0, 0...|
+----------+------------+--------+--------+--------+---------------+----------------+--------------------+--------------------+
only showing top 3 rows



In [18]:
# Count the rows of dfv, i.e. the rows that passed the filters applied above.
dfv.count()

18308

In [11]:
# Collect the sample IDs from dfv. `get_sample_ids` is not defined in this
# notebook; presumably it is provided by one of the glow star imports.
sample_ids = get_sample_ids(dfv)
# Report how many sample IDs were found ...
print(len(sample_ids))
# ... and display only the first five rather than the whole list.
sample_ids[:5]

165


['NA06989', 'NA11891', 'NA11843', 'NA12341', 'NA12739']

In [12]:
# Partition dfv into blocks of variants and samples.
# The two sizes are tuning knobs chosen for this tutorial dataset:
# 5000 variants per block and 5 sample blocks.
block_df, sample_blocks = block_variants_and_samples(
    dfv,
    sample_ids,
    variants_per_block=5000,
    sample_block_count=5,
)

root
 |-- values: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- fractionalSampleBlockSize: double (nullable = true)
 |-- sample_block: string (nullable = true)



In [13]:
# Inspect the columns of the blocked DataFrame produced by block_variants_and_samples.
block_df.printSchema()

root
 |-- header: string (nullable = false)
 |-- size: integer (nullable = false)
 |-- values: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- header_block: string (nullable = false)
 |-- sample_block: string (nullable = true)
 |-- sort_key: integer (nullable = true)
 |-- mu: double (nullable = true)
 |-- sig: double (nullable = true)



In [14]:
# Wrap sample_blocks in a pandas Series purely for a compact display;
# the recorded output shows one list of sample IDs per block ID.
pd.Series(sample_blocks)

1    [NA06989, NA11891, NA11843, NA12341, NA12739, ...
2    [NA12282, NA11920, NA12776, NA12283, NA07435, ...
3    [NA12489, NA12399, NA12413, NA10843, NA12842, ...
4    [NA11829, NA12239, NA12762, NA12716, NA12878, ...
5    [NA06994, NA11993, NA11995, NA12891, NA12864, ...
dtype: object

In [15]:
# The sample IDs become the index of the label and covariate frames built
# below, so they must be unique. Fail loudly here instead of relying on a
# visual comparison of the two numbers.
assert len(sample_ids) == len(set(sample_ids)), "duplicate sample IDs found"
# Still display (total, distinct) for the reader.
len(sample_ids), len(set(sample_ids))

(165, 165)

In [16]:
# Phenotypes. Real data would be loaded from disk, e.g.
#   label_df = pd.read_csv(phenotypes_path, index_col='sample_id')
# Here two synthetic traits are simulated instead. A dedicated, seeded
# generator keeps the simulated values identical across "Restart & Run All"
# and avoids touching NumPy's global random state.
label_rng = np.random.RandomState(0)
label_df = pd.DataFrame({
    'sample_id': sample_ids,
    'trait_1': label_rng.normal(size=len(sample_ids)),
    'trait_2': label_rng.normal(size=len(sample_ids))
}).set_index('sample_id')
# Standardize every trait: subtract the column mean and divide by the
# population (ddof=0) standard deviation; the trailing selection pins the
# column order.
label_df = ((label_df - label_df.mean()) / label_df.std(ddof=0))[['trait_1', 'trait_2']]
label_df

Unnamed: 0_level_0,trait_1,trait_2
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1
NA06989,-0.078535,0.550310
NA11891,-0.419782,-0.050742
NA11843,0.774156,-2.126330
NA12341,0.771144,-1.314122
NA12739,0.402055,-0.865296
...,...,...
NA12752,1.005390,-0.517686
NA12043,0.534191,-0.535534
NA12264,-0.383037,0.109099
NA10854,-0.121615,-1.001846


In [17]:
# Covariates. Real data would be loaded from disk, e.g.
#   covariates = pd.read_csv(covariates_path, index_col='sample_id')
# Here two synthetic covariates are simulated instead. The generator is
# seeded for reproducibility; the seed deliberately differs from the one used
# for the simulated traits so the covariates are not copies of the labels.
cov_rng = np.random.RandomState(1)
cov_df = pd.DataFrame({
    'sample_id': sample_ids,
    'cov_1': cov_rng.normal(size=len(sample_ids)),
    'cov_2': cov_rng.normal(size=len(sample_ids))
}).set_index('sample_id')
# Standardize every covariate: subtract the column mean and divide by the
# population (ddof=0) standard deviation.
cov_df = ((cov_df - cov_df.mean()) / cov_df.std(ddof=0))
cov_df

Unnamed: 0_level_0,cov_1,cov_2
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1
NA06989,-0.961607,-0.240424
NA11891,-1.312328,-0.032669
NA11843,0.720588,0.281981
NA12341,0.157190,-0.014216
NA12739,0.178456,-1.166741
...,...,...
NA12752,0.253341,-0.922501
NA12043,-0.030301,-0.419506
NA12264,-1.912005,0.365930
NA10854,0.596679,-0.790976


In [18]:
# Create a RidgeReducer with its default settings; the recorded output shows
# that it generates its own list of alpha values in that case.
stack = RidgeReducer()
# Fit the reducer and transform block_df in a single call, using the
# standardized labels, the sample-block mapping and the covariates.
reduced_block_df = stack.fit_transform(block_df, label_df, sample_blocks, cov_df)
# Per the recorded schema, the result has the block_df columns plus `alpha` and `label`.
reduced_block_df.printSchema()

Generated alphas: [1.29671818e+06 1.71166800e+06 2.56750200e+06 5.13500400e+06
 1.28375100e+08]
root
 |-- header: string (nullable = true)
 |-- size: integer (nullable = true)
 |-- values: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- header_block: string (nullable = true)
 |-- sample_block: string (nullable = true)
 |-- sort_key: integer (nullable = true)
 |-- mu: double (nullable = true)
 |-- sig: double (nullable = true)
 |-- alpha: string (nullable = true)
 |-- label: string (nullable = true)



In [None]:
# Fit a RidgeRegression (default settings) on the reduced blocks.
# NOTE(review): this cell has no recorded execution count or output, so the
# schemas of model_df / cv_df have not been verified in this notebook.
estimator = RidgeRegression()
# `fit` returns two DataFrames, unpacked here as model_df and cv_df.
model_df, cv_df = estimator.fit(reduced_block_df, label_df, sample_blocks, cov_df)
model_df.printSchema()
cv_df.printSchema()

In [None]:
# Apply the fitted estimator to the reduced blocks, passing the model_df and
# cv_df returned by `fit` together with the same labels, sample blocks and
# covariates.
# NOTE(review): this cell has no recorded execution count or output.
y_hat_df = estimator.transform(reduced_block_df, label_df, sample_blocks, model_df, cv_df, cov_df)
y_hat_df.printSchema()