In [41]:
import simdna
import simdna.synthetic as synthetic
from avutils import util
import numpy as np
import momma_dragonn
reload(momma_dragonn)

<module 'momma_dragonn' from '/Users/avantishrikumar/Research/momma_dragonn/momma_dragonn/__init__.pyc'>

In [54]:

def generate_sequences_set(seq_length, num_seqs, motif_names,
                           mean_motifs=1, min_motifs=0, max_motifs=2,
                           zero_prob=0):
    """Assemble a simdna generator of background sequences with embedded motifs.

    One ``RepeatedEmbedder`` is created per entry of ``motif_names``; all of
    them are attached to a single ``EmbedInABackground`` generator, which is
    then wrapped in ``GenerateSequenceNTimes``.

    Parameters
    ----------
    seq_length : passed as ``seqLength`` to ``ZeroOrderBackgroundGenerator``.
    num_seqs : passed to ``GenerateSequenceNTimes`` as the repeat count.
    motif_names : iterable of motif names looked up in the motifs loaded
        from ``simdna.ENCODE_MOTIFS_PATH``.
    mean_motifs : argument of ``PoissonQuantityGenerator``.
    min_motifs, max_motifs : ``theMin`` / ``theMax`` of ``MinMaxWrapper``.
    zero_prob : ``zeroProb`` of ``ZeroInflater``.

    Returns
    -------
    The ``synthetic.GenerateSequenceNTimes`` instance (nothing is generated
    here; callers iterate over ``generateSequences()``).
    """
    motif_library = synthetic.LoadedEncodeMotifs(
        simdna.ENCODE_MOTIFS_PATH, pseudocountProb=0.001)
    background = synthetic.ZeroOrderBackgroundGenerator(seqLength=seq_length)

    motif_embedders = []
    for motif_name in motif_names:
        # NOTE: no synthetic.ReverseComplementWrapper is applied around the
        # PWM sampler; an earlier revision had it commented out.
        pwm_sampler = synthetic.PwmSamplerFromLoadedMotifs(
            loadedMotifs=motif_library, motifName=motif_name)
        single_embedder = synthetic.SubstringEmbedder(
            substringGenerator=pwm_sampler,
            positionGenerator=synthetic.UniformPositionGenerator())

        # Number of embeddings per sequence: Poisson generator wrapped by the
        # min/max wrapper, wrapped in turn by the zero inflater.
        poisson_count = synthetic.PoissonQuantityGenerator(mean_motifs)
        bounded_count = synthetic.MinMaxWrapper(
            poisson_count, theMax=max_motifs, theMin=min_motifs)
        count_generator = synthetic.ZeroInflater(
            bounded_count, zeroProb=zero_prob)

        motif_embedders.append(
            synthetic.RepeatedEmbedder(
                single_embedder, quantityGenerator=count_generator))

    embedding_generator = synthetic.EmbedInABackground(
        backgroundGenerator=background, embedders=motif_embedders)
    return synthetic.GenerateSequenceNTimes(embedding_generator, num_seqs)

def one_hot_encode_sequences_set(sequence_set_generator):
    """Encode every generated sequence and stack the results into one array.

    Parameters
    ----------
    sequence_set_generator : object exposing ``generateSequences()``, whose
        items carry the raw sequence in a ``.seq`` attribute (e.g. the value
        returned by ``generate_sequences_set``).

    Returns
    -------
    numpy.ndarray built from the per-sequence results of
    ``util.seq_to_2d_image`` (one leading entry per generated sequence).
    """
    # `util` is the name bound by `from avutils import util`; the previous
    # `avutils.util....` spelling raised NameError on a fresh kernel because
    # the bare name `avutils` is never imported in this notebook.
    return np.array([
        util.seq_to_2d_image(sequence.seq)
        for sequence in sequence_set_generator.generateSequences()
    ])

# Length of every simulated sequence; also read by model_creator_func below.
seq_length = 200

In [55]:
# Training set: 5000 simulated sequences built around the CTCF_known1 motif.
one_hot_data_train = one_hot_encode_sequences_set(
    generate_sequences_set(
        seq_length=seq_length,
        num_seqs=5000,
        motif_names=["CTCF_known1"]
    )
)

In [56]:
# Validation set: 1000 simulated sequences, same motif and length as training.
one_hot_data_valid = one_hot_encode_sequences_set(
    generate_sequences_set(
        seq_length=seq_length,
        num_seqs=1000,
        motif_names=["CTCF_known1"]
    )
)

In [53]:
import momma_dragonn
reload(momma_dragonn)
reload(momma_dragonn.data_loaders)
reload(momma_dragonn.data_loaders.core)

