In [None]:
from __future__ import division
from operator import add
from math import log
import csv
import pickle
import sys
from collections import Counter
import re
import glob, os

import ntpath
import functools
import itertools

from time import time
import argparse
from PRS_run import *

In [None]:
# ---------------------------------------------------------------------------
# Run configuration: every tunable setting of this PRS run is defined here.
# ---------------------------------------------------------------------------

# Format of the genotype files; only "VCF" and "GEN" are handled (case-insensitive)
filetype="GEN"

# File that receives the SNP IDs used at each threshold; a falsy value disables SNP logging
snp_log="NFP_pruned_nodup_PRS_170116_snplog.csv"

## GWAS column layout (0-based column indices)
gwas_id=0   # column of SNP ID
gwas_p=1     # column of P value
gwas_or=2    # column of odds ratio
gwas_a1=3    # column of A1 in the GWAS; A2 is read from the column right after it (gwas_a1+1)
gwas_a1f=5  # column index of the A1 allele frequency (MAF) in the GWAS

# Define the column numbers for the contents of the genotype files
if filetype.lower()=="vcf":
    geno_id= 2 # column number with rsID
    geno_start=9 # column number of the 1st genotype, in the raw vcf files, after separated by the delimiter of choice
    geno_a1 = 3 # column number that contains the reference allele
    GENO_delim= "\t"
elif filetype.lower()=="gen":
    geno_id = 1
    geno_start=5
    geno_a1=3
    GENO_delim=" "
else:
    # Fail here rather than with a NameError much later, when the undefined
    # geno_* settings are first used by the processing cells.
    raise ValueError("Unsupported filetype {!r}; expected 'VCF' or 'GEN'".format(filetype))

# List of p-value thresholds at which a score is calculated
thresholds=[0.5,0.3,0.2,0.1,0.05, 0.01, 0.001, 0.0001]

# GWAS file delimiter
# NOTE(review): the cell that loads the GWAS passes a literal tab to the reader
# and does not consult this setting - confirm which one matches the file.
GWAS_delim=" "

# Name of GWAS file
gwasFiles="gwas.cross.txt"
GWAS_has_header=True

# Programme parameters
log_or=True  # specify whether you want to log your odds ratios
check_ref=True # if you know that there are mismatches between the top strand in the genotypes and that of the GWAS, set True. Not checking the reference allele will improve the speed
use_maf=True   # whether to use MAF to check the reference allele

# Sample file path and name; the value "NOSAMPLE" makes the output cell skip the sample file
sampleFilePath="/scratch/vvp-220-aa/NFP_plink/KieranNFP_0922.sample" # include the full/relative path and name of the sample file
sampleFileDelim=" "  # sample file delimiter
sampleFileID=[0]   # which column(s) in the sample file hold the ID
sample_skip=2 # how many lines to skip so that the sample names can be matched to the genotypes 1-to-1, taking into account the header of the sample file

## Output file information
outputPath="NFP_pruned_nodup_PRS_170116.csv"

# Specify whether to check for duplicate SNPs
checkDup=True


# Name (or glob pattern) of the genotype files
genoFileNamePattern="/scratch/vvp-220-aa/NFP/NFP_pruned_nodup.gen"

# Expand the pattern into the list of matching file names (used for reporting only;
# the data itself is read straight from the pattern)
genoFileNames=glob.glob(genoFileNamePattern)

In [None]:
##  start spark context
import pyspark
from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext
APP_NAME="PRS"

# Create the session for this application (or pick up the one that already exists)
spark=SparkSession.builder.appName(APP_NAME).getOrCreate()

# if using spark < 2.0.0, use the pyspark module to make Spark context
# conf = pyspark.SparkConf().setAppName(APP_NAME).set()#.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

sc  = spark.sparkContext

# Keep Spark's own console output down to warnings and errors
sc.setLogLevel("WARN")
log4jLogger = sc._jvm.org.apache.log4j
LOGGER = log4jLogger.LogManager.getLogger(__name__)
print("Start Reading Files")
print("Using these genoytpe files: ")

# List at most this many file names so the cell output stays readable
maxListedFiles=24
for filename in genoFileNames[:maxListedFiles]:
    print(filename)
# Announce further files only when some names were actually left out of the
# listing above (the previous ">23" test also fired for exactly 24 files).
if len(genoFileNames)>maxListedFiles:
    print("and more...")

