In [1]:
import graphviz
import itertools

from Bio import SeqIO

In [2]:
class DeBruijnGraph:
    ''' De Bruijn directed multigraph built from a collection of
        strings. User supplies strings and k-mer length k.  Nodes
        are k-1-mers.  An Edge corresponds to the k-mer that joins
        a left k-1-mer to a right k-1-mer. '''

    @staticmethod
    def chop(st, k):
        ''' Chop string into k-mers of given length.

            Yields one (kmer, left k-1-mer, right k-1-mer) triple per
            k-mer start position; yields nothing if len(st) < k. '''
        for i in range(len(st)-(k-1)):
            yield (st[i:i+k], st[i:i+k-1], st[i+1:i+k])

    class Node:
        ''' Node representing a k-1 mer.  Keep track of # of
            incoming/outgoing edges so it's easy to check for
            balanced, semi-balanced. '''

        def __init__(self, km1mer):
            self.km1mer = km1mer
            self.nin = 0   # number of incoming edges
            self.nout = 0  # number of outgoing edges

        def isSemiBalanced(self):
            ''' True iff in-degree and out-degree differ by exactly 1. '''
            return abs(self.nin - self.nout) == 1

        def isBalanced(self):
            ''' True iff in-degree equals out-degree. '''
            return self.nin == self.nout

        # Hashed by label but compared by identity (no __eq__); this is
        # consistent because the graph creates exactly one Node per k-1-mer.
        def __hash__(self):
            return hash(self.km1mer)

        def __str__(self):
            return self.km1mer

    def __init__(self, strIter, k, circularize=False):
        ''' Build de Bruijn multigraph given string iterator and k-mer
            length k.  If circularize is true, each string is treated as
            circular by appending its first k-1 characters to it. '''
        self.G = {}     # multimap: Node -> list of successor Nodes, one entry per edge
        self.nodes = {} # maps k-1-mers to Node objects
        for st in strIter:
            if circularize:
                st += st[:k-1]
            for _kmer, km1L, km1R in self.chop(st, k):
                nodeL = self._getOrAddNode(km1L)
                nodeR = self._getOrAddNode(km1R)
                nodeL.nout += 1
                nodeR.nin += 1
                self.G.setdefault(nodeL, []).append(nodeR)
        # Iterate over nodes; tally # balanced, semi-balanced, neither
        self.nsemi, self.nbal, self.nneither = 0, 0, 0
        # Keep track of head and tail nodes in the case of a graph with
        # Eulerian walk (not cycle)
        self.head, self.tail = None, None
        for node in self.nodes.values():
            if node.isBalanced():
                self.nbal += 1
            elif node.isSemiBalanced():
                if node.nin == node.nout + 1:
                    self.tail = node
                else:  # semi-balanced, so nout == nin + 1
                    self.head = node
                self.nsemi += 1
            else:
                self.nneither += 1

    def _getOrAddNode(self, km1mer):
        ''' Return the Node for km1mer, creating and registering it in
            self.nodes first if it has not been seen yet. '''
        node = self.nodes.get(km1mer)
        if node is None:
            node = self.nodes[km1mer] = self.Node(km1mer)
        return node

    def nnodes(self):
        ''' Return # nodes '''
        return len(self.nodes)

    def nedges(self):
        ''' Return # edges, i.e. # k-mer occurrences (parallel edges
            between the same pair of nodes are each counted). '''
        return sum(len(dsts) for dsts in self.G.values())

    def hasEulerianWalk(self):
        ''' Return true iff graph has Eulerian walk.  Only node degrees
            are checked; connectivity is NOT verified. '''
        return self.nneither == 0 and self.nsemi == 2

    def hasEulerianCycle(self):
        ''' Return true iff graph has Eulerian cycle.  Only node degrees
            are checked; connectivity is NOT verified. '''
        return self.nneither == 0 and self.nsemi == 0

    def isEulerian(self):
        ''' Return true iff graph has Eulerian walk or cycle (degree
            conditions only; see hasEulerianWalk / hasEulerianCycle). '''
        return self.hasEulerianWalk() or self.hasEulerianCycle()

    def eulerianWalkOrCycle(self):
        ''' Find and return sequence of nodes (represented by
            their k-1-mer labels) corresponding to Eulerian walk
            or cycle.  The graph is not modified, so this may be
            called repeatedly.  Returns [] for an empty graph.

            Raises ValueError if, in the walk case, the traversal
            cannot reach the head/tail nodes (disconnected graph). '''
        assert self.isEulerian()
        if not self.G:
            return []
        # Private copies of the adjacency lists: the traversal below
        # consumes edges by popping them.
        g = {node: list(dsts) for node, dsts in self.G.items()}
        hasWalk = self.hasEulerianWalk()
        if hasWalk:
            # Add a tail->head edge so the graph has an Eulerian cycle
            g.setdefault(self.tail, []).append(self.head)
        # Hierholzer's algorithm, iterative form of "follow all out-edges,
        # then append node"; avoids Python's recursion limit on long inputs.
        tour = []
        stack = [next(iter(g))]  # pick arbitrary starting node
        while stack:
            out = g.get(stack[-1])
            if out:
                stack.append(out.pop())
            else:
                tour.append(stack.pop())
        tour = tour[::-1][:-1]  # reverse, then drop repeated start node

        if hasWalk:
            # Cut the cycle at a tail->head step so the result starts at
            # head and ends at tail.  Rotating at the first occurrence of
            # head is wrong when head is visited more than once.  Any
            # tail->head step may be cut: such parallel edges are
            # indistinguishable in the returned labels.
            n = len(tour)
            cut = next((i for i in range(n)
                        if tour[i] is self.tail
                        and tour[(i + 1) % n] is self.head), None)
            if cut is None:
                raise ValueError('Eulerian walk does not reach head/tail; '
                                 'is the graph disconnected?')
            tour = tour[cut + 1:] + tour[:cut + 1]

        # Return node list
        return [str(node) for node in tour]

