In [None]:
import pandas as pd
import seaborn as sns
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from mne.stats import fdr_correction
from tqdm import tqdm
from sklearn.ensemble import RandomForestClassifier
import random 
import subprocess
import matplotlib as mpl
from matplotlib import animation
from sqlalchemy import create_engine, types, dialects

# Silence FutureWarning messages.
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)

# NOTE(review): this blanket filter hides *every* warning category, which makes
# the FutureWarning filter above redundant and can mask real problems
# (e.g. convergence or deprecation warnings) — confirm this is intended.
import warnings
warnings.filterwarnings("ignore")

# Global seaborn look: slightly enlarged fonts on the 'ticks' style.
sns.set(font_scale=1.2)
sns.set_style('ticks')

# Render inline figures as SVG; 'svg.fonttype' = 'none' keeps text as text
# instead of converting glyphs to paths, so labels stay editable.
%config InlineBackend.figure_format = 'svg'
plt.rcParams['figure.dpi'] = 200
plt.rcParams['svg.fonttype'] = 'none'

In [None]:
# Per-sample metadata table. The cells below read its `study_name`,
# `sample_id` and `study_condition` columns.
# NOTE(review): '.../metadata.pkl' looks like a placeholder path — point it at
# the real metadata file before running.
filtered_reads_samples_filterdis  = pd.read_pickle('.../metadata.pkl')

In [None]:
# use human/bacterial protein abundance profile to run RF
# One pickled abundance table per study is picked up from hp_abun/.
# glob() returns matches in arbitrary order, so the paths are sorted to keep
# the study order (and everything derived from it) reproducible across runs.
all_uniref_human_profile = sorted(glob('hp_abun/*human_cluster_filter.pkl'))
all_uniref_human_profile_df = [pd.read_pickle(x) for x in all_uniref_human_profile]

# all study names: the file path with the directory prefix and the file-name
# suffix stripped. Plain str.replace is a literal replacement (no regex), and
# an empty match list simply yields an empty list instead of raising.
study_name = [
    path.replace('hp_abun/', '').replace('_human_cluster_filter.pkl', '')
    for path in all_uniref_human_profile
]

In [None]:
# attach "labels" to the metagenomic samples
def return_all_factorized():
    """Build one labelled feature table per study.

    For every entry of ``study_name`` the matching abundance profile from
    ``all_uniref_human_profile_df`` is inner-joined (profile index against
    ``sample_id``) with that study's rows of
    ``filtered_reads_samples_filterdis`` that have both ``sample_id`` and
    ``study_condition`` set.

    Returns:
        list of DataFrame, one per study and in ``study_name`` order. Each
        frame holds the profile columns followed by a final integer column
        ``y`` (0 where ``study_condition`` is 'control', 1 otherwise) and
        carries a fresh 0..n-1 index.
    """
    metadata = filtered_reads_samples_filterdis
    labelled_tables = []
    for name, profile in zip(study_name, all_uniref_human_profile_df):
        # Metadata rows of this study that have both a sample id and a condition.
        study_labels = metadata[metadata.study_name == name]
        study_labels = study_labels[['sample_id', 'study_condition']].dropna()

        # Inner join: samples without a metadata match are dropped.
        table = profile.reset_index().merge(
            study_labels, left_on='index', right_on='sample_id')
        table = table.drop(['index', 'sample_id'], axis=1)
        table = table.rename(columns={'study_condition': 'y'})

        # Binary target: control -> 0, any other condition -> 1.
        table['y'] = (table['y'] != 'control').astype('int')

        labelled_tables.append(table.reset_index(drop=True))
    return labelled_tables

# One labelled table per study, in the same order as `study_name`.
all_uniref_human_profile_df_fac = return_all_factorized()

In [None]:
# read tuned hyperparams from gridsearch
# run_rf_eval looks rows up by `study_name` and reads `min_samples_leaf`,
# `min_samples_split`, `n_estimators` and `max_depth` from this table.
best_para = pd.read_pickle('RF_final/best_hyperparam.pkl')
# Every missing value in the table becomes the string 'None'; run_rf_eval
# compares max_depth against that sentinel to decide on an unlimited depth.
# NOTE(review): this presumably leaves max_depth as an object column holding a
# mix of numbers and 'None' — confirm the dtype of the numeric entries.
best_para = best_para.replace(np.nan, 'None')

In [None]:
# Disable matplotlib's "too many open figures" warning.
plt.rcParams.update({'figure.max_open_warning': 0})
# NOTE(review): of the names imported below only cross_validate is used in the
# cells visible here; the others may be needed further down the notebook.
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate
from sklearn.metrics import roc_curve, precision_recall_curve, auc, average_precision_score, precision_score, recall_score, f1_score
# NOTE(review): scipy.interp is a deprecated alias of numpy.interp and may be
# missing from recent SciPy releases — verify against the installed version.
from scipy import interp

