In [22]:
from Bio import SeqIO
import Bio
from collections import defaultdict
from graphviz import Digraph
import numpy as np

class Vertex:
    """A single k-mer node.

    ``seq`` holds the k-mer itself and ``coverage`` counts how many times it
    has been added; the first sighting already counts as 1. ``in_edges`` and
    ``out_edges`` start empty; in this file ``Graph`` fills them, keyed by the
    neighbouring k-mer, with a one-element list wrapping the ``Edge``.
    """

    def __init__(self, seq):
        self.seq = seq
        self.coverage = 1
        # predecessor k-mer -> [Edge] and successor k-mer -> [Edge]
        self.in_edges = dict()
        self.out_edges = dict()

    def increase_coverage(self):
        """Count one more occurrence of this k-mer."""
        self.coverage = self.coverage + 1

class Edge:
    """A link between two overlapping k-mers.

    ``seq`` spells the edge: the first k-mer followed by the last character
    of the second. ``n`` starts at 2 (the two k-mers it joins) and
    ``coverage`` stays 0 until ``calc_coverage`` is called.
    """

    def __init__(self, k1, k2):
        tail = k2[-1]
        self.seq = k1 + tail
        self.n = 2
        self.coverage = 0

    def calc_coverage(self, c1, c2):
        """Set coverage to the arithmetic mean of the two endpoint coverages."""
        total = c1 + c2
        self.coverage = total / 2


class Graph:
    """De Bruijn-style graph of k-mers built from reads.

    ``self.vertices`` maps each k-mer string to its ``Vertex``. Every vertex
    keeps ``in_edges`` / ``out_edges`` dicts keyed by the neighbouring k-mer,
    whose values are one-element lists wrapping the shared ``Edge`` object
    (both ``visualize`` and ``merge`` read element ``[0]``).
    """

    def __init__(self, k):
        self.vertices = {}
        self.k = k

    def _touch_vertex(self, kmer):
        """Return the vertex for *kmer*, creating it or bumping its coverage."""
        vertex = self.vertices.get(kmer)
        if vertex is None:
            vertex = self.vertices[kmer] = Vertex(kmer)
        else:
            vertex.increase_coverage()
        return vertex

    def add_read(self, read):
        """Add every k-mer of *read* as a vertex and link consecutive k-mers.

        Reads shorter than ``self.k`` are ignored. A k-mer seen again has its
        coverage incremented; a k-mer pair seen again has its edge replaced by
        a fresh ``Edge`` (edge coverage is derived later from vertex coverage
        by ``calc_init_edge_coverage``).
        """
        # Use the graph's own k; the previous code read a module-level global
        # ``k``, which only worked when the caller happened to define one.
        k = self.k
        if len(read) < k:
            return

        prev_kmer = read[:k]
        prev_vertex = self._touch_vertex(prev_kmer)
        for start in range(1, len(read) - k + 1):
            kmer = read[start:start + k]
            vertex = self._touch_vertex(kmer)

            new_edge = Edge(prev_kmer, kmer)
            vertex.in_edges[prev_kmer] = [new_edge]
            prev_vertex.out_edges[kmer] = [new_edge]

            prev_kmer, prev_vertex = kmer, vertex

    def calc_init_edge_coverage(self):
        """Set each edge's coverage to the mean coverage of its two endpoints."""
        for vertex in self.vertices.values():
            for next_kmer, edges in vertex.out_edges.items():
                edges[0].calc_coverage(vertex.coverage,
                                       self.vertices[next_kmer].coverage)

    def visualize(self, path, full):
        """Write the graph in Graphviz DOT format to *path*.

        With ``full=True`` vertices are labelled by k-mer and edges by their
        sequence; otherwise by coverage (plus ``n`` as ``size`` on edges).
        Also prints each vertex's coverage and k-mer to stdout and keeps the
        built ``Digraph`` on ``self.graph``.
        """
        self.graph = Digraph()

        for kmer, vertex in self.vertices.items():
            print(vertex.coverage, str(kmer))
            label = str(kmer) if full else 'coverage={}'.format(vertex.coverage)
            self.graph.node(kmer, label=label)

            for child_kmer, child_edges in vertex.out_edges.items():
                edge = child_edges[0]
                label = \
                    str(edge.seq) if full \
                    else 'coverage={} size={}'.format(edge.coverage, edge.n)
                self.graph.edge(kmer, child_kmer, label=label)

        with open(path, 'w') as handle:
            handle.write(self.graph.source)

    def merge(self):
        """Compress non-branching paths.

        Every vertex V that had exactly one incoming (U->V) and one outgoing
        (V->W) edge when the pass started is removed. Its two edges are
        replaced by a single U->W edge that:

        - spells the whole path,
        - has ``n = n(U->V) + n(V->W) - 1``,
        - has coverage equal to the unweighted mean of the two edge coverages.

        Nothing is removed once the graph has 2 or fewer vertices, and
        self-loops (U == V or W == V) are left alone.
        """
        k = self.k
        candidates = [
            kmer for kmer, vertex in self.vertices.items()
            if len(vertex.in_edges) == 1 and len(vertex.out_edges) == 1
        ]

        for kmer in candidates:
            if kmer not in self.vertices or len(self.vertices) <= 2:
                continue

            vertex = self.vertices[kmer]
            in_kmer = next(iter(vertex.in_edges))
            out_kmer = next(iter(vertex.out_edges))
            # A self-loop has no distinct neighbour to splice into; merging it
            # would delete the vertex and its sequence from the graph.
            if in_kmer == kmer or out_kmer == kmer:
                continue

            in_vertex = self.vertices[in_kmer]
            out_vertex = self.vertices[out_kmer]
            in_edge = in_vertex.out_edges[kmer][0]
            out_edge = out_vertex.in_edges[kmer][0]

            merged = Edge(in_edge.seq, out_edge.seq)
            # Edge() appends only the last character of its second argument,
            # which is right only for an unmerged (k+1)-long out-edge. An
            # out-edge produced by an earlier merge carries several characters
            # after V's k-mer, so append everything past its first k.
            merged.seq = in_edge.seq + out_edge.seq[k:]
            merged.n = in_edge.n + out_edge.n - 1
            merged.coverage = np.mean([in_edge.coverage, out_edge.coverage])

            # NOTE(review): an existing direct U->W edge would be overwritten
            # here, since each neighbour key holds a single edge.
            in_vertex.out_edges[out_kmer] = [merged]
            out_vertex.in_edges[in_kmer] = [merged]

            del out_vertex.in_edges[kmer]
            del in_vertex.out_edges[kmer]
            del self.vertices[kmer]

    

if __name__ == '__main__':
    # Build a k-mer graph from every read in the FASTA dataset, dump it as
    # DOT, then merge non-branching paths and dump the compressed graph.
    dataset = './hw_4_5_dataset.fasta'

    # Kept as a module-level name in case other code reads the global ``k``.
    k = 3
    my_graph = Graph(k)

    with open(dataset, "r") as handle:
        reads = (str(record.seq) for record in SeqIO.parse(handle, "fasta"))
        for read in reads:
            my_graph.add_read(read)

    my_graph.calc_init_edge_coverage()

    # The .dot files can be viewed with a DOT viewer such as xdot.
    my_graph.visualize('graph.dot', full=True)

    my_graph.merge()
    my_graph.visualize('graph_merged.dot', full=True)


1 ABC
1 BCD
1 ABC
1 BCD