In [3]:
class DeBruijnGraph2(DeBruijnGraph):
    ''' DeBruijnGraph with a Graphviz drawing helper. '''

    def to_dot(self, weights=False):
        ''' Return a graphviz.Digraph drawing of this graph (not a
            string); call .render(...) on it or display it directly.

            One node is drawn per k-1-mer and one edge per k-mer.  If
            'weights' is true, parallel edges between the same pair of
            k-1-mers are collapsed into a single edge labelled with the
            number of copies; otherwise each copy is drawn as its own
            edge. '''
        g = graphviz.Digraph(comment='DeBruijn graph')
        # Declare every k-1-mer, including sinks that have no outgoing
        # edges and therefore never appear as keys of self.G.
        for node in self.nodes.values():
            g.node(node.km1mer, node.km1mer)
        for src, dsts in self.G.items():
            if weights:
                # Count parallel edges src -> dst
                weightmap = {}
                for dst in dsts:
                    weightmap[dst] = weightmap.get(dst, 0) + 1
                for dst, v in weightmap.items():
                    g.edge(src.km1mer, dst.km1mer, label=str(v))
            else:
                for dst in dsts:
                    g.edge(src.km1mer, dst.km1mer)
        return g

In [4]:
def neighbors1mm(kmer, alpha):
    ''' Generate all neighbors at Hamming distance 1 from kmer.

        Positions are varied from the last character of kmer back to
        the first.  At each position, every character of alpha other
        than the current one is substituted, in alpha order.  Returns
        a list of strings. '''
    return [kmer[:pos] + sub + kmer[pos + 1:]
            for pos in reversed(range(len(kmer)))
            for sub in alpha
            if sub != kmer[pos]]

In [5]:
def kmerHist(reads, k):
    ''' Return the k-mer histogram of a collection of reads.

        reads: iterable of strings.  k: k-mer length.
        Returns a dict mapping each k-mer to the number of times it
        occurs across all reads (overlapping occurrences are counted).
        Reads shorter than k contribute nothing. '''
    kmerhist = {}
    for read in reads:
        for i in range(len(read) - (k - 1)):
            kmer = read[i:i+k]
            kmerhist[kmer] = kmerhist.get(kmer, 0) + 1
    return kmerhist

