In [50]:
import re
from collections import Counter
import pandas as pd
import math
import pysam
import operator

class PileupRecord:
    '''One tab-separated pileup line, split into its first six columns.

    Attributes, in column order:
        seq    -- first column, kept as a string
        pos    -- second column, converted to int
        ref    -- third column, kept as a string
        rCount -- fourth column, converted to int
        rRes   -- fifth column (the per-read results string)
        qual   -- sixth column (the quality string)
    '''

    def __init__(self, line):
        '''Parse *line*; a trailing line terminator is optional.

        Raises IndexError if the line has fewer than six columns and
        ValueError if column two or four is not an integer.
        '''
        # Strip only the line terminator. Slicing off "the last character"
        # would eat a real quality symbol when the line has no newline or
        # carries extra columns, and would leave '\r' behind on CRLF files.
        fields = line.rstrip("\r\n").split("\t")
        self.seq = fields[0]
        self.pos = int(fields[1])
        self.ref = fields[2]
        self.rCount = int(fields[3])
        self.rRes = fields[4]
        self.qual = fields[5]

    def print(self):
        '''Write the six parsed columns to stdout, separated by four spaces.'''
        print('{}    {}    {}    {}    {}    {}'.format(self.seq, self.pos, self.ref, self.rCount, self.rRes, self.qual))
        


# Header of a pileup indel annotation: a sign followed by the decimal length
# of the inserted/deleted sequence that comes immediately after it.
_INDEL_HEADER = re.compile(r'[+-](\d+)')


def DetectPolymorphicSite(readResult, rCount):
    '''Count the alleles observed at one pileup position and keep the top ones.

    input:
        readResult -- string, the per-read results column of a pileup line
        rCount     -- int, read depth reported for the position
    output:
        list of [variant, count, type] lists sorted by descending count,
        where type is 'indel', 'SNV' or 'match' (variant '.' for matches).
        At most two entries are returned. When more than two alleles were
        seen, the runner-up has a single supporting read and rCount > 50,
        only the top entry is returned.

    Both strands are merged (the string is upper-cased and ',' is treated
    as '.'). Characters other than '.', A, C, G, T are not counted.
    '''
    # '^' is always followed by exactly one mapping-quality character, which
    # can be any printable symbol -- including '.', ',' or a base letter --
    # so the pair has to go before anything is interpreted as a base.
    bases = re.sub(r'\^.', '', readResult).upper().replace(',', '.')

    indel_counts = Counter()
    snv_counts = Counter()
    match_count = 0

    cursor = 0
    while cursor < len(bases):
        base = bases[cursor]
        cursor += 1
        header = _INDEL_HEADER.match(bases, cursor)
        if header:
            # The number in the header says how many characters belong to
            # the indel; consuming exactly that many keeps a following
            # mismatch base from being swallowed and handles lengths >= 10.
            seq_end = header.end() + int(header.group(1))
            if base in '.ACGT':
                indel_counts[base + bases[cursor:seq_end]] += 1
            cursor = seq_end
        elif base == '.':
            match_count += 1
        elif base in 'ACGT':
            snv_counts[base] += 1

    # Counters preserve first-seen order, so ties are broken reproducibly
    # by the stable sort below (indels, then SNVs, then the match entry).
    var = [[allele, count, 'indel'] for allele, count in indel_counts.items()]
    var.extend([allele, count, 'SNV'] for allele, count in snv_counts.items())
    var.append(['.', match_count, 'match'])
    var.sort(key=lambda entry: entry[1], reverse=True)

    if len(var) > 2:
        # A lone read supporting the runner-up at high depth is discarded.
        if var[1][1] == 1 and rCount > 50:
            return var[:1]
        return var[:2]
    return var


def _log_pow(base, exponent):
    '''Return log(base ** exponent) for base >= 0, keeping 0 ** 0 == 1.'''
    if exponent == 0:
        return 0.0
    if base == 0:
        return -math.inf
    return exponent * math.log(base)


def Genotyping(var, pro):
    '''Determine the genotype from the one or two most frequent alleles.

    input:
        var -- list of one or two [variant, count, type] entries, most
               frequent first
        pro -- float in [0, 1], the probability parameter of the model
    output:
        (genotype, P): genotype is a pair of allele indices (0 = reference)
        and P is [] for a single allele, otherwise the three normalised
        probabilities. When the reference is the second allele, P is
        reversed so that it stays ordered from (0,0) to (1,1).

    With k1 and k2 the two read counts, the three hypotheses are weighted
        a1a1: C(k1+k2, k1) * p**k1 * (1-p)**k2
        a1a2: p**(k1+k2)
        a2a2: C(k1+k2, k1) * p**k2 * (1-p)**k1
    The weights are evaluated in log space, so deep positions neither
    overflow the binomial coefficient nor underflow every term to zero.

    NOTE(review): the a1a2 term does not use an allele balance of 0.5;
    it is kept exactly as in the original model -- confirm it is intended.

    Raises ValueError if var does not hold one or two entries, or if every
    hypothesis has zero probability.
    '''
    if len(var) == 1:
        genotype = (0, 0) if var[0][2] == 'match' else (1, 1)
        return genotype, []
    if len(var) != 2:
        raise ValueError('expected 1 or 2 alleles, got {}'.format(len(var)))

    k1 = var[0][1]
    k2 = var[1][1]
    p = pro

    # Exact integer binomial coefficient; math.log accepts arbitrarily
    # large ints, whereas multiplying such an int by a float can overflow.
    log_binom = math.log(
        math.factorial(k1 + k2) // math.factorial(k1) // math.factorial(k2)
    )
    log_like = [
        log_binom + _log_pow(p, k1) + _log_pow(1 - p, k2),  # a1a1
        _log_pow(p, k1 + k2),                               # a1a2
        log_binom + _log_pow(p, k2) + _log_pow(1 - p, k1),  # a2a2
    ]

    peak = max(log_like)
    if peak == -math.inf:
        raise ValueError('all genotype likelihoods are zero for p={}'.format(p))

    # Shifting by the largest term keeps at least one weight equal to 1.
    weights = [math.exp(value - peak) for value in log_like]
    total = sum(weights)
    P = [weight / total for weight in weights]
    index = P.index(max(P))  # first maximum wins on ties

    if var[0][2] == 'match':
        genotype = ((0, 0), (0, 1), (1, 1))[index]
    elif var[1][2] == 'match':
        genotype = ((1, 1), (0, 1), (0, 0))[index]
        P.reverse()
    else:
        # Neither allele is the reference: two distinct alternates.
        genotype = ((1, 1), (1, 2), (2, 2))[index]

    return genotype, P

