In [1]:
import pandas as pd
import numpy as np
import time 
import pixiedust

from pomegranate import State, DiscreteDistribution, HiddenMarkovModel
from sklearn.model_selection import train_test_split
from utils import *
from hmm import GenerativeHMM, hmm_amino_acid_args

%load_ext autoreload
%autoreload 2


Pixiedust database opened successfully


Unable to check latest version <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:852)>


In [2]:
# Load the GFP train/test split and the saved mutated-GFP frame, timing the
# whole load so its cost shows up in the cell output.
print("Loading data...")
start_time = time.time()
# NOTE(review): load_gfp_data / load_saved_mutated_gfp_data are not defined in
# this notebook — presumably they arrive via `from utils import *`. The string
# argument looks like a file-path prefix; confirm against utils.
X_train, X_test, y_train, y_test = load_gfp_data("./data/gfp_amino_acid_")
mutated_df = load_saved_mutated_gfp_data()
print("Finished loading data in {0:.2f} seconds".format(time.time() - start_time))

Loading data...
Finished loading data in 3.21 seconds


In [3]:
# Sanity checks on the loaded data; each assert raises AssertionError if the
# data on disk no longer matches these expectations.
wild_type_amino_acid = get_wild_type_amino_acid_sequence()
# The first training sequence is expected to be the wild type itself.
assert(X_train[0] == wild_type_amino_acid)
# Training sequence 1000 is expected to score exactly 8 against the wild type.
# NOTE(review): presumably count_substring_mismatch counts position-wise
# differences — confirm in utils.
assert(count_substring_mismatch(wild_type_amino_acid, X_train[1000]) == 8)
# The first saved mutated sequence is expected to score exactly 1 against the wild type.
assert(count_substring_mismatch(wild_type_amino_acid, mutated_df["mutated_amino_acid_sequence"].values[0]) == 1)

In [5]:
def get_data(X_train, length, n = 100, random=True):
    """
    Pick n sequences from X_train and truncate each one to its first `length` symbols.

    Parameters
    ----------
    X_train : indexable collection of sequences (strings). The random branch
        indexes it with an integer array, so it must support NumPy-style
        fancy indexing (e.g. a 1-D np.ndarray).
    length : number of leading symbols to keep from every selected sequence.
    n : number of sequences to return.
    random : if True, draw n indices with np.random.choice (sampling with
        replacement, so duplicates are possible); if False, take the first n
        sequences in order.

    Returns
    -------
    np.ndarray built from one list of symbols per selected sequence.
    """
    if not random:
        # Take the first n sequences. The previous code sliced by `length`
        # here, which returned the wrong number of sequences.
        data = X_train[0:n]
    else:
        indexes = np.random.choice(len(X_train), n)
        data = X_train[indexes]
    # list(x) splits a sequence into single symbols; the slice truncates it.
    return np.array([list(x)[0:length] for x in data])

def sample_and_score(hmm, wild_type, n = 100, length = 100, file = None):
    """
    Draw n sequences of `length` symbols from hmm and measure their distance to wild_type.

    Prints the mean mismatch count (plus the per-letter rate) and one randomly
    chosen sample to `file`, then returns the mean mismatch count.
    wild_type must have exactly `length` symbols.
    """
    assert(len(wild_type) == length)
    generated = hmm.sample(n, length)
    mismatch_counts = []
    for candidate in generated:
        mismatch_counts.append(count_substring_mismatch(candidate, wild_type))
    mean_mismatches = np.average(mismatch_counts)
    summary = "Average difference: {0:.2f}, or {1:.2f} mismatches per letter".format(
        mean_mismatches, mean_mismatches / length)
    print(summary, file = file)
    # Show one arbitrary sample so the log contains a concrete example.
    picked = np.random.randint(0, n)
    print("Example sequence {0}".format(generated[picked]), file = file)
    return mean_mismatches

