# Week 7 - de novo assembly with de Bruijn graphs


<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">

For this week's lab you'll be writing functions to count k-mers and to find the edges of a de Bruijn graph.


As a reminder, we have a relevant tutorial document intended for those new to programming or learning Python: [kmer-counting tutorial](kmer_counting.ipynb). The first couple of challenges below are covered (more gently) in this tutorial. 
    
By now you have probably already learned the programming concepts in the tutorial, but you might still find it worth looking at the _Dictionaries: algorithmic considerations_ section. 
</div>

## Setup

In [None]:
import os
import requests
from IPython.display import HTML
import matplotlib.pyplot as plt
import networkx as nx

In [None]:
# Load stylesheet
HTML(requests.get('https://raw.githubusercontent.com/melbournebioinformatics/COMP90014/main/data/2023/style/custom.css').text)

## Data 

### Kmer Size
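The k-mer size has a large impact on overall assembly quality. A smaller k makes the graph more connected but collapses any repeat of k-1 or more bases into shared nodes; a larger k can resolve longer repeats but needs deeper coverage, because a k-mer is only seen if it lies entirely within at least one read.  <br>
The effect of this on real assemblies can be seen here: 
https://github.com/rrwick/Bandage/wiki/Effect-of-kmer-size
<br>
Today we will work only with short strings and small values of k, but keep this trade-off in mind for the future. 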
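To make the repeat side of this trade-off concrete, here is a small sketch that counts how many (k-1)-mers, i.e. de Bruijn graph nodes, occur more than once in a sequence. The toy sequence and the helper `repeated_nodes` are made up for illustration; each repeated node is a place where paths through the graph can merge and branch.

```python
from collections import Counter

def repeated_nodes(sequence, k):
    """Count distinct (k-1)-mers that occur more than once in sequence.

    Hypothetical helper: de Bruijn graph nodes are (k-1)-mers, so a
    repeated (k-1)-mer is a node that several paths pass through.
    """
    node_len = k - 1
    counts = Counter(
        sequence[i:i + node_len] for i in range(len(sequence) - node_len + 1)
    )
    return sum(1 for count in counts.values() if count > 1)

# Made-up toy sequence containing the 4-base repeat "GATC" twice
toy = "AAGATCTTGATCCA"
for k in range(3, 8):
    print(f"k={k}: {repeated_nodes(toy, k)} repeated nodes")
```

Once k-1 is longer than the repeat (here from k=6 upwards), no node is shared. With real reads the other side of the trade-off applies: k can never exceed the read length.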

We'll define some toy "reads" from English strings, which you can use to see your functions in action.

The first two examples contain reads with some typos or "sequencing errors". 

What do you think these might do to our graphs?

In [None]:
yoda_reads = '''
ters_no
ers_not
matners
size_ma
tters_
atters_n
size_m
e_matter'''.strip().split("\n")
yoda_reads

In [None]:
caesar_reads = '''
_your_e
l,nd_me
nd_me_y
d_me_y
omans,_
s,_coun
me_you
eids,_
mans,_c
our_ea
'''.strip().split("\n")

In [None]:
# a version with no read errors
caesar_reads_perfect = '''
me_your_e
ymen,_len
_your_e
ds,_Roma
nds,_Roma
trymen,_
lend_me_
omans,_co
n,_lend_
riends,_R
ntrymen
,_country
e_your_ears
untryme
riends,_
'''.strip().split("\n")

And some error-free "reads" from a very small tRNA gene:

In [None]:
# Error free reads
# Read lengths range from 18 to 22 bases
mt_te_reads = '''TGTAGTTGAAATACAACGAT
GGTCGTGGTTGTAGTCCGTGC
TGGTCGTGGTTGTAGTCCG
TATCATTGGTCGTGGTTGTAG
TTGTAGTTGAAATACAACGAT
ATCATTGGTCGTGGTTGTAG
ATGGTTTTTCATATCATTGG
GGTTTTTCATATCATTGGTCGT
GGTCGTGGTTGTAGTCCGT
TCATTGGTCGTGGTTGTAGTCC
GGTCGTGGTTGTAGTCCGTGCG
GATGGTTTTTCATATCATT
TGGTCGTGGTTGTAGTCC
AACGATGGTTTTTCATATCA
GGTTTTTCATATCATTGGTCG
GTTTTTCATATCATTGGTCGTG
ATTGGTCGTGGTTGTAGTCCGT
AACGATGGTTTTTCATAT
GTTTTTCATATCATTGGT
AACGATGGTTTTTCATAT'''.strip().split()

# Section 1: Counting k-mers 

## Exercise 1: Extract Kmers From a String


<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">

<b>Challenge:</b> Complete the function below to return all <b>unique</b> k-mers from a given string.

- [ ] Input: One read (str), kmer len k (int)
- [ ] Extract all kmers from read
- [ ] Return: All unique kmers as a Set
    
</div>

In [None]:
def get_string_kmers(read, k):
    '''
    Return a set of all k-mers of length k from string read.
    '''
    ### BEGIN SOLUTION
    kmers = set()
    for i in range(len(read) - k + 1):
        kmer = read[i:i+k]
        kmers.add(kmer)
    return kmers
    ### END SOLUTION

In [None]:
# Get unique kmers from the first read in "caesar_reads"
# Should return {'_yo', 'our', 'r_e', 'ur_', 'you'}
get_string_kmers(caesar_reads[0], 3)

In [None]:
# Get kmers from read "TGTAGTTGAAATACAACGAT"
get_string_kmers(mt_te_reads[0], 15)
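A read of length L contains exactly L - k + 1 k-mers if duplicates are counted, which is why the loop runs over `range(len(read) - k + 1)`. The set of unique k-mers can be smaller when a k-mer occurs more than once, and it is empty when k is longer than the read. A quick check, using a hypothetical helper `all_kmers` that keeps duplicates:

```python
def all_kmers(read, k):
    """Return every k-mer in read, in order, keeping duplicates."""
    return [read[i:i + k] for i in range(len(read) - k + 1)]

read = "ATATAT"
kmers = all_kmers(read, 2)
print(kmers)                # ['AT', 'TA', 'AT', 'TA', 'AT']
print(len(kmers))           # 5, i.e. len(read) - 2 + 1
print(sorted(set(kmers)))   # ['AT', 'TA']
print(all_kmers("ACG", 5))  # [] because k is longer than the read
```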

## Exercise 2: Get Kmers From Many Strings


<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">

<b>Challenge:</b> Write a function which applies `get_string_kmers()` to a list of reads and returns all **unique** kmers.

- [ ] Input: List of reads
- [ ] Extract set of kmers for each read
- [ ] Merge kmers from all reads into single set
- [ ] Return: Set of all unique kmers from all reads

</div>

In [None]:
def get_kmers(reads, k):
    '''
    Given a list of strings representing reads, and a value k, 
    return a set of all k-mers of length k.
    '''    
    ### BEGIN SOLUTION
    kmers = set()
    for read in reads:
        read_kmers = get_string_kmers(read, k)
        kmers.update(read_kmers)
        #kmers = kmers | read_kmers
    return kmers
    ### END SOLUTION

In [None]:
# Should return {'tters_','ize_ma','atners','ers_no','matner','rs_not','size_m',
#                '_matte','e_matt','atters','ters_n','matter'}
# (i.e. 12 unique kmers)

the_kmers = get_kmers(yoda_reads, 6)

for kmer in sorted(the_kmers):
    print(kmer)
print(f'Total kmer count: {len(the_kmers)}')
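The functions above return sets, so they record which k-mers are present but not how often each one occurs. Assemblers commonly keep a count per k-mer instead, because k-mers created by sequencing errors tend to be rare. A minimal sketch with `collections.Counter` (the name `count_kmers` is hypothetical), applied to a copy of the Yoda reads:

```python
from collections import Counter

def count_kmers(reads, k):
    """Return a Counter mapping each k-mer to its number of occurrences across all reads."""
    counts = Counter()
    for read in reads:
        for i in range(len(read) - k + 1):
            counts[read[i:i + k]] += 1
    return counts

# Same strings as yoda_reads above
reads = ["ters_no", "ers_not", "matners", "size_ma",
         "tters_", "atters_n", "size_m", "e_matter"]
counts = count_kmers(reads, 4)
print(counts["ters"])  # 3: seen in three different reads
print(counts["atne"])  # 1: only in the read with the typo
```

In a data set this tiny, many correct k-mers also occur only once, so a count threshold only becomes a useful error filter at realistic coverage.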

# Section 2: Building a de Bruijn Graph

## Building a Networkx graph 101

We could represent a graph using Python data structures. For instance, here is one way to represent a graph where nodes B and C are linked from A:
    
    B <- A -> C


In [None]:
nodes = ['A','B','C']
edges = [('A','B'), ('A','C')]
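Lists of nodes and edges are easy to write but slow to query: finding the children of a node means scanning every edge. A common alternative, sketched here in plain Python, is an adjacency list stored in a dictionary (the name `adjacency` is just for illustration):

```python
nodes = ['A', 'B', 'C']
edges = [('A', 'B'), ('A', 'C')]

# Map each node to the list of nodes it points to
adjacency = {node: [] for node in nodes}
for source, target in edges:
    adjacency[source].append(target)

print(adjacency)       # {'A': ['B', 'C'], 'B': [], 'C': []}
print(adjacency['A'])  # children of A, found with one dictionary lookup
```

This is roughly the idea behind networkx's own graph classes, which store neighbours in nested dictionaries.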

Notice that a tree is a kind of graph, and we have already built trees using data structures in past labs.

This time we'll use a library `networkx` intended specifically for graph manipulation.

Here's how to build that same graph in networkx. Note that we use `DiGraph` instead of `Graph`, which gives a directed graph.

In [None]:
g = nx.DiGraph()
g.add_edge('A','B')
g.add_edge('A','C')

We could have explicitly added nodes with code like `g.add_node('A')`, but since every node is connected to at least one edge in this case, networkx automatically adds the nodes for us when the edges are added. 

Remember you can look at networkx method documentation with `?` or `help()`. Have a look at the documentation for the `add_edge` method.

In [None]:
g.nodes()

In [None]:
g.edges()

Here's a way to draw the graph with a spring layout, where networkx tries to place nodes so they are not too close together. This is a simple layout algorithm: the result is easy to read for this tiny graph, but can be hard to read for larger graphs. The layout is partly random, and for larger graphs it will look different every time you draw it.

In [None]:
nx.draw_spring(g, with_labels=True)

In [None]:
nx.draw_spring(g, with_labels=True, node_size=1200, node_color='#eeeeff', edge_color='red')
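Because the spring layout starts from random positions, you can make a drawing reproducible by computing the layout yourself with a fixed `seed` and passing the positions to `nx.draw` (the `draw_debruijn` helper further down also fixes a seed). The seed value 42 here is arbitrary:

```python
import matplotlib.pyplot as plt
import networkx as nx

g = nx.DiGraph([('A', 'B'), ('A', 'C')])

# A fixed seed makes the random starting positions, and so the layout, repeatable
pos = nx.spring_layout(g, seed=42)
nx.draw(g, pos, with_labels=True, node_size=1200, node_color='#eeeeff', edge_color='red')
plt.show()
```

Reusing the same `pos` dictionary also keeps nodes in the same place if you redraw the graph after changing colours or labels.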

## Exercise 3: Get Suffixes and Prefixes from kmers

<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">
<b>Challenge:</b> Write functions `get_suffix()` and `get_prefix()` that return, respectively, the suffix and prefix of length `k-1` of the supplied k-mer.
    
- [ ] Input: kmer (str)
- [ ] Prefix = first k-1 characters
- [ ] Suffix = last k-1 characters
- [ ] Return: substring corresponding to prefix or suffix
</div>

In [None]:
def get_prefix(kmer):
    ### BEGIN SOLUTION   
    return kmer[:-1]
    ### END SOLUTION


In [None]:
def get_suffix(kmer):
    ### BEGIN SOLUTION
    return kmer[1:]
    ### END SOLUTION

Let's test out those functions. 

We'll use the string AGGTA and try to extract the suffix A**GGTA** and the prefix **AGGT**A

In [None]:
# Should return 'GGTA'
get_suffix('AGGTA')

In [None]:
# Should return 'AGGT'
get_prefix('AGGTA')
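Prefixes and suffixes are what make the de Bruijn construction work: consecutive k-mers in a read overlap by k-1 characters, so the suffix of one k-mer is exactly the prefix of the next, and that shared (k-1)-mer becomes the node joining their two edges. A quick self-contained check on an arbitrary read:

```python
read = "AGGTACC"
k = 5
kmers = [read[i:i + k] for i in range(len(read) - k + 1)]
print(kmers)  # ['AGGTA', 'GGTAC', 'GTACC']

for left, right in zip(kmers, kmers[1:]):
    suffix_of_left = left[1:]     # last k-1 characters
    prefix_of_right = right[:-1]  # first k-1 characters
    print(f"{left} -> {right} via node {suffix_of_left}")
    assert suffix_of_left == prefix_of_right
```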

## Exercise 4: Build a Directed Graph

<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">
  
<b>Challenge:</b> Given a set of reads and a value k, build a directed graph, using networkx, where nodes are the prefixes/suffixes of kmers and an edge exists for every kmer.

When adding an edge between a prefix and suffix node we will also add the full kmer as an edge label.

- [ ] Input: Reads (list of str), k (int)
- [ ] Init diGraph
- [ ] Extract all kmers
- [ ] For each kmer add edge from prefix to suffix
- [ ] Label edge with kmer name
- [ ] Return: populated graph

Hint: To label an edge: `graph.add_edge(from_node, to_node, label="the_label")`
</div>


In [None]:
def build_graph(reads, k):
    '''
    Given a set of reads and a value k, return the networkx de Bruijn graph object.
    '''
    
    kmers = get_kmers(reads, k)
    graph = nx.DiGraph()
    
    # Complete this function

    ### BEGIN SOLUTION    
    for kmer in kmers:
        prefix = get_prefix(kmer)
        suffix = get_suffix(kmer)
        graph.add_edge(prefix, suffix, label=kmer)

    return graph
    ### END SOLUTION


Let's build graphs for our test data and inspect them with the plotting function below.

In [None]:
# helper func to render
def draw_debruijn(graph):
    fig = plt.figure(1, figsize=(16, 8), dpi=60)
    pos = nx.spring_layout(graph, seed=2, k=0.1, iterations=50)
    nx.draw_networkx_nodes(graph, pos, node_color='white', node_size=1000, edgecolors='black', linewidths=1)
    nx.draw_networkx_edges(graph, pos, width=1, arrows=True, arrowstyle='-|>', arrowsize=12, min_target_margin=22)
    nx.draw_networkx_labels(graph, pos, font_size=12, font_family="sans-serif")
    nx.draw_networkx_edge_labels(
        graph, pos, font_color='red', font_size=12, label_pos=0.6,
        edge_labels={e: graph.edges[e]['label'] for e in graph.edges}
    )
    plt.tight_layout()
    plt.show()

In [None]:
# Testing your build_graph() function

# build it
yoda_graph = build_graph(yoda_reads, 4)

# render it
draw_debruijn(yoda_graph)

<div class="alert alert-info">
<b>Question 1:</b> What has caused the bubble in this graph?
</div>

=== BEGIN MARK SCHEME ===

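The read "matners" contains a typo (a simulated sequencing error) in the word "matters". 

For k=4, the correct k-mers matt, atte, tter and ters are replaced in this read by matn, atne, tner and ners. 

The node "mat" can now be followed by either "att" or "atn", and the two paths rejoin at "ers", forming the bubble.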

=== END MARK SCHEME ===

<span style="color:rgb(17, 122, 121); font-family:Courier"><i><b># -- GRADED CELL (1 marks) - complete this cell --</b></i></span>

YOUR ANSWER HERE

<div class="alert alert-info">
<b>Question 2:</b> How many kmers will be affected by a single error in a sequencing read?
</div>

=== BEGIN MARK SCHEME ===
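Up to k k-mers will contain the error: every k-mer whose window covers the erroneous base. An error within k-1 bases of either end of the read falls in fewer k-mers.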
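Worked check: the k-mers covering position `pos` (0-based) start at positions max(0, pos-k+1) through min(pos, L-k), where L is the read length. The hypothetical helper below counts them and compares the formula with a brute-force count.

```python
def kmers_covering(read_len, k, pos):
    """Number of k-mers in a read of length read_len that include position pos (0-based)."""
    first = max(0, pos - k + 1)    # leftmost k-mer start that still reaches pos
    last = min(pos, read_len - k)  # rightmost k-mer start that fits in the read
    return max(0, last - first + 1)

read_len, k = 20, 5
for pos in (0, 2, 10, 19):
    brute = sum(1 for s in range(read_len - k + 1) if s <= pos < s + k)
    assert kmers_covering(read_len, k, pos) == brute
    print(pos, kmers_covering(read_len, k, pos))
```

An error in the middle of this 20-base read falls in all 5 k-mers; one at either end falls in just 1.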
=== END MARK SCHEME ===

<span style="color:rgb(17, 122, 121); font-family:Courier"><i><b># -- GRADED CELL (1 marks) - complete this cell --</b></i></span>

YOUR ANSWER HERE

# Section 3: Extract Contigs From de Bruijn Graph

<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">

<b>Challenge:</b> Try to extract contigs from your graph as described in lectures. To extract a contig, we need to find an unbalanced node (or breakpoint), then walk along the directed graph only so long as our path is unambiguous.<br>
    
We will do this in 3 steps:
- Identify nodes that correspond to break points in the graph (ambiguous path, start, or end).
- Identify all the potential contig start points.
- Extend contigs from start points to break points.

</div>



**Finding Unbalanced nodes** 

You can find the in-degree and out-degree of nodes (number of edges leading in and out), and the edges themselves, with networkx methods like so:

In [None]:
# Build a sample graph
g = nx.DiGraph()
g.add_edge('A','B')
g.add_edge('A','C')
g.add_edge('Z','A')

In [None]:
# Count inbound edges of node A
g.in_degree('A')

In [None]:
# Count outbound edges of node A
g.out_degree('A')

In [None]:
# List outgoing edges of node A
g.out_edges('A')

In [None]:
# Get a list of children of A
# Here we extract the 2nd item from each outgoing edge tuple 
[x[1] for x in g.out_edges('A')]

In [None]:
# Alternative method to find the downstream neighbours (successors) of node A
# neighbors() returns an iterator, so wrap it in list() to see the contents
list(g.neighbors('A'))

To get a list of all nodes in the graph we can use the `.nodes()` method.

In [None]:
# Get all nodes 
list(g.nodes())
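You can also get the degrees of every node at once: calling `in_degree()` or `out_degree()` with no argument returns a view of `(node, degree)` pairs, which converts neatly to a dictionary. `predecessors()` is the upstream counterpart of `neighbors()` for a directed graph. A sketch on the same sample graph:

```python
import networkx as nx

g = nx.DiGraph()
g.add_edge('A', 'B')
g.add_edge('A', 'C')
g.add_edge('Z', 'A')

# With no argument, the degree methods return views over every node
in_degrees = dict(g.in_degree())
out_degrees = dict(g.out_degree())
print(in_degrees)               # {'A': 1, 'B': 1, 'C': 1, 'Z': 0}
print(out_degrees)              # {'A': 2, 'B': 0, 'C': 0, 'Z': 1}
print(list(g.predecessors('A')))  # ['Z']
```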

## Exercise 5: Find break points


<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">
   
To extract contigs from the graph we must first identify branch points or termini in the graph.
    
To do this we will check for unbalanced nodes (in_degree ≠ out_degree) and for balanced nodes where the in/out degrees are > 1.

**Challenge:** Write a function that returns a list of the nodes that are break points in the graph.
    
- [ ] Input: diGraph object
- [ ] Get list of all nodes
- [ ] Check if node is unbalanced
- [ ] Check if balanced but degree > 1
- [ ] Return: List of break points
</div>


In [None]:
def get_breaks(graph):
    
    # Get a list of all nodes in graph
    nodes = list(graph.nodes())
    # Create empty list to store unbalanced nodes
    break_nodes = []  
    
    ### BEGIN SOLUTION
     
    for node in nodes:
        # Check the in and out degrees for the current node
        ins = graph.in_degree(node)
        outs = graph.out_degree(node)
        # If unbalanced add to list
        if ins != outs:
            break_nodes.append(node)
        # If balanced BUT in/out degree > 1, add to list.
        elif ins > 1:
            break_nodes.append(node)
    ### END SOLUTION
    
    return break_nodes


In [None]:
# Test your function
# Where are these nodes in the graph?
get_breaks(yoda_graph)
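If you want a graph small enough to check by hand, the sketch below builds a hypothetical bubble shaped like the one in the Yoda graph and prints each node's degrees. By the rules above, X, A, D and E are break points (each is unbalanced), while B and C have exactly one edge in and one out and are not, so `get_breaks()` on this graph should return those four nodes.

```python
import networkx as nx

# Hypothetical bubble: X -> A splits into B and C, which rejoin at D -> E
bubble = nx.DiGraph([('X', 'A'), ('A', 'B'), ('A', 'C'),
                     ('B', 'D'), ('C', 'D'), ('D', 'E')])

for node in bubble.nodes():
    print(node, 'in:', bubble.in_degree(node), 'out:', bubble.out_degree(node))
```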

## Exercise 6: Get contig start nodes

<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">

Now we will identify unbalanced and branch-point nodes which correspond to contig starts. <br>
    
<b>Challenge:</b> Write a function that takes as input a graph and a list of break-point nodes, and returns a list of nodes that correspond to contig starts.

    
<b>Hint:</b> Consider the first and last nodes in a graph: are these both unbalanced? Should both be contig starts? 
    
- [ ] Input: graph, list of break-point nodes
- [ ] For each break-point node, assess whether it should be a contig start
- [ ] Return: List of start nodes
</div>

In [None]:
def get_contig_starts(graph, break_nodes):
    
    contig_starts = []
    
    ### BEGIN SOLUTION
    for node in break_nodes:
        in_degree = graph.in_degree(node)
        out_degree = graph.out_degree(node)
        
        # A node with no incoming edges is a source,
        # so a contig starts here
        if in_degree == 0:
            contig_starts.append(node)
        
        # If internal node has > 1 child, 
        # Each child becomes a start node
        elif out_degree > 1:
            children = list(graph.neighbors(node))
            contig_starts += children
        
        # If node has multiple parents but only one child
        # node becomes a start node
        # This excludes nodes with 0 children at graph termini
        elif in_degree > out_degree and out_degree > 0:
            contig_starts.append(node)
    ### END SOLUTION
    
    return contig_starts

In [None]:
yoda_breaks = get_breaks(yoda_graph)

yoda_starts = get_contig_starts(yoda_graph, yoda_breaks)

print(yoda_starts)

## Exercise 7: Extract a contig

<div style="color: rgb(27,94,32); background: rgb(200,230,201); border: solid 1px rgb(129,199,132); padding: 10px;">
Get the full contig sequences from starting nodes to end nodes.

<b>Challenge:</b> Write a function that returns the extended contig sequence, given a starting node.

- [ ] Input: graph, start node, list of all start nodes
- [ ] Init contig with the sequence of the start node
- [ ] Move to next node
- [ ] Append final character from the new node to the contig
- [ ] End if the current node has no children or more than one child, OR the next node is in the contig starts list
- [ ] Return: Contig (str)
</div>
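The "append final character" step is how a path of overlapping (k-1)-mer nodes is spelled out as one sequence: each node after the first adds only its last character, since the rest overlaps the previous node. A sketch with a hypothetical helper `spell_path` and a hand-picked path:

```python
def spell_path(path):
    """Spell a path of overlapping (k-1)-mer nodes as one sequence (hypothetical helper)."""
    # Each node after the first must overlap its predecessor in all but one character
    for prev, node in zip(path, path[1:]):
        assert prev[1:] == node[:-1], f"{prev} and {node} do not overlap"
    return path[0] + ''.join(node[-1] for node in path[1:])

print(spell_path(['GAT', 'ATT', 'TTA', 'TAC']))  # GATTAC
```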


In [None]:
def get_contig(graph, starting_node, contig_starts):
    
    ### BEGIN SOLUTION
    
    contig = ''
    contig += starting_node
    node = starting_node
    
    # List of children of current node
    next_nodes = [x for x in graph.neighbors(node)]
    
    # While only one child node and that node is not a contig start
    while len(next_nodes) == 1 and next_nodes[0] not in contig_starts:
        # update current node
        node = next_nodes[0]
        # Add last character of current node to contig string
        contig += node[-1]
        # Get children of current node
        next_nodes = [x for x in graph.neighbors(node)]
    
    ### END SOLUTION
    
    return contig

In [None]:
# Get contig starting with 'atn'
get_contig(yoda_graph, 'atn', yoda_starts)

In [None]:
# Get contig starting with 'siz'
get_contig(yoda_graph, 'siz', yoda_starts)

## Exercise 8: Bring it all together!

Now let's tie it all together!

In [None]:
def print_contigs(contigs):
    for contig in contigs:
        print(contig)

In [None]:
def extract_contigs(graph):
    breakpoint_nodes = get_breaks(graph)
    contig_starts = get_contig_starts(graph, breakpoint_nodes)
    
    contigs = []
    for starting_node in contig_starts:
        contig = get_contig(graph, starting_node, contig_starts)
        contigs.append(contig)
    
    print_contigs(contigs)
    return contigs

In [None]:
contigs = extract_contigs(yoda_graph)

In [None]:
# Let's look at the graph again to see if our contigs match expected break-points in the graph.
nx.draw_spring(yoda_graph, with_labels=True, node_size=1200, node_color='#eeeeff', edge_color='red')
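One limitation worth knowing about: contigs are only started from break points, so a part of the graph in which every node has exactly one edge in and one edge out contains no break points, and `extract_contigs` returns no contig for it. This happens, for example, with reads from a small circular sequence with no repeated (k-1)-mers. A sketch building such a graph from a made-up circular sequence:

```python
import networkx as nx

# Hypothetical circular sequence; its first k-1 bases are appended to close the circle
circular = "ACGGT"
k = 3
wrapped = circular + circular[:k - 1]  # "ACGGTAC"

g = nx.DiGraph()
for i in range(len(wrapped) - k + 1):
    kmer = wrapped[i:i + k]
    g.add_edge(kmer[:-1], kmer[1:], label=kmer)

# Every node is balanced with one edge in and one edge out, so none is a break point
for node in g.nodes():
    print(node, g.in_degree(node), g.out_degree(node))
# The graph is a single cycle rather than a path
print(nx.is_directed_acyclic_graph(g))  # False
```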