In [11]:
import random, glob, os, argparse
import pandas as pd
import gzip as gz

In [12]:
# --- Input files (all gzip-compressed) ---
# BUSCO table, read line by line when the BUSCO/contig lookup tables are built.
buscofileloc = 'anof_funestus_new_buscooutputs.tsv.gz'
# Self-alignment table, loaded into the `mydf` DataFrame.
hapalignmentfile = 'anof_funestus_new_self_aln.hap.gz'
# Assembly FASTA, indexed by CalculateContigSizes.
myasmFileName = 'primary_new.fasta.gz'

# Default capacity of each PriorityQueue (number of best solutions kept per search).
n_best_sol = 5
# NOTE(review): not referenced anywhere in the visible notebook code — confirm whether it is still needed.
maxzeros = 10

# Upper bound of the hill-climbing loop, which iterates over range(1, niterations).
niterations = 5
# Per-category weights used by cost().
weight_missing = 1
weight_duplicate = 1
weight_single = 1
weight_fragmented = 1

# Set of contig names; reassigned from contigsDictionary's keys once the BUSCO table has been read.
contigsDictSet = set()
# contig name -> set of BUSCO ids recorded on that contig.
contigsDictionary = dict()
# NOTE(review): not referenced anywhere in the visible notebook code; original description kept below.
buscosDictionary = dict() #key: busco, value = [count of complete, count of fragmented], each busco can only be added to one contig


In [13]:
class PriorityQueue(object):
    """Bounded, unsorted collection that keeps the lowest-cost entries seen so far.

    Every entry is an indexable record whose element 0 is its cost and whose
    element 1 is an iterable of contig names. The remaining elements are
    carried along untouched.

    Parameters
    ----------
    max_size : int, optional
        Maximum number of entries kept. When omitted, the notebook-level
        ``n_best_sol`` setting is used.
    """

    def __init__(self, max_size=None): #set of cost, contigs, pid, qpct
        self.pq = []
        # Only look at the global when no explicit capacity was requested, so
        # the class also works outside the notebook namespace.
        self.max_size = n_best_sol if max_size is None else max_size

    def _worst_index(self):
        """Return the position of the first entry with the highest cost."""
        if not self.pq:
            raise IndexError("PriorityQueue is empty")
        # max() keeps the first maximal element, matching a strict '>' scan.
        return max(range(len(self.pq)), key=lambda idx: self.pq[idx][0])

    def _best_index(self):
        """Return the position of the first entry with the lowest cost."""
        if not self.pq:
            raise IndexError("PriorityQueue is empty")
        # min() keeps the first minimal element, matching a strict '<' scan.
        return min(range(len(self.pq)), key=lambda idx: self.pq[idx][0])

    def length(self):
        """Return the number of stored entries."""
        return len(self.pq)

    def delete_node(self):
        """Remove and return the highest-cost entry.

        Raises
        ------
        IndexError
            If the queue is empty.
        """
        return self.pq.pop(self._worst_index())

    def print_queue(self):
        """Print every stored entry, one per line, in storage order."""
        for item in self.pq:
            print(item)

    def add_set(self):
        """Return the union of the contig collections (element 1) of all entries."""
        myGoodSet = set()
        for item in self.pq:
            myGoodSet.update(set(item[1]))
        return myGoodSet

    def head(self):
        """Return the lowest cost currently stored (the cost only, not the entry).

        Raises
        ------
        IndexError
            If the queue is empty.
        """
        return self.pq[self._best_index()][0]

    def insert_node(self, data):
        """Store ``data`` if there is room or if it beats the current worst entry.

        When the queue is full, ``data`` replaces the highest-cost entry only
        when its cost is strictly lower; otherwise it is dropped.
        """
        if self.length() < self.max_size:
            self.pq.append(data)
            return
        if not self.pq:
            # Capacity of zero: nothing can ever be stored.
            return
        worst = self._worst_index()
        if data[0] < self.pq[worst][0]:
            # Single scan: evict the entry just located instead of searching again.
            self.pq.pop(worst)
            self.pq.append(data)

    def return_pq(self):
        """Return the underlying list of entries (not a copy)."""
        return self.pq


