To train a model, you first need to convert your sequences and targets into the input HDF5 format. Check out my tutorials for how to do that; they're linked from the [main page](../README.md).

For this tutorial, grab a small example HDF5 file that I constructed with 10% of the training sequences and only the GM12878 targets from various DNase-seq, ChIP-seq, and CAGE experiments.

```python
import os, subprocess

# download the example HDF5 file if it isn't already present
os.makedirs('data', exist_ok=True)
if not os.path.isfile('data/gm12878_l262k_w128_d10.h5'):
    subprocess.call('curl -o data/gm12878_l262k_w128_d10.h5 https://storage.googleapis.com/262k_binned/gm12878_l262k_w128_d10.h5', shell=True)
```

Next, you need to decide what sort of architecture to use. The parameter file format probably needs work; my goal was to let hyperparameter searches write their parameters to file so that I could run parallel training jobs to explore the hyperparameter space. I included an example set of parameters that will work well with this data in *models/params_small.txt*.
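To illustrate the kind of hyperparameter search described above, here is a minimal sketch that writes one parameter file per point in a grid. It assumes a simple file format of one tab-separated `key value` pair per line, with list values joined by commas. That format is an assumption, not a documented spec, so compare the output with *models/params_small.txt* before relying on it. The function names are hypothetical.

```python
import itertools
import os


def expand_grid(base_params, grid):
    """Return one parameter dict per combination of the values listed in `grid`."""
    keys = sorted(grid)
    combos = []
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(base_params)
        params.update(zip(keys, values))
        combos.append(params)
    return combos


def format_params(params):
    """Render parameters as `key<TAB>value` lines (assumed file format).

    Lists are joined with commas, e.g. cnn_filters -> "128,160,200".
    """
    lines = []
    for key in sorted(params):
        value = params[key]
        if isinstance(value, (list, tuple)):
            value = ','.join(str(v) for v in value)
        lines.append('%s\t%s' % (key, value))
    return '\n'.join(lines) + '\n'


def write_param_grid(base_params, grid, out_dir):
    """Write params_000.txt, params_001.txt, ... into `out_dir`; return their paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, params in enumerate(expand_grid(base_params, grid)):
        path = os.path.join(out_dir, 'params_%03d.txt' % i)
        with open(path, 'w') as params_out:
            params_out.write(format_params(params))
        paths.append(path)
    return paths
```

Each resulting file could then be handed to its own basenji_train.py job via `--params`, so the grid points train in parallel.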

Then, run [basenji_train.py](https://github.com/calico/basenji/blob/master/bin/basenji_train.py) to train a model. The program reports training progress on stdout and writes the model output files to the directory given by the *--logdir* parameter.

The most relevant options here are:

| Option/Argument | Value | Note |
|:---|:---|:---|
| --rc | | Process even-numbered epochs as forward, odd-numbered as reverse complemented. Average the forward and reverse complement to assess validation accuracy. |
| --logdir | models/gm12878 | Directory in which to save the model output files. |
| --params | models/params_small.txt | Table of parameters to set up the model architecture and optimization. |
| --data | data/gm12878_l262k_w128_d10.h5 | HDF5 file containing the training and validation input and output datasets, as generated by [basenji_hdf5_single.py](https://github.com/calico/basenji/blob/master/bin/basenji_hdf5_single.py). |
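The `--rc` option can be pictured with a small NumPy sketch. It assumes one-hot sequences of shape `(length, 4)` with columns ordered A, C, G, T, and uses a hypothetical `predict_fn` as a stand-in for the model. It is an illustration of the idea, not basenji's implementation.

```python
import numpy as np


def reverse_complement_1hot(seq_1hot):
    """Reverse complement a one-hot DNA array of shape (length, 4).

    Assumes columns are ordered A, C, G, T, so reversing the column order
    swaps A<->T and C<->G while reversing the rows flips the strand direction.
    """
    return seq_1hot[::-1, ::-1]


def rc_epoch(epoch):
    """Even-numbered epochs train forward, odd-numbered reverse complemented."""
    return epoch % 2 == 1


def predict_rc_average(predict_fn, seq_1hot):
    """Average forward and reverse-complement predictions.

    `predict_fn` is a hypothetical stand-in for the model, mapping a
    (length, 4) one-hot array to (bins, targets) predictions. Predictions on
    the reverse complement run in the opposite direction, so they are
    flipped back along the bin axis before averaging.
    """
    preds_fwd = predict_fn(seq_1hot)
    preds_rc = predict_fn(reverse_complement_1hot(seq_1hot))[::-1]
    return (preds_fwd + preds_rc) / 2
```

The flip back along the bin axis is the important step: without it, position 0 of the forward prediction would be averaged with the last position of the reverse-complement prediction. During training, the targets for a reverse-complemented sequence would need the same flip.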

To train the model yourself, run the following command. Depending on your hardware, it may take many hours.

```
! basenji_train.py --logdir models/gm12878 --params models/params_small.txt --data data/gm12878_l262k_w128_d10.h5
```

```
{'optimizer': 'adam', 'cnn_filters': [128, 160, 200, 250, 256, 32, 32, 32, 32, 32, 32, 384], 'loss': 'poisson', 'cnn_dropout': 0.05, 'learning_rate': 0.002, 'batch_buffer': 16384, 'cnn_filter_sizes': [20, 6, 6, 6, 3, 3, 3, 3, 3, 3, 3, 1], 'batch_size': 1, 'adam_beta2': 0.98, 'batch_renorm': 1, 'link': 'softplus', 'cnn_dense': [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0], 'adam_beta1': 0.97, 'cnn_dilation': [1, 1, 1, 1, 1, 2, 4, 8, 16, 32, 64, 1], 'cnn_pool': [2, 4, 4, 4, 1, 0, 0, 0, 0, 0, 0, 0]}
Targets pooled by 128 to length 2048
Convolution w/ 39 384x1 filters to final targets
Model building time 11.449232
Batcher initialized
2018-05-11 16:23:40.480890: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
Initializing...
Initialization time 17.491073
```
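The `Targets pooled by 128 to length 2048` line in the log can be checked by hand from the printed `cnn_pool` list. The sketch below assumes that pool sizes of 0 or 1 mean "no pooling" for that layer, which is an assumption about the parameter format. Under it, the pooling layers multiply to 2 * 4 * 4 * 4 = 128, and 2048 bins of 128 bp imply a 262,144 bp input sequence, which is presumably what `l262k` in the data file name refers to.

```python
from functools import reduce
import operator

# cnn_pool as printed in the training log above
cnn_pool = [2, 4, 4, 4, 1, 0, 0, 0, 0, 0, 0, 0]

# Assumption: pool sizes of 0 or 1 mean "no pooling" for that layer.
total_pool = reduce(operator.mul, (p for p in cnn_pool if p > 1), 1)

target_length = 2048                      # bins, from "to length 2048"
seq_length = total_pool * target_length   # input sequence length in bp

print(total_pool, seq_length)  # 128 262144
```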


Alternatively, you can just download a trained model.

```python
import os, subprocess

# download the pretrained checkpoint under the prefix used by basenji_test.py below
model_prefix = 'models/gm12878_d10/model_best.tf'
url_prefix = 'https://storage.googleapis.com/basenji_tutorial_data/model_gm12878_d10.tf'
if not os.path.isfile(model_prefix + '.meta'):
    os.makedirs(os.path.dirname(model_prefix), exist_ok=True)
    for suffix in ['.index', '.meta', '.data-00000-of-00001']:
        subprocess.call('curl -o %s%s %s%s' % (model_prefix, suffix, url_prefix, suffix), shell=True)
```

*models/gm12878_d10/model_best.tf* is now the prefix of your saved model, which you provide to other programs.
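Because the saved model is a prefix rather than a single file, it can help to confirm that all three checkpoint files downloaded above are present before passing the prefix on; otherwise the problem only shows up when a program tries to restore the model. A minimal sketch, with a hypothetical helper name:

```python
import os

# The three files that make up the checkpoint, as downloaded above.
CHECKPOINT_SUFFIXES = ('.index', '.meta', '.data-00000-of-00001')


def missing_checkpoint_files(prefix):
    """List the checkpoint files for `prefix` that are not on disk (empty if complete)."""
    return [prefix + s for s in CHECKPOINT_SUFFIXES if not os.path.isfile(prefix + s)]
```

For example, `missing_checkpoint_files('models/gm12878_d10/model_best.tf')` should return an empty list once the downloads succeed.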

To further benchmark the accuracy (e.g. computing significant "peak" accuracy), use [basenji_test.py](https://github.com/calico/basenji/blob/master/bin/basenji_test.py).

The most relevant options here are:

| Option/Argument | Value | Note |
|:---|:---|:---|
| --rc | | Average the forward and reverse complement to form prediction. |
| -o | data/gm12878_test | Output directory. |
| --ai | 0,1,2 | Make accuracy scatter plots for targets 0, 1, and 2. |
| --ti | 3,4,5 | Make BigWig tracks for targets 3, 4, and 5. |
| -t | data/gm12878_l262k_w128_d10.bed | BED file describing sequence regions for BigWig track output. |
| params_file | models/params_small.txt | Table of parameters to set up the model architecture and optimization. |
| model_file | models/gm12878_d10/model_best.tf | Trained saved model prefix. |
| data_file | data/gm12878_l262k_w128_d10.h5 | HDF5 file containing the test input and output datasets as generated by [basenji_hdf5_single.py](https://github.com/calico/basenji/blob/master/bin/basenji_hdf5_single.py) |
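To illustrate how binned predictions line up with the regions in the `-t` BED file, here is a hypothetical helper that writes one region's predictions as bedGraph lines. It assumes the 128 bp bins reported in the logs tile the region contiguously from its BED start coordinate. basenji_test.py's own track writer may work differently; bedGraph output like this can be converted to BigWig with UCSC's bedGraphToBigWig.

```python
def bins_to_bedgraph(chrom, start, preds, bin_size=128):
    """Convert one region's binned predictions to bedGraph lines.

    Hypothetical helper: assumes bin i covers
    [start + i * bin_size, start + (i + 1) * bin_size).
    """
    lines = []
    for i, value in enumerate(preds):
        bin_start = start + i * bin_size
        lines.append('%s\t%d\t%d\t%.5f' % (chrom, bin_start, bin_start + bin_size, value))
    return lines
```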

```
! basenji_test.py --rc -o data/gm12878_test --ai 0,1,2 -t data/gm12878_l262k_w128_d10.bed --ti 3,4,5 models/params_small.txt models/gm12878_d10/model_best.tf data/gm12878_l262k_w128_d10.h5
```

```
{'batch_buffer': 16384, 'loss': 'poisson', 'full_dropout': 0.05, 'adam_beta1': 0.97, 'cnn_filter_sizes': [22, 1, 6, 6, 6, 3], 'cnn_filters': [128, 128, 160, 200, 250, 256], 'adam_beta2': 0.98, 'dcnn_filter_sizes': [3, 3, 3, 3, 3, 3], 'dense': 1, 'full_units': 384, 'learning_rate': 0.002, 'link': 'softplus', 'cnn_dropout': 0.05, 'cnn_pool': [1, 2, 4, 4, 4, 1], 'batch_size': 1, 'batch_renorm': 1, 'dcnn_dropout': 0.1, 'dcnn_filters': [32, 32, 32, 32, 32, 32]}
Targets pooled by 128 to length 2048
Convolution w/ 128 4x22 filters strided by 1
Batch normalization
ReLU
Dropout w/ probability 0.050
Convolution w/ 128 128x1 filters strided by 1
Batch normalization
ReLU
Max pool 2
Dropout w/ probability 0.050
Convolution w/ 160 128x6 filters strided by 1
Batch normalization
ReLU
Max pool 4
Dropout w/ probability 0.050
Convolution w/ 200 160x6 filters strided by 1
Batch normalization
ReLU
Max pool 4
Dropout w/ probability 0.050
Convolution w/ 250 200x6 filters strided by 1
Batch normalization
ReLU
```

*data/gm12878_test/acc.txt* is a table specifying, for each target dataset, the loss function value, R2, R2 after a log2 transform, and Spearman correlation.
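For reference, here is a minimal NumPy sketch of the four statistics for a single target, given arrays of observed and predicted coverage. These are standard textbook definitions and are assumptions about what basenji_test.py computes: the constant dropped from the Poisson loss, whether R2 is 1 - SS_res/SS_tot or a squared correlation, and the pseudocount added before log2 may all differ.

```python
import numpy as np


def poisson_loss(targets, preds):
    """Mean Poisson negative log-likelihood, dropping the log(y!) term.

    preds must be positive (the softplus link in the printed params keeps them so).
    """
    return np.mean(preds - targets * np.log(preds))


def r2(targets, preds):
    """Coefficient of determination, 1 - SS_res / SS_tot."""
    ss_res = np.sum((targets - preds) ** 2)
    ss_tot = np.sum((targets - targets.mean()) ** 2)
    return 1 - ss_res / ss_tot


def r2_log2(targets, preds, pseudocount=1):
    """R2 after a log2(x + pseudocount) transform (the pseudocount is an assumption)."""
    return r2(np.log2(targets + pseudocount), np.log2(preds + pseudocount))


def spearman(targets, preds):
    """Spearman correlation as the Pearson correlation of ranks (assumes no ties)."""
    rank_t = np.argsort(np.argsort(targets))
    rank_p = np.argsort(np.argsort(preds))
    return np.corrcoef(rank_t, rank_p)[0, 1]
```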

```
! cat data/gm12878_test/acc.txt
```

```
   0  2.55459  0.12923  0.06884  0.20332  ENCSR000EJD_3_1
   1  2.07825  0.26504  0.11945  0.25228  ENCSR000EMT_2_1
   2  1.34195  0.23470  0.12942  0.26919  ENCSR000EMT_1_1
   3  2.73190  0.11670  0.08710  0.26056  ENCSR000EJD_1_1
   4  2.34003  0.13552  0.10720  0.28923  ENCSR000EJD_2_1
   5  1.65317  0.42328  0.22528  0.37787  ENCSR057BWO_2_1
   6  1.15429  0.17104  0.17254  0.34555  ENCSR000AKE_1_1
   7  0.84690  0.08580  0.07383  0.30523  ENCSR000AKF_2_1
   8  1.01278  0.30182  0.09221  0.16016  ENCSR000AOV_2_1
   9  0.82520  0.05287  0.06443  0.26996  ENCSR000AKI_2_1
  10  2.23546  0.50709  0.20589  0.29792  ENCSR000AKA_2_1
  11  1.02572  0.03412  0.03277  0.14032  ENCSR000AOX_2_1
  12  1.06876  0.19844  0.23515  0.46800  ENCSR000DRW_1_1
  13  1.23316  0.21266  0.22506  0.38231  ENCSR000AOW_1_1
  14  1.13732  0.06730  0.08054  0.29732  ENCSR000AKD_1_1
  15  0.91039  0.18123  0.20173  0.38487  ENCSR000AKE_2_1
  16  1.04031  0.20278  0.23683  0.47334  ENCSR000DRW_2_1
  17  1.33496
```
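To work with these numbers programmatically, here is a small parser sketch (hypothetical helper names) that follows the column order described above: target index, loss, R2, R2 after log2, Spearman correlation, and target identifier. Lines without all six fields, such as a truncated last line, are skipped.

```python
def parse_acc_lines(lines):
    """Parse acc.txt lines into one dict per target, skipping incomplete lines."""
    rows = []
    for line in lines:
        fields = line.split()
        if len(fields) != 6:
            continue
        rows.append({
            'index': int(fields[0]),
            'loss': float(fields[1]),
            'r2': float(fields[2]),
            'r2_log2': float(fields[3]),
            'spearman': float(fields[4]),
            'target': fields[5],
        })
    return rows


def read_acc_table(path):
    """Read and parse an acc.txt file."""
    with open(path) as acc_in:
        return parse_acc_lines(acc_in)
```

For example, `sorted(read_acc_table('data/gm12878_test/acc.txt'), key=lambda r: r['spearman'])` orders the targets from lowest to highest Spearman correlation.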

*data/gm12878_test/peaks.txt* is a table specifying, for each target dataset, the number of peaks called, AUROC, and AUPRC.
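How basenji_test.py calls peaks is not described here, so the sketch below simply takes binary labels (1 if a bin is treated as a peak) and predicted scores as given. It computes AUROC through the Mann-Whitney rank formulation and AUPRC as average precision. Both functions assume there are no tied scores, and neither is claimed to match basenji's exact computation.

```python
import numpy as np


def auroc(labels, scores):
    """Area under the ROC curve via the Mann-Whitney rank statistic (assumes no ties)."""
    labels = np.asarray(labels, dtype=bool)
    ranks = np.argsort(np.argsort(scores)) + 1  # 1-based ranks, lowest score = 1
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    return (ranks[labels].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def auprc(labels, scores):
    """Area under the precision-recall curve, computed as average precision."""
    labels = np.asarray(labels, dtype=bool)
    order = np.argsort(scores)[::-1]  # highest score first
    hits = labels[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return precision_at_k[hits].sum() / hits.sum()
```

Unlike AUROC, whose chance level is 0.5, AUPRC for a random ranking is close to the fraction of positive bins, so AUPRC values are easiest to compare between targets with similar numbers of peaks.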

```
! cat data/gm12878_test/peaks.txt
```

```
   0     627  0.64745  0.23860
   1     194  0.76438  0.26973
   2     124  0.82391  0.26973
   3     867  0.65559  0.28928
   4     644  0.66481  0.26887
   5     267  0.79283  0.33708
   6     343  0.82138  0.27525
   7     191  0.78184  0.12368
   8     143  0.80043  0.25877
   9       3  0.63227  0.00096
  10     350  0.79971  0.37140
  11     184  0.63327  0.08505
  12     295  0.85445  0.29855
  13     324  0.87453  0.34045
  14     130  0.67821  0.05758
  15     289  0.85102  0.27473
  16     273  0.85927  0.31028
  17     201  0.83182  0.37427
  18     116  0.84729  0.35814
  19      98  0.68407  0.03596
  20     189  0.86034  0.34349
  21     108  0.83877  0.40408
  22      95  0.93912  0.53441
  23     104  0.95152  0.54211
  24     145  0.71995  0.06288
  25     182  0.84500  0.29557
  26      55  0.66193  0.02123
  27      94  0.87050  0.26843
  28     202  0.76177  0.15448
  29     117  0.73136  0.05396
  30     468  0.73932  0.21746
  31     318  0.86267  0.32694
  32
```

The directories *pr*, *roc*, *violin*, and *scatter* in *data/gm12878_test* contain plots for the targets indexed by 0, 1, and 2 as specified by the --ai option above.

For example, to view the precision-recall curve for target 0:

```python
from IPython.display import IFrame
IFrame('data/gm12878_test/pr/t0.pdf', width=600, height=500)
```