print("total of {} files".format(str(len(genoFileNames))))
# 1. Load files

In [None]:
# Load the raw genotype lines (one text line per record) from every matching file
genodata = sc.textFile(genoFileNamePattern)

print("Using the GWAS file: {}".format(gwasFiles))

# Read the GWAS summary table as a DataFrame and keep it cached for later passes.
# NOTE(review): the delimiter is a literal tab here; the GWAS_delim setting from
# the configuration cell is not used - confirm which one matches the GWAS file.
gwasReader = spark.read.option("header", GWAS_has_header).option("delimiter", "\t")
gwastable = gwasReader.csv(gwasFiles).cache()

print("Showing top 5 rows of GWAS file")
gwastable.show(5)

# Echo the column layout that the rest of the notebook will rely on
print("System recognizes the following information in the GWAS :")
gwasColumnReport = [
    ("SNP ID : Column {}", gwas_id),
    ("P-values : Column {}", gwas_p),
    ("Effect size : Column {}", gwas_or),
    ("Allele A1 : Column {}", gwas_a1),
    ("Allele A2 : Column {}", gwas_a1 + 1),
]
if use_maf:
    gwasColumnReport.append(("Allele Frequencies : Column {}", gwas_a1f))
for reportTemplate, columnIndex in gwasColumnReport:
    print(reportTemplate.format(columnIndex))

In [None]:
# 1.1 Filter GWAS and prepare odds ratio

# Complementary base for each nucleotide, handed to the strand-alignment checks below
bpMap = dict(zip("ATCG", "TAGC"))

# Keep only the GWAS entries below the largest p-value threshold; every smaller
# threshold selects a subset of these SNPs.
maxThreshold = max(thresholds)
gwasOddsMapMax = filterGWASByP_DF(
    GWASdf=gwastable,
    pcolumn=gwas_p,
    idcolumn=gwas_id,
    oddscolumn=gwas_or,
    pHigh=maxThreshold,
    logOdds=log_or,
)

# Broadcast the filtered GWAS result and keep its value for the membership tests below
gwasOddsBroadcast = sc.broadcast(gwasOddsMapMax)
gwasOddsMapMaxCA = gwasOddsBroadcast.value

