In [1]:
%matplotlib inline
import os, pickle, logging, pickle, joblib, sys, warnings
warnings.simplefilter('ignore')
from scipy import stats
import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import catboost as cb

from sklearn import ensemble, metrics, pipeline, preprocessing, impute, model_selection
from scipy.stats import pearsonr, spearmanr

import shap

font_path = "/home/zhoujb/local/font/Times New Roman.ttf"
mpl.font_manager.fontManager.addfont(font_path)
prop = mpl.font_manager.FontProperties(fname=font_path)
mpl.rcParams['font.family'] = prop.get_name()
mpl.rcParams['font.sans-serif'] = ["Times New Roman"]
mpl.rcParams['axes.unicode_minus'] = False
plt.rcParams['pdf.fonttype'] = 42

RAW_PATH = "/data2/zhoujb/project/cowpea_project/rawData/"
ML_RAW_PATH = "/data2/zhoujb/project/cowpea_project/basedXPXLR/ML/rawData/"
FS_PATH = "/data2/zhoujb/project/cowpea_project/basedXPXLR/ML/fs_PL/"
TEST_RES_PATH = "/data2/zhoujb/project/cowpea_project/basedXPXLR/ML/tesRes/"
FIG_PATH = "/data2/zhoujb/project/cowpea_project/basedXPXLR/ML/figRes/"

In [2]:
feat_col_raw = []
with open(os.path.join(FS_PATH, "cb_rfa_cv_pl_rmse")) as f:
    for line in f:
        feat_col_raw.append(line.strip())

feat_col = []
feat_col_map = {}
for item in feat_col_raw:
    feat_col.append(item[4:])
    feat_col_map[item] = item[4:]


raw_data = pd.read_table(os.path.join(ML_RAW_PATH, "raw_data_PL.txt"), sep="\t", index_col=0)
raw_data = raw_data.rename(columns=feat_col_map)

target_col = ['HZ-PL']
raw_data = raw_data.dropna(subset=target_col)

kf = model_selection.KFold(n_splits=5, shuffle=True,  random_state=0)
y_test_final, y_pred_final = [], []
for i, (train_index, test_index) in enumerate(kf.split(raw_data)):
    data_train = raw_data.iloc[train_index].copy()
    data_test = raw_data.iloc[test_index].copy()

    scale_tool = preprocessing.StandardScaler()
    scale_tool.fit(data_train.loc[:, feat_col])
    data_train.loc[:, feat_col] = scale_tool.transform(data_train.loc[:, feat_col])
    data_test.loc[:, feat_col] = scale_tool.transform(data_test.loc[:, feat_col])

    train_sel = data_train.sample(frac=0.8, random_state=0)
    val_sel = data_train.drop(train_sel.index).copy()

    X_train = train_sel[feat_col].copy()
    y_train = train_sel[target_col].values.ravel()

    X_val = val_sel[feat_col].copy()
    y_val = val_sel[target_col].values.ravel()

    X_test = data_test[feat_col].copy()
    y_test = data_test[target_col].values.ravel()

    if i == 1:
        break

# Initialize CatBoostClassifier
clf_model = cb.CatBoostRegressor(random_state=0, thread_count=4, loss_function='RMSE')
# Fit model
clf_model.fit(X_train, y_train, eval_set=[(X_train, y_train), (X_val, y_val)], verbose=0, plot=False)

y_pred = clf_model.predict(X_test)

#score_pear = pearsonr(y_test, y_pred)[0]
score_spear = spearmanr(y_test, y_pred)[0]
score_rmse = metrics.root_mean_squared_error(y_test, y_pred)
score_nrmse = score_rmse / np.std(y_test)

score_spear, score_rmse, score_nrmse

(0.9126073318342379, 7.743411706248816, 0.43435578958003573)

In [3]:
data_scale = raw_data[feat_col].copy()
data_scale.loc[:, feat_col] = scale_tool.transform(data_scale.loc[:, feat_col])

explainer = shap.TreeExplainer(clf_model)
shap_values = explainer(data_scale[feat_col])

In [4]:
diff_df = pd.DataFrame([y_test, y_pred])
diff_df = diff_df.T
diff_df.columns = ["y_test", "y_pred"]
diff_df["diff"] = np.abs(diff_df["y_test"] - diff_df["y_pred"])
diff_df = diff_df.sort_values(by="diff", ascending=True)
diff_df

Unnamed: 0,y_test,y_pred,diff
29,49.216667,49.324271,0.107604
30,59.55,59.66126,0.11126
3,51.875,51.75496,0.12004
35,41.375,41.64708,0.27208
48,50.275,49.982545,0.292455
54,43.975,44.369394,0.394394
57,17.75,18.206584,0.456584
13,42.405556,41.929512,0.476044
11,64.55,65.027843,0.477843
23,19.55,20.057012,0.507012


In [5]:
detail_info = pd.read_excel(os.path.join(RAW_PATH, "Detail_information_344.xlsx"), index_col="Genotype No.")
detail_info = detail_info.loc[data_scale.index]
detail_info["Type"] = detail_info["Types"].map({'Grain':"G", 'V-landrace':"VL", '-':"NA", 'V-cultivar':"VC"})

Source = detail_info.loc[data_scale.index]["Type"].to_list()

target_df = raw_data[target_col]
target_df = pd.concat([target_df, detail_info[["Type"]]], axis=1)
target_df = target_df.reset_index()

In [6]:
target_df[target_df["HZ-PL"]==59.55]

Unnamed: 0,index,HZ-PL,Type
161,D659,59.55,VC


In [8]:
need_plot_num, need_plot_name = 161, "D659"

with plt.rc_context():
    mpl.rcParams['font.family'] = prop.get_name()
    mpl.rcParams['font.sans-serif'] = ["Times New Roman"]
    mpl.rcParams['axes.unicode_minus'] = False
    
    shap.force_plot(explainer.expected_value,
                    shap_values[:, feat_col].values[need_plot_num],
                    raw_data[feat_col].round(3).loc[need_plot_name, feat_col].values,
                    feat_col,text_rotation=15,
                    matplotlib=True, show = False)
    plt.savefig(os.path.join(FIG_PATH, "fig_3_f_{}_force_plot.pdf".format(need_plot_name)),  format="pdf", dpi=1000, bbox_inches="tight", transparent=True)

In [9]:
target_df[target_df["HZ-PL"]==17.75]

Unnamed: 0,index,HZ-PL,Type
294,D646,17.75,G


In [12]:
need_plot_num, need_plot_name = 294, "D646"

with plt.rc_context():
    mpl.rcParams['font.family'] = prop.get_name()
    mpl.rcParams['font.sans-serif'] = ["Times New Roman"]
    mpl.rcParams['axes.unicode_minus'] = False
    
    shap.force_plot(explainer.expected_value,
                    shap_values[:, feat_col].values[need_plot_num],
                    raw_data[feat_col].round(3).loc[need_plot_name, feat_col].values,
                    feat_col,text_rotation=15,
                    matplotlib=True, show = False)
    
    plt.savefig(os.path.join(FIG_PATH, "fig_3_g_{}_force_plot.pdf".format(need_plot_name)),  format="pdf", dpi=1000, bbox_inches="tight", transparent=True)