In [14]:
def find_count(mydf, step1, step2,step3, myqthreshold,mypidthreshold, qr_align):
    """Tally BUSCO records attached to contigs that do NOT pass the alignment filters.

    Rows of ``mydf`` are kept when ``QPct``, ``PID`` and ``QRAlignLenPct`` are
    each at least their threshold plus the matching step. Every contig in the
    notebook-level ``myAllContigsSet`` whose name is absent from the ``qName``
    column of the kept rows is then visited. For each BUSCO recorded on such a
    contig (``contigsDictionary``), every ``[contig, type]`` record of that
    BUSCO in ``BUSCOS2CTGSDICT`` adds one to the counter chosen by its type
    code: 'C' complete, 'D' duplicate, 'M' missing, anything else fragmented.

    Returns a dict with the keys "fragmented", "missing", "complete" and
    "duplicate".
    """
    cutoffs = (
        ('QPct', myqthreshold + step1),
        ('PID', mypidthreshold + step2),
        ('QRAlignLenPct', qr_align + step3),
    )
    surviving_rows = mydf
    for column, minimum in cutoffs:
        surviving_rows = surviving_rows[surviving_rows[column] >= minimum]

    names_in_filter = set(surviving_rows['qName'])
    contigs_outside_filter = myAllContigsSet - names_in_filter

    tally = {"fragmented": 0, "missing": 0, "complete": 0, "duplicate": 0}
    counter_for_code = {'C': "complete", 'D': "duplicate", 'M': "missing"}
    for contig in contigs_outside_filter:
        for busco in contigsDictionary[contig]:
            for record in BUSCOS2CTGSDICT[busco]:
                # Any type code other than C/D/M is counted as fragmented.
                tally[counter_for_code.get(record[1], "fragmented")] += 1
    return tally


In [15]:
def cost(find_count_dict):
    """Turn the counts produced by find_count into a single score.

    The weighted sum of the missing, duplicate and fragmented counts is
    divided by the weighted complete count. When that weighted complete count
    is zero the undivided sum is returned instead. Weights come from the
    notebook-level ``weight_*`` settings.
    """
    penalty = (
        (weight_missing*find_count_dict["missing"])
        + (weight_duplicate*find_count_dict["duplicate"])
        + (find_count_dict["fragmented"]*weight_fragmented)
    )
    reward = weight_single*find_count_dict["complete"]
    if reward != 0:
        return penalty/reward
    return penalty

def generated_df(mydf, step1, step2,step3,myqthreshold,mypidthreshold,qr_align):
    """Filter ``mydf`` on the three stepped thresholds and echo the steps back.

    A row is kept when ``QPct >= myqthreshold + step1``,
    ``PID >= mypidthreshold + step2`` and ``QRAlignLenPct >= qr_align + step3``.

    Returns the tuple ``(filtered frame, step1, step2, step3)``.
    """
    minimum_by_column = (
        ('QPct', myqthreshold + step1),
        ('PID', mypidthreshold + step2),
        ('QRAlignLenPct', qr_align + step3),
    )
    survivors = mydf
    # Apply the filters one after another, each on the rows left by the previous one.
    for column, minimum in minimum_by_column:
        survivors = survivors[survivors[column] >= minimum]
    return (survivors, step1, step2, step3)

def find_neighbors_cost(step,myqthreshold,mypidthreshold, qr_align):
    """Evaluate the eight diagonal neighbours of the current thresholds and pick the cheapest.

    Each neighbour shifts the QPct, PID and QRAlignLenPct thresholds by
    ``+step`` or ``-step``. Its cost is ``cost(find_count(...))`` computed on
    the notebook-level ``mydf``.

    Returns
    -------
    tuple
        ``(lowest cost, QPct step, PID step, QRAlignLenPct step)`` of the
        winning neighbour. On equal costs the neighbour visited first wins.

    Notes
    -----
    The previous version overwrote its running minimum with the cost of the
    (+, +, -) neighbour paired with the steps of the (+, +, +) neighbour, and
    never used the (+, +, +) cost. Every neighbour is now compared on its own
    cost.
    """
    # Visiting order is kept from the original hand-unrolled comparisons so
    # that ties resolve exactly as before.
    neighbor_steps = (
        (step, step, -step),
        (step, step, step),
        (-step, step, step),
        (-step, step, -step),
        (-step, -step, step),
        (-step, -step, -step),
        (step, -step, step),
        (step, -step, -step),
    )

    min_cost = None
    for q_step, pid_step, qr_step in neighbor_steps:
        neighbor_cost = cost(find_count(mydf, q_step, pid_step, qr_step, myqthreshold, mypidthreshold, qr_align))
        # Strict '<' keeps the earliest neighbour when several share the best cost.
        if min_cost is None or neighbor_cost < min_cost[0]:
            min_cost = (neighbor_cost, q_step, pid_step, qr_step)
    return min_cost