In [6]:
def correct1mm(read, k, kmerhist, alpha, thresh):
    ''' Return an error-corrected version of read.

        Windows of length k are scanned left to right.  A window whose
        count in kmerhist is <= thresh is replaced by the first of its
        Hamming-distance-1 neighbours (in neighbors1mm order, over
        alphabet alpha) whose count is > thresh, if any.  Later windows
        are read from the already-corrected string; kmerhist is not
        updated. '''
    n_windows = len(read) - (k - 1)
    for start in range(n_windows):
        window = read[start:start + k]
        if kmerhist.get(window, 0) <= thresh:
            # Infrequent k-mer: look for the first frequent neighbour
            replacement = next((cand for cand in neighbors1mm(window, alpha)
                                if kmerhist.get(cand, 0) > thresh), None)
            if replacement is not None:
                read = read[:start] + replacement + read[start + k:]
    return read

In [7]:
# Non-greedy SCS, for comparison:
def overlap(a, b, min_length=3):
    ''' Return length of longest suffix of 'a' matching
        a prefix of 'b' that is at least 'min_length'
        characters long.  If no such overlap exists,
        return 0. '''
    # An overlap can be at most len(b) long, so a b shorter than
    # min_length can never qualify.  Without this guard b[:min_length]
    # is all of b and a too-short overlap could be returned.
    if len(b) < min_length:
        return 0
    start = 0  # start all the way at the left
    while True:
        start = a.find(b[:min_length], start)  # look for b's prefix in a
        if start == -1:  # no more occurrences to right
            return 0
        # found occurrence; check for full suffix/prefix match
        # (leftmost match = longest overlap)
        if b.startswith(a[start:]):
            return len(a)-start
        start += 1  # move just past previous match

def scs(ss):
    ''' Returns shortest common superstring of given
        strings, which must be the same length.

        Brute force: every ordering of ss is tried (n! orderings) and
        adjacent strings are merged by their longest suffix/prefix
        overlap of any length >= 1.  On ties, the first ordering
        produced by itertools.permutations wins.  Strings contained in
        the interior of another string are not detected, hence the
        equal-length requirement.  An empty input yields ''. '''
    if not ss:
        return ''
    shortest_sup = None
    for ssperm in itertools.permutations(ss):
        sup = ssperm[0]  # superstring starts as first string
        for left, right in zip(ssperm, ssperm[1:]):
            # add only the non-overlapping portion of right
            olen = overlap(left, right, min_length=1)
            sup += right[olen:]
        if shortest_sup is None or len(sup) < len(shortest_sup):
            shortest_sup = sup  # found shorter superstring
    return shortest_sup

In [8]:
def pick_maximal_overlap(reads, k):
    ''' Scan every ordered pair of reads and return (a, b, olen) for
        the pair with the largest suffix/prefix overlap olen >= k; on
        ties the first pair in itertools.permutations order wins.
        Returns (None, None, 0) when no pair overlaps by at least k. '''
    best = (None, None, 0)
    for left, right in itertools.permutations(reads, 2):
        olen = overlap(left, right, min_length=k)
        if olen > best[2]:
            best = (left, right, olen)
    return best

def greedy_scs(reads, k):
    ''' Greedy shortest-common-superstring merge.

        Repeatedly merges the pair of reads with the longest
        suffix/prefix overlap (of length >= k) until no such overlap
        remains, then concatenates whatever reads are left.

        reads: iterable of strings.  It is not modified; a copy is used.
        k: minimum overlap length required for a merge.
        Returns the resulting superstring.  A read occurring only in the
        interior of another read is never merged and ends up in the
        final concatenation. '''
    reads = list(reads)  # don't mutate the caller's list
    read_a, read_b, olen = pick_maximal_overlap(reads, k)
    while olen > 0:
        reads.remove(read_a)
        reads.remove(read_b)
        # Append only the part of read_b not covered by the overlap.
        # read_b[olen:] is '' when read_b is a suffix of read_a (e.g. a
        # duplicate read); a negative-index slice would re-append all of it.
        reads.append(read_a + read_b[olen:])
        read_a, read_b, olen = pick_maximal_overlap(reads, k)
    return ''.join(reads)

