# 1. Generate perfect toy data

In [4]:
# Version tag embedded in the names of the reads / sketch / Jaccard / overlap-graph
# files this notebook writes and reloads, so runs of different versions don't collide.
pipeline_version = 'v3'

In [1]:
import random
import numpy  as np

# Seed both RNGs so the toy genome (and everything derived from it) is reproducible.
np.random.seed(1234)
random.seed(1234)

genome_length = 5000

nuc = ['A','C','G','T']
# Exactly one np.random.choice call per base, in order, so the seeded stream --
# and therefore the genome -- is the same as appending one base at a time.
genome = ''.join(np.random.choice(nuc) for _ in range(genome_length))
genome

'TTGCAAACTCTCGGTGAAGGGAAACACTGGTGATACGGGTTTACTATGTACTTTGCGTTAGTGACTCAAACCCTCTCACACACAAAGAGAGTTCGCGGCCGTATCGTGAGTTGGAAGTCTTGTGCGTACACGCGCGCCCTATTCGAACACGCCGTGACCCTGAATAAGAAAACAGCTAGATGGTTGGCTAACCGGTCATGACCCAACTTAAACCAGTCCGCTTTCTAACACTGCCGGGCATAAAAATTCCCGCTCTATTCTTGTCGTGGAGCCTTTATGAGCTTGACCTGATAGTCTCTATCAACGGTCATCACCAGTCGTCATTTAGGCTTACTAAATAGCTAGACCCATCACGCCGTCCAGCGCCTTCCCAAAGCTATCGGTGCCAAGTTCTTGAAGGCGCAGTGTGATAGAGAATCGAATTGTAGAGCCAAACATGTCTCACCCGGTACGTCATCGGGATTTCACGTAACGTGTTTAACCCACCCGTGTTTAAGGATAGAGTGCAGCTGGCGCGGGCGGTACAACCTGGGGCAGCTGTCATCGGCGCTCTAGGGCTATTCGGTGTTGGACATAGGTCAGCGCACCGTGGAAGATCGTCTATTGGCGTTTAGCCCCGGCGCCCAGTTAACCTTTCCGCCGAGCTCCTAACTCCACCTAACGAAACCGTTAAGGATTCTATACCTTATATGGCTCATCGTCCAACTTGTGATCTTGAGTACGCTTTTGCTGGAAGCTCTCTGCTCTGGTCACAAACTCTAAAGGCATATTGTCAGTATCTCTCCTCCCTACGTCCGACGAGACACGCGGGACGCACCCAGGTGGGTCGTGCCCGCTTCGCTCCTATTATTATAGCACCCTTGGACCGCGTATCTACCGACGTATTCGATTATGTGTACTAGGGTAGTAAACCTCCATAAGGACTCCGGCGTAACGGGATTGGTGCGAAATACACTTTTCTGTCAATGGGGTGTTTTTGACGAGCAACGTTCTGGACGA

In [60]:
def generate_long_reads(genome, read_length, overlap, copies=5):
    """Tile `genome` into overlapping reads and emit `copies` identical copies of each.

    Read boundaries fall every `read_length` bases. Every read after the first is
    extended `overlap` bases to the left (clamped at position 0). So, when
    overlap <= read_length, its first `overlap` bases equal the last `overlap`
    bases of the previous read. The last read is truncated at the end of the genome.

    Returns (reads, idxs, count_unique_reads):
      - reads: the read strings, with `copies` consecutive entries per distinct read;
      - idxs: the matching (start, end) slice bounds (end may exceed len(genome)
        for the last read);
      - count_unique_reads: the number of distinct reads.
    """
    reads = []
    idxs = []
    count_unique_reads = 0
    for i in range(0, len(genome), read_length):
        # Clamp the left extension: with overlap > read_length a negative start
        # would make the slice wrap around to the end of the genome.
        start = max(0, i - overlap) if i > 0 else i
        end = i + read_length
        read = genome[start:end]
        count_unique_reads += 1
        reads.extend([read] * copies)
        idxs.extend([(start, end)] * copies)
    return reads, idxs, count_unique_reads

read_length_v1 = 500
overlap_v1 = 100
reads, idxs, count_unique_reads = generate_long_reads(genome, read_length_v1, overlap_v1)
print(f"len(reads)={len(reads)}, count_unique_reads={count_unique_reads}")

# Dump one read per line and sanity-check the tiling: generate_long_reads emits
# 5 copies of each read, so entry n-5 is a copy of the previous distinct read,
# whose last `overlap_v1` bases must equal this read's first `overlap_v1` bases.
reads_path = 'reads_' + pipeline_version + '.txt'
with open(reads_path, 'w') as f:  # the with-block closes the file
    for idx_read, (read, span) in enumerate(zip(reads, idxs)):
        print(idx_read, span, len(read), read)
        f.write(f'{read}\n')
        if idx_read >= 5:
            previous = reads[idx_read - 5]
            assert read[:overlap_v1] == previous[-overlap_v1:], 'The overlaps were not produced as expected!!!'

