# Data 


In [None]:
import logging 
import numpy as np
import importlib
import pandas as pd
from os.path import join
from data.data_access import Data
from preprocessing import pre
from utils.logs import set_logging
from config_path import ENH_LOG_PATH

# Parameter file that is loaded as the `params` module further down.
params_file = 'train/params/P1000/pnet/onsplit_average_reg_10_tanh_large_testing.py'

# Logging goes to <ENH_LOG_PATH>/log (the former `log_dir = log_dir` no-op is removed).
log_dir = join(ENH_LOG_PATH, 'log')
set_logging(log_dir)

In [None]:
# Load the parameter file as a module named 'params'.
# `importlib.util` is imported explicitly (a bare `import importlib` does not
# guarantee that submodules are available) and the deprecated
# `Loader.load_module()` is replaced by the documented spec/exec_module recipe.
import importlib.util
import sys

spec = importlib.util.spec_from_file_location('params', params_file)
params = importlib.util.module_from_spec(spec)
# Register the module like load_module() did, so `import params` elsewhere
# resolves to this same object.
sys.modules[spec.name] = params
spec.loader.exec_module(params)


In [None]:
# Bare expression: in a notebook cell this displays the loaded `params` module.
params  

In [None]:
# Build the Data object from the first entry of `params.data`,
# expanded as keyword arguments.
data = Data(**params.data[0])

In [None]:
# Unpack the ten values returned by get_train_validate_test():
# features, labels and info for each of the three splits, plus `cols`.
(
    x_train, x_validate_, x_test_,
    y_train, y_validate_, y_test_,
    info_train, info_validate_, info_test_,
    cols,
) = data.get_train_validate_test()

In [None]:
# NOTE(review): `csv` is not used anywhere in the visible part of this file.
import csv

# Second name for the test-split info; it is later passed to save_prediction,
# where it becomes the row index of the saved predictions.
info_t = info_test_


In [None]:
# Report the shape of every split as a quick sanity check.
print("x_train:", x_train.shape)
# Label fixed: this line prints y_train but was labelled "x_train:".
print("y_train:", y_train.shape)

print("x_validate_:", x_validate_.shape)
print("y_validate_:", y_validate_.shape)

print("x_test:", x_test_.shape)
print("y_test:", y_test_.shape)

print("columns:", len(cols))

# Model

In [None]:
from copy import deepcopy
from model import nn

In [None]:
# Deep-copy the first model configuration so that `params.models` itself
# is not affected by anything done to the copy.
model_params_ = deepcopy(params.models[0])

In [None]:
# Instantiate nn.Model with the 'params' entry of the copied configuration,
# expanded as keyword arguments.
model = nn.Model(**model_params_['params'])

In [None]:
# Fit on the training split; the validation split is passed as the third and
# fourth positional arguments (how it is used is defined in model.nn, not here).
history = model.fit(x_train, y_train, x_validate_, y_validate_)

In [None]:
# Predictions for the test split; reused below for scoring and for the CSV export.
y_pred_test = model.predict(x_test_)

from sklearn import metrics
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
 