def model_creator_func():
    """Build and compile the conv -> maxpool-filter -> deconv Keras Graph.

    Reads the module-level ``seq_length`` to size the input and the pooling
    window. The graph takes an input named "sequence" of shape
    ``(1, 4, seq_length)`` and exposes one output named "output".

    NOTE(review): ``MaxPoolFilter2D``, ``ExchangeChannelsAndRows``,
    ``SoftmaxAcrossRows`` and the loss name used below do not look like
    stock Keras; presumably they come from a custom Keras build. Confirm
    that build is the one installed.

    Returns
    -------
    The compiled ``keras.models.Graph`` instance.
    """
    # Function-scope imports: the notebook never binds the bare name `keras`,
    # so every `keras.<submodule>` reference below would raise NameError
    # without these.
    import keras.constraints
    import keras.layers.convolutional
    import keras.optimizers
    from keras.models import Graph

    filter_width = 20
    # pool over entire region for now
    maxpool_filter_width = seq_length - (filter_width - 1)

    graph = Graph()
    graph.add_input(name="sequence", input_shape=(1, 4, seq_length))
    # add convolutional layer
    # (the comma after the layer argument was missing -> SyntaxError)
    graph.add_node(
        keras.layers.convolutional.Convolution2D(
            nb_filter=5, nb_row=4, nb_col=filter_width),
        name="conv", input="sequence")
    # add maxpool filter layer
    # (the comma after the layer argument was missing here as well)
    graph.add_node(
        keras.layers.convolutional.MaxPoolFilter2D(
            pool_size=(1, maxpool_filter_width), pool_stride=(1, 1)),
        name="filt", input="conv")
    # add a padding layer so deconv will be the right size
    # NOTE(review): if ZeroPadding2D pads both sides symmetrically and "filt"
    # keeps the conv width, (filter_width-1)*2 per side looks wider than what
    # is needed to get back to seq_length columns -- confirm against the
    # layer implementations in use.
    graph.add_node(
        keras.layers.convolutional.ZeroPadding2D(
            padding=(0, (filter_width - 1) * 2)),
        name="padding", input="filt")
    # add a deconv layer to reconstruct the input
    graph.add_node(
        keras.layers.convolutional.Convolution2D(
            nb_filter=4, nb_row=1, nb_col=filter_width,
            W_constraint=keras.constraints.MaxNorm(m=10)),
        name="deconv", input="padding")
    # transpose to make the deconv axis the row axis
    graph.add_node(
        keras.layers.convolutional.ExchangeChannelsAndRows(),
        name="swapaxes", input="deconv")
    # softmax across rows
    graph.add_node(
        keras.layers.convolutional.SoftmaxAcrossRows(),
        name="output_softmax", input="swapaxes")
    # designate output node
    graph.add_output(name="output", input="output_softmax")
    # compile
    graph.compile(
        optimizer=keras.optimizers.Adam(),
        loss={"output": "one_hot_from_logits_categorical_cross_entropy"}
    )
    return graph
# --- model creator ----------------------------------------------------------
# Wraps model_creator_func so momma_dragonn can build the Graph on demand.
model_creator = momma_dragonn.model_creators.flexible_keras.KerasModelFromFunc(
    func=model_creator_func,
    model_wrapper_class=(
        momma_dragonn.model_wrappers.keras_model_wrappers.KerasGraphModelWrapper
    )
)

# --- data loaders -----------------------------------------------------------
# The target under 'output' is the same array as the input under 'sequence'.
train_data_loader = momma_dragonn.data_loaders.core.AtOnceDataLoader_XYDictAPI(
    X={'sequence': one_hot_data_train},
    Y={'output': one_hot_data_train}
)
valid_data_loader = momma_dragonn.data_loaders.core.AtOnceDataLoader_XYDictAPI(
    X={'sequence': one_hot_data_valid},
    Y={'output': one_hot_data_valid}
)

# --- model evaluator --------------------------------------------------------
model_evaluator = momma_dragonn.model_evaluators.GraphAccuracyStats(
    key_metric="onehot_rows_crossent",
    all_metrics=["onehot_rows_crossent"]
)

# --- stopping criterion -----------------------------------------------------
stopping_criterion_config = {
    "class": "EarlyStopping",
    "kwargs": {"max_epochs": 300, "epochs_to_wait": 10},
}

# --- callbacks --------------------------------------------------------------
end_of_epoch_callbacks = [
    momma_dragonn.end_of_epoch_callbacks.PrintPerfAfterEpoch(print_trend=True),
]

# --- trainer ----------------------------------------------------------------
trainer = momma_dragonn.model_trainers.keras_model_trainer.KerasFitGeneratorModelTrainer(
    samples_per_epoch=3000,
    stopping_criterion_config=stopping_criterion_config
)

# --- train model ------------------------------------------------------------
trainer.train(
    model_wrapper=model_creator.get_model_wrapper(),
    model_evaluator=model_evaluator,
    valid_data_loader=valid_data_loader,
    other_data_loaders={'train': train_data_loader},
    end_of_epoch_callbacks=end_of_epoch_callbacks
)