Update LICENSE
cangermueller committed Mar 2, 2017
1 parent 3195a37 commit 70a0ec5
Showing 7 changed files with 876 additions and 36 deletions.
2 changes: 2 additions & 0 deletions LICENSE
@@ -1,3 +1,5 @@
The MIT License (MIT)

+Copyright (c) 2017 Christof Angermueller

Permission is hereby granted, free of charge, to any person obtaining a copy
23 changes: 18 additions & 5 deletions deepcpg/models/cpg.py
@@ -1,6 +1,6 @@
"""CpG models.
-Provides models trained observed neighboring methylation states of multiple
+Provides models trained with observed neighboring methylation states of multiple
cells.
"""

@@ -18,6 +18,7 @@


class CpgModel(Model):
"""Abstract class of a CpG model."""

def __init__(self, *args, **kwargs):
super(CpgModel, self).__init__(*args, **kwargs)
@@ -34,8 +35,12 @@ def _merge_inputs(self, inputs):
return kl.merge(inputs, mode='concat', concat_axis=2)


-class DenseAvg(CpgModel):
-    """54000 params"""
+class FcAvg(CpgModel):
+    """Fully-connected layer followed by global average layer.
+    Parameters: 54,000
+    Specification: fc[512]_gap
+    """

def _replicate_model(self, input):
w_reg = kr.WeightRegularizer(l1=self.l1_decay, l2=self.l2_decay)
@@ -57,7 +62,11 @@ def __call__(self, inputs):


class RnnL1(CpgModel):
"""810000 parameters"""
"""Bidirectional GRU with one layer.
Parameters: 810,000
Specification: fc[256]_bgru[256]_do
"""

def __init__(self, act_replicate='relu', *args, **kwargs):
super(RnnL1, self).__init__(*args, **kwargs)
@@ -85,7 +94,11 @@ def __call__(self, inputs):


class RnnL2(RnnL1):
"""1112069 params"""
"""Bidirectional GRU with two layers.
Parameters: 1,100,000
Specification: fc[256]_bgru[128]_bgru[256]_do
"""

def __call__(self, inputs):
x = self._merge_inputs(inputs)
61 changes: 33 additions & 28 deletions deepcpg/models/dna.py
@@ -27,10 +27,10 @@ def inputs(self, dna_wlen):


class CnnL1h128(DnaModel):
"""CNN with one convolutional and one hidden layer with 128 units.
"""CNN with one convolutional and one fully-connected layer with 128 units.
Specification: conv[128@11]_mp[4]_fc[128]_do[0.0]
Parameters: 4.100.000
Parameters: 4,100,000
Specification: conv[128@11]_mp[4]_fc[128]_do
"""

def __init__(self, nb_hidden=128, *args, **kwargs):
@@ -56,10 +56,10 @@ def __call__(self, inputs):


class CnnL1h256(CnnL1h128):
"""CNN with one convolutional and one hidden layer with 256 units.
"""CNN with one convolutional and one fully-connected layer with 256 units.
Specification: conv[128@11]_mp[4]_fc[256]_do[0.0]
Parameters: 8.100.000
Parameters: 8,100,000
Specification: conv[128@11]_mp[4]_fc[256]_do
"""

def __init__(self, *args, **kwargs):
@@ -68,10 +68,10 @@ def __init__(self, *args, **kwargs):


class CnnL2h128(DnaModel):
"""CNN with two convolutional and one hidden layer with 128 units.
"""CNN with two convolutional and one fully-connected layer with 128 units.
Specification: conv[128@11]_mp[4]_conv[256@3]_mp[2]_fc[128]_do[0.0]
Parameters: 4.100.000
Parameters: 4,100,000
Specification: conv[128@11]_mp[4]_conv[256@3]_mp[2]_fc[128]_do
"""

def __init__(self, nb_hidden=128, *args, **kwargs):
@@ -102,10 +102,10 @@ def __call__(self, inputs):


class CnnL2h256(CnnL2h128):
"""CNN with two convolutional and one hidden layer with 256 units.
"""CNN with two convolutional and one fully-connected layer with 256 units.
Specification: conv[128@11]_mp[4]_conv[256@3]_mp[2]_fc[256]_do[0.0]
Parameters: 8.100.000
Parameters: 8,100,000
Specification: conv[128@11]_mp[4]_conv[256@3]_mp[2]_fc[256]_do
"""

def __init__(self, *args, **kwargs):
@@ -114,11 +114,11 @@ def __init__(self, *args, **kwargs):


class CnnL3h128(DnaModel):
"""CNN with three convolutional and one hidden layer with 128 units.
"""CNN with three convolutional and one fully-connected layer with 128 units.
Parameters: 4,400,000
Specification: conv[128@11]_mp[4]_conv[256@3]_mp[2]_conv[512@3]_mp[2]_
fc[128]_do[0.0]
Parameters: 4.400.000
fc[128]_do
"""

def __init__(self, nb_hidden=128, *args, **kwargs):
@@ -154,11 +154,11 @@ def __call__(self, inputs):


class CnnL3h256(CnnL3h128):
"""CNN with three convolutional and one hidden layer with 256 units.
"""CNN with three convolutional and one fully-connected layer with 256 units.
Parameters: 8,300,000
Specification: conv[128@11]_mp[4]_conv[256@3]_mp[2]_conv[512@3]_mp[2]_
fc[256]_do[0.0]
Parameters: 8.300.000
fc[256]_do
"""

def __init__(self, *args, **kwargs):
@@ -172,8 +172,9 @@ class CnnRnn01(DnaModel):
Convolutional-recurrent model with two convolutional layers followed by a
bidirectional GRU layer.
-    Specification: conv[128@11]_pool[4]_conv[256@7]_pool[4]_bGRU[256]_do[0.0]
-    Parameters: 1.100.000"""
+    Parameters: 1,100,000
+    Specification: conv[128@11]_pool[4]_conv[256@7]_pool[4]_bgru[256]_do
+    """

def __call__(self, inputs):
x = inputs[0]
@@ -196,9 +197,10 @@ def __call__(self, inputs):


class ResNet01(DnaModel):
"""Residual network with 3x2 bottleneck residual units.
"""Residual network with bottleneck residual units.
Parameters: 1.700.000
Parameters: 1,700,000
Specification: conv[128@11]_mp[2]_resb[2x128|2x256|2x512|1x1024]_gap_do
He et al., 'Identity Mappings in Deep Residual Networks.'
"""
@@ -289,9 +291,10 @@ def __call__(self, inputs):


class ResNet02(ResNet01):
"""Residual network with 3x3 bottleneck residual units.
"""Residual network with bottleneck residual units.
Parameters: 2.000.000
Parameters: 2,000,000
Specification: conv[128@11]_mp[2]_resb[3x128|3x256|3x512|1x1024]_gap_do
He et al., 'Identity Mappings in Deep Residual Networks.'
"""
@@ -333,9 +336,10 @@ def __call__(self, inputs):


class ResConv01(ResNet01):
"""Residual network with two convolutional layers in each residual units.
"""Residual network with two convolutional layers in each residual unit.
Parameters: 2.800.000
Parameters: 2,800,000
Specification: conv[128@11]_mp[2]_resc[2x128|1x256|1x256|1x512]_gap_do
He et al., 'Identity Mappings in Deep Residual Networks.'
"""
@@ -420,7 +424,8 @@ class ResAtrous01(DnaModel):
units. Atrous convolutional layers increase the receptive field and hence
model long-range dependencies better.
Parameters: 2.000.000
Parameters: 2,000,000
Specification: conv[128@11]_mp[2]_resa[3x128|3x256|3x512|1x1024]_gap_do
He et al., 'Identity Mappings in Deep Residual Networks.'
Yu and Koltun, 'Multi-Scale Context Aggregation by Dilated Convolutions.'
@@ -518,7 +523,7 @@ def __call__(self, inputs):
def list_models():
models = dict()
for name, value in globals().items():
-        if inspect.isclass(value) and name.lower().find('model') == 0:
+        if inspect.isclass(value) and name.lower().find('model') == -1:
models[name] = value
return models

33 changes: 33 additions & 0 deletions deepcpg/models/joint.py
@@ -1,6 +1,12 @@
"""Joint models.
Provides models to join the features of the DNA and CpG models.
"""
+from __future__ import division
+from __future__ import print_function
+
+import inspect

from keras import layers as kl
from keras import models as km
from keras import regularizers as kr
@@ -39,12 +45,21 @@ def _build(self, models, layers=[]):


class JointL0(JointModel):
"""Concatenates inputs without trainable layers.
Parameters: 0
"""

def __call__(self, models):
return self._build(models)


class JointL1h512(JointModel):
"""One fully-connected layer with 512 units.
Parameters: 524,000
Specification: fc[512]
"""

def __init__(self, nb_layer=1, nb_hidden=512, *args, **kwargs):
super(JointL1h512, self).__init__(*args, **kwargs)
@@ -64,18 +79,36 @@ def __call__(self, models):


class JointL2h512(JointL1h512):
"""Two fully-connected layers with 512 units.
Parameters: 786,000
Specification: fc[512]_fc[512]
"""

def __init__(self, *args, **kwargs):
super(JointL2h512, self).__init__(*args, **kwargs)
self.nb_layer = 2


class JointL3h512(JointL1h512):
"""Three fully-connected layers with 512 units.
Parameters: 1,000,000
Specification: fc[512]_fc[512]_fc[512]
"""

def __init__(self, *args, **kwargs):
super(JointL3h512, self).__init__(*args, **kwargs)
self.nb_layer = 3


+def list_models():
+    models = dict()
+    for name, value in globals().items():
+        if inspect.isclass(value) and name.lower().find('model') == -1:
+            models[name] = value
+    return models


def get(name):
return get_from_module(name, globals())
80 changes: 80 additions & 0 deletions docs/source/modules.md
@@ -0,0 +1,80 @@
# Module architectures

DeepCpG consists of a DNA module to recognize features in the DNA sequence, a CpG module to recognize features in the methylation neighborhood of multiple cells, and a joint module to combine the features from the DNA and CpG modules.

DeepCpG provides different architectures for the DNA, CpG, and joint module. Architectures differ in the number of layers and neurons, and are hence more or less complex. More complex modules are usually more accurate, but more expensive to train. You can select an architecture using the `--dna_module`, `--cpg_module`, and `--joint_module` arguments of `dcpg_train.py`, for example:

```
dcpg_train.py \
    --dna_module CnnL2h128 \
    --cpg_module RnnL1 \
    --joint_module JointL2h512
```
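
Both `dna.py` and `joint.py` expose a `list_models()` helper (see the diffs above) that collects the architecture classes defined in each module. A minimal sketch, assuming this repository's package layout, of listing the valid architecture names:

```
# Sketch: enumerate architecture names via the list_models() helpers from
# deepcpg/models/dna.py and deepcpg/models/joint.py (shown in the diffs above).
from deepcpg.models import dna, joint

print(sorted(dna.list_models().keys()))    # e.g. CnnL2h128, ..., ResNet02
print(sorted(joint.list_models().keys()))  # e.g. JointL0, ..., JointL3h512
```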

The following layer specifications are used throughout:

| Specification | Description |
|---------------|--------------------------------------------------------------------------|
| conv[x@y] | Convolutional layer with x filters of size y |
| mp[x] | Max-pooling layer with size x |
| fc[x]         | Fully-connected layer with x units                                         |
| do | Dropout layer |
| bgru[x] | Bidirectional GRU with x units |
| gap | Global average pooling layer |
| resb[x,y,z] | Residual network with three bottleneck residual units of size x, y, z |
| resc[x,y,z] | Residual network with three convolutional residual units of size x, y, z |
| resa[x,y,z] | Residual network with three Atrous residual units of size x, y, z |
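
To make the mini-language concrete, the sketch below (an illustration, not DeepCpG's own code) shows how a specification such as `conv[128@11]_mp[4]_fc[128]_do` maps onto Keras 1.x layers; the DNA window length of 1001 and the dropout rate of 0.1 are assumed values:

```
# Sketch only: conv[128@11]_mp[4]_fc[128]_do as Keras 1.x layers.
# Window length (1001) and dropout rate (0.1) are assumptions for illustration.
from keras import layers as kl
from keras import models as km

dna = kl.Input(shape=(1001, 4), name='dna')              # one-hot DNA window
x = kl.Convolution1D(128, 11, activation='relu')(dna)   # conv[128@11]
x = kl.MaxPooling1D(4)(x)                                # mp[4]
x = kl.Flatten()(x)
x = kl.Dense(128, activation='relu')(x)                  # fc[128]
x = kl.Dropout(0.1)(x)                                   # do
model = km.Model(input=dna, output=x)
```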


## DNA modules

| Name | Parameters | Specification |
|-------------|------------|-------------------------------------------------------------------|
| CnnL1h128 | 4,100,000 | conv[128@11]_mp[4]_fc[128]_do |
| CnnL1h256 | 8,100,000 | conv[128@11]_mp[4]_fc[256]_do |
| CnnL2h128 | 4,100,000 | conv[128@11]_mp[4]_conv[256@3]_mp[2]_fc[128]_do |
| CnnL2h256 | 8,100,000 | conv[128@11]_mp[4]_conv[256@3]_mp[2]_fc[256]_do |
| CnnL3h128 | 4,400,000 | conv[128@11]_mp[4]_conv[256@3]_mp[2]_conv[512@3]_mp[2]_fc[128]_do |
| CnnL3h256   | 8,300,000  | conv[128@11]_mp[4]_conv[256@3]_mp[2]_conv[512@3]_mp[2]_fc[256]_do |
| CnnRnn01 | 1,100,000 | conv[128@11]_pool[4]_conv[256@7]_pool[4]_bgru[256]_do |
| ResNet01    | 1,700,000  | conv[128@11]_mp[2]_resb[2x128\|2x256\|2x512\|1x1024]_gap_do |
| ResNet02    | 2,000,000  | conv[128@11]_mp[2]_resb[3x128\|3x256\|3x512\|1x1024]_gap_do |
| ResConv01   | 2,800,000  | conv[128@11]_mp[2]_resc[2x128\|1x256\|1x256\|1x512]_gap_do  |
| ResAtrous01 | 2,000,000  | conv[128@11]_mp[2]_resa[3x128\|3x256\|3x512\|1x1024]_gap_do |

The prefixes `Cnn`, `CnnRnn`, `ResNet`, `ResConv`, and `ResAtrous` denote the class of the DNA module.

Modules starting with `Cnn` are convolutional neural networks (CNNs). DeepCpG CNN architectures consist of a series of convolutional and max-pooling layers, which are followed by one fully-connected layer. Module `CnnLxhy` has `x` convolutional-pooling layers and one fully-connected layer with `y` units. For example, `CnnL2h128` has two convolutional layers and one fully-connected layer with 128 units, and `CnnL3h256` has three convolutional layers and one fully-connected layer with 256 units. `CnnL1h128` is the fastest module, but modules with more layers and neurons usually perform better. In my experiments, `CnnL2h128` provided a good trade-off between performance and runtime, and I recommend it as the default.

`CnnRnn01` is a [convolutional-recurrent neural network](http://nar.oxfordjournals.org/content/44/11/e107). It consists of two convolutional-pooling layers, which are followed by a bidirectional recurrent neural network (RNN) with one layer and gated recurrent units (GRUs). `CnnRnn01` is slower than `Cnn` architectures and did not perform better in my experiments.

Modules starting with `ResNet` are [residual neural networks](https://arxiv.org/abs/1603.05027). ResNets are very deep networks with skip connections to improve the gradient flow and to allow learning how many layers to use. A residual network consists of multiple residual blocks, and each residual block consists of multiple residual units. Residual units have a bottleneck architecture with three convolutional layers to speed up computations. `ResNet01` and `ResNet02` have three residual blocks with two and three residual units, respectively. ResNets are slower than CNNs, but can perform better on large datasets.
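
For illustration, one pre-activation bottleneck unit could be sketched in Keras 1.x as below. This is an assumption-laden sketch, not DeepCpG's exact implementation; it assumes the input already has `4 * nb_filter` channels, so the identity shortcut applies without a projection:

```
# Sketch of a pre-activation bottleneck residual unit (He et al.): a 1-3-1
# convolution stack plus an identity skip connection. Assumes the input
# already has 4 * nb_filter channels; otherwise a 1x1 convolution would be
# needed to project the shortcut.
from keras import layers as kl

def bottleneck_unit(x, nb_filter):
    shortcut = x
    h = kl.BatchNormalization()(x)
    h = kl.Activation('relu')(h)
    h = kl.Convolution1D(nb_filter, 1, border_mode='same')(h)      # reduce
    h = kl.BatchNormalization()(h)
    h = kl.Activation('relu')(h)
    h = kl.Convolution1D(nb_filter, 3, border_mode='same')(h)      # transform
    h = kl.BatchNormalization()(h)
    h = kl.Activation('relu')(h)
    h = kl.Convolution1D(4 * nb_filter, 1, border_mode='same')(h)  # expand
    return kl.merge([shortcut, h], mode='sum')                     # skip connection
```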

Modules starting with `ResConv` are ResNets with modified residual units that have two convolutional layers instead of a bottleneck architecture. `ResConv` modules performed worse than `ResNet` modules in my experiments.

Modules starting with `ResAtrous` are ResNets with modified residual units that use [Atrous convolutional layers](http://arxiv.org/abs/1511.07122) instead of normal convolutional layers. Atrous convolutional layers have dilated filters, i.e. filters with 'holes', which allow scanning wider regions of the input sequence and thereby better capturing distant patterns in the DNA sequence. However, `ResAtrous` modules performed worse than `ResNet` modules in my experiments.
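
The receptive-field gain from dilation follows from simple arithmetic: each layer widens the receptive field by `(filter_length - 1) * dilation` positions. A small sketch (filter length 3 is an arbitrary example):

```
# Receptive field of stacked convolutions: rf = 1 + sum((k - 1) * dilation).
def receptive_field(filter_len, dilations):
    rf = 1
    for d in dilations:
        rf += (filter_len - 1) * d
    return rf

print(receptive_field(3, [1, 1, 1]))  # 7: three normal conv layers
print(receptive_field(3, [1, 2, 4]))  # 15: same depth with dilated filters
```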


## CpG modules

| Name | Parameters | Specification |
|-------|------------|--------------------------------|
| FcAvg | 54,000 | fc[512]_gap |
| RnnL1 | 810,000 | fc[256]_bgru[256]_do |
| RnnL2 | 1,100,000 | fc[256]_bgru[128]_bgru[256]_do |

`FcAvg` is a lightweight module with only 54,000 parameters, which first transforms the observed neighboring CpG sites of all cells independently, and then averages the transformed features across cells. `FcAvg` is very fast, but performs worse than RNN modules.

`Rnn` modules consist of bidirectional recurrent neural networks (RNNs) with gated recurrent units (GRUs) to summarize the methylation neighborhood of cells in a more clever way than averaging. `RnnL1` consists of one fully-connected layer with 256 units to transform the methylation neighborhood of each cell independently, and one bidirectional GRU with 2x256 units to summarize the transformed methylation neighborhoods across cells. `RnnL2` has two GRU layers instead of one. `RnnL1` is faster and performed as well as `RnnL2` in my experiments.
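
A sketch of the `RnnL1` idea in Keras 1.x, with hypothetical input sizes and the bidirectional GRU written out as a forward and a backward GRU whose outputs are concatenated (not DeepCpG's exact code):

```
# Sketch of RnnL1: a shared fc[256] transform per cell, then a bidirectional
# GRU across cells. nb_cells and nb_features are hypothetical input sizes.
from keras import layers as kl

nb_cells, nb_features = 10, 100
cpg = kl.Input(shape=(nb_cells, nb_features), name='cpg')
x = kl.TimeDistributed(kl.Dense(256, activation='relu'))(cpg)  # fc[256] per cell
fwd = kl.GRU(256)(x)                                           # forward GRU
bwd = kl.GRU(256, go_backwards=True)(x)                        # backward GRU
x = kl.merge([fwd, bwd], mode='concat')                        # bgru[256] -> 2x256
x = kl.Dropout(0.1)(x)                                         # do
```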


## Joint modules

| Name | Parameters | Specification |
|-------------|------------|----------------------------------------|
| JointL0 | 0 | |
| JointL1h512 | 524,000 | fc[512] |
| JointL2h512 | 786,000 | fc[512]_fc[512] |
| JointL3h512 | 1,000,000  | fc[512]_fc[512]_fc[512]                |

Joint modules join the features from the DNA and CpG modules. `JointL0` simply concatenates the features and has no learnable parameters (ultra fast). `JointLXh512` has `X` fully-connected layers with 512 neurons. Modules with more layers usually perform better, at the cost of a higher runtime. I recommend using `JointL2h512` or `JointL3h512`.
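
A sketch of the joint idea with hypothetical feature sizes (`JointL0` stops after the concatenation; `JointL2h512` adds two fully-connected layers):

```
# Sketch of JointL2h512: concatenate the DNA and CpG feature vectors, then
# stack two fc[512] layers. The feature sizes below are hypothetical.
from keras import layers as kl

dna_feat = kl.Input(shape=(256,), name='dna_features')
cpg_feat = kl.Input(shape=(512,), name='cpg_features')
x = kl.merge([dna_feat, cpg_feat], mode='concat')  # JointL0: concatenation only
for _ in range(2):                                 # JointL2h512: fc[512]_fc[512]
    x = kl.Dense(512, activation='relu')(x)
```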
