In [12]:
import pandas as pd
import numpy as np
import time 
import pixiedust
import sys
import ipdb 

from pomegranate import State, DiscreteDistribution, HiddenMarkovModel
from sklearn.model_selection import train_test_split
from utils import load_gfp_data, count_substring_mismatch, get_all_amino_acids, get_wild_type_amino_acid_sequence
from hmm import GenerativeHMM, hmm_amino_acid_args
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
print("Loading data...")
start_time = time.time()
X_train, X_test, y_train, y_test = load_gfp_data("./data/gfp_amino_acid_")
mutated_df = load_saved_mutated_gfp_data()
print("Finished loading data in {0:.2f} seconds".format(time.time() - start_time))

Loading data...
Finished loading data in 1.62 seconds


In [14]:
wild_type_amino_acid = get_wild_type_amino_acid_sequence()
assert(X_train[0] == wild_type_amino_acid)
assert(count_substring_mismatch(wild_type_amino_acid, X_train[1000]) == 8)
assert(count_substring_mismatch(wild_type_amino_acid, mutated_df["mutated_amino_acid_sequence"].values[0]) == 1)

In [15]:
def get_data(X_train, length, n = 100, random=True): 
    if not random: 
        data = X_train[0:length]
    else: 
        indexes = np.random.choice(len(X_train), n)
        data = X_train[indexes]
    return np.array([list(x)[0:length] for x in data])

def sample_and_score(hmm, wild_type, n = 100, length = 100, logger = None):
    """
    use the hmm model to sample n sequences of size = length. 
    then use the wild_type to count how far off the average sample is from the wild_type
    """
    assert(len(wild_type) == length)
    samples = hmm.sample(n, length)        
    average_diff = np.average([count_substring_mismatch(seq, wild_type) for seq in samples])
    print("Average difference: {0:.2f}, or {1:.2f} mismatches per letter".format(average_diff, 
                                                                 average_diff / length), file = logger)
    print("Example sequence {0}".format(samples[np.random.randint(0, n)]), file = logger)
    return average_diff

small_length, medium_length, large_length = 15, len(wild_type_amino_acid) // 4, len(wild_type_amino_acid)
small_X = get_data(X_train, small_length, 100)
medium_X = get_data(X_train, medium_length, 100)
large_X = get_data(X_train, large_length, 100)

In [16]:
diffs = [count_substring_mismatch(i, wild_type_amino_acid[0:small_length]) for i in small_X]
print("Small diffs:", diffs)
diffs = [count_substring_mismatch(i, wild_type_amino_acid[0:medium_length]) for i in medium_X]
print("Medium diffs:", diffs)
diffs = [count_substring_mismatch(i, wild_type_amino_acid[0:large_length]) for i in large_X]
print("Large diffs:", diffs)

Small diffs: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Medium diffs: [1, 0, 0, 4, 1, 0, 1, 3, 2, 1, 4, 1, 1, 0, 0, 0, 1, 2, 2, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 2, 1, 0, 0, 3, 2, 0, 0, 1, 0, 3, 3, 1, 1, 1, 1, 0, 0, 0, 1, 2, 0, 0, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 2, 2]
Large diffs: [6, 2, 1, 3, 1, 6, 3, 3, 6, 4, 2, 2, 7, 2, 5, 2, 4, 6, 6, 2, 3, 1, 1, 2, 4, 6, 4, 5, 3, 2, 5, 3, 2, 8, 4, 9, 2, 3, 3, 2, 3, 3, 7, 2, 2, 2, 4, 2, 2, 2, 4, 2, 4, 7, 3, 4, 8, 4, 5, 5, 2, 7, 4, 8, 5, 5, 4, 6, 5, 4, 2, 3, 5, 2, 1, 8, 6, 5, 6, 5, 2, 2, 2, 3, 2, 3, 2, 2, 4, 6, 7, 3, 3, 1, 3, 1, 2, 3, 8, 2]


In [21]:
def get_data(X_train, length, n = 100, random = True): 
    """
    gets n random sequences of size length from the dataset X_train
    """
    if not random: 
        data = X_train[0:length]
    else: 
        indexes = np.random.choice(len(X_train), n)
        data = X_train[indexes]
    return np.array([list(x[0:length]) for x in data])

def sample_and_score(hmm, base_str, n = 100, length = 100, logger = None):
    """
    use the hmm model to sample n sequences of size = length. 
    then use the wild_type to count how far off the average sample is from the wild_type
    prints all results in the logger file
    """
    assert(len(base_str) == length)
    samples = hmm.sample(n, length)        
    average_diff = np.mean([count_substring_mismatch(seq, base_str) for seq in samples])
    print("Average difference: {0:.2f}, or {1:.2f} mismatches per letter".format(average_diff, 
                                                                 average_diff / length), file = logger)
    print("Example sequence {0}".format(samples[np.random.randint(0, n)]), file = logger)
    return average_diff