In [9]:
# read sequences from FASTA file
def read_fasta(file_path):
    ''' Return the sequences of all records in the FASTA file at
        file_path, in file order, as plain strings (record IDs and
        descriptions are discarded). '''
    return [str(record.seq) for record in SeqIO.parse(file_path, 'fasta')]

def visualize_overlap_graph(sequences, output_file, k):
    '''Visualize overlap graph with Graphviz.

    One node per distinct sequence.  Each directed edge src -> dst, for
    distinct strings, is labelled with overlap(src, dst, k), the
    suffix/prefix overlap of length >= k; pairs with no such overlap get
    no edge.  The graph is rendered to PNG via graph.render(output_file,
    ...).  Presumably the file is written as '<output_file>.png'; confirm
    against the graphviz docs.'''
    dot = graphviz.Digraph()
    for read in sequences:
        dot.node(read)

    for src in sequences:
        for dst in sequences:
            if src == dst:
                continue
            olen = overlap(src, dst, k)
            if olen > 0:
                dot.edge(src, dst, label=str(olen))

    dot.render(output_file, format='png', cleanup=True)


In [13]:
# Parameters
sequences = read_fasta("test.fa")
alpha = "ACTG"  # alphabet used to propose 1-mismatch corrections
k = 7           # k-mer length for correction and the de Bruijn graph;
                # also the minimum overlap for the overlap graph and greedy SCS
thresh = 3      # k-mers seen <= thresh times are treated as likely errors

# Error-correct reads using k-mer frequencies
kmer_histogram = kmerHist(sequences, k)
corrected_sequences = [correct1mm(seq, k, kmer_histogram, alpha, thresh) for seq in sequences]

# Build de Bruijn graph from the corrected reads and render it
dbg = DeBruijnGraph2(corrected_sequences, k)
dbg.to_dot().render("debruijn_graph", format="png", cleanup=True)

# Render overlap graph (edges = suffix/prefix overlaps of length >= k)
visualize_overlap_graph(corrected_sequences, "overlap_graph", k)

# Assemble the corrected reads with greedy shortest common superstring
print(corrected_sequences)
print("Reconstructed Sequence:", greedy_scs(corrected_sequences.copy(), k))


['AAGCTTGCGTTACCTAAAGG', 'AATGGATAGTGAGGTTAAGG', 'ACTAGATCCGAATGGATAGT', 'ACTAGCAAGCTTGCGTTACC', 'AGATCCGAATGGATAGTGAG', 'AGCAAGCTTGCGTTACCTAA', 'AGCACTAGCAAGCTTGCGTT', 'AGCTGAATGCGAGAAAGCAC', 'ATGCGAGAAAGCACTAGCAA', 'CCCCCCGAGAGCTGAATGCG', 'CCCGAGAGCTGAATGCGAGA', 'CGAGAAAGCACTAGCAAGCT', 'CTTGCGTTACCTAAAGGGCA', 'GAAAGCACTAGCAAGCTTGC', 'GAATCGAATGGATAGTGAGG', 'GAGAGCTGAATGCGAGAAAG', 'GCGTTACCTAAAGGGCAGGT', 'GGATAGTGAGGTTAAGGCCC', 'GGCCCCCCCGAGAGCTGAAT', 'GGTTAAGGCCCCCCCGAGAG', 'TAAGGCCCCCCCGAGAGCTG', 'TAGTGAGGTTAAGGCCCCCC', 'TCCGAATGGATAGTGAGGTT', 'TCGAATGGATAGTGAGGTTA', 'TGAATGCGAGAAAGCACTAG', 'TGAGGTTAAGGCCCCCCCGA', 'TTACCTAAAGGGCAGGTGGT']
Reconstructed Sequence: ACTAGATCCGAATGGATAGTGAGGTTGAATCGAATGGATAGTGAGGTTAAGGCCCCCCCGAGAGCTGAATGCGAGAAAGCACTAGCAAGCTTGCGTTACCTAAAGGGCAGGTGGT