len(reads)=50, count_unique_reads=10
0 (0, 500) 500 TTGCAAACTCTCGGTGAAGGGAAACACTGGTGATACGGGTTTACTATGTACTTTGCGTTAGTGACTCAAACCCTCTCACACACAAAGAGAGTTCGCGGCCGTATCGTGAGTTGGAAGTCTTGTGCGTACACGCGCGCCCTATTCGAACACGCCGTGACCCTGAATAAGAAAACAGCTAGATGGTTGGCTAACCGGTCATGACCCAACTTAAACCAGTCCGCTTTCTAACACTGCCGGGCATAAAAATTCCCGCTCTATTCTTGTCGTGGAGCCTTTATGAGCTTGACCTGATAGTCTCTATCAACGGTCATCACCAGTCGTCATTTAGGCTTACTAAATAGCTAGACCCATCACGCCGTCCAGCGCCTTCCCAAAGCTATCGGTGCCAAGTTCTTGAAGGCGCAGTGTGATAGAGAATCGAATTGTAGAGCCAAACATGTCTCACCCGGTACGTCATCGGGATTTCACGTAACGTGTTTAACCCACCCGTGTTTAAGGAT
1 (0, 500) 500 TTGCAAACTCTCGGTGAAGGGAAACACTGGTGATACGGGTTTACTATGTACTTTGCGTTAGTGACTCAAACCCTCTCACACACAAAGAGAGTTCGCGGCCGTATCGTGAGTTGGAAGTCTTGTGCGTACACGCGCGCCCTATTCGAACACGCCGTGACCCTGAATAAGAAAACAGCTAGATGGTTGGCTAACCGGTCATGACCCAACTTAAACCAGTCCGCTTTCTAACACTGCCGGGCATAAAAATTCCCGCTCTATTCTTGTCGTGGAGCCTTTATGAGCTTGACCTGATAGTCTCTATCAACGGTCATCACCAGTCGTCATTTAGGCTTACTAAATAGCTAGACCCATCACGCCGTCCAGCGCCTTCCCAAAGCTATCGGTGCCAAGTTCTTGAAGGCGCAGTGTGATAGAGAATCGAATTGTAGAGCC

In [51]:
# # Example string
# my_string = ''.join([str(i) for i in range(0,100)])

# # Get the last 100 characters
# last_100_characters = my_string[-100:]

# # Print the result
# print(last_100_characters)
# print(my_string[len(my_string)-100:len(my_string)])

# 2. Overlap - MHAP; Jaccard Matrix generation

In [52]:
# Show how many CPUs / threads are available; mp.Pool() in the cells below starts
# one worker process per CPU by default.
!lscpu | grep -E '^Thread|^Core|^Socket|^CPU\('

CPU(s):                             8
Thread(s) per core:                 2
Core(s) per socket:                 4
Socket(s):                          1


In [53]:
# Install third-party packages used below (mmh3 for k-mer hashing, tqdm for progress bars).
# NOTE(review): biopython is not imported in the visible cells, and networkx (imported
# later) is not installed here -- confirm the environment provides it.
%pip install mmh3
%pip install biopython
%pip install tqdm

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [7]:
import operator, mmh3, os
import numpy as np
from tqdm import tqdm
import multiprocessing as mp
from typing import List

## Define basic functions needed for MHAP

In [8]:
# hyperparameters (GLOBAL): Fig 2 of MHAP paper
CONSTANT_k = 16 # MHAP paper, this is k1 from the paper; k-mer length passed to create_kmers by worker
CONSTANT_H = 1256 # number of min-hash values per read sketch (length of each sketch array)
seed = 42 # mmh3 seed for the first hash function in create_sketch
jaccard_thr = 0.06 # NOTE(review): unused by the live cells; the Jaccard-matrix cell sets its own `threshold`


# BASE FUNCTIONS ##########################################################

def create_kmers(read: str, k: int)->List[str]:
    """Return every length-k substring of `read`, ordered by start position.

    There are len(read) - k + 1 of them, or none when the read is shorter than k.
    """
    return [read[start:start + k] for start in range(len(read) - k + 1)]

def create_sketch(kmers_read_i: List[str],
                       seed: int)->np.ndarray:
    """Build a MinHash-style sketch of one read: CONSTANT_H minimum hash values.

    - kmers_read_i: all k-mers of the read (must be non-empty).
    - seed: seed for mmh3; it defines the first hash function.

    sketch[0] is the minimum unsigned 32-bit mmh3 hash over the k-mers. For h >= 1,
    sketch[h] is the minimum over k-mers of the h-th value of an xorshift stream
    seeded with that k-mer's mmh3 hash. Together these act as CONSTANT_H hash
    functions.

    Returns an int64 array of length CONSTANT_H.
    Raises ValueError if kmers_read_i is empty (e.g. the read is shorter than k).
    """
    if len(kmers_read_i) == 0:
        raise ValueError("create_sketch needs at least one k-mer (is the read shorter than k?)")

    # int64 explicitly: the platform default int may be 32-bit, which cannot hold
    # unsigned 32-bit hash values.
    sketch = np.zeros(CONSTANT_H, dtype=np.int64)

    fingerprint_1 = [mmh3.hash(key=kmer, seed=seed, signed=False) for kmer in kmers_read_i]
    sketch[0] = min(fingerprint_1)

    # One generator per k-mer; advancing all of them once yields the k-mers'
    # values under the next "hash function".
    xor_shifts = [xorshift(x) for x in fingerprint_1]
    for sketch_idx in range(1, CONSTANT_H):
        sketch[sketch_idx] = min(next(rng) for rng in xor_shifts)

    return sketch

def xorshift(seed):
    """XORShift random number generator (Marsaglia's xorshift32, shifts 13/17/5).

    Yields an endless stream of 32-bit unsigned ints derived from `seed`. The state
    is truncated to 32 bits after each left shift, as in the reference generator.
    Without that truncation the Python int state grows by ~18 bits per step, and
    the high bits leak back in through the right shift. A seed of 0 (mod 2**32) is
    a fixed point and yields 0 forever.
    """
    mask = 0xFFFFFFFF
    state = seed & mask
    while True:
        state ^= (state << 13) & mask
        state ^= state >> 17
        state ^= (state << 5) & mask
        yield state

def jaccard_index(sketch_1, sketch_2):
    """Estimate Jaccard similarity as the fraction of sketch slots where the two sketches agree.

    Slots are paired position by position (stopping at the shorter input), and
    the number of equal slots is divided by len(sketch_1).
    """
    agreeing_slots = sum(1 for delta in map(operator.sub, sketch_1, sketch_2) if delta == 0)
    return agreeing_slots / len(sketch_1)

# BASE FUNCTIONS END ##########################################################


## Get all the sketches

