In [None]:
from pyspark.sql import SparkSession

#start session 
spark_session = SparkSession\
        .builder\
        .master("spark://192.168.2.81:7077") \
        .appName("glow_FL")\
        .config("spark.dynamicAllocation.enabled", False)\
        .config("spark.shuffle.service.enabled", False)\
        .config("spark.dynamicAllocation.executorIdleTimeout","30s")\
        .config("spark.executor.cores",2)\
        .getOrCreate()
        
spark_context = spark_session.sparkContext

In [11]:
vcf = spark_session.read.format("vcf").option("vcfHeader", True).option("includeSampleIds", True).option('flattenInfoFields', True).load("hdfs://192.168.2.81:9000//user/LDSA/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz")
#vcf = spark_context.textFile("hdfs://192.168.2.81:9000//user/LDSA/ALL.chr1.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz")
#vcf.show(2)

# display() is only available on Databricks
# display(vcf.limit(10))

In [18]:
vcf.select('genotypes').select('element').show()
vcf.printSchema()

AnalysisException: "cannot resolve '`element`' given input columns: [genotypes];;\n'Project ['element]\n+- Project [genotypes#749]\n   +- Relation[contigName#713,start#714L,end#715L,names#716,referenceAllele#717,alternateAlleles#718,qual#719,filters#720,splitFromMultiAllelic#721,INFO_MEND#722,INFO_AC#723,INFO_CIEND#724,INFO_NS#725,INFO_AFR_AF#726,INFO_VT#727,INFO_AN#728,INFO_MULTI_ALLELIC#729,INFO_SAS_AF#730,INFO_CIPOS#731,INFO_AA#732,INFO_AF#733,INFO_EAS_AF#734,INFO_AMR_AF#735,INFO_DP#736,... 13 more fields] vcf\n"

In [None]:
#remove rows that start with ## 
#vcf = vcf.filter(lambda line : not  line.startswith('t=VCF'))\
#    .filter(lambda line : not line.startswith('##'))\
#    .map(lambda line : line.split())
#
#vcf.take(10)

In [None]:
#create df with header 
#vcf = vcf.toDF(schema = vcf.first()) 
#vcf = vcf.filter(vcf["#CHROM"]!="#CHROM")
#vcf.show()

In [None]:
from pyspark.sql.types import *
from pyspark.sql.functions import col
import re 

#Get names of sample columns (samples start with HG and NA)
columns = vcf.schema.names
columns = [x for x in columns if re.match("HG*|NA*", x)]

#extract 100 samples for test 
vcf = vcf.drop(*columns[100:len(columns)])

#Make POS and QUAL integer, filter on QUAL 
vcf = vcf.withColumn("POS", vcf["POS"].cast(IntegerType()))\
        .withColumn('QUAL', vcf['QUAL'].cast(IntegerType()))\
        .filter(vcf['QUAL'] > 20)
                 
vcf.printSchema()

In [None]:
import re 
from pyspark.sql.functions import col, split

#Get names of sample columns 
columns = vcf.schema.names
columns = [x for x in columns if re.match("HG*|NA*", x)]

for sample in columns: 
    vcf = vcf.withColumn(sample, split(col(sample), ":").cast(ArrayType(StringType())).alias(sample)[0])
    
vcf.select('POS', 'ALT', 'QUAL', 'FORMAT', 'HG00101').show()

In [None]:
from pyspark.sql.functions import concat
import re 

#Concatenate the genotype columns to one column 
vcf = vcf.withColumn('GENOTYPES',concat(*columns))

#drop the sample columns 
vcf = vcf.drop(*columns) 

vcf.show()

In [None]:
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf, struct
import re 

#count number of time each allel occurs 
def count_allel(GENOTYPES, ALT):
    #get number of allels 
    n = re.findall('[ATGC.]+', ALT)
    res = []
    for i in range(0,len(n)+1):
        res.append(GENOTYPES.count(str(i)))
    return res

count_allel_udf = udf(count_allel, StringType())

vcf = vcf.withColumn('ALLEL_FREQ', count_allel_udf('GENOTYPES', 'ALT'))
vcf.show()

In [None]:
#Calculate Nd
def Nd(ALLEL_FREQ):
    sqrd = [x**2 for x in ALLEL_FREQ]
    diff = [x-y for x, y in zip(sqrd, ALLEL_FREQ)]
    return 1-(sum(diff)/((sum(ALLEL_FREQ)**2)-sum(ALLEL_FREQ)))
    
Nd_udf = udf(Nd, StringType())

vcf = vcf.withColumn('Nd', Nd_udf('ALLEL_FREQ'))
vcf.select('Nd').show()  

In [None]:
from pyspark.ml.feature import Bucketizer

#define the splits (use min max values from POS argument...) 
splits = list(range(0, 57227415, 100000))
bucketizer = Bucketizer(splits=splits, inputCol="POS", outputCol="POS_BUCKET")

#Transform original data into bucket index
vcf = bucketizer.transform(vcf)

vcf.select('POS_BUCKET', 'POS', 'Nd').show()

In [None]:
from pyspark.sql.functions import avg 

vcf = vcf.withColumn('Nd', vcf['Nd'].cast(StringType()))

#Group by 100000 bp interval and calculate per base nucleotide diversity 
vcf_Nd = vcf.groupBy('POS_BUCKET')\
            .agg(avg(col("Nd")))\
            .orderBy('POS_BUCKET')

#Save data to hdfs 
vcf_Nd.write.format('csv').option('header',True).mode('overwrite').option('sep',',').save('/user/LDSA/output.csv')
        
vcf_Nd.show()


In [None]:
#Collect values for plotting 
y = vcf_Nd.select('AVG(Nd)').collect()
x = vcf_Nd.select('POS_BUCKET').collect()

In [None]:
#plot
import matplotlib.pyplot as plt

plt.plot(x, y)
plt.ylabel('Nucleotide Diversity')
plt.xlabel('bp')
plt.title('Y Chromosome')
plt.legend('', loc='upper left')

plt.show()

In [None]:
spark_context.stop()