In [1]:
# Notebook setup: draw matplotlib figures inline in the cell output.
%matplotlib inline
# Load the autoreload extension in mode 2 so edited project modules
# (data_import_preprocessing, model_define) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for high-resolution ("retina") figures.
%config InlineBackend.figure_format ='retina'

In [2]:
import numpy as np
import tensorflow as tf
from scipy.stats import pearsonr, spearmanr

In [3]:
# Pin TensorFlow to a single physical GPU and cap the memory it may allocate.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # The preferred device is GPU 1. Clamp the index so that a host with only
    # one GPU falls back to GPU 0 instead of raising an IndexError on gpus[1]
    # (which the RuntimeError handler below would not catch).
    gpu_num = min(1, len(gpus) - 1)
    try:
        tf.config.experimental.set_visible_devices(gpus[gpu_num], 'GPU')
        # Alternative to the hard cap below: grow the allocation on demand.
        #tf.config.experimental.set_memory_growth(gpus[gpu_num], True)
        tf.config.experimental.set_virtual_device_configuration(
            gpus[gpu_num],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)])
    except RuntimeError as e:
        # TensorFlow raises RuntimeError when devices are reconfigured after
        # they have been initialized; report it and keep the current setup.
        print(e)

In [23]:
from data_import_preprocessing import import_data_preprocessing
from model_define import CNNLSTM_model, CNNLSTM_offtarget_model

In [6]:
# Load the training CSV through the project preprocessing helper.
# No separate test file is supplied in this run; the loader also accepts a
# `test_data_file_name` argument (e.g. 'data/nb_data_changed_HT_1-2.csv').
train_data_name = 'data/indel_simpletargetcompare_joinedun_081308231007combined.csv'
preprocessing = import_data_preprocessing(train_data_file_name=train_data_name)

# Show the column names the loader reports for the training file.
print('train_data : ', preprocessing.train_data_file_sample_column)

train_data :  ['barcode', 'sgRNA', 'target', 'Number_mismatch', 'total_readcnt', 'mutated_readcnt', 'indel_rate']


In [7]:
# Run the preprocessing pipeline on the loaded file.
# Column names match the CSV header printed when the file was loaded.
# NOTE(review): split_data=0.1 presumably holds out 10% of the rows for
# validation -- confirm against import_data_preprocessing.
preprocess_kwargs = dict(
    sgRNA_column='sgRNA',
    indel_column='indel_rate',
    targetDNA_column='target',
    offtarget=True,
    mismatch_calc=True,
    split_data=0.1,
    sgRNA_have_NGG=False,
)
data = preprocessing(**preprocess_kwargs)

# Inspect the available splits and the fields stored for each split.
print(data.keys())
print(data['train'].keys())

dict_keys(['train', 'val', 'total', 'test'])
dict_keys(['seq', 'indel_rate', 'indel_class', 'mis_position', 'mis_number', 'read_cnt', 'info'])


In [19]:

# Per-sample input shape: the training tensor's shape without its leading
# (sample) axis. The full shape is printed for reference.
train_seq_shape = data['train']['seq'].shape
input_shape = train_seq_shape[1:]
print(train_seq_shape)

(73917, 23, 4, 2)


In [29]:
# Unpack, for each split, the encoded sequences, the class labels and the
# indel rates (in that order).
X_train, class_train, rate_train = (
    data['train'][field] for field in ('seq', 'indel_class', 'indel_rate'))

X_val, class_val, rate_val = (
    data['val'][field] for field in ('seq', 'indel_class', 'indel_rate'))

# NOTE(review): the "test" arrays are read from the *validation* split, not
# from data['test'], so downstream test metrics are computed on the same data
# that drives early stopping. Confirm whether data['test'] is populated when
# no test file is supplied before switching this over.
X_test, class_test, rate_test = (
    data['val'][field] for field in ('seq', 'indel_class', 'indel_rate'))

In [24]:
# Create the off-target CNN-LSTM model builder for the per-sample input shape
# derived from the training tensor.
CNNLSTM = CNNLSTM_offtarget_model(input_shape=input_shape)

In [25]:
# Build the multi-output model from the builder and print its layer summary.
# NOTE(review): this rebinds `CNNLSTM_model`, which was imported from
# model_define; after this line the imported object is no longer reachable
# under that name in the notebook namespace.
CNNLSTM_model = CNNLSTM.MTL_model()
CNNLSTM_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 23, 4, 2)]   0                                            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 23, 1, 128)   1152        input_5[0][0]                    
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 23, 1, 128)   2176        input_5[0][0]                    
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 23, 1, 128)   3200        input_5[0][0]                    
______________________________________________________________________________________________

