In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, split, concat, udf, struct, avg
import re
from pyspark.ml.feature import Bucketizer
from pyspark.sql.types import StringType
import glow


#start session 
spark_session = SparkSession\
        .builder\
        .master("spark://192.168.2.81:7077") \
        .appName("nucleotide_div")\
        .config("spark.dynamicAllocation.enabled", False)\
        .config("spark.shuffle.service.enabled", False)\
        .config("spark.dynamicAllocation.executorIdleTimeout","30s")\
        .config("spark.executor.cores",2)\
        .getOrCreate()

glow.register(spark)
df = spark_session.read.format('vcf').load("hdfs://192.168.2.81:9000//user/LDSA/ALL.chr1.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz")


#Get names of sample columns (samples start with HG and NA)
columns = df.schema.names
columns = [x for x in columns if re.match("HG*|NA*", x)]

#extract 100 samples for test 
df = df.drop(*columns[100:len(columns)])

df.show()

TypeError: condition should be string or Column

In [4]:
spark_context.stop()

NameError: name 'spark_context' is not defined

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, split, concat, udf, struct, avg
import re
from pyspark.ml.feature import Bucketizer
from pyspark.sql.types import StringType

#start session 
spark_session = SparkSession\
        .builder\
        .master("spark://192.168.2.81:7077") \
        .appName("nucleotide_div")\
        .config("spark.dynamicAllocation.enabled", False)\
        .config("spark.shuffle.service.enabled", False)\
        .config("spark.dynamicAllocation.executorIdleTimeout","30s")\
        .config("spark.executor.cores",2)\
        .getOrCreate()
        
spark_context = spark_session.sparkContext

vcf = spark_context.textFile("hdfs://192.168.2.81:9000//user/LDSA/ALL.chr1.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz")


vcf = vcf.filter(lambda line : not  line.startswith('t=VCF'))\
    .filter(lambda line : not line.startswith('##'))\
    .map(lambda line : line.split())
    
#create df with header 
vcf = vcf.toDF(schema = vcf.first()) 
vcf = vcf.filter(vcf["#CHROM"]!="#CHROM")

#Get names of sample columns (samples start with HG and NA)
columns = vcf.schema.names
columns = [x for x in columns if re.match("HG*|NA*", x)]

#extract 100 samples for test 
vcf = vcf.drop(*columns[100:len(columns)])

#Make POS and QUAL integer, filter on QUAL 
vcf = vcf.withColumn("POS", vcf["POS"].cast(IntegerType()))\
        .withColumn('QUAL', vcf['QUAL'].cast(IntegerType()))\
        .filter(vcf['QUAL'] > 20)
        
#Get names of sample columns 
columns = vcf.schema.names
columns = [x for x in columns if re.match("HG*|NA*", x)]

for sample in columns: 
    vcf = vcf.withColumn(sample, split(col(sample), ":").cast(ArrayType(StringType())).alias(sample)[0])

#Concatenate the genotype columns to one column 
vcf = vcf.withColumn('GENOTYPES',concat(*columns))

#drop the sample columns 
vcf = vcf.drop(*columns) 

#count number of time each allel occurs 
def count_allel(GENOTYPES, ALT):
    #get number of allels 
    n = re.findall('[ATGC.]+', ALT)
    res = []
    for i in range(0,len(n)+1):
        res.append(GENOTYPES.count(str(i)))
    return res

count_allel_udf = udf(count_allel, StringType())

vcf = vcf.withColumn('ALLEL_FREQ', count_allel_udf('GENOTYPES', 'ALT'))

#Calculate Nd
def Nd(ALLEL_FREQ):
    sqrd = [x**2 for x in ALLEL_FREQ]
    diff = [x-y for x, y in zip(sqrd, ALLEL_FREQ)]
    return 1-(sum(diff)/((sum(ALLEL_FREQ)**2)-sum(ALLEL_FREQ)))
    
Nd_udf = udf(Nd, StringType())

vcf = vcf.withColumn('Nd', Nd_udf('ALLEL_FREQ'))

#define the splits (use min max values from POS argument...) 
splits = list(range(0, 57227415, 1000000))
bucketizer = Bucketizer(splits=splits, inputCol="POS", outputCol="POS_BUCKET")

#Transform original data into bucket index
vcf = bucketizer.transform(vcf)

vcf = vcf.withColumn('Nd', vcf['Nd'].cast(StringType()))

#Group by 100000 bp interval and calculate per base nucleotide diversity 
vcf_Nd = vcf.groupBy('POS_BUCKET')\
            .agg(avg(col("Nd")))\
            .orderBy('POS_BUCKET')

vcf_Nd.select('POS_BUCKET').show()       
#Save data to hdfs 
#vcf.write.format('csv').option('header',True).mode('overwrite').option('sep',',').save('/user/LDSA/output.csv')
spark_context.stop()