In [None]:
# Optional: auto-format code cells with the "black" formatter via the nb_black
# extension. This cell can be skipped; nothing below depends on it.
%load_ext nb_black

# recreate "representation learning" paper

Koo PK, Eddy SR (2019). Representation learning of genomic sequence motifs with convolutional neural networks. _PLOS Computational Biology_ 15(12): e1007560. https://doi.org/10.1371/journal.pcbi.1007560

Also at https://www.biorxiv.org/content/10.1101/362756v4.full

## install python dependencies

In [None]:
# Install tfomics straight from the tip of the master branch on GitHub.
# NOTE(review): this is unpinned, so results may change as the upstream repository changes.
%pip install --no-cache-dir https://github.com/p-koo/tfomics/tarball/master

## load data

In [None]:
# Download the synthetic dataset into the working directory. With --timestamping,
# wget only re-downloads when the remote file is newer than the local copy.
!wget --timestamping https://www.dropbox.com/s/c3umbo5y13sqcfp/synthetic_dataset.h5

In [None]:
from pathlib import Path
import h5py
import numpy as np

In [None]:
data_path = Path("synthetic_dataset.h5")


def _read_split(h5_file, key, dtype):
    """Read the whole dataset stored under `key` and cast it to `dtype`."""
    return h5_file[key][:].astype(dtype)


# Inputs and training labels are cast to float32; validation and test labels
# are cast to int32.
with h5py.File(data_path, "r") as dataset:
    x_train = _read_split(dataset, "X_train", np.float32)
    y_train = _read_split(dataset, "Y_train", np.float32)
    x_valid = _read_split(dataset, "X_valid", np.float32)
    y_valid = _read_split(dataset, "Y_valid", np.int32)
    x_test = _read_split(dataset, "X_test", np.float32)
    y_test = _read_split(dataset, "Y_test", np.int32)

# Swap the last two axes of every input array so the arrays are laid out as
# (sequence, position, nucleotide), which is the order unpacked below.
x_train, x_valid, x_test = (
    np.transpose(inputs, (0, 2, 1)) for inputs in (x_train, x_valid, x_test)
)

N, L, A = x_train.shape
print(f"{N} sequences, {L} nts per sequence, {A} nts in alphabet")

## Max-pooling influences ability to build hierarchical motif representations

>The goal of this computational task is to simultaneously make 12 binary predictions for the presence or absence of each transcription factor motif in the sequence.


### make CNN models

From the paper's Methods section, "CNN models":

>All CNNs take as input a 1-dimensional one-hot-encoded sequence with 4 channels (one for each nucleotide: A, C, G, T), then processes the sequence with two convolutional layers, a fully-connected hidden layer, and a fully-connected output layer with 12 output neurons that have sigmoid activations for binary predictions. Each convolutional layer consists of a 1D cross-correlation operation, which calculates a running sum between convolution filters and the inputs to the layer, followed by batch normalization (Ioffe and Szegedy, 2015), which independently scales the features learned by each convolution filter, and a non-linear activation with a rectified linear unit (ReLU), which replaces negative values with zero.
>
>The first convolutional layer employs 30 filters each with a size of 19 and a stride of 1. The second convolutional layer employs 128 filters each with a size of 5 and a stride of 1. All convolutional layers incorporate zero-padding to achieve the same output length as the inputs. Each convolutional layer is followed by max-pooling with a window size and stride that are equal, unless otherwise stated. The product of the two max-pooling window sizes is equal to 100. Thus, if the first max-pooling layer has a window size of 2, then the second max-pooling window size is 50. This constraint ensures that the number of inputs to the fully-connected hidden layer is the same across all models. The fully-connected hidden layer employs 512 units with ReLU activations.
>
>Dropout (Srivastava et al, 2014), a common regularization technique for neural networks, is applied during training after each convolutional layer, with a dropout probability set to 0.1 for convolutional layers and 0.5 for fully-connected hidden layers. During training, we also employed L2-regularization with a strength equal to 1e-6. The parameters of each model were initialized according to (He et al, 2015), commonly known as He initialization.
>
>All models were trained with mini-batch stochastic gradient descent (mini-batch size of 100 sequences) for 100 epochs, updating the parameters after each mini-batch with Adam updates (Kingma and Ba, 2014), using recommended default parameters with a constant learning rate of 0.0003. Training was performed on a NVIDIA GTX Titan X Pascal graphical processing unit with acceleration provided by cuDNN libraries (Chetlur et al, 2014). All reported performance metrics and saliency logos are drawn strictly from the test set using the model parameters which yielded the lowest binary cross-entropy loss on the validation set, a technique known as early stopping.

In [None]:
import pandas as pd
import tensorflow as tf

print("tensorflow version", tf.__version__)