In [26]:
# Configure training for the four output heads: three classification heads
# (class_1, class_2, class_final) and one regression head (rate).
# An additional 'num_mis' head is currently disabled and therefore not listed.
CNNLSTM_model.compile(
    optimizer='adam',
    # Cross-entropy for the classification heads, MSE for the rate head.
    loss=dict(
        class_1='categorical_crossentropy',
        class_2='categorical_crossentropy',
        class_final='categorical_crossentropy',
        rate='mean_squared_error',
    ),
    # Relative contribution of each head's loss to the total loss.
    loss_weights=dict(class_1=1, class_2=1, class_final=0.5, rate=1),
    # Accuracy is only tracked for the classification heads.
    metrics=dict(class_1="accuracy", class_2="accuracy", class_final="accuracy"),
)

In [27]:
from sklearn.utils import class_weight

# Integer class label of every training sample (labels are stored one-hot /
# per-class scores, so take the argmax over the last axis).
class_train_num = class_train.argmax(axis=-1)

# "Balanced" weights: inversely proportional to class frequency.
# Arguments are passed by keyword because recent scikit-learn releases made
# `classes` and `y` keyword-only (the positional form raises TypeError there).
classes_present = np.unique(class_train_num)
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=classes_present,
                                                  y=class_train_num)

# Key the weights by the actual class label rather than by position, so the
# mapping stays correct even when some class has no training samples.
class_weights_dict = {int(label): float(weight)
                      for label, weight in zip(classes_present, class_weights)}
print("class_weight")
print(class_weights_dict)

# Parameters of an optional tanh-based squashing of the weights of classes
# above `cutoff_class`. That tuning is currently disabled, so the tuned
# weights equal the balanced weights.
multiple_constant = 2
cutoff_class = 5

# Work on a copy: a plain assignment would alias `class_weights_dict`, and any
# future tuning would silently change the class weights passed to fit() too.
class_wieghts_dict_tuned = dict(class_weights_dict)

print("sample_weight")
print(class_wieghts_dict_tuned)

# Per-sample weight for the regression head, looked up from each sample's class.
sample_weight = np.array([class_wieghts_dict_tuned[int(label)] for label in class_train_num])

class_weight
{0: 0.2465593040554514, 1: 0.6256146795202749, 2: 0.807659527972028, 3: 1.0341223873079828, 4: 1.554772622102562, 5: 1.948876819236448, 6: 3.762445281482236, 7: 4.834336167429693, 8: 5.962490925223844, 9: 7.272432113341204, 10: 0.8288796438543571}
sample_weight
{0: 0.2465593040554514, 1: 0.6256146795202749, 2: 0.807659527972028, 3: 1.0341223873079828, 4: 1.554772622102562, 5: 1.948876819236448, 6: 3.762445281482236, 7: 4.834336167429693, 8: 5.962490925223844, 9: 7.272432113341204, 10: 0.8288796438543571}


In [28]:
# Stop training once the total validation loss has failed to improve by at
# least 0.0005 for 50 consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0.0005, patience=50, verbose=0, mode='auto')

# Train all heads jointly. Every classification head is fit against the same
# class labels; the 'rate' head is fit against the indel rates.
# (The disabled 'num_mis' head has no targets here.)
CNNLSTM_model.fit(
    x=X_train,
    y=dict(class_1=class_train,
           class_2=class_train,
           class_final=class_train,
           rate=rate_train),
    validation_data=(
        X_val,
        dict(class_1=class_val,
             class_2=class_val,
             class_final=class_val,
             rate=rate_val),
    ),
    # The same class-weight mapping is applied to each classification head ...
    class_weight=dict.fromkeys(('class_1', 'class_2', 'class_final'), class_weights_dict),
    # ... while the regression head is weighted per sample.
    sample_weight=dict(rate=sample_weight),
    shuffle=True,
    epochs=200,
    batch_size=128,
    verbose=1,
    callbacks=[early_stopping],
)

Train on 73917 samples, validate on 8214 samples
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200


Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200


Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200


Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200


<tensorflow.python.keras.callbacks.History at 0x7f0d770ebc18>

In [None]:
# Evaluate the regression ('rate') head on the test arrays.
# The model trained above is `CNNLSTM_model`; the previously referenced
# `regression_model` is not defined anywhere in this notebook.
all_predictions = CNNLSTM_model.predict(X_test)

# A multi-output model returns one array per head. Select the 'rate' head
# before flattening; flattening everything would mix class probabilities
# into the predicted rates.
if isinstance(all_predictions, dict):
    rate_prediction = all_predictions['rate']
else:
    rate_prediction = all_predictions[CNNLSTM_model.output_names.index('rate')]

test_prediction = np.asarray(rate_prediction).reshape(-1,)
test_true = np.asarray(rate_test).reshape(-1,)

# Off-diagonal entry of the 2x2 correlation matrix (equals Pearson's r).
test_correlation = np.corrcoef(test_prediction, test_true).flatten()[1]
# pearsonr / spearmanr return (statistic, p-value); only the statistic is shown.
pearson_corr = pearsonr(test_prediction, test_true)
spearman_corr = spearmanr(test_prediction, test_true)
print('Correlation : {}'.format(test_correlation))
print('Pearson Correlation : {}'.format(pearson_corr[0]))
print('Spearman Correlation : {}'.format(spearman_corr[0]))