In [None]:
import pickle
import numpy as np
import pandas as pd
from keras import models
from keras.layers import Dense, Concatenate, Dropout, BatchNormalization
from keras.callbacks import EarlyStopping
from keras.regularizers import l2
import time
import os
import sys
import CNN_functions as CNNFuncs

In [None]:
# ---------------------------------------------------------------------------
# Dataset parameters (forwarded to CNNFuncs.preprocess_data in the next cell)
# ---------------------------------------------------------------------------
model_mode = "RCC"            # passed as `mode`; also used in plot titles and save paths
test_size = 0.2               # passed as `test_size`
random_state = 1              # passed as `random_state`
extract_genes_no = 8000       # passed as `extract_top_genes`
extract_neg_selec = False     # passed as `extract_neg_selec`
GOIs_mode = "top100"          # NOTE(review): not referenced by any cell visible in this notebook

# ---------------------------------------------------------------------------
# Model / training parameters
# ---------------------------------------------------------------------------
activation_func = "relu"              # activation of every hidden Dense layer
activation_func2 = "linear"           # activation of the output Dense layer
kernel_initializer = "he_uniform"     # weight initialiser shared by all Dense layers
dense_layer_dim = 500                 # width of the last four hidden Dense layers
l2_reg = None                         # L2 penalty factor; None disables kernel regularisation
batch_size = 250                      # batch size for both fit() and evaluate()
num_epoch = 300                       # upper bound on epochs (early stopping may end sooner)


In [None]:
# Run the project's preprocessing step with the configuration from the
# parameter cell. It returns three objects: the model datasets, the tRCC
# datasets and a mapping of essential-gene lists.
model_datasets, tRCC_datasets, essential_genes = CNNFuncs.preprocess_data(
    test_size=test_size,
    random_state=random_state,
    mode=model_mode,
    extract_top_genes=extract_genes_no,
    extract_neg_selec=extract_neg_selec,
)

# model_datasets is indexed positionally: 0/1 are the train/test inputs and
# 2/3 the train/test targets.
# NOTE(review): this ordering is assumed from how the values are used below;
# confirm against CNNFuncs.preprocess_data.
X_train, X_test = model_datasets[0], model_datasets[1]
Y_train, Y_test = model_datasets[2], model_datasets[3]

# tRCC data, kept separate from the train/test split above.
tRCC_gene_exp, tRCC_gene_effect = tRCC_datasets[0], tRCC_datasets[1]

# Gene lists used later for plotting and per-gene error reporting.
top_common_essential_genes = essential_genes["common_essential_genes"]
top100_essential_genes = essential_genes["top100_essential_genes"]

In [None]:
def _dense(units, activation, **extra):
    """Build a Dense layer using the notebook-wide initialiser and L2 setting.

    A fresh regulariser is created per layer when `l2_reg` is set; otherwise
    the layer is left unregularised. `extra` is forwarded to Dense unchanged
    (used for `input_dim` on the first layer).
    """
    return Dense(units,
                 activation=activation,
                 kernel_initializer=kernel_initializer,
                 kernel_regularizer=None if l2_reg is None else l2(l2_reg),
                 **extra)


# Fully connected regression network: 500 -> 300 -> 150 -> 75, then four
# hidden layers of width `dense_layer_dim`, then one output unit per column
# of Y_train.
model = models.Sequential()
model.add(_dense(500, activation_func, input_dim=X_train.shape[1]))
for units in (300, 150, 75) + (dense_layer_dim,) * 4:
    model.add(_dense(units, activation_func))
model.add(_dense(Y_train.shape[1], activation_func2))

model.compile(optimizer='adam', loss='mse')

# Wall-clock start; the elapsed time reported below covers fit() and evaluate().
t = time.time()

# Stop once val_loss has failed to improve for 5 consecutive epochs.
early_stopping_callback = EarlyStopping(monitor='val_loss',
                                        mode='min',
                                        min_delta=0,
                                        patience=5,
                                        verbose=0)

# validation_split=1/9 asks Keras to hold out one ninth of the training data
# for validation.
history = model.fit(X_train,
                    Y_train,
                    batch_size=batch_size,
                    epochs=num_epoch,
                    shuffle=True,
                    validation_split=1/9,
                    callbacks=[early_stopping_callback])

cost_testing = model.evaluate(X_test, Y_test, batch_size=batch_size, verbose=0)

# Losses recorded for the final epoch that was actually run.
loss = history.history['loss'][-1]
val_loss = history.history['val_loss'][-1]

print("\n\nModel training completed in %.1f mins.\nloss:%.4f valloss:%.4f testloss:%.4f" % ((time.time() - t)/60, loss, val_loss, cost_testing))