def evaluate_classification_binary(y_test, y_pred, y_pred_score=None,
                                   roc_path='roc_curve.png'):
    """Score binary predictions and plot the ROC curve.

    Parameters
    ----------
    y_test : array-like
        Ground-truth labels; label 1 is treated as the positive class.
    y_pred : array-like
        Predicted labels.
    y_pred_score : array-like, optional
        Scores used for the ranking metrics (ROC AUC and average precision).
        When omitted, ``y_pred`` is used for those metrics as well.
    roc_path : str, optional
        Where the ROC figure is saved. Defaults to ``'roc_curve.png'``;
        pass ``None`` or ``''`` to skip saving.

    Returns
    -------
    dict
        Keys ``accuracy``, ``precision``, ``auc``, ``f1``, ``aupr``, ``recall``.
    """
    # One fallback for both ranking metrics. Previously only the ROC branch
    # handled a missing score vector, and average_precision_score was called
    # with None, which raised.
    ranking_scores = y_pred if y_pred_score is None else y_pred_score

    accuracy = accuracy_score(y_test, y_pred)
    fpr, tpr, _ = metrics.roc_curve(y_test, ranking_scores, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    f1 = metrics.f1_score(y_test, y_pred)
    precision = metrics.precision_score(y_test, y_pred)
    recall = metrics.recall_score(y_test, y_pred)
    logging.info(metrics.classification_report(y_test, y_pred))
    # Uses the already-imported `metrics` module instead of a function-local import.
    aupr = metrics.average_precision_score(y_test, ranking_scores)

    score = {
        'accuracy': accuracy,
        'precision': precision,
        'auc': auc,
        'f1': f1,
        'aupr': aupr,
        'recall': recall,
    }

    # ROC curve plus the diagonal reference line.
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')

    # Save before showing, as in the original flow.
    if roc_path:
        plt.savefig(roc_path)

    plt.show()

    return score

# Use column 1 of predict_proba() as the score vector when the model offers
# that method; otherwise fall back to the plain predictions.
y_pred_test_scores = (
    model.predict_proba(x_test_)[:, 1]
    if hasattr(model, 'predict_proba')
    else y_pred_test
)

# Compute the test metrics (this also draws and saves the ROC curve).
test_score = evaluate_classification_binary(
    y_test_, y_pred_test, y_pred_score=y_pred_test_scores
)

logging.info('Test score {}'.format(test_score))



In [None]:
# Disabled notebook shell command:
# !python analysis/run_it_all.py

# Name passed to save_prediction, where it becomes the CSV file-name prefix.
model_name = 'P-net'

# save prediction data for plot generation

def save_prediction(info, y_pred, y_pred_scores, y_test, model_name, training=False,
                    directory=''):
    """Write predictions, scores and ground truth to a CSV file.

    The file is named ``<model_name>_training.csv`` when ``training`` is true
    and ``<model_name>_testing.csv`` otherwise.

    Parameters
    ----------
    info : index-like
        Row labels of the output table (used as the DataFrame index).
    y_pred : array-like
        Predicted labels, stored in column ``pred``.
    y_pred_scores : array-like
        Prediction scores, stored in column ``pred_scores``.
    y_test : numpy.ndarray
        Ground truth. A plain array is stored in column ``y``; a structured
        array (survival case) gets one ``y_<field>`` column per named field.
    model_name : str
        Prefix of the output file name.
    training : bool, optional
        Selects the ``_training`` or ``_testing`` file-name suffix.
    directory : str, optional
        Directory to write into. The default ``''`` keeps the previous
        behaviour of writing relative to the current working directory.
    """
    suffix = '_training.csv' if training else '_testing.csv'
    # The original called join() with a single argument, which is a no-op;
    # it now actually joins the target directory and the file name.
    file_name = join(directory, model_name + suffix)

    table = pd.DataFrame(index=info)
    # Former debug print() calls dumped whole arrays to stdout; kept as debug logs.
    logging.debug('info %s', table)
    logging.debug('y_test %s', y_test)

    table['pred'] = y_pred
    table['pred_scores'] = y_pred_scores

    # survival case
    # https://docs.scipy.org/doc/numpy/user/basics.rec.html
    field_names = y_test.dtype.names
    if field_names is not None:
        # dtype.names lists each field once in declaration order, whereas
        # dtype.fields can additionally contain title aliases.
        for field in field_names:
            table['y_{}'.format(field)] = y_test[field]
    else:
        table['y'] = y_test

    table.to_csv(file_name)

# Export the test-set predictions to '<model_name>_testing.csv'.
save_prediction(info_t, y_pred_test, y_pred_test_scores, y_test_, model_name)

# save model
# The model file now lives under <log_dir>/fs/<model_name>.h5 instead of a
# hard-coded, machine-specific absolute path, and the directory is created
# when missing.
# NOTE(review): assumes the previous absolute path was equivalent to
# join(ENH_LOG_PATH, 'log', 'fs', 'P-net.h5') — confirm.
import os

model_dir = join(log_dir, 'fs')
os.makedirs(model_dir, exist_ok=True)
filename = join(model_dir, model_name + '.h5')
model.save_model(filename)

# load model
# Reload from the file that was just written (kept from the original flow).
model.load_model(filename)



In [None]:
# shap implementation
# Compute SHAP values for a small batch of test rows and draw a summary plot.

import shap   
# from tensorflow.keras.models import load_model, model_from_json

# Rows 1..10 of the test set are explained.
# NOTE(review): the slice starts at 1, so row 0 is skipped — confirm this is intended.
sample = x_test_[1:11]

# sample_output = model.predict(sample)

# # Find the indices where the value is 1
# indices_of_ones = np.where(sample_output == 1)[0]

# # Print the indices
# print("Indices of '1':", indices_of_ones)

# print('sample_output:', sample_output)

from config_path import INTERACTIONS_PATH

# Despite its name, `gene_names` holds the path of the CSV file, not the names.
gene_names = join(INTERACTIONS_PATH, 'enh_vs_genes_selected_genes.csv')

# Read the CSV file
df = pd.read_csv(gene_names)

# Extract gene names into a list (column 'genes')
genes_list = df['genes'].tolist()

# The full test set is passed as the second argument of shap.Explainer.
# NOTE(review): presumably used as background data, and the nn.Model wrapper
# must be something shap.Explainer accepts — confirm against the SHAP docs.
explainer = shap.Explainer(model, x_test_)
# NOTE(review): not every explainer type exposes `shap_values`; verify that
# the one selected here does.
shap_values = explainer.shap_values(sample)
# NOTE(review): assumes genes_list has one entry per column of `sample`,
# in the same order — verify.
shap.summary_plot(shap_values, sample, feature_names=genes_list)





In [None]:
# shap.plots.force(explainer.expected_value, shap_values)