def random_restart(myqthreshold,mypidthreshold, qr_align):
    """Draw a fresh random value for each of the three thresholds.

    For every threshold ``t`` an offset is sampled with
    ``random.uniform(-t, 1.0 - t)`` and added back to ``t``. The three results
    are returned as a tuple in the same order as the arguments.
    """
    restarted = []
    for current in (myqthreshold, mypidthreshold, qr_align):
        offset = random.uniform(-current, 1.0 - current)
        restarted.append(offset + current)
    return tuple(restarted)

def hill_climbing2(mytuple):
    """Run one hill-climbing search over the three alignment-filter thresholds.

    Parameters
    ----------
    mytuple : sequence
        Starting point, read positionally as (QPct threshold, PID threshold,
        QRAlignLenPct threshold).

    Returns
    -------
    PriorityQueue
        The solutions recorded during the search. Each entry is
        ``[cost, qName column of the rows passing the thresholds,
        QPct threshold, PID threshold, QRAlignLenPct threshold]``.

    Notes
    -----
    Reads the notebook-level ``mydf`` and ``niterations``. The loop iterates
    over ``range(1, niterations)``, i.e. ``niterations - 1`` passes.
    """
    myqthreshold = mytuple[0]
    mypidthreshold = mytuple[1]
    qr_align = mytuple[2]
    priority_queue = PriorityQueue()
    step = 0.005
    searching_plateau = False
    # NOTE(review): this value is replaced by find_neighbors_cost at the top of
    # the first loop pass before anything reads it.
    current_cost = cost(find_count(mydf,step,step,step,myqthreshold,mypidthreshold, qr_align))
    my_q_step = 0
    my_pid_step = 0
    qr_align_step = 0
    steps_searched_plateau = 0
    max_steps_in_plateau =10
    
    for i in range(1,niterations):
        # The neighbourhood radius grows by a fixed amount on every pass.
        step += 0.0005
        # Move to the cheapest neighbour and adopt its cost.
        current_cost, my_q_step,my_pid_step, qr_align_step = find_neighbors_cost(step,myqthreshold,mypidthreshold, qr_align)
        myqthreshold +=my_q_step
        mypidthreshold +=my_pid_step
        qr_align+=qr_align_step
        # NOTE(review): this tests the full alignment table, not the rows that
        # survive the current thresholds — confirm that is the intent.
        if mydf.empty:
            myqthreshold, mypidthreshold,qr_align  = random_restart(myqthreshold,mypidthreshold, qr_align)
            # NOTE(review): the cost computed before each `continue` is
            # overwritten by find_neighbors_cost on the next pass.
            current_cost = cost(find_count(mydf, step,step,step,myqthreshold,mypidthreshold, qr_align))
            continue
        # Restart when the QPct or PID threshold went negative.
        # NOTE(review): qr_align is not checked here — confirm whether it should be.
        if myqthreshold < 0 or mypidthreshold<0:
            myqthreshold, mypidthreshold,qr_align  = random_restart(myqthreshold, mypidthreshold, qr_align)
            current_cost = cost(find_count(mydf, step, step, step, myqthreshold, mypidthreshold, qr_align))
            continue

        # Same cost as the best recorded solution: treat as a plateau.
        if priority_queue.length() != 0 and (current_cost == priority_queue.head()):
            searching_plateau = True

        # First usable point: record it so that head() has something to compare against.
        if priority_queue.length() == 0:
            filtered_df = mydf[mydf['QPct'] >= myqthreshold]
            filtered_df = filtered_df[filtered_df['PID'] >= mypidthreshold]
            filtered_df = filtered_df[filtered_df['QRAlignLenPct'] >= qr_align]
            priority_queue.insert_node([current_cost, filtered_df['qName'],myqthreshold,mypidthreshold, qr_align])
        # Worse than the best recorded solution: leave any plateau state and restart elsewhere.
        if current_cost >  priority_queue.head():
            if searching_plateau == True:
                searching_plateau = False
                steps_searched_plateau = 0
            myqthreshold, mypidthreshold, qr_align  = random_restart(myqthreshold,mypidthreshold, qr_align)
            current_cost = cost(find_count(mydf, step,step,step,myqthreshold,mypidthreshold, qr_align))
            continue
        # Better than the best recorded solution: leave any plateau state and record the point.
        if current_cost < priority_queue.head():
            if searching_plateau == True:
                searching_plateau = False
                steps_searched_plateau = 0
            filtered_df = mydf[mydf['QPct'] >= myqthreshold]
            filtered_df = filtered_df[filtered_df['PID'] >= mypidthreshold]
            filtered_df = filtered_df[filtered_df['QRAlignLenPct'] >= qr_align]
            priority_queue.insert_node([current_cost,filtered_df['qName'],myqthreshold,mypidthreshold, qr_align])
        # Count consecutive plateau passes; give up and restart once the limit is reached.
        if (searching_plateau == True):
            steps_searched_plateau+=1
        if (searching_plateau == True and steps_searched_plateau == max_steps_in_plateau):
            searching_plateau = False
            steps_searched_plateau = 0
            myqthreshold, mypidthreshold, qr_align = random_restart(myqthreshold, mypidthreshold, qr_align)
            current_cost = cost(find_count(mydf, step, step, step,myqthreshold, mypidthreshold, qr_align))
            continue

    return priority_queue

    #WriteNewAssembly("/Users/mansiagrawal/PycharmProjects/hillClimbing/primary_new.fasta","/Users/mansiagrawal/PycharmProjects/hillClimbing/primary_new_new.fasta",myGoodContigsSet)
    #WriteNewAssembly(myasmFileName,"./primary.hap.fasta",myGoodContigsSet)

