rescale filters after initializing with pwm #2

Open · wants to merge 4 commits into base: master
207 changes: 127 additions & 80 deletions concise/initializers.py
@@ -1,11 +1,9 @@
from keras import layers as kl
from keras import regularizers as kr
import keras.initializers as ki
from keras.initializers import Initializer, _compute_fans
from keras import backend as K
import concise
from concise.utils.pwm import PWM, pwm_list2pwm_array, pwm_array2pssm_array, DEFAULT_BASE_BACKGROUND
from keras.utils.generic_utils import get_custom_objects

import numpy as np
from scipy.stats import truncnorm
@@ -54,24 +52,113 @@ def _truncated_normal(mean,
return X
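Review note: the body of `_truncated_normal` is collapsed in the diff above. For reference, a scipy-based truncated normal helper of this shape typically looks like the following sketch (an assumption, not the PR's exact code; the zero-stddev guard is added because `scipy.stats.truncnorm` rejects `scale=0`):

```python
import numpy as np
from scipy.stats import truncnorm

def _truncated_normal_sketch(mean, stddev, seed=None):
    """Sample values within two standard deviations of the mean (sketch)."""
    if stddev == 0:
        return mean
    return truncnorm.rvs(-2, 2, loc=mean, scale=stddev,
                         size=np.shape(mean), random_state=seed)
```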


class PSSMBiasInitializer(Initializer):

def __init__(self, pwm_list=[], kernel_size=None, mean_max_scale=0.,
background_probs=DEFAULT_BASE_BACKGROUND):
"""Bias initializer

By default, it will initialize all weights to 0.

# Arguments
pwm_list: list of PWM's
kernel_size: Has to be the same as kernel_size in kl.Conv1D
mean_max_scale: float; factor for convex combination between
mean pwm match (mean_max_scale = 0.) and
max pwm match (mean_max_scale = 1.)
background_probs: A dictionary of background probabilities. Default: `{'A': .25, 'C': .25, 'G': .25, 'T': .25}`
"""

# handle pwm_list as a dictionary
if len(pwm_list) > 0 and isinstance(pwm_list[0], dict):
pwm_list = [PWM.from_config(pwm) for pwm in pwm_list]

if kernel_size is None:
kernel_size = len(pwm_list)

_check_pwm_list(pwm_list)
self.pwm_list = pwm_list
self.kernel_size = kernel_size
self.mean_max_scale = mean_max_scale
self.background_probs = background_probs

def __call__(self, shape, dtype=None):
# pwm_array
# print("PWMBiasInitializer shape: ", shape)
pwm = pwm_list2pwm_array(self.pwm_list,
shape=(self.kernel_size, 4, shape[0]),
background_probs=self.background_probs,
dtype=dtype)

pssm = pwm_array2pssm_array(pwm, background_probs=self.background_probs)

# maximum sequence match
max_scores = np.sum(np.amax(pssm, axis=1), axis=0)
mean_scores = np.sum(np.mean(pssm, axis=1), axis=0)

biases = - (mean_scores + self.mean_max_scale * (max_scores - mean_scores))

# ret = - (biases - 1.5 * self.init_motifs_scale)
return biases.astype(dtype)

def get_config(self):
return {
"pwm_list": [pwm.get_config() for pwm in self.pwm_list],
"kernel_size": self.kernel_size,
"mean_max_scale": self.mean_max_scale,
"background_probs": self.background_probs,
}
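Review note: a minimal usage sketch pairing the new `PSSMBiasInitializer` with `PSSMKernelInitializer` on a `Conv1D` layer. The motif values below are hypothetical, and the `PWM((length, 4) probability array)` constructor is assumed from `concise.utils.pwm`:

```python
import numpy as np
from keras import layers as kl
from concise.utils.pwm import PWM
from concise.initializers import PSSMBiasInitializer, PSSMKernelInitializer

# hypothetical 3 bp motif; columns are base probabilities for A, C, G, T
pwm_list = [PWM(np.array([[0.7, 0.1, 0.1, 0.1],
                          [0.1, 0.1, 0.7, 0.1],
                          [0.1, 0.7, 0.1, 0.1]]))]

# kernel_size passed to the bias initializer has to match the Conv1D kernel_size
conv = kl.Conv1D(filters=1, kernel_size=15,
                 kernel_initializer=PSSMKernelInitializer(pwm_list=pwm_list),
                 bias_initializer=PSSMBiasInitializer(pwm_list=pwm_list,
                                                      kernel_size=15),
                 input_shape=(100, 4))
```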

# TODO - specify the fraction of noise?
# stddev_pwm
# stddev_frac_pssm
#

# scale_glorot feature:
# TODO - add shift_mean_max_scale - this allows you to drop the bias initializer?
#        - what would be a better name for this argument?
# TODO - write some unit tests checking the initialization scale
# TODO - finish the PWM initialization example notebook


# TODO - why glorot normal and not uniform?
# TODO - can we have just a single initializer for both, pwm and pssm?

# IDEA - draw first a dirichlet distributed pwm (sums to 1) and then transform it to the pssm
# - how to choose the parameters of the dirichlet distribution?
# - create a histogram of all pwm values (for each base)
#
# related papers: http://web.stanford.edu/~hmishfaq/cs273b.pdf

# alpha * random + (1 - alpha) * motif


class PSSMKernelInitializer(Initializer):
"""Truncated normal distribution shifted by a position-specific scoring matrix (PSSM)
"""Initializer that generates tensors with a
truncated normal initializer shifted by
a position specific scoring matrix (PSSM)

# Arguments
pwm_list: a list of PWM's or motifs
pwm_list: a list of `concise.utils.pwm.PWM`'s
stddev: a python scalar or a scalar tensor. Standard deviation of the
random values to generate.
random values to generate.
seed: A Python integer. Used to seed the random generator.
background_probs: A dictionary of background probabilities.
Default: `{'A': .25, 'C': .25, 'G': .25, 'T': .25}`
add_noise_before_Pwm2Pssm: bool, if True the gaussian noise is added
to the PWM (representing nt probabilities) which is then
transformed to a PSSM with $log(p_{ij}/b_i)$. If False, the noise is added directly to the
PSSM.

Default: `{'A': .25, 'C': .25, 'G': .25, 'T': .25}`
scale_glorot: boolean; If True, each generated filter is min-max scaled to match

resulting PWM's are centered and rescaled
to match glorot_normal distribution.
add_noise_before_Pwm2Pssm: boolean; if True, add random noise before the
pwm->pssm transformation

# TODO - write down the exact formula for this initialization
"""

def __init__(self, pwm_list=[], stddev=0.05, seed=None,
background_probs=DEFAULT_BASE_BACKGROUND,
scale_glorot=True,
add_noise_before_Pwm2Pssm=True):
if len(pwm_list) > 0 and isinstance(pwm_list[0], dict):
pwm_list = [PWM.from_config(pwm) for pwm in pwm_list]
@@ -82,32 +169,36 @@ def __init__(self, pwm_list=[], stddev=0.05, seed=None,
self.seed = seed
self.background_probs = background_probs
self.add_noise_before_Pwm2Pssm = add_noise_before_Pwm2Pssm
self.scale_glorot = scale_glorot

def __call__(self, shape, dtype=None):
# print("PWMKernelInitializer shape: ", shape)

print("shape: ", shape)
pwm = pwm_list2pwm_array(self.pwm_list, shape, dtype, self.background_probs)

if self.add_noise_before_Pwm2Pssm:
# add noise with numpy truncnorm function
# adding noise on the pwm level
pwm = _truncated_normal(mean=pwm,
stddev=self.stddev,
seed=self.seed)

            stddev_after = 0  # don't need to add any further noise on the PSSM level
        else:
            stddev_after = self.stddev

        # TODO - could be problematic if any pwm < 0
        pssm = pwm_array2pssm_array(pwm, background_probs=self.background_probs)
        pssm = _truncated_normal(mean=pssm,
                                 stddev=stddev_after,
                                 seed=self.seed)
if self.scale_glorot:
# max, min for each motif individually
min_max_range = pssm.max(axis=1).max(0) - pssm.min(axis=1).min(0)
            # TODO - wrong! a [1, 2] range will just get rescaled, not centered:
            #        i.e. *2 gives [2, 4], not [-1, 1]
alpha = _glorot_uniform_scale(shape) * 2 / min_max_range
pssm = alpha * pssm

return K.constant(pssm, dtype=dtype)

def get_config(self):
return {
@@ -175,6 +266,7 @@ def get_config(self):
}
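Review note: a quick numeric check of the `scale_glorot` TODO above — min-max rescaling alone does not center a filter, so the values would need to be shifted by the range midpoint first (a sketch, not part of the PR):

```python
import numpy as np

x = np.array([1.0, 2.0])                  # filter values spanning [1, 2]
alpha = 2.0                               # a min-max derived scale factor
print(alpha * x)                          # [2. 4.] -- rescaled but not centered
x_centered = x - (x.max() + x.min()) / 2
print(alpha * x_centered)                 # [-1. 1.] -- centered, then rescaled
```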


# TODO pack everything into a single initializer without the bias init?
class PWMKernelInitializer(Initializer):
"""Truncated normal distribution shifted by a PWM

@@ -209,62 +301,17 @@ def get_config(self):
}


# util functions
def _glorot_uniform_scale(shape):
    """Compute the glorot_uniform scale
    """
    fan_in, fan_out = _compute_fans(shape)
    return np.sqrt(2 * 3.0 / max(1., float(fan_in + fan_out)))
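Review note: the returned value equals the limit $\sqrt{6 / (fan_{in} + fan_{out})}$ that keras' glorot_uniform samples from; a quick sanity check with an arbitrarily chosen filter shape:

```python
import numpy as np
from keras.initializers import _compute_fans

shape = (15, 4, 32)  # (kernel_size, 4 bases, n_filters), chosen arbitrarily
fan_in, fan_out = _compute_fans(shape)
assert np.isclose(np.sqrt(2 * 3.0 / max(1., float(fan_in + fan_out))),
                  np.sqrt(6.0 / (fan_in + fan_out)))
```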


AVAILABLE = ["PWMBiasInitializer", "PWMKernelInitializer",
"PSSMBiasInitializer", "PSSMKernelInitializer"]


def get(name):
try:
return ki.get(name)
except ValueError:
        return get_from_module(name, globals())
2 changes: 2 additions & 0 deletions tests/test_initializers.py
@@ -124,3 +124,5 @@ def test_empty_pwm_list(kernel_initializer, bias_initializer):
batch_input_shape=input_shape,
)


# TODO - write the test for glorot_normal
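Review note: a sketch of the missing test — after `scale_glorot`, each filter's min-max span should equal twice the glorot_uniform scale by construction of `alpha`. This assumes the empty-`pwm_list` path fills the array with background probabilities, as exercised by `test_empty_pwm_list`:

```python
import numpy as np
import keras.backend as K
from concise.initializers import PSSMKernelInitializer, _glorot_uniform_scale

def test_scale_glorot_span():
    shape = (15, 4, 32)
    init = PSSMKernelInitializer(pwm_list=[], scale_glorot=True,
                                 add_noise_before_Pwm2Pssm=False, seed=42)
    filters = K.eval(init(shape))
    # after rescaling with alpha = 2 * scale / (max - min), every filter spans 2 * scale
    span = filters.max(axis=1).max(0) - filters.min(axis=1).min(0)
    assert np.allclose(span, 2 * _glorot_uniform_scale(shape))
```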