In [None]:
# Predict on the held-out inputs and wrap the result in a DataFrame.
# The network's output layer has one unit per target column, so the predicted
# columns are labelled with the target frame's columns (not the input frame's);
# rows correspond to the rows of X_test.
# NOTE(review): assumes Y_test is a DataFrame whose columns match Y_train's.
predictions = pd.DataFrame(model.predict(X_test), columns=Y_test.columns, index=X_test.index)

In [None]:
# Figure title: reports the dataset mode, the configured epoch limit and the
# training-set dimensions (rows = samples, columns = features).
prediction_plot_title = "CNN-trained Model Predictions on {0} CCL Test Dataset (num_epochs={1}, num_samples={3}, num_features={2})".format(
    model_mode, num_epoch, X_train.shape[1], X_train.shape[0]
)

# Plot the first 30 genes of the top-100 essential list, with two series:
# (X_test, Y_test) and (X_test, predictions), labelled via `legend_titles`.
CNNFuncs.GeneRelMiniPlot(
    X_test,
    Y_test,
    top100_essential_genes[:30].to_list(),
    rel2_X=X_test,
    rel2_Y=predictions,
    main_title=prediction_plot_title,
    axes_labels=["Gene Expression", "Gene Effect Score"],
    axis_label_fontsize=9.75,
    legend_titles={"rel1": "Testing CCL", "rel2": "Testing CCL Prediction"},
    sizeProps=[0.925, 0.08, 0.99, 0],
    wspace=0.35,
    hspace=0.55,
)

In [None]:
# Compare the true test targets with the model predictions for the top-100
# essential genes. The result is indexed with an "RMSE" key in later cells.
# NOTE(review): presumably one RMSE value per gene — confirm in CNNFuncs.PredGeneRMSE.
pred_RMSE_top100_essential_genes = CNNFuncs.PredGeneRMSE(Y_test, predictions, top100_essential_genes)

In [None]:
# Figure title: reports the dataset mode, the configured epoch limit and the
# training-set dimensions (rows = samples, columns = features).
rmse_plot_title = "CNN-trained Model RMSE on {0} CCL Test Dataset (num_epochs={1}, num_samples={3}, num_features={2})".format(
    model_mode, num_epoch, X_train.shape[1], X_train.shape[0]
)

# Line plot of the "RMSE" series (drawn in red) computed in the previous cell.
CNNFuncs.LinePlot(
    pred_RMSE_top100_essential_genes,
    {"RMSE": "red"},
    axes_labels=["Top 100 Essential Genes", "Root Mean Squared Error"],
    main_title=rmse_plot_title,
    legend_titles={"RMSE": "{0} Testing RMSE".format(model_mode)},
)

In [None]:
# Mean of the "RMSE" values over the top-100 essential genes.
# The ".4" format spec prints four significant digits (not four decimals).
mean_top100_rmse = pred_RMSE_top100_essential_genes["RMSE"].mean()
print("Average RMSE for predicted gene effect scores on testing subset of CCLs: {0:.4}".format(mean_top100_rmse))

In [None]:
# Optionally persist the trained model and the configuration it was trained
# with under models/<model_mode>/<model_name>/.
save_bool = input("Save Model? y/n: ")

# Tolerate surrounding whitespace and an upper-case "Y" in the answer.
if save_bool.strip().lower() == "y":

    model_name = input("Model Name: ").strip()

    # An empty name would resolve to the models/<model_mode>/ directory itself
    # and be misreported as "model already exists" (or overwrite into it).
    if not model_name:
        raise ValueError("model name must not be empty")

    model_dir = "models/{0}/{1}".format(model_mode, model_name)

    # Never overwrite a previously saved model.
    if os.path.exists(model_dir):
        raise ValueError("model already exists")

    # makedirs rather than mkdir: the parent models/<model_mode>/ directory
    # may not exist yet on the first save for a given mode.
    os.makedirs(model_dir)

    model_path = "{0}/model.h5".format(model_dir)
    model.save(model_path)
    print("saved model to {0}".format(model_path))

    # Record the settings needed to identify how this model was produced.
    model_info = {"test_size": test_size,
                  "random_state": random_state,
                  "model_mode": model_mode,
                  "extract_genes_no": extract_genes_no,
                  "extract_neg_selec": extract_neg_selec,
                  "dense_layer_dim": dense_layer_dim,
                  "l2_reg": l2_reg,
                  "num_epoch": num_epoch}

    # One row per setting (setting name as the row index).
    pd.DataFrame.from_dict(model_info, orient='index').to_csv("{0}/model_info.csv".format(model_dir))