def CalculateContigSizes(asmFileName):
    """Index every FASTA record of a gzip-compressed assembly in a single pass.

    Parameters
    ----------
    asmFileName : str
        Path of the gzip-compressed FASTA file.

    Returns
    -------
    dict
        ``contig name -> [contiglen, headerpos, startseqpos, endseqpos]`` where

        * the name is the header text up to the first space, without ``>`` and
          with ``/`` replaced by ``_``;
        * ``contiglen`` is the summed length of the record's sequence lines;
        * ``headerpos`` is the uncompressed offset of the header line;
        * ``startseqpos`` is the offset just after the header line;
        * ``endseqpos`` is the offset at which the last sequence line starts
          (equal to ``startseqpos`` for a record without sequence lines).

    Notes
    -----
    A blank line closes the current record; lines that follow it are ignored
    until the next header. This notebook defines ``CalculateContigSizes`` again
    in a later cell; the later definition is the one in effect at call time.
    """
    # contigsDict[contigname] = [contiglen,headerpos,startseqpos,endseqpos]
    myContigSizeDict = dict()
    seqName = None  # None means "not inside a record"; '' is a legal name
    seqLen = headerPos = startPos = endPos = 0

    # The context manager closes the handle even when decoding fails part-way.
    with gz.open(asmFileName) as fin:
        while True:
            linePos = fin.tell()
            raw = fin.readline()
            if not raw:
                break
            line = raw.decode().replace('\n', '')
            isHeader = line[0:1] == '>'
            if isHeader or line == '':
                # A header or a blank line ends whatever record was open.
                if seqName is not None:
                    myContigSizeDict[seqName] = [seqLen, headerPos, startPos, endPos]
                    seqName = None
                if isHeader:
                    seqName = line.split(" ")[0].replace('>', '').replace('/', '_')
                    headerPos = linePos
                    # endPos starts at startPos so an empty record still gets a defined value.
                    startPos = endPos = fin.tell()
                    seqLen = 0
            elif seqName is not None:
                seqLen = seqLen + len(line)
                endPos = linePos

    # The last record is closed by end of file rather than by another header.
    if seqName is not None:
        myContigSizeDict[seqName] = [seqLen, headerPos, startPos, endPos]
    return myContigSizeDict