def DetermineAltsField(polymorphic_site):
    '''Return the non-reference allele strings of a polymorphic site.

    input:  list of entries whose first element is the allele string
    output: list of every allele string other than '.', in input order;
            ['.'] when no such allele exists
    '''
    alts = [entry[0] for entry in polymorphic_site if entry[0] != '.']
    return alts or ['.']

def AvgQual(qual):
    '''Return the mean of 1 - 10**(-Q/10) over the characters of qual.

    Each character is read as a score Q = ord(char) - 33.
    Raises ZeroDivisionError for an empty string.
    '''
    accuracies = [1 - 10**-((ord(symbol)-33)/10.0) for symbol in qual]
    return sum(accuracies)/len(accuracies)

def Metrics(tp, fn, fp):
    '''Return (precision, recall, F1) for the given confusion counts.

    input:  tp, fn, fp -- numbers of true positives, false negatives and
            false positives
    output: tuple (precision, recall, f1)

    A ratio whose denominator is zero is reported as 0.0 instead of
    raising ZeroDivisionError, so a run without any positive call can
    still be tabulated and plotted.
    '''
    prec = tp/(tp+fp) if tp+fp else 0.0
    rec = tp/(tp+fn) if tp+fn else 0.0
    f1 = 2*prec*rec/(prec+rec) if prec+rec else 0.0

    return prec, rec, f1

VCF_read = pysam.VariantFile("merged-normal.mpileup.called.vcf", "r")

# (chrom, pos) of every record in the reference VCF. The set is read once;
# it does not depend on the probability being tested.
called_sites = {(rec.chrom, rec.pos) for rec in VCF_read.fetch()}

probability = [0.6, 0.7, 0.8, 0.9, 0.95]
TP = []
FN = []
FP = []
for pr in probability:
    print(pr)
    # Per-run copy: a site is consumed the first time the pileup reaches it.
    remaining = set(called_sites)

    true_pos = 0
    false_neg = 0
    false_pos = 0
    true_neg = 0

    # The pileup is re-opened for every probability. A file object can be
    # iterated only once, so sharing a single handle across runs would
    # leave every run after the first with no lines to score.
    with open("merged-normal.pileup") as pileup_file:
        for line in pileup_file:
            if not line.strip():
                continue

            rec = PileupRecord(line)
            polymorphic_site = DetectPolymorphicSite(rec.rRes, rec.rCount)
            genotype, probs = Genotyping(polymorphic_site, pr)

            site = (rec.seq, rec.pos)
            called_here = genotype != (0, 0)
            in_reference = site in remaining
            remaining.discard(site)

            if called_here and in_reference:
                true_pos += 1
            elif called_here:
                false_pos += 1
            elif in_reference:
                false_neg += 1
            else:
                true_neg += 1

    # NOTE(review): VCF sites that never appear in the pileup stay in
    # `remaining` and are not counted as false negatives -- confirm that
    # this is the intended scoring.
    TP.append(true_pos)
    FN.append(false_neg)
    FP.append(false_pos)

In [None]:
# Precision, recall and F1 for each tested probability, in the same order
# as the TP / FN / FP lists.
p = []
r = []
f = []
# zip() walks the three lists in lockstep; iterating the bare tuple
# (TP, FN, FP) would instead try to unpack each whole list into 3 names.
for tp, fn, fp in zip(TP, FN, FP):
    # The recall is not stored under the name `re`: that would rebind the
    # regex module that DetectPolymorphicSite relies on.
    prec, recall, f1 = Metrics(tp, fn, fp)
    p.append(prec)
    r.append(recall)
    f.append(f1)

In [None]:
import matplotlib.pyplot as plt

# Raw confusion counts as a function of the tested probability.
fig = plt.figure(figsize=(10, 6))
for series, series_label in (
    (TP, 'true positives'),
    (FP, 'false positives'),
    (FN, 'false negatives'),
):
    plt.plot(probability, series, label=series_label)
plt.xlabel('p')
plt.ylabel('metrics')
plt.legend()

In [None]:
import matplotlib.pyplot as plt

# Precision, recall and F1 as a function of the tested probability.
fig = plt.figure(figsize=(10, 6))
for series, series_label in (
    (p, 'precision'),
    (r, 'recall'),
    (f, 'f1 score'),
):
    plt.plot(probability, series, label=series_label)
plt.xlabel('p')
plt.ylabel('metrics')
plt.legend()