# ### 2. Initial processing
# From here on the genotypes are restricted to the SNPs present in 'gwasOddsMapMax'
tic = time()  # start of the dosage-generation timer
# Build `genotypeMax`, an RDD of (snpid, dosage values), from the raw genotype lines.
# Both file formats follow the same steps:
#   1. parse each line and keep the SNPs whose ID is in gwasOddsMapMaxCA
#   2. if check_ref: join genotype and GWAS alleles per SNP, derive a flag per SNP
#      (checkAlignmentDF / checkAlignmentDFnoMAF), drop SNPs flagged "discard" and
#      build the dosage with makeGenotypeCheckRef
#   3. otherwise: build the dosage with makeGenotype and, if checkDup, keep only
#      SNP IDs that occur exactly once
# NOTE(review): the helpers come from PRS_run and are not visible here; their exact
# semantics should be confirmed there.
# NOTE(review): no branch handles any other filetype, in which case `genotypeMax`
# is never assigned.
if filetype.lower()=="vcf":
    print("Genotype data format : VCF ")

    # Drop every line containing "#", split the rest, keep the GWAS SNPs, then for each
    # genotype column take the 4th ":"-separated sub-field and split it on ",".
    # Resulting row: [chrom, bp, snpid, A1, A2, *genotype]
    genointermediate=genodata.filter(lambda line: ("#" not in line)).map(lambda line: line.split(GENO_delim)).filter(lambda line: line[geno_id] in gwasOddsMapMaxCA).map(lambda line: line[0:5]+[chunk.split(":")[3] for chunk in line[geno_start::]]).map(lambda line: line[0:5]+[triplet.split(",") for triplet in line[5::]])

    ## (snpid, [genotypes]) - the per-sample value lists are flattened into one list of floats
    genotable=genointermediate.map(lambda line: (line[geno_id], list(itertools.chain.from_iterable(line[5::])))).mapValues(lambda geno: [float(x) for x in geno])
    if check_ref:
        if use_maf:
            print("Correcting strand alignment, using MAF")
            # (snpid, A1, A2, getA1f(values)) per SNP; getA1f presumably estimates the A1 frequency - confirm in PRS_run
            genoA1f=genointermediate.map(lambda line: (line[geno_id], (line[geno_a1], line[geno_a1+1]), [float(x) for x in list(itertools.chain.from_iterable(line[5::]))])).map(lambda line: (line[0], line[1][0], line[1][1], getA1f(line[2]))).toDF(["Snpid_geno", "GenoA1", "GenoA2", "GenoA1f"])

            # 'GwasA1F' means the allele of the A1 frequency in the GWAS
            # NOTE(review): the GEN branch names this column "GwasA1f" (lower-case f)
            gwasA1f=gwastable.rdd.map(lambda line:(line[gwas_id], line[gwas_a1], line[gwas_a1+1], line[gwas_a1f])).toDF(["Snpid_gwas", "GwasA1", "GwasA2", "GwasA1F"])

            # checktable = [ geno_snpid, genoA1, genoA2, genoA1f, gwas_snpid, gwasA1, gwasA2, gwasA1f]
            checktable=genoA1f.join(gwasA1f, genoA1f["Snpid_geno"]==gwasA1f["Snpid_gwas"], "inner").cache()
            if checkDup:
                flagList = checktable.rdd.map(lambda line: checkAlignmentDF(line, bpMap)).collect()  #  (snpid, flag)
                # presumably resolves SNP IDs that appear more than once - confirm in PRS_run
                flagMap = rmDup(flagList)
            else:
                flagMap = checktable.rdd.map(lambda line: checkAlignmentDF(line, bpMap)).collectAsMap()

        else:
            print("Correcting strand alignment, without using MAF")
            # NOTE(review): the float-converted genotype list built by the first map is
            # discarded by the second map, which keeps only (snpid, A1, A2)
            genoalleles=genointermediate.map(lambda line: (line[geno_id], (line[geno_a1], line[geno_a1+1]), [float(x) for x in list(itertools.chain.from_iterable(line[5::]))])).map(lambda line: (line[0], line[1][0], line[1][1])).toDF(["Snpid_geno", "GenoA1", "GenoA2"])

            gwasalleles=gwastable.rdd.map(lambda line:(line[gwas_id], line[gwas_a1], line[gwas_a1+1])).toDF(["Snpid_gwas", "GwasA1", "GwasA2"])

            checktable=genoalleles.join(gwasalleles, genoalleles["Snpid_geno"]==gwasalleles["Snpid_gwas"], "inner").cache()

            if checkDup:
                flagList = checktable.rdd.map(lambda line: checkAlignmentDFnoMAF(line, bpMap)).collect()
                flagMap = rmDup(flagList)
            else:
                # no need to check the duplicates if the data is preprocessed
                flagMap = checktable.rdd.map(lambda line: checkAlignmentDFnoMAF(line, bpMap)).collectAsMap()

        print("Generating genotype dosage while taking into account difference in strand alignment")
        # NOTE(review): only .value is kept, so the broadcast handle itself is discarded
        # and the lambdas below reference the driver-side object
        flagMap=sc.broadcast(flagMap).value
        # Keep SNPs that have a flag other than "discard", then build the dosage using the flags
        genotypeMax=genotable.filter(lambda line: line[0] in flagMap and flagMap[line[0]]!="discard").map(lambda line: makeGenotypeCheckRef(line, checkMap=flagMap)).cache()

    else:
        print("Generating genotype dosage without checking reference allele alignments")
        genotypeMax=genotable.mapValues(lambda line: makeGenotype(line)).cache()
        if checkDup:
            # Count rows per SNP ID and keep only the IDs seen exactly once
            genotypeCount=genotypeMax.map(lambda line: (line[0], 1)).reduceByKey(lambda a,b: a+b).filter(lambda line: line[1]==1).collectAsMap()
            genotypeMax=genotypeMax.filter(lambda line: line[0] in genotypeCount)

