In [19]:
from pyspark.sql import SparkSession

# Connect to the standalone Spark master and obtain (or reuse) a session.
# Dynamic allocation and the external shuffle service are both switched off,
# and each executor is given 2 cores.
spark_session = (
    SparkSession.builder
    .master("spark://192.168.2.81:7077")
    .appName("nucleotide_div_alva")
    .config("spark.dynamicAllocation.enabled", False)
    .config("spark.shuffle.service.enabled", False)
    .config("spark.dynamicAllocation.executorIdleTimeout", "30s")
    .config("spark.executor.cores", 2)
    .getOrCreate()
)

# RDD-level entry point of the session.
spark_context = spark_session.sparkContext

In [20]:
# Read the chr22 VCF from HDFS as an RDD of text lines.
# NOTE(review): the input is gzip-compressed; gzip is normally not a splittable
# codec, so this RDD probably has a single partition - confirm with
# vcf.getNumPartitions() before tuning parallelism.
vcf_path = "hdfs://192.168.2.81:9000//user/LDSA/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz"
vcf = spark_context.textFile(vcf_path)

In [21]:
# Drop every line that starts with 't=VCF' or with '##', then split each
# remaining line on whitespace into a list of fields. Lines starting with a
# single '#' are not removed by this step.
vcf = (
    vcf.filter(lambda line: not line.startswith(('t=VCF', '##')))
       .map(lambda line: line.split())
)

#vcf.take(10)

In [22]:
# Turn the RDD into a DataFrame, using the fields of its first record as the
# column names (that record is expected to be the '#CHROM ...' header line),
# then remove the header record itself from the data.
header_fields = vcf.first()
vcf = vcf.toDF(schema=header_fields)
vcf = vcf.where(vcf["#CHROM"] != "#CHROM")
#vcf.show()

In [23]:
from pyspark.sql.types import *
from pyspark.sql.functions import col
import re 

# Sample columns are the ones whose names start with "HG" or "NA".
# The pattern anchors on the literal two-letter prefixes: a pattern such as
# "HG*|NA*" would instead accept any name beginning with "H" or "N", because
# "*" makes the second letter optional.
columns = [name for name in vcf.schema.names if re.match(r"HG|NA", name)]

# Keep only the first 100 sample columns for this test run; drop all others.
vcf = vcf.drop(*columns[100:])

In [24]:
from datetime import datetime

# Record the wall-clock time at which the timed part of the notebook begins.
start = datetime.now()

current_time = f"{start:%H:%M:%S}"
print("Start =", current_time)

Start = 07:57:19


In [25]:
# Cast POS and QUAL to integers, then keep only rows with QUAL > 20.
# The filter refers to the column by name via col("QUAL") so that it is
# resolved against the freshly cast integer column; vcf['QUAL'] would still be
# bound to the pre-cast column of the previous DataFrame.
vcf = (
    vcf.withColumn("POS", col("POS").cast(IntegerType()))
       .withColumn("QUAL", col("QUAL").cast(IntegerType()))
       .filter(col("QUAL") > 20)
)
                 
#vcf.printSchema()

In [26]:
import re 
from pyspark.sql.functions import col, split

# Sample columns are the ones whose names start with "HG" or "NA".
# (A pattern such as "HG*|NA*" would accept any name starting with "H" or "N".)
columns = [name for name in vcf.schema.names if re.match(r"HG|NA", name)]
sample_names = set(columns)

# Every sample value is a ':'-separated string; keep only its first sub-field
# (presumably the genotype call - confirm against the FORMAT column).
# All columns are rewritten in a single select(): calling withColumn() once per
# sample adds one projection per call and can produce a very large query plan.
# Column order is preserved because the list follows vcf.columns.
vcf = vcf.select([
    split(col(name), ":").getItem(0).alias(name) if name in sample_names else vcf[name]
    for name in vcf.columns
])
    
#vcf.select('POS', 'ALT', 'QUAL', 'FORMAT', 'HG00101').show()

In [27]:
from pyspark.sql.functions import concat
import re 

# Join the per-sample genotype strings (no separator) into one GENOTYPES column
# and discard the individual sample columns in the same chained expression.
vcf = vcf.withColumn('GENOTYPES', concat(*columns)).drop(*columns)

#vcf.show()

In [28]:
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf, struct
import re 