def train_and_save_hmm(X, args):
    start_time = time.time()
    hmm = GenerativeHMM(args)
    logger = None
    hmm.fit(X)
    print("Finished training in {:.2f} seconds".format(time.time() - start_time), file = logger)
    print("HMM Parameters:", file = logger)
    print(hmm.get_args(), file = logger)
    sample_and_score(hmm, wild_type_amino_acid[0:args["length"]], 100, args["length"], logger = logger)
    wild_type_prob = np.e ** hmm.predict([list(wild_type_amino_acid[0:args["length"]])])
    mutation_prob = np.e ** hmm.predict([list(wild_type_amino_acid[0:args["length"] - 3] + "ACG")])
    print("Wild type prob: {0}. Mutation prob: {1}".format(wild_type_prob, mutation_prob), file = logger)
    model_path = "./models/{0}.json".format(hmm.name)
    hmm.save_model(model_path)
    cached_hmm = GenerativeHMM(args)
    cached_hmm.load_model(model_path)
    try: 
        for i in get_all_amino_acids():
            for j in get_all_amino_acids(): 
                np.testing.assert_almost_equal(hmm.predict([list(i + j)]), cached_hmm.predict([list(i + j)]))
        print("Successfully finished training and saving {0} model!".format(hmm.name), file = logger)
        if logger: logger.close()
    except:
        for i in get_all_amino_acids():
            for j in get_all_amino_acids(): 
                print(hmm.predict([list(i + j)]), cached_hmm.predict([list(i + j)]), file = logger)
        print("Error in loading {0} hmm".format(hmm.name), file = logger)
        if logger: logger.close()

def get_args(parser_args):
    args = hmm_amino_acid_args()
    args["n_jobs"] = parser_args.n_jobs
    args["hidden_size"] = parser_args.hidden_size
    args["max_iterations"] = parser_args.max_iterations
    args["name"] = parser_args.name
    args["length"] = parser_args.length
    return args

def get_base_args():
    base_args = hmm_amino_acid_args()
    base_args["name"] = "hmm_base"
    base_args["max_iterations"] = 100
    base_args["hidden_size"] = 20
    base_args["n_jobs"] = 10
    base_args["length"] = 15
    return base_args

In [22]:
train_and_save_hmm(small_X, get_base_args())

[1] Improvement: 1398.7271447617622	Time (s): 0.0424
[2] Improvement: 100.26230049001333	Time (s): 0.04371
[3] Improvement: 167.0684396539914	Time (s): 0.04403
[4] Improvement: 307.7334251848615	Time (s): 0.04365
[5] Improvement: 503.4235540154327	Time (s): 0.04374
[6] Improvement: 623.5416012330415	Time (s): 0.04401
[7] Improvement: 515.8650776445228	Time (s): 0.04417
[8] Improvement: 313.81438339190197	Time (s): 0.04363
[9] Improvement: 280.9509821049363	Time (s): 0.04636
[10] Improvement: 226.54593051518884	Time (s): 0.04502
[11] Improvement: 65.89640639356455	Time (s): 0.04528
[12] Improvement: 31.125666125417297	Time (s): 0.04591
[13] Improvement: 8.982763983025407	Time (s): 0.04537
[14] Improvement: 0.6497688962727182	Time (s): 0.04544
[15] Improvement: 0.004298003109752813	Time (s): 0.04541
[16] Improvement: 2.1676243022739072e-09	Time (s): 0.0467
[17] Improvement: 0.0	Time (s): 0.04737
Total Training Improvement: 4544.59174239921
Total Training Time (s): 0.9058
Finished trainin

In [21]:
"""
10, 100 sequences, 100 iterations, 100 sequences. 
10, 200, 1e8, 100 sequences, 
10, 200, 1000, 100 sequences. 
10, 500, 1000, 100 sequences. 
10, 200, 1000, 10000 seqeunces. 


## Fit 3 types.
## Fit small data -> large data. 
## Fit with different hidden sizes 10, 50, 200. 
## fit until it 1e8, 1e2 iterations
## all with more cores 5
## record times of all these. 
"""




'\n10, 100 sequences, 100 iterations, 100 sequences. \n10, 200, 1e8, 100 sequences, \n10, 200, 1000, 100 sequences. \n10, 500, 1000, 100 sequences. \n10, 200, 1000, 10000 seqeunces. \n\n\n## Fit 3 types.\n## Fit small data -> large data. \n## Fit with different hidden sizes 10, 50, 200. \n## fit until it 1e8, 1e2 iterations\n## all with more cores 5\n## record times of all these. \n'