elif filetype.lower() == "gen":
    print("Genotype data format : GEN")
    # (snpid, [float values from column geno_start onwards]) for the SNPs present in the filtered GWAS
    genotable=genodata.map(lambda line: line.split(GENO_delim)).filter(lambda line: line[geno_id] in gwasOddsMapMaxCA).map(lambda line: (line[geno_id], line[geno_start::])).mapValues(lambda geno: [float(call) for call in geno])
    if check_ref:
        if use_maf:
            print("Correcting strand alignment, using MAF")
            # NOTE(review): built from every line of genodata, not from the GWAS-filtered
            # rows; the inner join below is what restricts it to SNPs present in the GWAS table
            genoA1f=genodata.map(lambda line: line.split(GENO_delim)).map(lambda line: (line[geno_id], line[geno_a1], line[geno_a1+1], getA1f([float(x) for x in line[geno_start::]]))).toDF(["Snpid_geno", "GenoA1", "GenoA2", "GenoA1f"])
            gwasA1f=gwastable.rdd.map(lambda line:(line[gwas_id], line[gwas_a1], line[gwas_a1+1], line[gwas_a1f])).toDF(["Snpid_gwas", "GwasA1", "GwasA2", "GwasA1f" ])
            # One row per SNP found in both tables: genotype columns followed by GWAS columns
            checktable=genoA1f.join(gwasA1f, genoA1f["Snpid_geno"]==gwasA1f["Snpid_gwas"], "inner").cache()
            if checkDup:
                print("Searching and removing duplicated SNPs")
                flagList = checktable.rdd.map(lambda line: checkAlignmentDF(line, bpMap)).collect()
                flagMap = rmDup(flagList)
            else:
                flagMap = checktable.rdd.map(lambda line: checkAlignmentDF(line, bpMap)).collectAsMap()
        else:
            print("Correcting strand alignment, without using MAF")
            # NOTE(review): also built from every line of genodata (see note above)
            genoalleles=genodata.map(lambda line: line.split(GENO_delim)).map(lambda line: (line[geno_id], line[geno_a1], line[geno_a1+1])).toDF(["Snpid_geno", "GenoA1", "GenoA2"])
            gwasalleles=gwastable.rdd.map(lambda line:(line[gwas_id], line[gwas_a1], line[gwas_a1+1])).toDF(["Snpid_gwas", "GwasA1", "GwasA2"])
            checktable=genoalleles.join(gwasalleles, genoalleles["Snpid_geno"]==gwasalleles["Snpid_gwas"], "inner").cache()

            if checkDup:
                print("Searching and removing duplicated SNPs")
                flagList = checktable.rdd.map(lambda line: checkAlignmentDFnoMAF(line, bpMap)).collect()
                flagMap = rmDup(flagList)
            else:
                flagMap = checktable.rdd.map(lambda line: checkAlignmentDFnoMAF(line, bpMap)).collectAsMap()

        print("Generating genotype dosage while taking into account difference in strand alignment")
        # NOTE(review): only .value is kept, so the broadcast handle itself is discarded
        flagMap=sc.broadcast(flagMap).value
        # Keep SNPs that have a flag other than "discard", then build the dosage using the flags
        genotypeMax=genotable.filter(lambda line: line[0] in flagMap and flagMap[line[0]]!="discard" ).map(lambda line: makeGenotypeCheckRef(line, checkMap=flagMap)).cache()

    else:
        print("Generating genotype dosage without checking strand alignments")
        genotypeMax=genotable.mapValues(lambda line: makeGenotype(line)).cache()
        if checkDup:
            # Count rows per SNP ID and keep only the IDs seen exactly once
            genotypeCount=genotypeMax.map(lambda line: (line[0], 1)).reduceByKey(lambda a,b: a+b).filter(lambda line: line[1]==1).collectAsMap()
            genotypeMax=genotypeMax.filter(lambda line: line[0] in genotypeCount)

# Report the time elapsed since the timer was started before the format branches
elapsedSeconds = time() - tic
print("Dosage generated in {:f} seconds".format(elapsedSeconds))

# The length of the first SNP's value list is reported as the number of samples
firstGenotypeRow = genotypeMax.first()
samplesize = int(len(firstGenotypeRow[1]))
print("Detected {} samples".format(str(samplesize)))

# Calculate PRS at the specified thresholds

