In [None]:
import os

from tensorflow import keras
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifier
from keras.models import Sequential
from keras.layers import Dense,Dropout
from keras.layers import LeakyReLU
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve

In [None]:
class ModelDataset(Dataset):
    """Map-style dataset pairing float32 feature rows with barcode labels.

    Holds tensor views (``X``, ``Y``) used for indexing, plus the original
    pandas objects (``Xdf``, ``Ydf``) and the rows of ``cells_df`` selected
    by ``x.index`` (``cells``) for later inspection.

    Parameters
    ----------
    x : pandas.DataFrame
        Feature matrix; converted to a float32 tensor.
    y : pandas.DataFrame
        Label frame whose ``'barcode'`` column entries become ``Y``.
    cells_df : pandas.DataFrame
        Per-cell metadata; rows are selected with ``.loc[x.index, :]``.
    """

    def __init__(self, x, y, cells_df):
        # Original pandas objects, kept for inspection.
        self.Xdf = x
        self.Ydf = y
        self.cells = cells_df.loc[x.index, :]
        # Tensor versions served by __getitem__.
        feature_matrix = x.to_numpy(dtype='float32')
        self.X = torch.tensor(feature_matrix)
        barcodes = list(y['barcode'])
        self.Y = torch.tensor(barcodes)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index):
        features = self.X[index]
        label = self.Y[index]
        return features, label

In [None]:
# Location of the saved train/validation/test splits. Set the WCR8_MODEL_DATA
# environment variable to point at a local copy on another machine.
MODEL_DATA_PATH = os.environ.get(
    'WCR8_MODEL_DATA',
    '/Users/elikond/Desktop/wcr8_model_data__inputs_scRNA_sa_nomt__labels_barcodes.pt',
)

# NOTE: torch.load unpickles arbitrary Python objects (presumably ModelDataset
# instances here, given the .Xdf/.Ydf access below), which can execute code —
# only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# such objects; pass weights_only=False there (trusted files only).
model_data = torch.load(MODEL_DATA_PATH)
train_data = model_data['train_data']
test_data = model_data['test_data']
valid_data = model_data['val_data']

In [None]:
# Dense feature arrays for Keras, one per split (np.array copies the frame's values).
X_train, X_test, X_valid = (
    np.array(split.Xdf) for split in (train_data, test_data, valid_data)
)

In [None]:
# Label frames for each split — the same objects held by the datasets, not copies.
Y_train, Y_test, Y_valid = train_data.Ydf, test_data.Ydf, valid_data.Ydf

In [None]:
def convert_y(Y, n_processes=6):
    """Expand the ``'barcode'`` column into an integer label matrix.

    Each entry of ``Y['barcode']`` is expected to be a sequence of length
    ``n_processes``; its elements become the columns "Process 1" ...
    "Process n" of the result.

    Parameters
    ----------
    Y : pandas.DataFrame
        Label frame with a ``'barcode'`` column. It is not modified.
    n_processes : int, default 6
        Number of entries in each barcode.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(len(Y), n_processes)``.

    Raises
    ------
    ValueError
        If the barcode entries do not have ``n_processes`` elements.
    """
    process_columns = ["Process {}".format(i + 1) for i in range(n_processes)]
    # Build the label frame directly and select it by name rather than writing
    # columns back into Y and slicing by position: positional slicing depends on
    # how many columns Y happened to have before the expansion.
    expanded = pd.DataFrame(Y['barcode'].tolist(), index=Y.index, columns=process_columns)
    return expanded.astype(int).to_numpy()

In [None]:
# Integer label matrices (one column per process) for each split.
y_train, y_test, y_valid = (convert_y(frame) for frame in (Y_train, Y_test, Y_valid))

