In [1]:
import os
import json
import pickle
import hashlib
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, balanced_accuracy_score, matthews_corrcoef, confusion_matrix
from sklearn.ensemble import RandomForestClassifier

In [5]:
## ********** data preprocessing functions **********

def clean_data(X):
  """Return the data matrix untouched, together with the list of its feature names.

  X : object whose iteration yields the feature labels (for a pandas DataFrame,
      iterating yields the column labels).
  Returns the tuple (X, features_names); X is the very same object that was passed in.
  """
  # no cleaning is applied to the values: only the feature names are extracted
  return X, list(X)

def clean_metadata(meta_df):
  """Rename the sample ids and the columns of the metadata table.

  meta_df : DataFrame indexed by sample id, with a 'Geography' column and nine
            columns in total (they are renamed positionally below).
  Returns a new DataFrame; the frame passed in is left unchanged.

  Index rewriting rules:
    - Geography == 'Fiji'    : '_profiled_Metaphlan3' is appended to the sample id
    - Geography == 'Nunavik' : 'X' is prepended to the sample id
    - anything else          : the sample id is kept as is
  (presumably so that the ids match the sample names of the data matrix -- to be confirmed)
  """
  renamed_ids = np.array(meta_df.index)
  for position, (sample_id, record) in enumerate(meta_df.iterrows()):
    geography = record['Geography']
    if geography == 'Fiji':
      renamed_ids[position] = sample_id + '_profiled_Metaphlan3'
    elif geography == 'Nunavik':
      renamed_ids[position] = 'X' + sample_id
  meta_df = meta_df.set_index(renamed_ids)
  # positional renaming: the input must have exactly these nine columns, in this order
  meta_df.columns = ['Age', 'Community', 'Sex', 'SampleIDC', 'Region', 'Lifestyle', 'Geography', 'Coast', 'Community size']
  print('available metadata:', list(meta_df))
  return meta_df

def preprocess_data(X, metadata):
  """Keep the samples of X that are described in the metadata and build their labels.

  X        : DataFrame with one row per sample (index = sample names).
  metadata : DataFrame indexed by sample name, with a 'Lifestyle' column.
  Returns (X, y) where X only contains the samples found in metadata.index
  (original row order kept) and y is a numpy array holding the 'Lifestyle'
  value of each kept sample, in the same order as the rows of X.
  The frame passed in is not modified.
  """
  # One boolean mask instead of one X.drop() per missing sample: the repeated drop
  # copied the whole matrix each time and raised a KeyError when a missing sample
  # name appeared more than once in X.index (the first drop removed every such row).
  described = X.index.isin(metadata.index)
  X = X.loc[described]
  y = [metadata.loc[sample_name, 'Lifestyle'] for sample_name in X.index]
  print('classes in y:', list(dict.fromkeys(y)))
  return X, np.array(y)

def get_hash(df):
  """Return a hexadecimal MD5 fingerprint of a DataFrame.

  Used to ensure that the data is the same as in the original experiment run
  (idea from https://death.andgravity.com/stable-hashing).
  The frame is serialized to JSON with explicit options and the UTF-8 bytes of
  that text are hashed. Raises AssertionError if df is not a pandas DataFrame.
  """
  assert isinstance(df, pd.DataFrame)
  # every serialization option is spelled out so that the fingerprint does not depend on defaults
  serialization_options = dict(orient='split', date_format='epoch', double_precision=10, force_ascii=True, date_unit='ms', lines=False, index=True, indent=None)
  payload = df.to_json(**serialization_options).encode('utf-8')
  return hashlib.md5(payload).hexdigest()

In [18]:
## ********** machine learning utils functions **********

# Root folder under which the split files and the grid-search outputs are stored
# ('' makes both sub-folders relative to the current working directory).
# Named base_dir rather than dir so that the builtin dir() is not shadowed.
base_dir = ''
splits_subdirectory = os.path.join(base_dir, 'splits')
results_subdirectory = os.path.join(base_dir, 'grid_search_results')