In [18]:
def worker(worker_args):
    """Pool task: sketch one read.

    worker_args is (idx, read, seed); returns (idx, sketch) so results can be
    matched back to their read.
    """
    read_index, read_seq, hash_seed = worker_args
    kmers = create_kmers(read_seq, k=CONSTANT_k)
    sketch = create_sketch(kmers, hash_seed)
    return read_index, sketch

def get_all_sketches_parallel_v2(reads, seed, chunksize=10):  # Example chunksize set to 10
    """Sketch every read in a process pool, with a tqdm progress bar.

    Returns a list of (read_idx, sketch) tuples. pool.imap yields them in the
    same order as `reads`.
    """
    jobs = [(read_idx, read, seed) for read_idx, read in enumerate(reads)]
    with mp.Pool() as pool:
        progress = tqdm(pool.imap(worker, jobs, chunksize=chunksize), total=len(reads), unit='read')
        return list(progress)

In [19]:
# Cache file for the per-read sketches. NOTE(review): the name encodes CONSTANT_H and
# pipeline_version but not CONSTANT_k, seed or the reads themselves -- delete the
# file after changing those.
sketches_file = f'all_sketches_{CONSTANT_H}_{pipeline_version}.npy'
try:
    all_sketches = np.load(sketches_file)
except Exception as err:
    # Missing or unreadable cache: recompute. `except Exception` (not a bare
    # `except:`) still lets KeyboardInterrupt / SystemExit through.
    print(f"could not load {sketches_file} ({err!r}); recalculating")
    all_sketches = get_all_sketches_parallel_v2(reads, seed, chunksize=1)
    # pool.imap already returns results in read order; sorting by index keeps that explicit.
    all_sketches.sort(key=lambda sketch_tuple: sketch_tuple[0])
    all_sketches = np.array([sketch for _, sketch in all_sketches])
    np.save(sketches_file, all_sketches)
else:
    print("LOADED! no need to recalculate")

100%|██████████| 50/50 [00:44<00:00,  1.12read/s]


## Get Jaccard Matrix

In [12]:

# jaccard_mtx_directory = f"./jaccard_mtx_directory_{str(jaccard_thr)}"

def get_jaccard_row(args):
    """Compute row i of the thresholded Jaccard matrix.

    args is (i, all_sketches, jaccard_threshold, save_rows_bool). Entry j of the
    (1, len(all_sketches)) bool row is True when
    jaccard_index(sketch_i, sketch_j) > jaccard_threshold.

    If save_rows_bool is true, the row is written to
    ./jaccard_mtx_directory/jaccard_mtx_row_{i}.npy and None is returned;
    otherwise (i, row) is returned.
    """
    i, all_sketches, jaccard_threshold, save_rows_bool = args
    row = np.zeros(shape=(1, len(all_sketches)), dtype=bool)
    query_sketch = all_sketches[i]
    for col in range(len(all_sketches)):
        similarity = jaccard_index(query_sketch, all_sketches[col])
        row[0, col] = 1 if similarity > jaccard_threshold else 0

    if not save_rows_bool:
        return (i, row)
    np.save(f'./jaccard_mtx_directory/jaccard_mtx_row_{i}.npy', row)
    
def get_jaccard_row_start_stop(args):
    """Compute row i of the thresholded Jaccard matrix, only if i is in [start_row_idx, end_row_idx).

    args is (i, all_sketches, jaccard_threshold, save_rows_bool, start_row_idx, end_row_idx).

    Returns -1 when i is outside the range. Otherwise, if save_rows_bool is true,
    the row is written to ./jaccard_mtx_directory/jaccard_mtx_row_{i}.npy and None
    is returned; if it is false, (i, row) is returned. The row has shape
    (1, len(all_sketches)) and dtype bool, the same as get_jaccard_row, so both
    code paths produce identical row files.
    """
    i, all_sketches, jaccard_threshold, save_rows_bool, start_row_idx, end_row_idx = args

    if i < start_row_idx or i >= end_row_idx:
        return -1

    # bool like get_jaccard_row: entries are only ever 0/1 (float64 rows were 8x larger on disk)
    jaccard_mtx_row = np.zeros(shape=(1, len(all_sketches)), dtype=bool)
    for j in range(len(all_sketches)):
        jaccard_mtx_row[0, j] = jaccard_index(all_sketches[i], all_sketches[j]) > jaccard_threshold

    if not save_rows_bool:
        return (i, jaccard_mtx_row)
    np.save(f'./jaccard_mtx_directory/jaccard_mtx_row_{i}.npy', jaccard_mtx_row)
    return None