In [None]:
def calcPRSFromGeno(genotypeRDD, oddsMap):
    """Sum the effect-weighted genotype values of every SNP in an RDD.

    Parameters
    ----------
    genotypeRDD : RDD of (snpid, values) pairs
        `values` is a sequence of numbers; every row is expected to have the
        same length so the rows can be added position by position.
        The RDD must not be empty (calcAll checks isEmpty() before calling).
    oddsMap : mapping
        Looked up as oddsMap[snpid] to get the weight applied to that SNP's values.

    Returns
    -------
    tuple
        (totalcount, PRS): the number of rows in `genotypeRDD` and a list holding,
        for each position, the sum over all SNPs of value * weight. The sums are
        returned as-is; they are not divided by `totalcount` here.
    """
    totalcount=genotypeRDD.count()
    weighted=genotypeRDD.map(lambda line:[call * oddsMap[line[0]] for call in line[1]])
    # Materialise a real list at every merge step. Under Python 3 a bare
    # map(add, a, b) is a one-shot iterator: reading it once leaves nothing for
    # the caller, which is what made the returned PRS come back empty.
    PRS=weighted.reduce(lambda a,b: list(map(add, a, b)))
    return (totalcount,PRS)

def calcAll(genotypeRDD, gwasRDD, thresholdlist, logsnp):
    """Calculate the score at every p-value threshold.

    Relies on the module-level names `sc`, `gwas_p`, `gwas_id`, `gwas_or` and
    `log_or`, and on `filterGWASByP_DF` and `calcPRSFromGeno`.

    Parameters
    ----------
    genotypeRDD : RDD of (snpid, values) pairs
        Genotype values, keyed by SNP ID.
    gwasRDD : GWAS table
        Passed to filterGWASByP_DF as its `GWASdf` argument.
    thresholdlist : iterable of numbers
        P-value thresholds; they are processed from largest to smallest.
    logsnp : any
        When truthy, the SNP IDs used at each threshold are collected as well.

    Returns
    -------
    tuple
        (prsMap, idlog). `prsMap` maps threshold -> result of calcPRSFromGeno;
        `idlog` maps threshold -> list of SNP IDs and stays empty when `logsnp`
        is falsy. A threshold at which no genotype row is selected appears in
        neither dict.
    """
    prsMap={}
    idlog={}
    for threshold in sorted(thresholdlist, reverse=True):
        tic=time()
        gwasFilteredBC=sc.broadcast(filterGWASByP_DF(GWASdf=gwasRDD, pcolumn=gwas_p, idcolumn=gwas_id, oddscolumn=gwas_or, pHigh=threshold, logOdds=log_or))
        print("Filtered GWAS at threshold of {}. Time spent : {:f} seconds".format(str(threshold), time()-tic))
        checkpoint=time()
        # Keep only the genotype rows whose SNP passed the GWAS filter at this threshold
        filteredgenotype=genotypeRDD.filter(lambda line: line[0] in gwasFilteredBC.value)

        if not filteredgenotype.isEmpty():
            if logsnp:
                idlog[threshold]=filteredgenotype.map(lambda line:line[0]).collect()
            prsMap[threshold]=calcPRSFromGeno(filteredgenotype, gwasFilteredBC.value)

            print("Finished calculating PRS at threshold of {}. Time spent : {:f} seconds".format(str(threshold), time()-checkpoint))

        # Every action on this threshold's RDD has completed by now, so the
        # executor-side copies of the broadcast can be released before the next one is made.
        gwasFilteredBC.unpersist()
    return prsMap, idlog

In [None]:
# Calculate the scores at every configured threshold. The SNP-log file name is passed
# as `logsnp`, where it only acts as a truthy flag that switches SNP-ID collection on.
prsDict, snpids=calcAll(genotypeMax,gwastable, thresholds, logsnp=snp_log)

In [None]:
# Write out which SNPs went into the score at each threshold (only when a log file is configured)
if snp_log:
    logoutput=writeSNPlog(snpids, snp_log)

# Write the scores; sample labels come from the sample file unless it is set to "NOSAMPLE"
if sampleFilePath=="NOSAMPLE":
    print("No sample file detected, generating labels for samples.")
    output=writePRS(prsDict,  outputPath, samplenames=None)
else:
    # Read the sample labels from the configured sample file
    subjNames=getSampleNames(sampleFilePath,sampleFileDelim,sampleFileID, skip=sample_skip)
    print("Extracted {} sample labels".format(len(subjNames[0])))
    output=writePRS(prsDict,  outputPath, samplenames=subjNames)