In [None]:
def create_model(optimizer="RMSprop", init='uniform'):
    """Build and compile the WCR8 process classifier.

    Parameters
    ----------
    optimizer : str or keras optimizer, default "RMSprop"
        Passed directly to ``model.compile``.
    init : str, default 'uniform'
        Accepted for interface compatibility; currently not used by any layer.

    Returns
    -------
    keras.models.Sequential
        Compiled model with six softmax outputs that tracks ``accuracy`` plus
        the custom ``sk_pr_auc`` and ``processes`` metrics.
    """
    model = Sequential()
    # NOTE(review): 1806 presumably equals the number of feature columns in X_train — confirm.
    model.add(Dense(16, input_dim=1806, activation=LeakyReLU()))
    model.add(Dropout(0.2))
    model.add(Dense(27, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(7, activation='sigmoid'))
    model.add(Dropout(0.2))
    model.add(Dense(27, activation='relu'))
    model.add(Dropout(0.2))
    # One output per process.
    # NOTE(review): softmax + categorical_crossentropy assumes one process per
    # cell; if a barcode can mark several, sigmoid + binary_crossentropy is the
    # usual pairing — confirm label semantics.
    model.add(Dense(6, activation='softmax'))

    def _auprc_per_process(y_true, y_pred):
        """Return the precision-recall AUC of each of the six output columns."""
        auprcs = []
        for i in range(6):
            precision, recall, _ = precision_recall_curve(
                tf.gather(y_true, i, axis=1), tf.gather(y_pred, i, axis=1))
            auprcs.append(auc(recall, precision))
        return auprcs

    # These metrics call sklearn on the batch tensors, which presumably is why
    # compile() below uses run_eagerly=True.
    def sk_pr_auc(y_true, y_pred):
        """Macro-averaged (unweighted mean) AUPRC over the six processes."""
        auprcs = _auprc_per_process(y_true, y_pred)
        return sum(auprcs) / len(auprcs)

    def processes(y_true, y_pred):
        """Per-process AUPRC values as a list.

        NOTE(review): Keras reduces metric results to a scalar, so the logged
        value is likely the mean of this list rather than six values — confirm.
        """
        return _auprc_per_process(y_true, y_pred)

    model.compile(loss='categorical_crossentropy', optimizer=optimizer,
                  metrics=["accuracy", sk_pr_auc, processes], run_eagerly=True)
    return model


model = create_model()

In [None]:
# Train the model, scoring the validation split after every epoch.
train = model.fit(X_train, y_train, epochs=50, batch_size=50, verbose=0,
                  validation_data=(X_valid, y_valid))

# Enable grid lines before plotting: matplotlib reads `axes.grid` when an Axes
# is created, so changing it after plt.plot() would not affect this figure.
# plt.rcParams is matplotlib's global rcParams object.
grid_theme = {'axes.grid': True}
plt.rcParams.update(grid_theme)

plt.plot(train.history['sk_pr_auc'], label='train')
plt.plot(train.history['val_sk_pr_auc'], label='validation')
plt.title('WCR8 Macro-Averaged AUPRC')
plt.xlabel('Epochs')
plt.ylabel('Macro AUPRC')
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

plt.show()

In [None]:
# Validation-only learning curve for the macro AUPRC metric.
# Grid lines are enabled first because `axes.grid` is read when the Axes is created.
grid_theme = {'axes.grid': True}
plt.rcParams.update(grid_theme)

plt.plot(train.history['val_sk_pr_auc'], label='Validation')
plt.title('WCR8 Macro-Averaged AUPRC')
plt.xlabel('Epochs')
plt.ylabel('Macro AUPRC')
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

plt.show()

In [None]:
# Metric names recorded per epoch by model.fit (validation metrics carry a "val_" prefix).
train.history.keys()

In [None]:
# Per-epoch metrics as a table (one row per epoch, one column per metric)
# rather than a raw dict of long lists.
pd.DataFrame(train.history)

In [None]:
# Training-set macro AUPRC after each epoch.
train_sk_pr_auc = train.history["sk_pr_auc"]
print(train_sk_pr_auc)

In [None]:
# Hyperparameter grid: `optimizer` matches a create_model argument; `epochs`
# and `batch_size` are training options. 3 x 3 x 3 = 27 combinations.
param_grid = dict(
    epochs=[50, 75, 100],
    batch_size=[32, 50, 100],
    optimizer=['RMSprop', 'Adam', 'SGD'],
)

# Wrap the Keras model factory so scikit-learn can build and fit it per fold.
model_cv = KerasClassifier(build_fn=create_model, verbose=0)

grid = GridSearchCV(
    estimator=model_cv,
    param_grid=param_grid,
    cv=5,
    n_jobs=-1,
    verbose=1,
)

# 5-fold cross-validation on the training split only (27 x 5 = 135 fits).
grid_cv_model = grid.fit(X_train, y_train)

# NOTE(review): best_score_ uses the estimator's default scorer (presumably
# accuracy for KerasClassifier), not AUPRC — confirm this is the intended criterion.
print("Best: %f using %s" % (grid_cv_model.best_score_, grid_cv_model.best_params_))

In [None]:
# Best estimator found by the grid search.
# NOTE(review): cv_model is not used by the evaluation cells below, which call
# `model` (the manually trained network) instead — confirm which model the
# test-set evaluation is meant to use.
cv_model = grid_cv_model.best_estimator_

In [None]:
# Test-set precision-recall curve for each process, one panel per process.
pred = model.predict(X_test)
n_processes = y_test.shape[1]

precision = dict()
recall = dict()
# squeeze=False keeps `axes` 2-D even when there is a single process.
fig, axes = plt.subplots(1, n_processes, figsize=(4 * n_processes, 4),
                         sharey=True, squeeze=False)
for i, ax in enumerate(axes[0]):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i], pred[:, i])
    ax.plot(recall[i], precision[i], lw=2)
    # 1-based to match the "Process 1".."Process 6" names built in convert_y.
    ax.set_title('Process ' + str(i + 1))
    ax.set_xlabel("Recall")