def get_jaccard_matrix_parallel(all_sketches: List[np.ndarray], jaccard_threshold: float, save_rows_bool=True, start_row_idx:int=None, end_row_idx:int=None):
    """Compute the thresholded (0/1) Jaccard matrix of `all_sketches` in a process pool.

    Modes:
      - start_row_idx or end_row_idx is None: every row is computed with get_jaccard_row.
      - both given: only rows in [start_row_idx, end_row_idx) are computed, with
        get_jaccard_row_start_stop. This requires save_rows_bool (raises Exception
        otherwise), so a large matrix can be built in chunks across several calls.

    With save_rows_bool, each row goes to ./jaccard_mtx_directory/jaccard_mtx_row_{i}.npy
    and None is returned. A full run refuses to start (prints and returns None) if
    that directory already exists. Chunked runs reuse it, so later chunks can add
    their rows. Without save_rows_bool, the full float matrix is returned.

    Exceptions raised in a worker are re-raised here while iterating pool results.
    NOTE(review): all_sketches is pickled into every task; for large inputs a Pool
    initializer would avoid that copying.
    """
    row_dir = "./jaccard_mtx_directory"
    total_rows = len(all_sketches)
    # Decide the mode before start/end are filled in below (re-testing them after
    # they were overwritten would always pick the chunked branch).
    full_run = start_row_idx is None or end_row_idx is None

    if full_run:
        task_function = get_jaccard_row # generate all the reads
        start_row_idx, end_row_idx = 0, total_rows
        task_args = [(i, all_sketches, jaccard_threshold, save_rows_bool) for i in range(start_row_idx, end_row_idx)]
    else:
        if not save_rows_bool:
            raise Exception('Must save rows if going to start and stop anywhere except the very beginning or end')
        task_function = get_jaccard_row_start_stop # generate a subset of reads
        task_args = [(i, all_sketches, jaccard_threshold, save_rows_bool, start_row_idx, end_row_idx) for i in range(start_row_idx, end_row_idx)]

    total_tasks = end_row_idx - start_row_idx

    if save_rows_bool:
        print('Saving rows!')
        if not os.path.exists(row_dir):
            os.mkdir(row_dir)
        elif full_run:
            # Don't overwrite a previous full run.
            print('Directory already exists! Exiting now')
            return
        # Chunked runs: the directory was created by an earlier chunk; keep going.

    with mp.Pool() as pool:
        results_pre = pool.imap_unordered(task_function, task_args)
        results_post = list(tqdm(results_pre, total=total_tasks, desc="Processing rows", unit="row"))

    if save_rows_bool:
        if full_run:
            print('All rows of jaccard matrix between successfully saved to numpy files!')
        else:
            print(f'start_row_idx={start_row_idx} up to but not including end_row_idx={end_row_idx} of jaccard matrix saved to numpy files!')
        return

    jaccard_matrix: np.ndarray = np.zeros(shape=(total_rows, total_rows))
    for row_idx, row in results_post:
        jaccard_matrix[row_idx, :] = row
    return jaccard_matrix


In [20]:
threshold = 0.05
# NOTE(review): the cache name encodes the threshold and pipeline_version but not the
# sketch parameters -- delete the file after changing CONSTANT_H / CONSTANT_k / seed.
jaccard_matrix_file = f'jaccard_matrix_{str(threshold)}_{pipeline_version}.npy'
try:
    jaccard_matrix = np.load(jaccard_matrix_file)
except Exception as err:
    # Missing or unreadable cache: rebuild in memory and save. `except Exception`
    # (not a bare `except:`) still lets KeyboardInterrupt / SystemExit through.
    print(f"could not load {jaccard_matrix_file} ({err!r}); recalculating")
    jaccard_matrix = get_jaccard_matrix_parallel(all_sketches=all_sketches, jaccard_threshold=threshold, save_rows_bool=False)
    np.save(jaccard_matrix_file, jaccard_matrix)
else:
    print("LOADED! no need to recalculate")

LOADED! no need to recalculate


# 3. Error Correction - Falcon Sense

In [21]:
import numpy as np

def FittingAlignment(s: str, t: str,
                     match, mismatch, indel_penalty):
    """Fit the shorter of two strings inside the longer one (ties: t is fitted into s).

    Scoring (see FittingAlignmentMatrix): a match adds `match`, a mismatch subtracts
    `mismatch`, and each indel subtracts `indel_penalty`.

    Returns ((alignment_s, alignment_t, index), score):
      - alignment_s / alignment_t are the gapped alignment strings for s and t;
      - index is the 1-based position in s where the gap-free aligned part of s
        first occurs (0 if it does not occur);
      - score is the optimal fitting-alignment score.
    """
    score, backtrack_matrix = FittingAlignmentMatrix(s, t, match, mismatch, indel_penalty)

    alignment_s, alignment_t = FittingAlignmentOutput(backtrack_matrix, s, t)

    # str.find returns -1 when absent, so "+ 1" gives the 1-based start or 0.
    aligned_s = alignment_s.replace('-', '')
    index = s.find(aligned_s) + 1

    return ((alignment_s, alignment_t, index), score)

def FittingAlignmentMatrix(s: str, t: str,
                           match, mismatch, indel_penalty):
    """Fill the fitting-alignment DP grid for s and t.

    The longer string (s on ties) is v and indexes the rows; the shorter, w, indexes
    the columns. All of w is aligned against a substring of v:
      - column 0 is left at 0, so a prefix of v is skipped for free;
      - row 0 is penalized, so w must be consumed from its first character;
      - at the sink, every node in the last column is also a candidate, so a
        suffix of v is skipped for free.
    A match adds `match`, a mismatch subtracts `mismatch`, and an indel subtracts
    `indel_penalty`. Ties go to the first candidate in the order
    up (i-1, j), left (i, j-1), diagonal (i-1, j-1), then the sink's jumps.
    Time and memory are O(len(s) * len(t)).

    Returns (score, backtrack_matrix). The backtracking matrix elements are None or
    tuples of integers giving the preceding index in the path."""
    # whichever string is longer will be the "row side" of the alignment matrix (or s if the same length)
    #   will call the longer string v and the shorter one w
    v = s
    w = t
    if len(t) > len(s):
        v = t
        w = s
    
    scores = []
    backtrack_matrix = []
    for _ in range(len(v) + 1):
        scores.append([0] * (len(w) + 1))
        backtrack_matrix.append([None] * (len(w) + 1))
    
    # avoid penalizing gaps in w prior to its first character incorporation by leaving all scores in first column 0
    #  and leaving backtrack matrix value None
        
    # fill in first row
    for j in range(1, len(w) + 1):
        scores[0][j] = scores[0][j - 1] - indel_penalty
        backtrack_matrix[0][j] = (0, j-1)
    
    # fill in remaining positions
    for i in range(1, len(v) + 1):
        for j in range(1, len(w) + 1):
            # determine the value to add to scores[i-1][j-1] from scoring matrix
            match_mismatch_adjustment = match
            if v[i-1] != w[j-1]:
                match_mismatch_adjustment = -mismatch
            
            # consider scores from the three neighbors
            considered_scores = [scores[i-1][j] - indel_penalty, scores[i][j-1] - indel_penalty, scores[i-1][j-1] + match_mismatch_adjustment]
            considered_positions = [(i-1, j), (i, j-1), (i-1, j-1)]
            
            # if we're at the sink, additionally consider all other nodes in final column
            #  (free jumps that leave the rest of v unaligned)
            if i == len(v) and j == len(w):
                for k in range(len(v)):
                    considered_scores.append(scores[k][len(w)])
                    considered_positions.append((k, len(w)))
                    
            # list.index returns the first maximum, which fixes the tie-break order
            best_score_index = considered_scores.index(max(considered_scores))
            
            # set score and backtrack matrix direction
            scores[i][j] = considered_scores[best_score_index]
            backtrack_matrix[i][j] = considered_positions[best_score_index]
                
    return (scores[-1][-1], backtrack_matrix)