def run_rf_eval(data, study_name_one, n_repeats=20, cv=5, random_state=None):
    """Repeatedly cross-validate a tuned random forest on one study.

    The hyper-parameters are looked up in the module-level ``best_para`` table
    by study name. A fresh ``RandomForestClassifier`` is built for every
    repeat and scored with ``cross_validate``.

    Args:
        data: DataFrame with the feature columns plus a binary target column
            named ``y``.
        study_name_one: study identifier; must match a value in
            ``best_para.study_name`` and is written to the ``study`` column of
            the result.
        n_repeats: number of cross-validation repeats (default 20).
        cv: forwarded to ``cross_validate`` (default 5).
        random_state: optional integer base seed. When given, repeat ``k``
            uses ``random_state + k`` for the forest so a run can be
            reproduced; ``None`` (default) leaves the forests unseeded.

    Returns:
        DataFrame with one row per fold and repeat, holding the columns
        produced by ``cross_validate`` (timings and ``test_<metric>`` scores)
        plus a ``study`` column. The per-repeat fold index is kept.

    Raises:
        IndexError: if ``best_para`` has no row for ``study_name_one``.
    """
    # Select features by name so the target can never leak into X, whatever
    # position the 'y' column has.
    X = data.drop(columns='y')
    y = data['y']

    # Look the tuned parameters up once instead of filtering per parameter.
    tuned_rows = best_para[best_para.study_name == study_name_one]
    if tuned_rows.empty:
        raise IndexError(
            "no tuned hyper-parameters found for study %r" % (study_name_one,))
    tuned = tuned_rows.iloc[0]

    # Left as stored: scikit-learn accepts both counts (int) and fractions
    # (float) for these two parameters.
    min_samples_leaf = tuned.min_samples_leaf
    min_samples_split = tuned.min_samples_split
    n_estimators = int(tuned.n_estimators)

    # Missing depths are stored as the string 'None' (unlimited depth). Real
    # depths may come back as floats from a column that also held NaN, and
    # recent scikit-learn versions only accept integers here.
    max_depth = tuned.max_depth
    if max_depth is None or max_depth == 'None':
        max_depth = None
    else:
        max_depth = int(max_depth)

    scoring = ('accuracy', 'average_precision', 'roc_auc',
               'precision', 'recall', 'f1')

    # NOTE(review): with an integer `cv` the folds are presumably identical in
    # every repeat (unshuffled stratified k-fold), so repeats differ only by
    # the forest's own randomness — confirm this is the intended design.
    fold_scores = []
    for repeat in range(n_repeats):
        seed = None if random_state is None else random_state + repeat
        # 'sqrt' is what the removed 'auto' alias meant for classifiers.
        clf = RandomForestClassifier(n_jobs=10, bootstrap=True,
                                     max_features='sqrt',
                                     min_samples_leaf=min_samples_leaf,
                                     min_samples_split=min_samples_split,
                                     n_estimators=n_estimators,
                                     max_depth=max_depth,
                                     class_weight='balanced',
                                     random_state=seed)
        fold_scores.append(pd.DataFrame(
            cross_validate(clf, X, y, cv=cv, n_jobs=30, scoring=scoring)))

    # Stack once after the loop rather than re-concatenating on every repeat.
    results_all_df = pd.concat(fold_scores)
    results_all_df['study'] = study_name_one

    print(study_name_one, ":done")

    return results_all_df

In [None]:
# run 5fold CV
# One score table per study (same order as `study_name`); the stacked tables
# are written to disk once every study has finished.
all_results_eval = [
    run_rf_eval(study_table, study_label)
    for study_table, study_label in zip(all_uniref_human_profile_df_fac, study_name)
]

pd.concat(all_results_eval).to_pickle('RF_final/5fold_CV.after_gridsearch.pkl')

In [None]:
# mean and std
# Per-study mean and standard deviation over all folds and repeats; timing
# columns, accuracy and average precision are left out of both summaries.
_unreported_columns = ['fit_time', 'score_time', 'test_accuracy', 'test_average_precision']
_scores_by_study = pd.concat(all_results_eval).groupby(['study'])
mean_table = _scores_by_study.mean().drop(columns=_unreported_columns)
std_table = _scores_by_study.std().drop(columns=_unreported_columns)

In [None]:
# visualize the eval
# Heatmap of the mean cross-validation metrics, studies ordered by AUROC.
# Columns are picked and labelled by name (not by position), so a heading can
# never end up on the wrong metric if the column order of mean_table changes.
metric_labels = {
    'test_roc_auc': 'AUROC',
    'test_precision': 'Precision',
    'test_recall': 'Recall',
    'test_f1': 'F1',
}
df = mean_table.sort_values('test_roc_auc', ascending = False).loc[:, list(metric_labels)]
df.columns = list(metric_labels.values())
sns.set(font_scale=1.1)
sns.set_style('white')
fig, ax = plt.subplots(figsize = (3,5))
# Draw on the axes created above; the colour scale starts at 0.4.
sns.heatmap(df, cmap='Oranges', annot=True, fmt='.2f', vmin = 0.4, ax = ax)
plt.xticks(rotation=30, fontsize = 12)
# Trailing ';' keeps the cell from echoing the returned Text object.
ax.set_ylabel('');