In [1]:
from pyspark.sql.session import SparkSession
from pyspark.sql import functions as F
from glow import *
from glow.wgr.functions import *
import glow  # bound after the star imports so the name `glow` refers to the package
import pandas as pd
# Reuse the active Spark session, or start one if none exists yet.
spark = SparkSession.builder.getOrCreate()
# Register Glow with the session in the setup cell. Glow's Python helpers call
# into JVM classes, and the GWAS cell at the end of this notebook fails with
# "TypeError: 'JavaPackage' object is not callable"; registering here is meant
# to surface a missing Glow jar immediately instead of after all the data
# loading. NOTE(review): the jar itself still has to be supplied to Spark
# (e.g. through `spark.jars.packages`) -- confirm the Maven coordinates that
# match the installed Glow version.
# `or spark` keeps the existing session if `register` returns nothing and
# adopts the returned session if it returns one.
spark = glow.register(spark) or spark

In [15]:
# df = spark.read.parquet('validation/result/glow/sim_01-v100-b50/reduced_blocks_flat.parquet')
# df

In [5]:
# Trait table from the sim_01 validation data, keyed by the `sample_id` column.
label_df = pd.read_csv(
    'validation/data/sim_01/traits.csv',
    index_col='sample_id',
)
label_df

Unnamed: 0_level_0,Y0000
sample_id,Unnamed: 1_level_1
S0000001,-1015.789060
S0000002,-956.278049
S0000003,-969.424269
S0000004,-1025.074713
S0000005,-990.430356
...,...
S0000246,-1000.212787
S0000247,-1007.231613
S0000248,-1088.631907
S0000249,-1036.334048


In [14]:
# Covariate table from the sim_01 validation data, keyed by the `sample_id` column.
cov_df = pd.read_csv(
    'validation/data/sim_01/covariates.csv',
    index_col='sample_id',
)
cov_df

Unnamed: 0_level_0,X000,X001,X002
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
S0000001,1.764052,0.400157,0.978738
S0000002,2.240893,1.867558,-0.977278
S0000003,0.950088,-0.151357,-0.103219
S0000004,0.410599,0.144044,1.454274
S0000005,0.761038,0.121675,0.443863
...,...,...,...
S0000246,-2.245322,0.564009,-1.284552
S0000247,-0.104343,-0.988002,-1.177629
S0000248,-1.140196,1.754986,-0.132988
S0000249,-0.765702,0.555787,0.010349


In [7]:
# Predictions written by the Glow run `sim_01-v100-b50`, keyed by the `sample_id` column.
y_hat_df = pd.read_csv(
    'validation/result/glow/sim_01-v100-b50/predictions.csv',
    index_col='sample_id',
)
y_hat_df

Unnamed: 0_level_0,Y0000
sample_id,Unnamed: 1_level_1
S0000001,0.005314
S0000002,0.096403
S0000003,0.108917
S0000004,-0.153500
S0000005,-0.007197
...,...
S0000246,-0.237901
S0000247,0.235395
S0000248,-0.330481
S0000249,-0.167726


In [9]:
# Load the variant table and print its schema to show which columns are available.
variant_df = spark.read.load('/tmp/variant_df.parquet', format='parquet')
variant_df.printSchema()

root
 |-- contigName: string (nullable = true)
 |-- names: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- position: double (nullable = true)
 |-- start: long (nullable = true)
 |-- end: long (nullable = true)
 |-- referenceAllele: string (nullable = true)
 |-- alternateAlleles: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- genotypes: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- sampleId: string (nullable = true)
 |    |    |-- calls: array (nullable = true)
 |    |    |    |-- element: integer (containsNull = true)
 |-- values: array (nullable = true)
 |    |-- element: double (containsNull = true)



In [10]:
def reshape_for_gwas(spark, label_df: pd.DataFrame):
    """Turn a sample-indexed pandas label table into one Spark row per label.

    Adapted from
    https://github.com/projectglow/glow/blob/04257f65ad64b45b2ad4a9417292e0ead6f94212/python/glow/wgr/functions.py

    Args:
        spark: Object exposing ``createDataFrame(data, schema)``; in this
            notebook, the active Spark session.
        label_df: Label table with one column per label. Its index must have
            either one level (sample id) or two levels (sample id, contig name).

    Returns:
        Whatever ``spark.createDataFrame`` returns for the reshaped table. The
        columns are ``label, values`` for a one-level index and
        ``label, contigName, values`` for a two-level index, where ``values``
        holds that row's per-sample values as a list.

    Raises:
        TypeError: If ``label_df`` is not a pandas DataFrame.
        ValueError: If the index of ``label_df`` has more than two levels.
    """
    # Explicit check instead of `assert check_argument_types()`: that helper is
    # not imported by name in this notebook, and an assert disappears under -O.
    if not isinstance(label_df, pd.DataFrame):
        raise TypeError(
            'label_df must be a pandas DataFrame, got ' + type(label_df).__name__)

    index_depth = label_df.index.nlevels
    if index_depth == 1:  # Indexed by sample id
        transposed_df = label_df.T
        column_names = ['label', 'values']
    elif index_depth == 2:  # Indexed by sample id and contig name
        # stacking sorts the new column index, so we remember the original sample
        # ordering in case it's not sorted
        ordered_cols = pd.unique(label_df.index.get_level_values(0))
        transposed_df = label_df.T.stack()[ordered_cols]
        column_names = ['label', 'contigName', 'values']
    else:
        raise ValueError('label_df must be indexed by sample id or by (sample id, contig name)')

    # Build a separate one-column frame rather than adding a scratch column to
    # `transposed_df`, so a sample that happens to be called 'values_array'
    # cannot be overwritten before the row values are collected.
    values_df = pd.DataFrame(
        {'values_array': transposed_df.to_numpy().tolist()},
        index=transposed_df.index,
    )
    return spark.createDataFrame(values_df.reset_index(), column_names)

In [11]:
# Subtract the loaded predictions from the trait values (pandas aligns the two
# frames on index and columns), then reshape the difference for Spark.
adjusted_phenotypes = reshape_for_gwas(spark, label_df.sub(y_hat_df))
adjusted_phenotypes.printSchema()

root
 |-- label: string (nullable = true)
 |-- values: array (nullable = true)
 |    |-- element: double (containsNull = true)



In [17]:
# Preview of the variant x phenotype pairing; the variant-side `values` column
# is renamed first so it does not share a name with the phenotype `values`.
(
    variant_df
    .withColumnRenamed('values', 'callValues')
    .crossJoin(adjusted_phenotypes)
)

DataFrame[contigName: string, names: array<string>, position: double, start: bigint, end: bigint, referenceAllele: string, alternateAlleles: array<string>, genotypes: array<struct<sampleId:string,calls:array<int>>>, values: array<double>, label: string, values: array<double>]

In [20]:
# Must be run in glow env
# Pair every variant row with every phenotype row, then run Glow's
# linear-regression GWAS on each pair and expand the result struct into columns.
wgr_gwas = (
    variant_df.withColumnRenamed('values', 'callValues')
    .crossJoin(adjusted_phenotypes.withColumnRenamed('values', 'phenotypeValues'))
    .select(
        'start',
        'names',
        'label',
        expand_struct(
            linear_regression_gwas(
                F.col('callValues'),
                F.col('phenotypeValues'),
                # NOTE(review): the covariate matrix is passed exactly as loaded
                # from covariates.csv (X000..X002) -- confirm whether an
                # intercept column of ones is expected here.
                F.lit(cov_df.to_numpy()),
            )
        ),
    )
)
wgr_gwas

TypeError: 'JavaPackage' object is not callable