def FittingAlignmentOutput(backtrack_matrix, s: str, t: str):
    """Read out the fitting alignment of s and t from a FittingAlignmentMatrix backtrack matrix.

    As in FittingAlignmentMatrix, the longer string (s on ties) indexes the rows (v)
    and the shorter one the columns (w). Returns the gapped alignment strings in
    argument order: (alignment of s, alignment of t). If the shorter string is
    empty, both alignments are "".
    """
    # whichever string is longer is the "row side" of the alignment matrix (or s if the same length)
    #   will call the longer string v and the shorter one w
    v = s
    w = t
    if len(t) > len(s):
        v = t
        w = s

    sink_back = backtrack_matrix[len(v)][len(w)]
    if sink_back is None:
        # The sink has no predecessor only when it lies in the free first column,
        # i.e. w is empty: there is nothing to align.
        return ("", "")

    alignment_v = ""
    alignment_w = ""

    # account for cases where optimal path didn't reach last column until final edge
    if sink_back == (len(v), len(w) - 1):
        alignment_v = "-" + alignment_v
        alignment_w = w[-1] + alignment_w
    elif sink_back == (len(v) - 1, len(w) - 1):
        alignment_v = v[-1] + alignment_v
        alignment_w = w[-1] + alignment_w
    # otherwise the predecessor is (k, len(w)) for some k < len(v): a free jump
    # (or the vertical edge when k == len(v) - 1); nothing is emitted for it

    pos = [sink_back[0], sink_back[1]]

    while backtrack_matrix[pos[0]][pos[1]] is not None:
        prev_i, prev_j = backtrack_matrix[pos[0]][pos[1]]
        if prev_i == pos[0] - 1 and prev_j == pos[1] - 1:
            # diagonal: v and w characters aligned together
            alignment_v = v[pos[0] - 1] + alignment_v
            alignment_w = w[pos[1] - 1] + alignment_w

            pos[0] -= 1
            pos[1] -= 1
        elif prev_i == pos[0] - 1 and prev_j == pos[1]:
            # vertical: v character against a gap
            alignment_v = v[pos[0] - 1] + alignment_v
            alignment_w = "-" + alignment_w

            pos[0] -= 1
        else:
            # horizontal: w character against a gap
            alignment_v = "-" + alignment_v
            alignment_w = w[pos[1] - 1] + alignment_w

            pos[1] -= 1

    # return in order based on what corresponds to s and t
    if len(t) > len(s):
        return (alignment_w, alignment_v)
    return (alignment_v, alignment_w)

def OverlapAlignment(match_reward: int, mismatch_penalty: int, indel_penalty: int,
                    s: str, t: str) -> int:
    """Return the score of the optimal overlap alignment of s and t.

    A suffix of s is aligned against a prefix of t. Matches add match_reward,
    mismatches subtract mismatch_penalty, and each indel subtracts indel_penalty.
    Only the score is returned; the backtrack matrix built by
    OverlapAlignmentMatrix is discarded.
    """
    score, _backtrack_matrix = OverlapAlignmentMatrix(match_reward, mismatch_penalty, indel_penalty, s, t)
    return score

def OverlapAlignmentMatrix(match_reward: int, mismatch_penalty: int, indel_penalty: int,
                           s: str, t: str) -> tuple[int, list[list[tuple[int, int]]]]:
    """Fill the overlap-alignment DP grid for s (rows) and t (columns).

    A suffix of s is aligned against a prefix of t:
      - column 0 is left at 0, so a prefix of s is skipped for free;
      - row 0 is penalized, so t must be consumed from its first character;
      - at the sink, every node in the last row is also a candidate, so a suffix
        of t is left unaligned for free.
    A match adds match_reward, a mismatch subtracts mismatch_penalty, and an indel
    subtracts indel_penalty. Ties go to the first candidate in the order
    up, left, diagonal, then the sink's jumps.

    Returns (score, backtrack_matrix). The backtracking matrix elements are None or
    tuples of integers giving the preceding index in the path (the annotation does
    not mention the None entries)."""
    # initialize scoring grid and backtracking matrix
    scores = []
    backtrack_matrix = []
    for _ in range(len(s) + 1):
        scores.append([0] * (len(t) + 1))
        backtrack_matrix.append([None] * (len(t) + 1))
        
    # avoid penalizing gaps in t prior to its first character incorporation by leaving all scores in first column 0
    #  and leaving backtrack matrix value None
        
    # fill in first row
    for j in range(1, len(t) + 1):
        scores[0][j] = scores[0][j - 1] - indel_penalty
        backtrack_matrix[0][j] = (0, j-1)
        
    # fill in remaining positions
    for i in range(1, len(s) + 1):
        for j in range(1, len(t) + 1):
            # determine whether there's a match here
            match_adjustment = -1 * mismatch_penalty
            if s[i-1] == t[j-1]:
                match_adjustment = match_reward
            
            # consider scores from the three neighbors
            considered_scores = [scores[i-1][j] - indel_penalty, scores[i][j-1] - indel_penalty, scores[i-1][j-1] + match_adjustment]
            considered_positions = [(i-1, j), (i, j-1), (i-1, j-1)]
            
            # if we're at the sink, additionally consider all other nodes in final row
            #  (free jumps that leave the rest of t unaligned)
            if i == len(s) and j == len(t):
                for k in range(len(t)):
                    considered_scores.append(scores[len(s)][k])
                    considered_positions.append((len(s), k))
                    
            # list.index returns the first maximum, which fixes the tie-break order
            best_score_index = considered_scores.index(max(considered_scores))
            
            # set score and backtrack matrix direction
            scores[i][j] = considered_scores[best_score_index]
            backtrack_matrix[i][j] = considered_positions[best_score_index]
        
    return (scores[-1][-1], backtrack_matrix)

