<a href="https://colab.research.google.com/github/epi2me-labs/tutorials/blob/master/Introduction_to_how_ONT's_medaka_works.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1>How medaka works</h1>

The following is a relatively short document describing how Oxford Nanopore Technologies' program for consensus calling of sequencing data, `medaka`, functions internally. We will demonstrate the core functionality required to process alignment data, how it is presented to a recurrent neural network, and how a consensus sequence is formed.

The document you are reading is not a static web page, but an interactive environment called a Colab notebook that lets you write and execute code. You can inspect, modify, and run any of the code on this page.

## Installation

Before getting started with how `medaka` works, we will install it into the Colab environment. This will enable us to both inspect and run the code. To install `medaka`, run the code cell below by clicking the "play" icon to the left of the cell, or by selecting the cell and pressing `<shift>-<enter>`.

In [None]:
# setup some prerequisites - medaka model files are stored in git-lfs, which
# we need to install
%cd /content
!wget https://github.com/git-lfs/git-lfs/releases/download/v2.10.0/git-lfs-linux-amd64-v2.10.0.tar.gz
!tar -xzvf git-lfs-linux-amd64-v2.10.0.tar.gz
!mv git-lfs /bin/
!git lfs install

# install medaka - we'll do this from source so we can make a couple of
# modifications and use some extra bits later on
!apt-get install file
!rm -rf /content/medaka
!git clone https://github.com/nanoporetech/medaka.git
%cd /content/medaka
!sed -i 's/tensorflow==/tensorflow-gpu==/' requirements.txt
!pip install -r requirements.txt
!make scripts/mini_align
!python setup.py install
%cd /content

# install minimap2, Porechop and pomoxis for later
!git clone https://github.com/lh3/minimap2
!cd minimap2 && make && cp minimap2 /bin
!pip install git+https://github.com/rrwick/Porechop
!pip install pomoxis

/content
Reading package lists... Done
Building dependency tree       
Reading state information... Done
file is already the newest version (1:5.32-2ubuntu0.4).
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.
/content/medaka
Bundling models: ['r103_min_high_g345', 'r103_min_high_g360', 'r103_prom_high_g360', 'r103_prom_snp_g3210', 'r103_prom_variant_g3210', 'r10_min_high_g303', 'r10_min_high_g340', 'r941_min_fast_g303', 'r941_min_high_g303', 'r941_min_high_g330', 'r941_min_high_g340_rle', 'r941_min_high_g344', 'r941_min_high_g351', 'r941_min_high_g360', 'r941_prom_fast_g303', 'r941_prom_high_g303', 'r941_prom_high_g330', 'r941_prom_high_g344', 'r941_prom_high_g360', 'r941_prom_snp_g303', 'r941_prom_snp_g322', 'r941_prom_snp_g360', 'r941_prom_variant_g303', 'r941_prom_variant_g322', 'r941_prom_variant_g360']
running install
running bdist_

After running the above, it is necessary to select `Runtime > Restart runtime...` from the menu at the top of the page, after which we should be able to import `medaka`:

In [None]:
import medaka
help(medaka)

## Medaka's input

As input, the core `medaka` algorithm accepts sequencing reads aligned to an assembly sequence. If you have run the `medaka_consensus` pipeline, you will have given it an assembly sequence and your sequencing data as input; the pipeline simply runs [`minimap2`](https://github.com/lh3/minimap2) to calculate alignments of the reads to the assembly.

For the purposes of this demonstration we will download pre-aligned data from an R9.4.1 MinION sequencing run:

In [None]:
!mkdir -p /content/data && cd /content/data/ \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/datasets/r941_zymo/references.fasta \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus.bam \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus.bam.bai \
    && wget https://ont-research.s3-eu-west-1.amazonaws.com/labs_resources/misc/saureus_canu.fasta

The downloaded `saureus.bam` file contains alignments of sequencing reads to the downloaded `saureus_canu.fasta`. The depth of sequencing has been reduced to around 150-fold coverage of the genome.

# Diving in: counting bases

The first step of `medaka`'s calculation is to parse the alignment data into a base counts table ready for input to the neural network. In this section we explore the functions responsible for doing this, how exactly counting is performed and what the results may represent.



## Pileup interface

At the heart of `medaka` resides a straightforward base-counting procedure. From the alignments of sequencing reads to the reference sequence, a pileup is created, much like the one an alignment viewer such as [IGV](https://software.broadinstitute.org/software/igv/) would display.
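A pileup is simply the reads stacked column by column, and summarising it means counting the symbols found in each column. The minimal pure-Python sketch below assumes the reads have already been laid out as gapped rows (with `'-'` marking "no base here"); the reads and the `count_columns` helper are hypothetical illustrations, not `medaka` code. Unlike `medaka`, the toy pools both strands and treats every gap alike.

```python
from collections import Counter

# Hypothetical reads, already laid out as rows of a gapped pileup.
# '-' marks a read that has no base in that column.
reads = [
    "ACGT-A",
    "ACGTTA",
    "AC-T-A",
    "ACGT-A",
]

def count_columns(rows, alphabet="ACGT-"):
    """Return one dict of symbol counts per pileup column."""
    table = []
    for col in range(len(rows[0])):
        column = Counter(row[col] for row in rows)
        table.append({sym: column.get(sym, 0) for sym in alphabet})
    return table

for i, col in enumerate(count_columns(reads)):
    print(i, col)
```

Column 4 is the interesting one: a single read carries an inserted `T` there, and the other three reads have nothing in that column.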

The pileup is summarised by counting the different base types contained within its columns. The function responsible for this counting exercise is called `pileup_counts`, in the [`features`](https://github.com/nanoporetech/medaka/blob/d195b9cc1ee7681a121be9fe4fb016a00744ef47/medaka/features.py#L109) module:

In [None]:
from medaka.features import pileup_counts
help(pileup_counts)

Help on function pileup_counts in module medaka.features:

pileup_counts(region, bam, dtype_prefixes=None, region_split=100000, workers=8, tag_name=None, tag_value=None, keep_missing=False, num_qstrat=1, weibull_summation=False)
    Create pileup counts feature array for region.
    
    :param region: `medaka.common.Region` object
    :param bam: .bam file with alignments.
    :param dtype_prefixes: prefixes for query names which to separate counts.
        If `None` (or of length 1), counts are not split.
    :param region_split: largest region to process in single thread.
    :param workers: worker threads for calculating pileup.
    :param tag_name: two letter tag name by which to filter reads.
    :param tag_value: integer value of tag for reads to keep.
    :param keep_missing: whether to keep reads when tag is missing.
    :param num_qstrat: number of layers for qscore stratification.
    :param weibull_summation: use a Weibull partial-counts approach,
        requires 'WL' and 

The `pileup_counts` function has various arguments, most of which are advanced options not used in the default operation of `medaka`. To create a counts matrix we call the function with a `Region` object, built from a samtools-style region string (note that `medaka` uses 0-based, end-exclusive coordinates), and the path to our alignment file:

In [None]:
from timeit import default_timer as now
from medaka.common import Region

t0 = now()
region = Region.from_string('tig00000061:0-1499707')
bam_file = '/content/data/saureus.bam'
pileup_data = pileup_counts(region, bam_file)
pileup_data = pileup_data[0]  # implementation detail that need not trouble us
counts, positions = pileup_data
t1 = now()
print("{:.2f}s to form pileup counts.".format(t1 - t0))

10.62s to form pileup counts.
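Because `medaka`'s regions are 0-based and end-exclusive, whereas samtools-style strings are conventionally 1-based and end-inclusive, off-by-one mistakes are easy to make. The two helpers below are hypothetical (they are not part of `medaka`) and sketch the conversion under that assumption. The standalone `pileup` program run later in this notebook appears to take the 1-based form, `tig00000061:1-1499707`, for the same span as the `Region` used above.

```python
def to_samtools(contig, start, end):
    """Convert a 0-based, end-exclusive interval to a 1-based, inclusive string."""
    return "{}:{}-{}".format(contig, start + 1, end)

def from_samtools(region):
    """Parse 'contig:start-end' (1-based, inclusive) into a 0-based, end-exclusive tuple."""
    contig, span = region.rsplit(":", 1)
    start, end = (int(x.replace(",", "")) for x in span.split("-"))
    return contig, start - 1, end

print(to_samtools("tig00000061", 0, 1499707))  # tig00000061:1-1499707
print(from_samtools("tig00000061:1-1499707"))  # ('tig00000061', 0, 1499707)
```

In both conventions the span length is the same (1,499,707 bases); only the numbering of the first base differs.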


### The counts matrix

The `pileup_counts` call above produced two structures. The second of these is a positions table, which records which pileup columns are reference positions and which are caused by inserted bases in one or more reads:


In [None]:
display(positions)

array([(      0, 0), (      1, 0), (      2, 0), ..., (1499704, 0),
       (1499705, 0), (1499706, 0)],
      dtype=[('major', '<i8'), ('minor', '<i8')])

The field `minor` in the above array indicates reference and insertion columns: it takes a value `0` for a reference position and counts upwards for all following insertion events. The `major` field keeps track of the reference base co-ordinate.
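A small hand-made positions table shows how the two fields combine; the array below is illustrative data, not `medaka` output. Selecting rows with `minor == 0` recovers the reference coordinates, and counting rows per `major` value gives the number of insertion columns that follow each reference position:

```python
import numpy as np

# Hypothetical positions table: reference positions 10-12, with two
# insertion columns following reference position 11.
toy_positions = np.array(
    [(10, 0), (11, 0), (11, 1), (11, 2), (12, 0)],
    dtype=[('major', '<i8'), ('minor', '<i8')])

is_insertion = toy_positions['minor'] > 0
print(toy_positions['major'][~is_insertion])  # [10 11 12]: reference coordinates
print(np.flatnonzero(is_insertion))           # [2 3]: row indices of insertion columns

# Insertion columns per reference position (assumes exactly one minor == 0
# column for every major value, as in a pileup).
majors, n_cols = np.unique(toy_positions['major'], return_counts=True)
ins_per_ref = dict(zip(majors.tolist(), (n_cols - 1).tolist()))
print(ins_per_ref)  # {10: 0, 11: 2, 12: 0}
```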

The base counts themselves from the alignment pileup are stored separately:

In [None]:
display(counts.shape)
display(counts)

(3508694, 10)

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 4, 0, 0],
       ...,
       [0, 0, 0, ..., 2, 0, 0],
       [0, 0, 0, ..., 2, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint64)

The matrix is of shape (# pileup columns, 10): each row of the matrix holds the counts of bases and deletions in one pileup column (yes, the rows and columns get confusing). The 10 entries are one for each of the four base types plus one for deletions, doubled because reads on the forward and reverse strands are counted separately. The ordering of the entries is given by:

In [None]:
from medaka.features import libmedaka
ffi, lib = libmedaka.ffi, libmedaka.lib
plp_bases = lib.plp_bases
codes = ffi.string(plp_bases).decode()
display(','.join(codes))

'a,c,g,t,A,C,G,T,d,D'

in which lower-case letters denote reverse-strand counts (upper-case, forward) and 'd' and 'D' count deletions. A point of note is that this counting strategy itself makes a distinction between bases which are deleted in reads with respect to the reference sequence and bases which are "deleted" in reads with respect to other reads (the bases in the other reads being insertions with respect to the reference): only the former are counted as deletions. Previous versions of `medaka` performed a symmetrization here, by adding deletion counts for all reads that span a pileup column, whether that pileup column is a reference position (`minor=0`) or an insertion column (`minor>0`).
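The asymmetry is easiest to see on a toy, single-strand example with counts ordered `A, C, G, T, del`. Four reads show `C` at a reference position and one of them carries an inserted `T` after it, so the insertion column records a single `T` and nothing for the three reads without the insertion. The hypothetical `symmetrize` helper sketches the older behaviour described above, under the simplifying assumption that every read counted at a reference column also spans the insertion columns that follow it; it is not `medaka` code.

```python
import numpy as np

# Hypothetical single-strand counts; columns are A, C, G, T, del.
sym_counts = np.array([
    [0, 4, 0, 0, 0],  # major 5, minor 0: four reads show 'C'
    [0, 0, 0, 1, 0],  # major 5, minor 1: one read has an inserted 'T'
])
sym_minor = np.array([0, 1])

def symmetrize(counts, minor, del_col=4):
    """Add deletion counts in insertion columns for reads that do not insert.

    Assumes every read counted at the preceding reference column (minor == 0)
    also spans the following insertion columns.
    """
    out = counts.copy()
    ref_depth = 0
    for i, m in enumerate(minor):
        if m == 0:
            ref_depth = counts[i].sum()
        else:
            out[i, del_col] += ref_depth - counts[i].sum()
    return out

print(symmetrize(sym_counts, sym_minor))
```

After symmetrization the insertion column reads one `T` against three deletions, making explicit that most reads do not support the insertion.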



### An aside on performance

Although `medaka` is mainly written in Python, this first calculation is performed in C for speed: `pileup_counts` defers to a C implementation of the base counting which makes use of the pileup API from [`htslib`](https://github.com/samtools/htslib). This C function is one to two orders of magnitude faster than a previous Python implementation using [`pysam`](https://github.com/pysam-developers/pysam). Nevertheless, this seemingly trivial step is often the performance bottleneck in `medaka`, particularly when GPUs are used to run the neural network calculations. Profiling a standalone build of the counting code shows that much of the time is spent within the `htslib` function [`resolve_cigar2`](https://github.com/samtools/htslib/blob/9279d76e1186d7155ceea9db9db8c9298f6139bd/sam.c#L3956):



In [None]:
%cd /content/medaka/
print("Compiling standalone pileup counts program.")
!rm -rf pileup
!sed -i "s/gcc -pthread/gcc -pg -pthread/" Makefile
!make pileup
print("\nRunning pileup...")
!time ./pileup /content/data/saureus.bam tig00000061:1-1499707 2>/dev/null > pileup.txt
!gprof pileup gmon.out | head -n 17

/content/medaka
Compiling standalone pileup counts program.
gcc -pg -pthread  -g -Wall -fstack-protector-strong -D_FORTIFY_SOURCE=2 -fPIC -std=c99 -msse3 -O3 \
	-Isrc -Isubmodules/samtools-1.9/htslib-1.9 \
	src/medaka_common.c src/medaka_counts.c src/medaka_bamiter.c libhts.a \
	-lm -lz -llzma -lbz2 -lpthread -lcurl -lcrypto \
	-o pileup -std=c99 -msse3 -O3

Running pileup...

real	0m28.472s
user	0m26.345s
sys	0m1.940s
Flat profile:

Each sample counts as 0.01 seconds.
  %   cumulative   self              self     total           
 time   seconds   seconds    calls  Ts/call  Ts/call  name    
 49.58      5.85     5.85                             resolve_cigar2
 27.97      9.15     3.30                             calculate_pileup
 16.02     11.04     1.89                             bam_plp_next
  2.88     11.38     0.34                             print_pileup_data
  1.10     11.51     0.13                             bam_cigar2rqlens
  1.10     11.64     0.13                         

Because of this, `medaka` uses multiple worker threads to perform the base counting: the calculation is split into 1 Mbase shards, which are further subdivided into 100 kbase chunks, with the chunks being reassembled for each shard. Finally, compare the time taken by the simple, pure C `pileup` program above with that for running the full `medaka consensus` program below:

In [None]:
!rm -rf /content/data/saureus.hdf
!time medaka consensus /content/data/saureus.bam /content/data/saureus.hdf --region tig00000061 --batch_size 50

[23:06:01 - Predict] Processing region(s): tig00000061:0-1499707
[23:06:01 - Predict] Setting tensorflow threads to 1.
[23:06:01 - Predict] Found a GPU.
[23:06:01 - Predict] If cuDNN errors are observed, try setting the environment variable `TF_FORCE_GPU_ALLOW_GROWTH=true`. To explicitely disable use of cuDNN use the commandline option `--disable_cudnn. If OOM (out of memory) errors are found please reduce batch size.
[23:06:01 - Predict] Processing 2 long region(s) with batching.
[23:06:01 - Predict] Using model: /usr/local/lib/python3.6/dist-packages/medaka-0.11.5-py3.6-linux-x86_64.egg/medaka/data/r941_min_high_g344_model.hdf5.
[23:06:01 - ModelLoad] Building model with cudnn optimization: True
[23:06:02 - DLoader] Initializing data loader
[23:06:02 - PWorker] Running inference for 1.5M draft bases.
[23:06:02 - Sampler] Initializing sampler for consensus of region tig00000061:0-1000000.
[23:06:02 - Sampler] Initializing sampler for consensus of region tig00000061:999000-1499707.
[23

On the standard Colab environment with two CPU cores and an NVIDIA P4 GPU, the wallclock time for `medaka consensus` is less than that of the single-threaded `pileup` program. This speed is achieved through the asynchronous, multi-threaded queuing that `medaka` implements for generating the counts matrices, together with the raw power of GPUs to process the data.
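The shard-and-chunk scheme can be sketched in a few lines. This is an illustration under stated assumptions, not `medaka`'s implementation: `split_region`, `count_chunk` and `count_shard` are hypothetical names; the 1 kbase shard overlap is inferred from the log lines above (the second shard starts at 999000); the 100 kbase chunk size matches the `region_split` default of `pileup_counts`. The placeholder `count_chunk` returns only the reference coordinates it covers, so the reassembly order can be checked. A real implementation must also stitch insertion columns consistently across chunk boundaries, which this sketch ignores.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def split_region(start, end, size, overlap=0):
    """Split [start, end) into pieces of at most `size` bases, consecutive
    pieces sharing `overlap` bases."""
    if overlap >= size:
        raise ValueError("overlap must be smaller than size")
    pieces, pos = [], start
    while True:
        stop = min(pos + size, end)
        pieces.append((pos, stop))
        if stop >= end:
            return pieces
        pos = stop - overlap

def count_chunk(chunk):
    """Hypothetical stand-in for the C counting routine: it just returns the
    reference coordinates covered by the chunk."""
    start, stop = chunk
    return np.arange(start, stop)

def count_shard(shard, chunk_size=100_000, workers=2):
    chunks = split_region(shard[0], shard[1], chunk_size)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Executor.map yields results in input order, so the chunk results
        # can be concatenated directly to reassemble the shard
        parts = list(pool.map(count_chunk, chunks))
    return np.concatenate(parts)

shards = split_region(0, 1_499_707, 1_000_000, overlap=1_000)
print(shards)  # [(0, 1000000), (999000, 1499707)]
for shard in shards:
    print(shard, count_shard(shard).shape)
```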

## Normalization

After obtaining the base-counts matrix in the section above, `medaka` performs a normalization of the counts. Across the pileup columns, all count vectors sharing the same `major` position index are normalized by the total count for the column with `minor=0` (the reference position). This choice of normalization accounts for the lack of symmetry described above: whilst consensus insertions are typically rare, isolated insertions may still occur within any one read spanning two reference positions, and dividing an insertion column by the reference depth, rather than by its own (usually small) total, keeps such an isolated insertion at a correspondingly small value. In this dataset there are on average about 2.3 pileup columns for every reference position (3,508,694 columns for 1,499,707 reference positions).

Ordinarily this normalization is performed in a [post-processing](https://github.com/nanoporetech/medaka/blob/d195b9cc1ee7681a121be9fe4fb016a00744ef47/medaka/features.py#L372) method of the `CountsFeatureEncoder` class; for the purposes of exposition, the operation under default behaviour is:

In [None]:
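import numpy as np
# indices of insertion columns (minor > 0)
minor_inds = np.where(positions['minor'] > 0)
major_pos_at_minor_inds = positions['major'][minor_inds]
# for each insertion column, the index of the reference column (minor == 0)
# with the same major position; `major` is sorted, so side='left' finds it
major_ind_at_minor_inds = np.searchsorted(
    positions['major'], major_pos_at_minor_inds, side='left')

# total counts per column; insertion columns take the depth of their
# reference column instead of their own total
depth = np.sum(counts, axis=1)
depth[minor_inds] = depth[major_ind_at_minor_inds]

# np.maximum guards against division by zero in uncovered columns
feature_array = counts / np.maximum(1, depth).reshape((-1, 1))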
display(feature_array)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
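The same steps applied to a tiny hypothetical pileup (counts pooled into five columns `A, C, G, T, del` for brevity) make the effect visible: the insertion column, supported by a single read, is divided by the depth of its reference column rather than by its own total of one.

```python
import numpy as np

# Hypothetical pileup: reference positions 0 and 1, with one insertion
# column after position 0. Count columns are A, C, G, T, del.
norm_pos = np.array([(0, 0), (0, 1), (1, 0)],
                    dtype=[('major', '<i8'), ('minor', '<i8')])
norm_counts = np.array([
    [8, 0, 0, 0, 2],  # reference position 0: 8 x A, 2 deletions
    [0, 0, 1, 0, 0],  # insertion column: a single read inserts G
    [0, 9, 0, 0, 1],  # reference position 1: 9 x C, 1 deletion
])

ins = np.where(norm_pos['minor'] > 0)[0]
# index of the minor == 0 column sharing each insertion column's major position
ref_of_ins = np.searchsorted(norm_pos['major'], norm_pos['major'][ins], side='left')

norm_depth = norm_counts.sum(axis=1)
norm_depth[ins] = norm_depth[ref_of_ins]
norm_features = norm_counts / np.maximum(1, norm_depth).reshape(-1, 1)
print(norm_features)  # insertion row: [0, 0, 0.1, 0, 0] rather than [0, 0, 1, 0, 0]
```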

The normalization is performed across all bases; it is not split by strand, since splitting the normalization by strand would potentially lose important information with respect to strand bias and relative errors. The plot below visualizes the final input to the neural network used in `medaka`.

In [None]:
#@markdown ***Plot feature array*** *(click to show code)*

from bokeh.plotting import figure
from bokeh.models import Range1d
import bokeh.io as bkio

# select just a region to plot
reg = slice(10000,10050)
pdata = feature_array[reg] * 255
ppos = positions[reg]
# create RGBA image
pdata = np.stack([pdata]*4, axis=-1)
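pdata[:,:,3] = 0  # alpha channel: becomes 255 (opaque) after the inversion below
# zero out colour channels per base type so that each base gets its own tint
pdata[:,[x.upper() == 'A' for x in codes],0] *= 0
pdata[:,[x.upper() == 'C' for x in codes],1] *= 0
pdata[:,[x.upper() == 'G' for x in codes],2] *= 0
pdata[:,[x.upper() == 'T' for x in codes],1:3] *= 0
pdata[:,[x.upper() == 'D' for x in codes],0:3] *= 10  # exaggerate deletions
# clip before casting: amplified deletion values can exceed 255 and would
# otherwise not fit in uint8
pdata = np.clip(pdata, 0, 255).astype(dtype=np.uint8)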
pdata = np.transpose(pdata, axes=[1,0,2])[::-1,:,:]
pdata = 255 - pdata

# create a figure
p = figure(
    title="Base counts",
    plot_height=300, plot_width=800)

p.x_range.range_padding = p.y_range.range_padding = 0
p.image_rgba(image=[pdata], x=0, y=0, dw=pdata.shape[1], dh=pdata.shape[0])
ylabels = np.arange(0.5,10.5)
p.yaxis.ticker = ylabels
p.yaxis.major_label_overrides = dict(zip(ylabels, reversed(codes)))
p.y_range = Range1d(
    start=0, end=10,
    bounds=(0, 10))
xlabels = np.arange(0.5, pdata.shape[1])
p.xaxis.ticker = xlabels
p.xaxis.major_label_overrides = dict(zip(
    xlabels, ('{}.{}'.format(x['major'], x['minor']) for x in ppos)
))
p.xaxis.major_label_orientation = 3.14/2
bkio.output_notebook(hide_banner=True)
bkio.show(p)

# The neural network

Having counted bases in an alignment pileup, `medaka` proceeds to analyse these counts using a [Recurrent Neural Network](https://en.wikipedia.org/wiki/Recurrent_neural_network) (RNN). A full discussion of such algorithms is beyond the scope of this document; this section demonstrates their use in calculating a consensus sequence from the base counts array. When `medaka` is used as a variant caller, different methods are used.





## The model

In order to construct a consensus sequence, `medaka` uses a multi-layer bidirectional RNN. This is defined using the [`keras`](https://www.tensorflow.org/guide/keras) API in `tensorflow`. The following code is adapted from the [`models`](https://github.com/nanoporetech/medaka/blob/d195b9cc1ee7681a121be9fe4fb016a00744ef47/medaka/models.py#L31) module of `medaka`; it has been simplified to show only the essential parts:


In [None]:
from pkg_resources import resource_filename

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, CuDNNGRU, Bidirectional

from medaka.labels import BaseLabelScheme

# parameters of the model
gru_size = 128
time_steps, feature_len = (1000, counts.shape[1])
symbols = BaseLabelScheme.symbols  # [*, A, C, G, T], '*' being the gap
num_classes = len(symbols)

# build the model
model = Sequential(name='medaka')
input_shape = (time_steps, feature_len)
for i in range(2):
    gru = CuDNNGRU(gru_size, return_sequences=True, name="gru_{}".format(i))
    model.add(Bidirectional(gru, input_shape=input_shape))
model.add(Dense(
    num_classes, activation='softmax', name='classify',
    input_shape=(time_steps, 2 * gru_size)))

# add pre-trained model weights
weight_file = resource_filename('medaka','data/r941_min_high_g344_model.hdf5')
model.load_weights(weight_file)
model.summary()

Model: "medaka"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bidirectional_16 (Bidirectio (None, 1000, 256)         107520    
_________________________________________________________________
bidirectional_17 (Bidirectio (None, 1000, 256)         296448    
_________________________________________________________________
classify (Dense)             (None, 1000, 5)           1285      
Total params: 405,253
Trainable params: 405,253
Non-trainable params: 0
_________________________________________________________________


The model takes the counts matrix as input and outputs, for each corresponding column of the counts matrix, a set of five scores. The five scores, produced by a softmax, express the probability that the consensus sequence should contain one of the four bases A, C, G or T, or a gap (represented by `*`), at the pileup column under consideration.
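Turning these scores into sequence amounts to taking the highest-scoring symbol in each column and discarding gaps. The sketch below uses made-up scores and assumes the symbol order `*ACGT`, with `*` as the gap (consistent with the `seq.replace('*', '')` in the prediction code that follows).

```python
import numpy as np

label_symbols = "*ACGT"  # assumed order: gap first, then the four bases
# Hypothetical softmax outputs for six pileup columns (each row sums to 1).
scores = np.array([
    [0.01, 0.95, 0.02, 0.01, 0.01],
    [0.02, 0.01, 0.90, 0.05, 0.02],
    [0.80, 0.05, 0.05, 0.05, 0.05],  # gap wins, e.g. an unsupported insertion column
    [0.03, 0.02, 0.05, 0.85, 0.05],
    [0.10, 0.10, 0.10, 0.10, 0.60],
    [0.05, 0.70, 0.10, 0.10, 0.05],
])

best = np.argmax(scores, axis=-1)
called = "".join(label_symbols[i] for i in best).replace("*", "")
print(called)  # ACGTA
```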

## Making predictions

In order to make predictions using the RNN, `medaka` splits the normalized counts array into overlapping chunks before they are processed by the model. Chunking the array allows for more efficient parallel computation, while overlapping mitigates edge effects at the boundaries of chunks.

As mentioned above in the aside on performance, `medaka` has a somewhat elaborate system for managing data chunks. For the purposes of exposition, the code below implements a simple chunking and batching of the data.

In [None]:
from functools import partial
from medaka.common import sliding_window, grouper

# create a function to perform windowing on an array
overlap = 200
window = partial(
    sliding_window,
    window=time_steps, step=time_steps - overlap, axis=0)

# run the network on input data
def get_predictions(data, batch_size=40):
    for batch in grouper(data, batch_size=batch_size):
        batch = np.stack(batch)
        results = model.predict_on_batch(batch)
        yield from results

t0 = now()
predictions = get_predictions(window(feature_array))
seq_chunks = list()
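for pred in predictions:
    # trim half the overlap from each end of every chunk so that neighbouring
    # chunks abut (this also drops overlap // 2 columns at the very ends)
    pred = pred[overlap // 2:-overlap // 2]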
    # find the most likely base at each position and form the sequence
    mp = np.argmax(pred, -1)
    seq = ''.join((symbols[x] for x in mp))
    seq = seq.replace('*', '')
    seq_chunks.append(seq)
sequence = ''.join(seq_chunks)
t1 = now()
print("{:.2f} to run predictions".format(t1 - t0))
print("Total sequence length: {}.".format(len(sequence)))

7.50 to run predictions
Total sequence length: 1500668.


The code simply trims the overlapping regions before stitching the consensus sequence pieces back together. This is sufficient to obtain results here; the full `medaka` implementation also keeps track of the `positions` array to ensure the sequence stitching is performed correctly with respect to the original input reference sequence.
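The trimming arithmetic is easy to check on a toy array of column indices. The sketch uses its own hypothetical `windows` helper as a stand-in for `medaka.common.sliding_window` (any trailing partial window is simply dropped here), with a window of 10 and an overlap of 4. After removing `overlap // 2` columns from each side of every window, each interior column is kept exactly once, while the first and last `overlap // 2` columns of the whole array are lost.

```python
import numpy as np

def windows(arr, size, step):
    """Yield full-length windows of `arr` along axis 0 (a trailing remainder is dropped)."""
    for start in range(0, len(arr) - size + 1, step):
        yield arr[start:start + size]

n_cols, size, ovl = 34, 10, 4
cols = np.arange(n_cols)
half = ovl // 2
# slice with len(w) - half rather than -half so that ovl == 0 still works
kept = np.concatenate([w[half:len(w) - half] for w in windows(cols, size, size - ovl)])
print(kept)  # columns 2 to 31, each exactly once
```

Slicing with `len(w) - half` matters when the overlap is zero: `w[0:-0]` is the same as `w[0:0]`, an empty slice, which is a pitfall of the `pred[overlap // 2:-overlap // 2]` form.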



### Checking our results

We can write out the full consensus sequence derived above and compare it to the truth sequence by using the `assess_assembly` program from the [`pomoxis`](https://github.com/nanoporetech/pomoxis) package. By also examining the original draft sequence, we can see the improvement in quality from `medaka`:

In [None]:
output = "/content/data/output.fasta"
with open(output, 'w') as fh:
    fh.write(">seq\n{}\n".format(sequence))
for fname in ("/content/data/saureus_canu.fasta", output):
    print("Analysing: {}.".format(fname))
    !assess_assembly -r /content/data/references.fasta -i "$fname" 2>/dev/null
    print("\n")

Analysing: /content/data/saureus_canu.fasta.
Writing list of indels 100 bases and longer to assm_indel_ge100.txt.
#  Percentage Errors
  name     mean     q10      q50      q90   
 err_ont  0.133%   0.056%   0.068%   0.290% 
 err_bal  0.133%   0.056%   0.068%   0.290% 
    iden  0.004%   0.000%   0.001%   0.024% 
     del  0.117%   0.049%   0.059%   0.166% 
     ins  0.012%   0.003%   0.006%   0.046% 

#  Q Scores
  name     mean      q10      q50      q90  
 err_ont  28.76    32.54    31.65    25.38  
 err_bal  28.76    32.54    31.65    25.38  
    iden  44.39      inf    50.00    36.20  
     del  29.30    33.10    32.29    27.79  
     ins  39.20    45.23    42.22    33.41  

All done, output written to assm_stats.txt, assm_summ.txt and assm_indel_ge100.txt


Analysing: /content/data/output.fasta.
Writing list of indels 100 bases and longer to assm_indel_ge100.txt.
#  Percentage Errors
  name     mean     q10      q50      q90   
 err_ont  0.007%   0.003%   0.005%   0.013% 
 err_ba
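The two tables produced by `assess_assembly` appear to be related by the usual Phred transform, Q = -10 log10(error rate), which is also why `inf` appears wherever an error rate is 0. A quick check against the draft assembly's mean `err_ont` of 0.133%, and the corresponding figure for the consensus' rounded mean error of 0.007%:

```python
import math

def phred(error_rate):
    """Phred-scaled quality for an error rate given as a fraction (not a percentage)."""
    return float("inf") if error_rate == 0 else -10 * math.log10(error_rate)

print(round(phred(0.133 / 100), 2))  # 28.76, matching the draft's mean err_ont Q score
print(round(phred(0.007 / 100), 1))  # roughly Q41.5 for the medaka consensus
```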

# Remarks

In this short walkthrough we have examined some of the internals of Oxford Nanopore Technologies' `medaka` program, and how it performs GPU-accelerated consensus calculations from aligned sequencing data. The public `medaka` codebase implements various alternative forms of the algorithms presented here, including run-length compression and support for multiple datatypes. Hopefully this guide will prove useful to anyone wishing to implement algorithms similar to those implemented in `medaka`.