def WriteNewAssembly(myasmFileName, newASMFileName, myGoodContigsSet, contigPositions=None):
    """Copy the selected contigs from a gzip-compressed FASTA into a plain-text FASTA.

    Parameters
    ----------
    myasmFileName : str
        Path of the gzip-compressed source assembly.
    newASMFileName : str
        Path of the uncompressed FASTA to write (overwritten if it exists).
    myGoodContigsSet : iterable of str
        Names of the contigs to copy, written in iteration order.
    contigPositions : dict, optional
        ``name -> [contiglen, headerpos, startseqpos, endseqpos]`` as returned
        by ``CalculateContigSizes``. Defaults to the notebook-level
        ``myContigSizeDict``.

    Each contig is written as its original header line followed by its whole
    sequence joined onto a single line.
    """
    if contigPositions is None:
        contigPositions = myContigSizeDict
    # Both handles are closed on exit, including when a lookup or decode fails.
    with gz.open(myasmFileName, 'r') as fin, open(newASMFileName, 'w') as fout:
        for contig in myGoodContigsSet:
            myContigPositionsList = contigPositions[contig]
            fin.seek(myContigPositionsList[1])  # extract headerpos
            fout.write(fin.readline().decode())
            if myContigPositionsList[0] == 0:
                # No sequence lines were indexed for this contig: emit an empty sequence.
                fout.write('\n')
                continue
            newPos = fin.tell()
            pieces = [fin.readline().decode().replace('\n', '')]
            # Keep reading until the line just consumed is the one starting at endseqpos.
            while newPos != myContigPositionsList[3]:
                newPos = fin.tell()
                chunk = fin.readline()
                if not chunk:
                    # End of file before endseqpos: the index does not match this file.
                    break
                pieces.append(chunk.decode().replace('\n', ''))
            fout.write(''.join(pieces) + '\n')

In [16]:
# Load the self-alignment table: tab-separated, no header row, six columns named here.
# qName/tName are forced to object dtype so contig names are not parsed as numbers.
# This frame is read as a global by find_neighbors_cost and hill_climbing2.
mydf = pd.read_csv(hapalignmentfile, sep='\t', header=None, names=['qName', 'tName', 'qSize', 'QPct', 'PID', 'QRAlignLenPct'], dtype={'qName': object, 'tName': object})

In [17]:
def CalculateContigSizes(asmFileName):
    """Index every FASTA record of a gzip-compressed assembly in a single pass.

    Parameters
    ----------
    asmFileName : str
        Path of the gzip-compressed FASTA file.

    Returns
    -------
    dict
        ``contig name -> [contiglen, headerpos, startseqpos, endseqpos]`` where

        * the name is the header text up to the first space, without ``>`` and
          with ``/`` replaced by ``_``;
        * ``contiglen`` is the summed length of the record's sequence lines;
        * ``headerpos`` is the uncompressed offset of the header line;
        * ``startseqpos`` is the offset just after the header line;
        * ``endseqpos`` is the offset at which the last sequence line starts
          (equal to ``startseqpos`` for a record without sequence lines).

    Notes
    -----
    A blank line closes the current record; lines that follow it are ignored
    until the next header.
    """
    # contigsDict[contigname] = [contiglen,headerpos,startseqpos,endseqpos]
    myContigSizeDict = dict()
    seqName = None  # None means "not inside a record"; '' is a legal name
    seqLen = headerPos = startPos = endPos = 0

    # The context manager closes the handle even when decoding fails part-way.
    with gz.open(asmFileName) as fin:
        while True:
            linePos = fin.tell()
            raw = fin.readline()
            if not raw:
                break
            line = raw.decode().replace('\n', '')
            isHeader = line[0:1] == '>'
            if isHeader or line == '':
                # A header or a blank line ends whatever record was open.
                if seqName is not None:
                    myContigSizeDict[seqName] = [seqLen, headerPos, startPos, endPos]
                    seqName = None
                if isHeader:
                    seqName = line.split(" ")[0].replace('>', '').replace('/', '_')
                    headerPos = linePos
                    # endPos starts at startPos so an empty record still gets a defined value.
                    startPos = endPos = fin.tell()
                    seqLen = 0
            elif seqName is not None:
                seqLen = seqLen + len(line)
                endPos = linePos

    # The last record is closed by end of file rather than by another header.
    if seqName is not None:
        myContigSizeDict[seqName] = [seqLen, headerPos, startPos, endPos]
    return myContigSizeDict