In [22]:
'''
Algorithm FalconSense
Require: A set of reads R that are aligned on the template S
Ensure: A corrected sequence for S based on the consensus of the reads in
R
1: for every read R ∈ R do
2: Compute the alignment A between R and S allowing matches or indels
(no mismatches);
3: for each aligned position in A do
4: If it is a match aligning S[j] with R[i], add tuple (j, 0, R[i]);
5: If it is a deletion aligning S[j] with −, add tuple (j, 0, −);
6: If it is an insertion aligning − with R[i], suppose this is the dth
insertion after S[j]; add tuple (j, d, R[i]);
7: end for
8: end for

9: For each distinct tuple (p, d, b), let count(p, d, b) be the number of
occurrences of the tuple in the list.

10: Sort all distinct tuples (p, d, b) by increasing order of p, d and followed
by the alphabetic order of b;
11: for each distinct tuple (p, d, b) in sorted order do
12: if count(p, d, b) > (1/2) · Σ_{x∈{A,C,G,T,−}} count(p, 0, x) then

13: Output b;
14: else
15: Output −;
16: end if
17: end for
'''
def FalconSense(aligned_reads):
    """Build a FalconSense-style consensus from reads fitting-aligned to one template.

    aligned_reads: iterable of (alignment_template, alignment_read, start) triples,
    as produced by FittingAlignment. The two alignment strings have equal length,
    and start is the 1-based template position where the alignment begins.

    Each alignment column is tallied as a tuple (p, d, b):
      - p is the template position;
      - d is 0 for a column holding a template character, or k for the k-th
        template gap after position p;
      - b is the read's character in that column.

    Returns (consensus_dict, consensus_string):
      - consensus_dict maps each distinct (p, d, b) to its count, ordered by p, d,
        then b (step 10 of the FalconSense pseudo-code), so the result does not
        depend on the order of aligned_reads;
      - consensus_string contains, for every tuple with b != '-', b if
        count(p, d, b) exceeds half the number of d == 0 tuples at p with a base
        in {A, C, G, T, -}, otherwise '-'.
    """
    options = ('A', 'C', 'G', 'T', '-')

    counts = {}
    for read in aligned_reads:
        template_aln, read_aln, start = read[0], read[1], read[2]
        p = start - 1
        d = 0
        for template_char, read_char in zip(template_aln, read_aln):
            if template_char != "-":
                p += 1
                d = 0
            else:
                d += 1
            key = (p, d, read_char)
            counts[key] = counts.get(key, 0) + 1

    # Keys are unique, so this sorts by (p, d, b).
    consensus_dict = dict(sorted(counts.items()))

    # Per-position coverage: sum of count(p, 0, x) over x in options, computed once.
    coverage = {}
    for (p, d, b), count in consensus_dict.items():
        if d == 0 and b in options:
            coverage[p] = coverage.get(p, 0) + count

    consensus_chars = []
    for (p, d, b), count in consensus_dict.items():
        if b == '-':
            continue
        consensus_chars.append(b if count > coverage.get(p, 0) / 2 else '-')
    return consensus_dict, ''.join(consensus_chars)

In [23]:
# say that what we receive is a dictionary mapping int (read index) to a list of ints (overlapped read indices)
#   (if a: [b] is in it, should b: [a] be as well? yes!! otherwise falcon sense won't work)

def draw_overlap_graph(reads, overlap_dict, match = 3, mismatch = 1, indel_penalty = 2):
    """Error-correct every read and build a directed overlap graph.

    reads: list of read strings. overlap_dict: read index -> indices of the reads
    it overlaps; it must be symmetric (b in overlap_dict[a] iff a in overlap_dict[b]).

    For each read:
      - fitting-align it (as the template) against every overlapping read, using
        mismatch=1000 to effectively forbid mismatches;
      - run FalconSense on those alignments; the consensus with gaps removed is
        the corrected read (an empty string if the read has no overlaps).

    For each unordered overlapping pair (a, b), compare overlap(a -> b),
    overlap(b -> a) and fitting(a, b). An edge a -> b or b -> a is added when that
    overlap score is the largest (np.argmax: the first maximum wins ties). No edge
    is added when the fitting score wins. Scores use the original, uncorrected reads.

    Returns (graph, corrected_reads): graph maps read index -> list of successor
    indices; corrected_reads lists the corrected read for each index.
    """
    graph = {k: [] for k in range(len(reads))}

    # keys are (read index, read index) pairs
    overlap_scores = {}
    fitting_scores = {}

    corrected_reads = []

    for read_idx in range(len(reads)):
        print(read_idx)
        read = reads[read_idx]
        fitting_aligned_reads = []

        for overlap_read_idx in overlap_dict[read_idx]:
            other = reads[overlap_read_idx]
            # perform fitting alignment with no mismatches allowed
            aligned, fit_score = FittingAlignment(read, other, match, mismatch=1000, indel_penalty=indel_penalty)
            fitting_aligned_reads.append(aligned)
            fitting_scores[(read_idx, overlap_read_idx)] = fit_score

            if (read_idx, overlap_read_idx) not in overlap_scores:
                overlap_scores[(read_idx, overlap_read_idx)] = OverlapAlignment(match, mismatch, indel_penalty, read, other)
                overlap_scores[(overlap_read_idx, read_idx)] = OverlapAlignment(match, mismatch, indel_penalty, other, read)

        error_correction = FalconSense(fitting_aligned_reads)[1]
        corrected_reads.append(error_correction.replace("-", ""))

    # set, not list: O(1) membership instead of a scan over every pair seen so far
    pairs_seen = set()

    for pair in overlap_scores:
        if pair in pairs_seen:
            continue
        reverse_pair = (pair[1], pair[0])
        pairs_seen.update((pair, reverse_pair))

        # NOTE(review): fitting_scores[pair] assumes overlap_dict is symmetric (KeyError otherwise)
        scores = [overlap_scores[pair], overlap_scores[reverse_pair], fitting_scores[pair]]

        max_idx = np.argmax(scores)
        if max_idx == 0:
            graph[pair[0]].append(pair[1])
        elif max_idx == 1:
            graph[reverse_pair[0]].append(reverse_pair[1])
        # max_idx == 2: fitting alignment scored best (containment) -- no edge

    return graph, corrected_reads