def generate_splits(X, y, n_splits, test_size, seed):
  """Generate n_splits train/test partitions of (X, y) and save each one on disk.

  X, y      : data matrix and labels, passed as is to train_test_split.
  n_splits  : number of partitions to generate.
  test_size : forwarded to train_test_split.
  seed      : split number k is drawn with random_state = seed + k.
  Each partition is written with np.savez under splits_subdirectory as 'split_<k>'
  (load_split reads it back as 'split_<k>.npz') with the keys
  X_train, X_test, y_train and y_test. Nothing is returned.
  """
  # exist_ok replaces the separate existence test (no check-then-create race)
  os.makedirs(splits_subdirectory, exist_ok=True)
  for split_id in range(n_splits):  # split_id rather than id: the builtin id() stays usable
    split_path = os.path.join(splits_subdirectory, 'split_' + str(split_id))
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed + split_id)
    np.savez(split_path, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)
    print('split {} saved in {}'.format(split_id, split_path))

def grid_search_on_split(split_id, classifier, hyperparameters_grid):
  """Tune a classifier with GridSearchCV on the training part of one saved split.

  split_id             : number of the split to load with load_split.
  classifier           : estimator handed to GridSearchCV.
  hyperparameters_grid : parameter grid handed to GridSearchCV.
  The best parameters are saved with save_best_params and the full cv_results_
  table with save_grid_search_results. Nothing is returned.
  """
  # the test part of the split is deliberately left out of the tuning
  X_train, _, y_train, _ = load_split(split_id)
  search = GridSearchCV(estimator=classifier, param_grid=hyperparameters_grid, verbose=20, n_jobs=4)
  fitted_search = search.fit(X_train, y_train)
  print('grid-Search done!\n')
  best_params = fitted_search.best_params_
  save_best_params(best_params, split_id)
  print('best params:', best_params)
  # the grid-search report is stored as a table (one row per parameter combination)
  save_grid_search_results(pd.DataFrame(fitted_search.cv_results_), split_id)

def load_split(split_id):
  """Load one saved train/test partition.

  split_id : number of the split; the file read is 'split_<split_id>.npz'
             inside splits_subdirectory.
  Returns the tuple (X_train, X_test, y_train, y_test) of numpy arrays.
  """
  split_filename = os.path.join(splits_subdirectory, 'split_{}.npz'.format(split_id))
  # np.load returns a lazy NpzFile that keeps the archive open: the arrays are read
  # inside the with-block and the file handle is released before returning, so the
  # file can later be deleted by clean_splits_files without a dangling handle.
  with np.load(split_filename) as split:
    X_train, X_test = split['X_train'], split['X_test']
    y_train, y_test = split['y_train'], split['y_test']
  return X_train, X_test, y_train, y_test

def clean_splits_files(n_splits):
  """Delete the files 'split_0.npz' ... 'split_<n_splits-1>.npz', then the splits folder.

  os.remove raises if one of the files is missing, and os.rmdir raises if the
  folder still contains something after the files have been removed.
  """
  split_paths = [os.path.join(splits_subdirectory, 'split_{}.npz'.format(split_id)) for split_id in range(n_splits)]
  for split_path in split_paths:
    os.remove(split_path)
  os.rmdir(splits_subdirectory)

def save_best_params(best_params, split_id):
  """Pickle the best grid-search parameters of a split.

  best_params : object to store (the best_params_ of a grid search in this notebook).
  split_id    : number of the split; the file written is
                'best_params_<split_id>.pkl' inside results_subdirectory,
                which is created when it does not exist yet.
  """
  if not os.path.exists(results_subdirectory):
    os.makedirs(results_subdirectory)
  save_params_path = os.path.join(results_subdirectory, 'best_params_{}.pkl'.format(split_id))
  with open(save_params_path, 'wb') as params_file:
    pickle.dump(best_params, params_file)
  print('grid search best parameters for split {} saved in {}'.format(split_id, save_params_path))

def load_best_params(split_id):
  """Read back the parameters stored by save_best_params for a split.

  split_id : number of the split; the file read is
             'best_params_<split_id>.pkl' inside results_subdirectory.
  Returns the unpickled object.
  """
  params_path = os.path.join(results_subdirectory, 'best_params_{}.pkl'.format(split_id))
  # NOTE(review): unpickling can execute arbitrary code -- only load files produced by this notebook
  with open(params_path, "rb") as params_file:
    return pickle.load(params_file)