# Short aliases used by the model-building and training cells.
tfk = tf.keras
tfkl = tfk.layers


def get_model(
    pool1: int, pool2: int, n_classes: int = 12, batch_size: int = None
) -> tfk.Sequential:
    """Return a Model object with two convolutional layers, a
    fully-connected hidden layer, and output. Sigmoid activation is
    applied to logits.

    The input shape is taken from the notebook-level globals `L` (sequence
    length) and `A` (alphabet size), so the data-loading cell must have been
    run before this function is called.

    Parameters
    ----------
    pool1 : int
        Size of pooling window in the max-pooling operation after the first
        convolution. Also used as the stride, so windows do not overlap.
    pool2 : int
        Size of pooling window in the max-pooling operation after the second
        convolution. Also used as the stride, so windows do not overlap.
    n_classes : int
        Number of output units.
    batch_size : int
        Batch size of input. If `None`, batch size can be variable.

    Returns
    -------
    Instance of `tf.keras.Sequential`. This model is not compiled.

    Raises
    ------
    ValueError
        If either pool size is not positive, or if the product of the two
        pool sizes is not 100.
    """
    # Guard against pairs such as (-1, -100), which satisfy the product check
    # but are not valid pooling windows.
    if pool1 <= 0 or pool2 <= 0:
        raise ValueError("pool sizes must be positive")
    if pool1 * pool2 != 100:
        raise ValueError("product of pool sizes must be 100")
    l2_reg = tfk.regularizers.l2(1e-6)
    # NOTE(review): the convolutional and hidden dense layers do not set a
    # kernel initializer, so the Keras default is used. The methods text quoted
    # in this notebook mentions He initialization; confirm this is intended.
    return tfk.Sequential(
        [
            tfkl.Input(shape=(L, A), batch_size=batch_size),
            # layer 1: conv -> batchnorm -> relu -> maxpool -> dropout.
            # The bias is disabled because batch normalization follows directly.
            tfkl.Conv1D(
                filters=30,
                kernel_size=19,
                strides=1,
                padding="same",
                use_bias=False,
                kernel_regularizer=l2_reg,
            ),
            tfkl.BatchNormalization(),
            tfkl.Activation(tf.nn.relu),
            tfkl.MaxPool1D(pool_size=pool1, strides=pool1),
            tfkl.Dropout(0.1),
            # layer 2: same pattern as layer 1 with more, shorter filters.
            tfkl.Conv1D(
                filters=128,
                kernel_size=5,
                strides=1,
                padding="same",
                use_bias=False,
                kernel_regularizer=l2_reg,
            ),
            tfkl.BatchNormalization(),
            tfkl.Activation(tf.nn.relu),
            tfkl.MaxPool1D(pool_size=pool2, strides=pool2),
            tfkl.Dropout(0.1),
            # layer 3: fully-connected hidden layer. `use_bias` is a boolean
            # flag; pass an explicit False rather than relying on None being falsy.
            tfkl.Flatten(),
            tfkl.Dense(
                units=512, activation=None, use_bias=False, kernel_regularizer=l2_reg
            ),
            tfkl.BatchNormalization(),
            tfkl.Activation(tf.nn.relu),
            tfkl.Dropout(0.5),
            # layer 4 (output). do not use activation (ie linear activation) so we can inspect
            # the logits later.
            tfkl.Dense(
                units=n_classes,
                activation=None,
                use_bias=True,
                kernel_initializer=tfk.initializers.GlorotNormal(),
                bias_initializer=tfk.initializers.Zeros(),
                name="logits",
            ),
            tfkl.Activation(tf.nn.sigmoid, name="predictions"),
        ]
    )

### train models

In [None]:
# Directory that receives the trained models and their training histories.
save_dir = Path("models")