In [24]:
# Short alias for the thresholded Jaccard matrix (1 where two reads' sketches passed
# the threshold); the bare `mtx` displays it in the notebook.
mtx = jaccard_matrix
mtx

array([[1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 1., 1.],
       [0., 0., 0., ..., 1., 1., 1.],
       [0., 0., 0., ..., 1., 1., 1.]])

In [25]:
# Adjacency lists from the thresholded Jaccard matrix: read i -> every other read j
# with mtx[i, j] == 1. The self index is filtered out rather than looked up with
# list.index, which would raise ValueError if a diagonal entry were not 1.
overlap_dict = {i: [j for j in np.where(arr == 1)[0] if j != i] for i, arr in enumerate(mtx)}
# Make the relation symmetric: draw_overlap_graph expects b in overlap_dict[a]
# whenever a in overlap_dict[b].
for k, v in overlap_dict.items():
    for val in v:
        if k not in overlap_dict[val]:
            overlap_dict[val].append(k)
overlap_dict

{0: [1, 2, 3, 4, 5, 6, 7, 8, 9],
 1: [0, 2, 3, 4, 5, 6, 7, 8, 9],
 2: [0, 1, 3, 4, 5, 6, 7, 8, 9],
 3: [0, 1, 2, 4, 5, 6, 7, 8, 9],
 4: [0, 1, 2, 3, 5, 6, 7, 8, 9],
 5: [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 6: [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14],
 7: [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14],
 8: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14],
 9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14],
 10: [5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19],
 11: [5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19],
 12: [5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19],
 13: [5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19],
 14: [5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19],
 15: [10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24],
 16: [10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24],
 17: [10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24],
 18: [10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24],
 19: [10, 11, 12, 13, 14, 15

# 4. Generate Overlap Graph (before reduction)

In [27]:
import networkx as nx

overlap_graph_file = f'overlap_graph_{pipeline_version}.gexf'
# Versioned like the other cache files so results of different pipeline versions are never mixed.
corrected_reads_file = f'corrected_reads_{pipeline_version}.npy'

try:
    # GEXF stores node ids as strings; node_type=int restores the integer read indices
    # produced by draw_overlap_graph, so a reloaded graph matches a fresh one.
    G = nx.read_gexf(overlap_graph_file, node_type=int)
    corrected_reads = np.load(corrected_reads_file)
except Exception as err:
    # Missing or unreadable cache. `except Exception` (not a bare `except:`) still
    # lets KeyboardInterrupt / SystemExit through.
    print(f'could not load cached results ({err!r})')
    print('computing the overlap graph...')
    g, corrected_reads = draw_overlap_graph(reads, overlap_dict)
    G = nx.DiGraph(g)
    nx.write_gexf(G, overlap_graph_file)
    np.save(corrected_reads_file, corrected_reads)
    print(' DONE computing the overlap graph!')
else:
    print('pre-computed results loaded!')

computing the overlap graph...
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
 DONE computing the overlap graph!


# 5. Layout - find and eliminate cycles, apply transitive reduction to get the reduced overlap graph, and find maximal non-branching paths to generate contigs

In [28]:
def maximal_non_branching_paths(g):
    """Find maximal non-branching paths (plus isolated cycles) in a graph.

    g: dict mapping node -> list of successor nodes.

    Returns a list of paths (lists of nodes):
      - every path from a node that is not 1-in-1-out to each of its successors,
        extended while the next node is 1-in-1-out (see check_1_in_1_out);
      - every simple cycle, closed by repeating its first node, whose nodes appear
        in none of the paths or cycles already reported.
    """
    paths = []
    graph = nx.DiGraph(g)
    nodes_seen = set()
    for v in g:
        if check_1_in_1_out(g, v):
            continue
        for w in g[v]:
            non_branching_path = [v, w]
            while check_1_in_1_out(g, w):
                w = g[w][0]
                non_branching_path.append(w)
            nodes_seen.update(non_branching_path)
            paths.append(non_branching_path)

    # Check a cycle against nodes_seen *before* marking its nodes, then mark them,
    # so overlapping cycles are reported only once.
    for c in sorted(nx.simple_cycles(graph)):
        if any(node in nodes_seen for node in c):
            continue
        paths.append(c + [c[0]])
        nodes_seen.update(c)
    return paths


def check_1_in_1_out(graph,node):
    """Return True iff `node` is a key of `graph` with exactly one successor and occurs
    exactly once across all successor lists (in-degree 1).

    graph: dict mapping node -> list of successors. Scans every list: O(edges) per call.
    """
    if node not in graph or len(graph[node]) != 1:
        return False
    incoming = [target for successors in graph.values() for target in successors]
    return incoming.count(node) == 1

def remove_cycles_from_graph(G):
    """Repeatedly find a directed cycle in G and delete every node on it until G is acyclic.

    Mutates G in place and returns it. Removing a node also removes its incident
    edges, so no separate edge removal is needed.
    """
    while True:
        try:
            cycle = nx.find_cycle(G, orientation='original')  # Find the first cycle
        except nx.NetworkXNoCycle:
            break  # No more cycles found, exit the loop

        # With an orientation, each cycle edge is (u, v, ..., direction); keep only
        # the endpoints so the direction label is never mistaken for a node.
        cycle_nodes = {node for edge in cycle for node in edge[:2]}
        G.remove_nodes_from(cycle_nodes)

    return G


In [29]:
#G = nx.read_gexf('overlap_graph_dummy.gexf')
# Acyclicity of the graph as loaded/computed (before any cycle removal).
is_acyclic = nx.is_directed_acyclic_graph(G)
print("Is acyclic now?", is_acyclic)
# Drop every node that lies on a cycle, then re-check.
G = remove_cycles_from_graph(G)
print('Nodes that are in cycles have been removed')
is_acyclic = nx.is_directed_acyclic_graph(G)
print("Is acyclic now?", is_acyclic)

# Transitive reduction (networkx requires a DAG here) removes edges implied by longer paths.
G_new = nx.transitive_reduction(G)
print('Transitive reduction complete!')
nx.write_gexf(G_new,'overlap_graph_tr.gexf')
# maximal_non_branching_paths works on a dict of adjacency lists.
dict_edge_list = nx.to_dict_of_lists(G_new)
contigs = maximal_non_branching_paths(dict_edge_list)
print('maximal branching paths found!')
print(contigs)    

Is acyclic now? True
Nodes that are in cycles have been removed
Is acyclic now? True
Transitive reduction complete!
maximal branching paths found!
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [19, 20, 21, 22, 23, 24, 25], [19, 10], [25, 26, 27, 28, 29], [35, 36, 37, 38, 39], [39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [39, 30, 31, 32, 33, 34, 25]]


# 7. Consensus - MSA for reads in each contig

# Celebrate! 

In [37]:
# Inspect the 100-base overlaps around the contig break between reads 10, 19 and 20.
# generate_long_reads emits 5 copies of each read, so reads[10] is a copy of the 3rd
# distinct read, reads[19] the last copy of the 4th and reads[20] the first copy of the 5th.
# Suffix of reads[10]; it must equal the prefix of reads[19].
contig_break_1_19to10_overlap = reads[10][-100:]
assert reads[10][-100:] == reads[19][:100]

# Suffix of reads[19]; it must equal the prefix of reads[20].
contig_break_1_19to20_overlap = reads[19][-100:]
assert reads[19][-100:] == reads[20][:100]

In [40]:
# Fraction of positions at which the two overlap windows agree -- a similarity
# (1 - normalized Hamming distance), despite the variable name.
window_len = len(contig_break_1_19to10_overlap)
hamming_break_1 = sum(
    1 for pos in range(window_len)
    if contig_break_1_19to10_overlap[pos] == contig_break_1_19to20_overlap[pos]
) / window_len

hamming_break_1

0.16

In [57]:
def _overlap_score(s, t):
    """OverlapAlignment score (suffix of s vs prefix of t) with match +3, mismatch -1, indel -2."""
    return OverlapAlignment(match_reward = 3, mismatch_penalty = 1, indel_penalty = 2, s=s, t=t)

# Observed scores for both directions of each read pair, and the "perfect" score for
# strings of the same lengths (all 'A', so every aligned position matches).
overlap_alignment_19to10 = _overlap_score(reads[19], reads[10])
overlap_alignment_10to19 = _overlap_score(reads[10], reads[19])
overlap_perfect_19to10 = _overlap_score('A' * len(reads[19]), 'A' * len(reads[10]))
overlap_perfect_10to19 = _overlap_score('A' * len(reads[10]), 'A' * len(reads[19]))

overlap_alignment_19to20 = _overlap_score(reads[19], reads[20])
overlap_alignment_20to19 = _overlap_score(reads[20], reads[19])
overlap_perfect_19to20 = _overlap_score('A' * len(reads[19]), 'A' * len(reads[20]))
overlap_perfect_20to19 = _overlap_score('A' * len(reads[20]), 'A' * len(reads[19]))

overlap_alignment_10to20 = _overlap_score(reads[10], reads[20])
overlap_alignment_20to10 = _overlap_score(reads[20], reads[10])

In [59]:
# Report the observed overlap-alignment scores next to the "perfect" (all-match) scores.
print(f"""overlap_alignment_19to10={overlap_alignment_19to10}\noverlap_alignment_10to19={overlap_alignment_10to19}\noverlap_alignment_19to20={overlap_alignment_19to20}\noverlap_alignment_20to19={overlap_alignment_20to19}\noverlap_alignment_10to20={overlap_alignment_10to20}\noverlap_alignment_20to10={overlap_alignment_20to10}
      \noverlap_perfect_19to10={overlap_perfect_19to10}\noverlap_perfect_10to19={overlap_perfect_10to19}\noverlap_perfect_19to20={overlap_perfect_19to20}\noverlap_perfect_20to19={overlap_perfect_20to19}""")

overlap_alignment_19to10=610
overlap_alignment_10to19=607
overlap_alignment_19to20=599
overlap_alignment_20to19=594
overlap_alignment_10to20=597
overlap_alignment_20to10=608
      
overlap_perfect_19to10=1800
overlap_perfect_10to19=1800
overlap_perfect_19to20=1800
overlap_perfect_20to19=1800


In [55]:
def identify_contig_breaks(contigs:List[List[int]])->List[List[int]]:
    """Placeholder for locating breaks between contigs.

    contigs: paths from maximal_non_branching_paths (lists of read indices).
    Currently it only prints each contig and returns None, despite the
    List[List[int]] return annotation. TODO: implement the break detection.
    """
    for contig in contigs:
        print(contig)

    return

In [56]:
# Print each contig (identify_contig_breaks currently only prints them).
identify_contig_breaks(contigs)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[10, 11, 12, 13, 14]
[15, 16, 17, 18, 19]
[19, 20, 21, 22, 23, 24, 25]
[19, 10]
[25, 26, 27, 28, 29]
[35, 36, 37, 38, 39]
[39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
[39, 30, 31, 32, 33, 34, 25]