def save_grid_search_results(results, split_id):
  """Write the grid-search results table of a split to a csv file.

  results  : DataFrame to store (written with its index, DataFrame.to_csv defaults).
  split_id : number of the split; the file written is
             'results_<split_id>.csv' inside results_subdirectory,
             which is created when it does not exist yet.
  """
  save_result_path = os.path.join(results_subdirectory, 'results_{}.csv'.format(split_id))
  # exist_ok replaces the separate existence test (no check-then-create race)
  os.makedirs(results_subdirectory, exist_ok=True)
  results.to_csv(save_result_path)
  # the path already carries its '.csv' extension: the message used to append a second one
  print('grid search results for split {} saved in {}'.format(split_id, save_result_path))

def load_grid_search_results(split_id):
  """Read back the grid-search results table of a split.

  split_id : number of the split; the file read is
             'results_<split_id>.csv' inside results_subdirectory.
  Returns the DataFrame parsed by pd.read_csv with default options, so the index
  written by save_grid_search_results comes back as an ordinary first column.
  """
  return pd.read_csv(os.path.join(results_subdirectory, 'results_{}.csv'.format(split_id)))

In [7]:
## ********** results visualization functions **********

def evaluate_predictions(y_test, pred):
  """Print classification scores and a normalized 3x3 confusion matrix.

  y_test : true labels.
  pred   : predicted labels, in the same order as y_test.
  Accuracy, balanced accuracy and Matthews correlation coefficient are printed
  rounded to 3 decimals. The confusion matrix is computed for the classes
  'Westernized', 'Nunavik' and 'Non Westernized' (rows = true class, columns =
  predicted class) and every cell is divided by the sum of all cells, i.e. each
  cell is a fraction of the whole matrix and not of its row. Nothing is returned.
  """
  scores = (
    ('random_forest accuracy: ', accuracy_score),
    ('random_forest balanced_accuracy_score: ', balanced_accuracy_score),
    ('random_forest matthews_correlation coefficient: ', matthews_corrcoef),
  )
  for caption, score in scores:
    print(caption, round(score(y_test, pred), 3))
  class_order = ['Westernized', 'Nunavik', 'Non Westernized']
  counts = confusion_matrix(y_test, pred, labels=class_order)
  fractions = counts/counts.sum()
  # horizontal rule shared by the rows of the table
  rule = '                                ---------------------------------------------------------'
  print('\n\n      # random forest confusion matrix :')
  print('                                              predicted values')
  print('                                  {}    |      {}     |  {}  | '.format(*class_order))
  print(rule)
  print('   true     {}       |      {:.4f}      |      {:.4f}      |      {:.4f}       |'.format('Westernized', *fractions[0]))
  print('  values                        ---------------------------------------------------------')
  print('            {}           |      {:.4f}      |      {:.4f}      |      {:.4f}       |'.format('Nunavik', *fractions[1]))
  print(rule)
  print('            {}   |      {:.4f}      |      {:.4f}      |      {:.4f}       |'.format('Non Westernized', *fractions[2]))
  print(rule)

def display_top_features(features_importance, features_names, n_top_feat):
  """Print the n_top_feat features having the largest importance.

  features_importance : sequence of importances, indexed like features_names.
  features_names      : sequence of feature names.
  n_top_feat          : number of lines to print below the header.
  Importances are rounded to 4 decimals BEFORE being ranked, so features whose
  rounded values are equal keep their original relative order. Nothing is returned.
  """
  rounded = {}
  for position, name in enumerate(features_names):
    rounded[name] = round(features_importance[position], 4)
  # sorted() is stable, also with reverse=True: ties stay in insertion order
  ranking = sorted(rounded.items(), key=lambda pair: pair[1], reverse=True)
  print('  rank    feature {:130s}'.format(''), 'importance')
  for rank, (name, importance) in enumerate(ranking, start=1):
    if rank > n_top_feat:
      break
    print('  {:3d}     {:140s}'.format(rank, name), importance)

In [19]:
## ********** running experiments **********

## ***** load data and metadata *****
data_filename = 'especes_metaphlan3.csv'
X = pd.read_csv(data_filename, index_col=0)
X, features_names = clean_data(X)
# ensure that the data file is the same as in the original experiment run:
assert get_hash(X) == '76db802436148af02fdc9d5ed48d7e5d'