# (pool1, pool2) window sizes for the two max-pooling layers. The second size
# is derived from the first so that each pair multiplies to 100.
pool_pairs = [
    (first_pool, 100 // first_pool) for first_pool in (1, 2, 4, 10, 25, 50, 100)
]

In [None]:
# Create the output directory before training so that a filesystem problem
# surfaces immediately instead of after a full training run.
save_dir.mkdir(parents=True, exist_ok=True)

for pool1, pool2 in pool_pairs:
    print(f"++ training model with pool sizes {pool1}, {pool2}")
    # Release the Keras global state left behind by the previous model so
    # memory use does not grow with every iteration of this loop.
    tfk.backend.clear_session()
    model = get_model(pool1=pool1, pool2=pool2)

    metrics = [
        tfk.metrics.AUC(curve="ROC", name="auroc"),
        tfk.metrics.AUC(curve="PR", name="aupr"),  # precision-recall
    ]
    # NOTE(review): the methods text quoted in this notebook describes a constant
    # learning rate of 0.0003 and model selection on validation loss. This cell
    # uses 0.001 with plateau-based decay and selects on validation AUPR;
    # confirm the deviation is intended.
    model.compile(
        optimizer=tfk.optimizers.Adam(learning_rate=0.001),
        # The model ends in a sigmoid activation, so the loss receives probabilities.
        loss=tfk.losses.BinaryCrossentropy(from_logits=False),
        metrics=metrics,
    )

    callbacks = [
        # Stop once validation AUPR has not improved for 20 epochs and roll the
        # weights back to the best epoch, so the model saved below is the best
        # one seen rather than the one from 20 epochs later.
        tfk.callbacks.EarlyStopping(
            monitor="val_aupr",
            patience=20,
            verbose=1,
            mode="max",
            restore_best_weights=True,
        ),
        # Shrink the learning rate by 5x after 5 epochs without improvement.
        tfk.callbacks.ReduceLROnPlateau(
            monitor="val_aupr",
            factor=0.2,
            patience=5,
            min_lr=1e-7,
            mode="max",
            verbose=1,
        ),
    ]
    # train
    history: tfk.callbacks.History = model.fit(
        x=x_train,
        y=y_train,
        batch_size=100,
        epochs=100,
        shuffle=True,
        validation_data=(x_valid, y_valid),
        callbacks=callbacks,
        verbose=2,
    )
    # save the model and, next to it, the per-epoch training history
    filepath = save_dir / f"model-{pool1:03d}-{pool2:03d}.h5"
    model.save(filepath)
    # cannot save directly with json standard lib because numpy datatypes
    # will cause an error. pandas converts things for us.
    df_hist = pd.DataFrame(history.history)
    df_hist.to_json(filepath.with_suffix(".json"))

In [None]:
# List the files written by the training cell (saved models and training histories).
!ls $save_dir

### evaluate models

The end goal is to compute, for each trained model, the fraction of first-layer filters that match a motif in the JASPAR database.

In [None]:
# Download the JASPAR CORE 2016 vertebrates motif database in MEME format.
# With --timestamping, wget only re-downloads when the remote file is newer than the local copy.
!wget --timestamping https://www.dropbox.com/s/ha1sryrxfhx7ex7/JASPAR_CORE_2016_vertebrates.meme

In [None]:
%%bash
# Build the MEME suite (which provides `tomtom`) from source and install it
# under $HOME/meme, unless a tomtom executable is already available.
#
# Abort on the first failing step so a failed download or `cd` does not lead
# to configure/make running in the wrong directory.
set -euo pipefail

prefix="$HOME/meme"

# $prefix/bin is only appended to PATH by a later notebook cell, so also look
# there directly; otherwise a fresh kernel would rebuild MEME on every run.
if command -v tomtom || [ -x "$prefix/bin/tomtom" ]; then
  echo "tomtom program installed"
  exit 0
fi

# -p: do not fail when the directory is left over from an earlier attempt.
mkdir -p meme-src
cd meme-src
curl -fL https://meme-suite.org/meme/meme-software/5.3.1/meme-5.3.1.tar.gz | tar xz --strip-components 1
./configure --prefix="$prefix" --with-url=http://meme-suite.org --enable-build-libxml2 --enable-build-libxslt
make
# Self-test failures are reported but, as before, do not block the install.
make test || echo "WARNING: 'make test' reported failures; installing anyway" >&2
make install

In [None]:
# add meme programs to PATH
import os

meme_bin = str(Path.home() / "meme" / "bin")
current_path = os.environ.get("PATH", "")
# Only append when the directory is missing, so re-running this cell does not
# keep growing PATH with duplicate entries.
if meme_bin not in current_path.split(os.pathsep):
    # Avoid a leading separator (an empty PATH entry) when PATH was unset or empty.
    os.environ["PATH"] = (
        f"{current_path}{os.pathsep}{meme_bin}" if current_path else meme_bin
    )

In [None]:
from collections import namedtuple
import subprocess

import matplotlib.pyplot as plt
import tfomics
import tfomics.impress
from tfomics import evaluate, moana

In [None]:
# Container for comparison between motifs and filters for one model.
meme_entry = namedtuple(
    "meme_entry",
    "match_fraction match_any filter_match filter_qvalue min_qvalue num_counts",
)

# tomtom is pointed at the same output directory and the same motif database
# for every model, so define both once. Each model's results are read back
# right after its tomtom run, before the next run writes to the directory.
output_path = "filters"
jaspar_path = "JASPAR_CORE_2016_vertebrates.meme"

# Maps "cnn-<pool1>-<pool2>" to the meme_entry for that model.
outputs = {}

for pool1, pool2 in pool_pairs:
    print("\n++++ evaluating cnn", pool1, pool2)
    tag = f"{pool1:03d}-{pool2:03d}"

    # Load model.
    model = tfk.models.load_model(save_dir / f"model-{tag}.h5")
    _ = model.evaluate(x_test, y_test)

    # layers: (0)conv -> (1)batchnorm -> (2)relu
    W = tfomics.moana.filter_activations(
        x_test=x_test, model=model, layer=2, window=20, threshold=0.5
    )

    # Create meme file
    W_clipped = tfomics.moana.clip_filters(W, threshold=0.5, pad=3)
    meme_file = save_dir / f"filters-{tag}.meme"
    tfomics.moana.meme_generate(W_clipped, output_file=meme_file, prefix="filter")
    print("++ saved motifs to", meme_file)

    # Use tomtom to determine which motifs our filters are similar to.
    print("++ running tomtom")
    args = [
        "tomtom",
        "-thresh",
        "0.5",
        "-dist",
        "pearson",
        "-evalue",
        "-oc",
        output_path,
        str(meme_file),
        jaspar_path,
    ]
    # check=True: raise instead of silently parsing stale or missing results.
    subprocess.run(args, check=True)

    # See which motifs the filters are similar to.
    num_filters = tfomics.moana.count_meme_entries(meme_file)
    out = tfomics.evaluate.motif_comparison_synthetic_dataset(
        Path(output_path) / "tomtom.tsv", num_filters=num_filters
    )
    # Save comparisons to dict.
    entry = meme_entry(*out)
    outputs[f"cnn-{tag}"] = entry

    # Plot logos with motif names. The names come from this model's comparison
    # result, not from a free-standing variable.
    fig = plt.figure(figsize=(25, 4))
    tfomics.impress.plot_filters(
        W, fig, num_cols=6, names=entry.filter_match, fontsize=14
    )
    fig.suptitle(f"filters - cnn {pool1} x {pool2}")
    plt.savefig(save_dir / f"filter-logos-{tag}.pdf")
    plt.show()
    # One figure is created per model; close it so figures do not pile up in memory.
    plt.close(fig)

In [None]:
# Report, per model, the match fraction stored in its comparison entry.
print("match fractions")
for model_name, entry in outputs.items():
    fraction = entry.match_fraction
    print(f"{model_name}: {fraction:0.3f}")

## Sensitivity of motif representations to the number of filters

## Motif representations are not very sensitive to 1st layer filter size

## Motif representations are affected by the ability to assemble whole motifs in deeper layers

In [None]:
# Stand-alone variant of the two-convolution CNN built by `get_model`. It is
# written out by hand because its first max-pooling layer uses a window of 50
# with a stride of 2, whereas `get_model` always sets the stride equal to the
# window size.
# NOTE(review): the overlapping first pooling layer looks intentional given
# the variable name (50 / 2) — confirm. The model is neither compiled nor
# trained in this cell.
l2_reg = tfk.regularizers.l2(1e-6)
cnn_50_2 = tfk.Sequential(
    [
        tfkl.Input(shape=(L, A)),
        # layer 1: conv -> batchnorm -> relu -> maxpool -> dropout
        tfkl.Conv1D(
            filters=30,
            kernel_size=19,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_regularizer=l2_reg,
        ),
        tfkl.BatchNormalization(),
        tfkl.Activation(tf.nn.relu),
        # Stride (2) is smaller than the window (50), so pooling windows overlap.
        tfkl.MaxPool1D(pool_size=50, strides=2),
        tfkl.Dropout(0.1),
        # layer 2: same pattern with non-overlapping pooling
        tfkl.Conv1D(
            filters=128,
            kernel_size=5,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_regularizer=l2_reg,
        ),
        tfkl.BatchNormalization(),
        tfkl.Activation(tf.nn.relu),
        tfkl.MaxPool1D(pool_size=50, strides=50),
        tfkl.Dropout(0.1),
        # layer 3: fully-connected hidden layer. `use_bias` is a boolean flag;
        # pass an explicit False rather than relying on None being falsy.
        tfkl.Flatten(),
        tfkl.Dense(
            units=512, activation=None, use_bias=False, kernel_regularizer=l2_reg
        ),
        tfkl.BatchNormalization(),
        tfkl.Activation(tf.nn.relu),
        tfkl.Dropout(0.5),
        # layer 4 (output). do not use activation (ie linear activation) so we can inspect
        # the logits later.
        tfkl.Dense(
            units=12,
            activation=None,
            use_bias=True,
            kernel_initializer=tfk.initializers.GlorotNormal(),
            bias_initializer=tfk.initializers.Zeros(),
            name="logits",
        ),
        tfkl.Activation(tf.nn.sigmoid, name="predictions"),
    ]
)

## Distributed representations build whole motif representations in deeper layers