In [18]:
# Index the assembly: contig name -> [contiglen, headerpos, startseqpos, endseqpos].
myContigSizeDict = CalculateContigSizes(myasmFileName)
myAllContigsSet = set(myContigSizeDict)

# Every assembly contig gets an (initially empty) BUSCO set, even if no BUSCO maps to it.
for elem in myAllContigsSet:
    contigsDictionary[elem] = set()

# BUSCO id -> list of [contig, type code] records; ids whose only rows are 'M' keep an empty list.
BUSCOS2CTGSDICT = dict()
# The context manager closes the gzip handle once the table has been read.
with gz.open(buscofileloc) as buscoFile:
    for line in buscoFile:
        line = line.decode()
        line = line.strip().split()

        # A usable row needs at least the BUSCO id and its status column;
        # blank or truncated rows used to raise IndexError here.
        if len(line) < 2:
            continue

        myBUSCO = line[0]
        myBUSCOtype = line[1][0]
        if len(line) > 2:
            myCtg = line[2]
        else:
            myCtg = ''

        BUSCOS2CTGSDICT.setdefault(myBUSCO, [])
        if myBUSCOtype != 'M':
            # Only rows that actually name a contig can be attached to one.
            if myCtg:
                contigsDictionary.setdefault(myCtg, set()).add(myBUSCO)
            BUSCOS2CTGSDICT[myBUSCO].append([myCtg, myBUSCOtype])

contigsDictSet = set(contigsDictionary.keys())


In [19]:
from multiprocessing import Pool
import torch
# Size of the multiprocessing Pool and number of random starting points (one search per worker).
num_of_cores = 5
# Upper bound on how many merged solutions are kept in final_solution.
num_total_solutions = 10

In [20]:
#regular cpu
#run with the pandas
import random
import heapq
import time

def generate_random_tuple(num_of_cores):
    """Return a list of ``num_of_cores`` values drawn with ``random.random()``.

    Despite the name, the result is a list, not a tuple.
    """
    draws = []
    for _ in range(num_of_cores):
        draws.append(random.random())
    return draws
parentUPQ = []
# One random starting value per worker for each of the three thresholds.
qthreshold_tuple = generate_random_tuple(num_of_cores)
pidthreshold_tuple = generate_random_tuple(num_of_cores)
qralign_tuple = generate_random_tuple(num_of_cores)
# One (QPct, PID, QRAlignLenPct) starting point per worker.
new_tuple = list(zip(qthreshold_tuple, pidthreshold_tuple, qralign_tuple))

start_time = time.time()

# Each worker runs an independent hill-climbing search and returns its PriorityQueue.
if __name__ == '__main__':
    with Pool(num_of_cores) as p:
        myUPQlist = p.map(hill_climbing2, new_tuple)

end_time = time.time()
difference = end_time - start_time
print(difference)

# Merge the per-worker queues and order them by cost alone. Pushing the raw
# entries onto a heap compared whole lists, so two entries with equal cost went
# on to compare their pandas Series, which raises ValueError.
# The sort is stable, so ties keep worker order.
pq = []
for lst in myUPQlist:
    pq.extend(lst.return_pq())
pq.sort(key=lambda item: item[0])

# Keep the cheapest solutions; `pq` is left holding the remainder, still cost-ordered.
n_keep = max(0, min(num_total_solutions, len(pq)))
final_solution = pq[:n_keep]
pq = pq[n_keep:]

4.000008821487427