meta_filename = 'Metadata_samples_all_controls_2020-09-02_V6.csv'
meta_df = pd.read_csv(meta_filename, sep=';', index_col=0)
# ensure that the metadata file is the same as in the original experiment run:
assert get_hash(meta_df) == '53c67bf4e333301df1e90cddea472ef8'
metadata = clean_metadata(meta_df)
print(metadata.head())

print('classes:', list(dict.fromkeys(metadata['Lifestyle'])))

X, y = preprocess_data(X, metadata)

print('# of samples:', X.shape[0])
print('# of features:', X.shape[1])
assert X.shape[0] == len(y)
# preprocess_data only removes rows, so the feature names extracted earlier must still match
assert X.shape[1] == len(features_names)

## ***** generate splits of the data matrix *****
n_splits = 10
test_size = 0.3
seed = 1
generate_splits(X, y, n_splits, test_size, seed)

## ***** grid search tuning on splits *****
# The tuning is expensive, so it is only run for the splits whose best parameters
# have not been saved yet: a fresh run rebuilds them instead of failing in
# load_best_params, while a re-run reuses the stored files.
random_forest_n_estims = [1, 10, 100, 1000]
random_forest_param_grid = {'n_estimators' : random_forest_n_estims}
for split_id in range(n_splits):
  best_params_path = os.path.join(results_subdirectory, 'best_params_{}.pkl'.format(split_id))
  if not os.path.exists(best_params_path):
    classifier = RandomForestClassifier(random_state=1)
    grid_search_on_split(split_id, classifier, random_forest_param_grid)

## ***** compute classifiers average performances and features importances *****
all_pred, all_y_test = [], []
sum_imp = []
all_imp = np.zeros(len(features_names)) # running sum of the importances over the splits

for split_id in range(n_splits):
  print('split', split_id)
  X_train, X_test, y_train, y_test = load_split(split_id)
  classifier = RandomForestClassifier(random_state=1) #define classifier
  best_params = load_best_params(split_id)
  classifier.set_params(**best_params)
  print('best params:', best_params)
  classifier.fit(X_train, y_train)
  y_pred = classifier.predict(X_test)
  all_pred.extend(y_pred)
  print('len(all_pred)', len(all_pred))
  all_y_test.extend(y_test)
  features_importance = classifier.feature_importances_
  assert len(features_importance) == X.shape[1] == len(features_names)
  sum_imp.append(features_importance.sum())
  all_imp = np.add(all_imp, features_importance)

# average importance of each feature over all the splits
mean_imp = all_imp / n_splits

## ***** display results *****
print('sum of importances for each split:', sum_imp) #check that it is always 1
evaluate_predictions(y_test=all_y_test, pred=all_pred)
# rank the features on the importances averaged over the splits
# (the importances of the last split alone used to be displayed here)
display_top_features(mean_imp, features_names, 30)

## ***** cleaning *****
clean_splits_files(n_splits)

available metadata: ['Age', 'Community', 'Sex', 'SampleIDC', 'Region', 'Lifestyle', 'Geography', 'Coast', 'Community size']
                                           Age  ... Community size
SchirmerM_2016_G89275_profiled_Metaphlan3   20  ...    Netherlands
SchirmerM_2016_G89250_profiled_Metaphlan3   22  ...    Netherlands
SchirmerM_2016_G89187_profiled_Metaphlan3   27  ...    Netherlands
SchirmerM_2016_G89182_profiled_Metaphlan3   21  ...    Netherlands
SchirmerM_2016_G89134_profiled_Metaphlan3   21  ...    Netherlands

[5 rows x 9 columns]
classes: ['Westernized', 'Non Westernized', 'Nunavik']
classes in y: ['Westernized', 'Non Westernized', 'Nunavik']
# of samples: 456
# of features: 714
split 0 saved in splits/split_0
split 1 saved in splits/split_1
split 2 saved in splits/split_2
split 3 saved in splits/split_3
split 4 saved in splits/split_4
split 5 saved in splits/split_5
split 6 saved in splits/split_6
split 7 saved in splits/split_7
split 8 saved in splits/split_8
split 9 save