axes[0, 0].set_ylabel("Precision")
# Figure-level title, so it does not replace the last panel's title.
fig.suptitle("WCR8 Test Precision-Recall Curves")
plt.show()

In [None]:
# Test-set AUPRC for each process and their unweighted (macro) average.
process_auprcs = []
for i in range(y_test.shape[1]):
    precision, recall, thresholds = precision_recall_curve(y_test[:, i], pred[:, i])
    auc_precision_recall = auc(recall, precision)
    process_auprcs.append(auc_precision_recall)
    print("Process {} AUPRC: {:.4f}".format(i + 1, auc_precision_recall))

j = sum(process_auprcs)  # total AUPRC across processes
print("Macro-averaged AUPRC: {:.4f}".format(j / len(process_auprcs)))

In [None]:
# Precision and recall of process 1 (column 0) as the decision threshold sweeps
# from 0 to 1, computed directly from the thresholded predictions.
y_test_probs = model.predict(X_test)[:, 0]
# Assumes y_test holds 0/1 labels (as produced by convert_y's astype(int)).
actual_positive = y_test[:, 0] == 1
n_actual = np.count_nonzero(actual_positive)

precision_scores = []
recall_scores = []

# Probability thresholds to evaluate, between 0 and 1.
probability_thresholds = np.linspace(0, 1, num=100)

for p in probability_thresholds:
    y_test_preds = (y_test_probs > p).astype("int32")
    predicted_positive = y_test_preds == 1
    true_positives = np.count_nonzero(predicted_positive & actual_positive)
    n_predicted = np.count_nonzero(predicted_positive)
    # With no positive predictions precision is undefined; use 1.0, the same
    # convention sklearn's precision_recall_curve uses at zero recall.
    precision_scores.append(true_positives / n_predicted if n_predicted else 1.0)
    recall_scores.append(true_positives / n_actual if n_actual else 0.0)

plt.plot(recall_scores, precision_scores, lw=2)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("WCR8 Test Precision vs. Recall by Threshold (Process 1)")
plt.show()