#count number of times each allele occurs
def count_allel(GENOTYPES, ALT):
    """Count how often each allele index occurs in a concatenated genotype string.

    Parameters
    ----------
    GENOTYPES : str or None
        Genotype calls of all samples joined into one string, e.g. "0|00|11|1".
    ALT : str or None
        The ALT field of the site: one or more alleles separated by commas,
        e.g. "A", "A,T" or "<INV>".

    Returns
    -------
    list of int or None
        res[i] is the number of occurrences of str(i) in GENOTYPES, for
        i = 0 (reference) up to the number of ALT alleles. None is returned
        when either argument is None.
    """
    if GENOTYPES is None or ALT is None:
        return None

    # ALT alleles are comma-separated, so count the entries directly. Matching
    # runs of A/C/G/T instead would miscount symbolic alleles such as "<INV>"
    # (0 matches) or alleles containing other letters such as "GNT" (2 matches).
    n_alt = len(ALT.split(','))

    # Plain substring counting: this is only exact while every allele index is
    # a single digit (fewer than 10 ALT alleles at the site).
    return [GENOTYPES.count(str(i)) for i in range(n_alt + 1)]

# Wrap count_allel as a Spark UDF. The function returns a list of ints, so the
# declared return type is array<int>; declaring StringType would not match the
# value the function actually produces.
count_allel_udf = udf(count_allel, ArrayType(IntegerType()))

# ALLEL_FREQ holds, per site, the occurrence count of each allele index.
vcf = vcf.withColumn('ALLEL_FREQ', count_allel_udf('GENOTYPES', 'ALT'))

#vcf.show()

In [29]:
#Calculate Nd
def Nd(ALLEL_FREQ):
    """Compute 1 - sum(c*(c-1)) / (N*(N-1)) from a list of allele counts.

    Parameters
    ----------
    ALLEL_FREQ : list of int or None
        Occurrence count of each allele at one site.

    Returns
    -------
    float or None
        The value of the formula above, where c runs over the counts and N is
        their sum. None is returned when the input is None or empty, or when
        N is 0 or 1 (the denominator N*(N-1) would be zero).
    """
    if not ALLEL_FREQ:
        return None

    total = sum(ALLEL_FREQ)
    # Equals total * (total - 1); zero when fewer than two alleles were counted.
    denominator = total ** 2 - total
    if denominator == 0:
        return None

    # sum of c**2 - c over all counts, i.e. sum of c * (c - 1)
    numerator = sum(count ** 2 - count for count in ALLEL_FREQ)
    return 1 - (numerator / denominator)
    
# Wrap Nd as a Spark UDF. Nd returns a float, so the declared return type is
# double; this keeps the Nd column numeric instead of relying on an implicit
# string-to-number cast when it is averaged.
Nd_udf = udf(Nd, DoubleType())

vcf = vcf.withColumn('Nd', Nd_udf('ALLEL_FREQ'))

#vcf.select('Nd').show()  

In [30]:
from pyspark.ml.feature import Bucketizer

# Bucket boundaries on POS: one edge every 500000, starting at 0 and stopping
# before 101991189.
bucket_edges = range(0, 101991189, 500000)
splits = list(bucket_edges)
bucketizer = Bucketizer(inputCol="POS", outputCol="POS_BUCKET", splits=splits)

# Add POS_BUCKET: the index of the bucket each POS value falls into.
vcf = bucketizer.transform(vcf)

#vcf.select('POS_BUCKET', 'POS', 'Nd').show()

In [31]:
from pyspark.sql.functions import avg 

# Average Nd per POS bucket (buckets are 500000 wide, see the Bucketizer splits)
# and sort the result by bucket index.
nd_per_bucket = vcf.groupBy('POS_BUCKET').agg(avg(col("Nd")))
vcf_Nd = nd_per_bucket.orderBy('POS_BUCKET')

# Write the per-bucket averages as comma-delimited CSV with a header row,
# replacing any previous output at the target path.
csv_writer = (
    vcf_Nd.write
    .format('csv')
    .option('header', True)
    .option('delimiter', ',')
    .mode('overwrite')
)
csv_writer.save('/user/LDSA/CHR22_out')

In [32]:
# Record the wall-clock time at which the timed part of the notebook ends.
stop = datetime.now()

current_time = f"{stop:%H:%M:%S}"
print("End =", current_time)

End = 08:16:57


In [33]:
#Collect values for plotting 
#y = vcf_Nd.select('AVG(Nd)').collect()
#x = vcf_Nd.select('POS_BUCKET').collect()

In [34]:
#plot
#import matplotlib.pyplot as plt

#plt.plot(x, y)
#plt.ylabel('Nucleotide Diversity')
#plt.xlabel('bp')
#plt.title('Chromosome 22')
#plt.legend('', loc='upper left')

#plt.show()

In [35]:
#spark_context.stop()

In [36]:
# Count the rows of the fully transformed DataFrame.
# NOTE(review): no cache()/persist() call is visible in this notebook, so this
# action presumably re-runs the whole pipeline from the HDFS read - confirm
# before running it on the full data set.
vcf.count()

KeyboardInterrupt: 