# Three prefix lengths to experiment with: 15 symbols, a quarter of the wild
# type, and the full wild-type length.
small_length, medium_length, large_length = 15, len(wild_type_amino_acid) // 4, len(wild_type_amino_acid)
# Each call draws 100 training sequences at random (get_data's default
# random=True) and truncates them to the given length. No seed is set in the
# visible cells, so the three subsets differ between runs and from each other.
small_X = get_data(X_train, small_length, 100)
medium_X = get_data(X_train, medium_length, 100)
large_X = get_data(X_train, large_length, 100)

In [6]:
# For each subset, print every sampled sequence's mismatch count against the
# wild-type prefix of the same length. `diffs` is overwritten each time, so
# only the large-subset values remain bound after this cell.
diffs = [count_substring_mismatch(i, wild_type_amino_acid[0:small_length]) for i in small_X]
print("Small diffs:", diffs)
diffs = [count_substring_mismatch(i, wild_type_amino_acid[0:medium_length]) for i in medium_X]
print("Medium diffs:", diffs)
diffs = [count_substring_mismatch(i, wild_type_amino_acid[0:large_length]) for i in large_X]
print("Large diffs:", diffs)

Small diffs: [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Medium diffs: [0, 3, 3, 1, 0, 1, 4, 0, 1, 0, 1, 1, 2, 0, 0, 0, 0, 2, 2, 2, 0, 2, 2, 3, 1, 1, 1, 0, 1, 0, 1, 4, 1, 0, 2, 1, 1, 1, 0, 2, 1, 2, 1, 0, 1, 1, 1, 0, 1, 1, 2, 0, 2, 3, 2, 0, 0, 1, 0, 2, 1, 0, 2, 1, 0, 0, 1, 1, 2, 1, 2, 1, 0, 1, 0, 1, 0, 2, 0, 1, 2, 1, 2, 0, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 1, 2, 0, 3, 1, 0]
Large diffs: [4, 3, 2, 3, 1, 5, 3, 3, 1, 5, 4, 5, 2, 3, 4, 3, 2, 4, 3, 2, 6, 3, 2, 2, 3, 3, 2, 5, 5, 5, 4, 5, 4, 5, 2, 1, 3, 11, 5, 4, 4, 2, 3, 2, 2, 2, 3, 4, 5, 2, 3, 2, 7, 3, 4, 2, 5, 2, 2, 4, 6, 5, 2, 6, 2, 3, 1, 2, 1, 4, 2, 1, 2, 2, 6, 3, 7, 2, 2, 1, 8, 4, 5, 3, 3, 2, 6, 2, 3, 2, 3, 5, 3, 2, 1, 2, 0, 3, 5, 2]


In [19]:
def train_and_save_hmm(args):
    """
    Train a GenerativeHMM on small_X, log a summary, save it, and check the save round-trips.

    Steps:
      1. Build a GenerativeHMM from `args` and fit it on the global small_X.
      2. Write the training time, the model arguments and a sample_and_score
         report to ./logs/<hmm.name>.txt.
      3. Save the model to ./models/<hmm.name>.json and load it into a second
         GenerativeHMM built from the same `args`.
      4. Compare the predictions of both models on every ordered pair of
         amino acids; the outcome is written to the log, and a failure is
         also printed to stdout.

    Parameters
    ----------
    args : argument object accepted by GenerativeHMM (see get_small_args).

    Errors raised while fitting, sampling, saving or loading propagate to the
    caller. Only failures in the round-trip comparison are caught and reported.
    """
    start_time = time.time()
    hmm = GenerativeHMM(args)
    hmm.fit(small_X)
    log_path = "./logs/{0}.txt".format(hmm.name)
    # The context manager closes the log on every exit path, including errors
    # raised while sampling, saving or loading the model.
    with open(log_path, "w") as logger:
        print("Finished training in {:.2f} seconds".format(time.time() - start_time), file = logger)
        print("HMM Parameters:", file = logger)
        print(hmm.get_args(), file = logger)
        sample_and_score(hmm, wild_type_amino_acid[0:small_length], 100, small_length, file = logger)
        model_path = "./models/{0}.json".format(hmm.name)
        hmm.save_model(model_path)
        cached_hmm = GenerativeHMM(args)
        cached_hmm.load_model(model_path)
        try:
            # The reloaded model must agree with the trained one on every
            # two-symbol sequence.
            for i in get_all_amino_acids():
                for j in get_all_amino_acids():
                    pair = [list(i + j)]
                    np.testing.assert_almost_equal(hmm.predict(pair), cached_hmm.predict(pair))
        except Exception as err:
            # Catch Exception rather than using a bare except, so Ctrl-C is
            # not swallowed, and report what actually went wrong.
            print("Error in loading {0} hmm".format(hmm.name), repr(err))
            print("Error in loading {0} hmm: {1!r}".format(hmm.name, err), file = logger)
        else:
            print("Successfully finished training and saving {0} hmm!".format(hmm.name), file = logger)

def get_small_args():
    """Return the base amino-acid HMM arguments with the small-model overrides applied."""
    small_args = hmm_amino_acid_args()
    overrides = (
        ("name", "amino_acid_small"),
        ("max_iterations", 100),
        ("hidden_size", 50),
        ("n_jobs", 5),
    )
    # Apply the overrides in the same order as individual assignments would.
    for key, value in overrides:
        small_args[key] = value
    return small_args

In [20]:
# Train, log and save the HMM configured by get_small_args().
train_and_save_hmm(get_small_args())

[1] Improvement: 1272.6339987469269	Time (s): 0.1651
[2] Improvement: 87.35922020702628	Time (s): 0.1985
[3] Improvement: 162.9487532171006	Time (s): 0.2252
[4] Improvement: 283.4199117419703	Time (s): 0.2014
[5] Improvement: 401.2173105789716	Time (s): 0.2375
[6] Improvement: 660.0943599881184	Time (s): 0.2318
[7] Improvement: 632.2907845468504	Time (s): 0.2319
[8] Improvement: 375.04132737638224	Time (s): 0.2284
[9] Improvement: 251.74877640948654	Time (s): 0.1704
[10] Improvement: 152.5301803204522	Time (s): 0.2084
[11] Improvement: 79.67922144952371	Time (s): 0.2547
[12] Improvement: 47.999740636191376	Time (s): 0.2205
[13] Improvement: 31.80944296712528	Time (s): 0.1831
[14] Improvement: 13.116383630731534	Time (s): 0.207
[15] Improvement: 3.4367271744879417	Time (s): 0.1779
[16] Improvement: 1.1882763113427046	Time (s): 0.1402
[17] Improvement: 0.2543010032409114	Time (s): 0.1362
[18] Improvement: 0.012192319394017659	Time (s): 0.1333
[19] Improvement: 3.2047428106807274e-05	Time

In [21]:
"""
10, 100 sequences, 100 iterations, 100 sequences. 
10, 200, 1e8, 100 sequences, 
10, 200, 1000, 100 sequences. 
10, 500, 1000, 100 sequences. 
10, 200, 1000, 10000 sequences. 


## Fit 3 types.
## Fit small data -> large data. 
## Fit with different hidden sizes 10, 50, 200. 
## fit until it 1e8, 1e2 iterations
## all with more cores 5
## record times of all these. 
"""




'\n10, 100 sequences, 100 iterations, 100 sequences. \n10, 200, 1e8, 100 sequences, \n10, 200, 1000, 100 sequences. \n10, 500, 1000, 100 sequences. \n10, 200, 1000, 10000 seqeunces. \n\n\n## Fit 3 types.\n## Fit small data -> large data. \n## Fit with different hidden sizes 10, 50, 200. \n## fit until it 1e8, 1e2 iterations\n## all with more cores 5\n## record times of all these. \n'