In [107]:
import pandas as pd
import numpy as np
from utils import get_species, get_labels, get_labels_all
from utils import get_taxonomy

In [108]:
# Load the feature table and the two label tables from the project helpers in utils.
# NOTE(review): presumably X is samples x taxa (later cells read an "UNKNOWN" column
# and a (study, sample) MultiIndex from it), y holds the binary target in its first
# column and y_all the raw phenotype label — confirm against utils.
X = get_taxonomy()
y = get_labels()
y_all = get_labels_all()

In [109]:
# Raw metadata/abundance table, indexed by CSV columns 1 and 4 so that it carries a
# two-level (study, sample) index.
# NOTE(review): the stray output saved under this cell looks like the tail of a pandas
# warning (possibly DtypeWarning about mixed column types) — confirm, and consider
# passing explicit dtypes or low_memory=False.
raw_csv_path = "../data/raw.csv"
raw = pd.read_csv(raw_csv_path, index_col=[1, 4])

  exec(code_obj, self.user_global_ns, self.user_ns)


In [110]:
# Size check of the freshly loaded raw table: (rows, columns).
raw.shape

(12532, 3217)

In [111]:
# Sequencing platforms other than Illumina. Samples sequenced on any of these are
# excluded when the removal mask is built, to reduce batch effects.
non_illumina = [
    "454 GS FLX Titanium",
    "Ion Torrent PGM", "Ion Torrent Proton",
    "BGISEQ-500",
]

In [112]:
# Boolean mask with one entry per sample; True marks a row to drop.
# NOTE(review): every term is combined positionally, so X, y_all and raw are assumed
# to hold the same samples in the same row order — confirm.
study_ids = y_all.index.get_level_values(0)
is_healthy = (y_all == "Healthy").values.flatten()

# Samples whose label is one of the BMI-related categories.
bmi_related = y_all.isin(
    ["Underweight", "Overweight", "Obesity", "Obese"]
).values.flatten()

# Samples sequenced on a non-Illumina platform (batch effects).
off_platform = raw["Sequencing Platform"].isin(non_illumina).values.flatten()

# Studies excluded in their entirety:
#   P4  - treats the poop for extracting viral DNA
#   P86 - healthy at baseline but half develop T2D
#   P59 - all technically healthy, but half are in heavily urbanized areas, and
#         "Microbes with higher relative abundance in Chinese urban samples have
#         been associated with disease in other studies"
#   P63 - deals with semisupercentenarians, i.e., 105 to 109 years old
excluded_study = study_ids.isin(["P4", "P86", "P59", "P63"])

# P48: only the samples labelled "Healthy" are dropped (alcohol or smoking).
p48_healthy = (study_ids == "P48") & is_healthy

# Rows whose UNKNOWN value is 100 or more; dropping them also keeps the later
# division by (100 - UNKNOWN) away from zero or negative denominators.
unclassified = (X["UNKNOWN"] >= 100).values.flatten()

# Disabled in the original cell and still not applied: an extra `y2` mask and the
# exclusion of study "GMHI-4".
remove = bmi_related | off_platform | excluded_study | p48_healthy | unclassified

In [113]:
# Keep only the rows that were not flagged by `remove`.
# Gotcha: this rebinds X, y and y_all, so the cell is not idempotent — on a second
# run `remove` still has the pre-filter length and no longer matches the frames.
keep = ~remove
X, y, y_all = X.iloc[keep, :], y.iloc[keep, :], y_all.iloc[keep, :]

In [114]:
# Apply the same row mask to raw so that it stays positionally aligned with X, y and
# y_all. Like the previous filter this rebinds its input, so run it exactly once.
kept_rows = ~remove
raw = raw.iloc[kept_rows, :]

In [115]:
# Rescale every row by (100 - UNKNOWN), i.e. by the part of the sample that is not
# assigned to the UNKNOWN column. The UNKNOWN column itself is divided as well.
# NOTE(review): presumably the values are percentages, which would turn them into
# fractions of the known part — confirm. Rebinding X makes this cell non-idempotent.
known_part = 100 - X["UNKNOWN"]
X = X.div(known_part, axis="index")

In [116]:
# Size of the feature table after filtering and rescaling: (samples, features).
X.shape

(9067, 3200)

In [117]:
# Sorted array of the distinct values in the first index level of X (the study ID).
study_level = X.index.get_level_values(0)
studies = np.unique(study_level)

In [118]:
# Split at the study level, not the sample level: shuffle the study IDs with a fixed
# seed and send the first `prop` share (rounded down) to train, the rest to test, so
# that no study contributes samples to both sets.
np.random.seed(42)
perm = np.random.permutation(len(studies))
prop = 0.9
n_train_studies = int(len(studies) * prop)
train_idx = perm[:n_train_studies]
test_idx = perm[n_train_studies:]
train_studies, test_studies = studies[train_idx], studies[test_idx]

In [119]:
# Peek at the per-sample region column that is one-hot encoded in the next cell.
raw["Geographical Region or Population"]

Study_ID  Sample Accession
GMHI-23   SAMEA3879547        USA
          SAMEA3879551        USA
          SAMEA3879543        USA
          SAMEA3879565        USA
          SAMEA3879546        USA
                             ... 
P140      SAMN07509555        USA
          SAMN07509557        USA
          SAMN07509546        USA
          SAMN07509552        USA
          SAMN07509921        USA
Name: Geographical Region or Population, Length: 9067, dtype: object

In [88]:
from sklearn.preprocessing import OneHotEncoder

# One-hot encode each sample's geographical region: one column per distinct value.
# The encoder is left at its default (sparse) output and densified with .toarray()
# instead of passing `sparse=False`. That keyword was renamed to `sparse_output` in
# scikit-learn 1.2 and later removed, so this form runs on both old and new releases
# and still yields the same dense float array.
oh = OneHotEncoder()
encoded = oh.fit_transform(raw[["Geographical Region or Population"]]).toarray()
encoded

array([[0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 1., 0.]])

In [90]:
# Multi-output label table: one indicator column per region, followed by the binary
# label as the last column ("target").
# Column names come from the fitted encoder (oh.categories_[0]) so they are guaranteed
# to be in the same order as the columns of `encoded`, instead of being recomputed
# separately with np.unique (which can also fail on missing values).
# NOTE(review): index=y.index assumes raw and y hold the same samples in the same
# row order — confirm.
y_new = pd.DataFrame(encoded, index=y.index, columns=oh.categories_[0])
y_new["target"] = y.iloc[:, 0]

In [106]:
# List the label columns: the region indicators followed by "target".
# NOTE(review): the saved output contains labels that look like variants of the same
# place ('German' / 'Germany', 'Italy' / 'Italy: Milan' / 'Italy: Vercelli',
# 'USA' / 'North America') — confirm whether they should be merged before encoding.
y_new.columns

Index(['Austria', 'Burkina Faso', 'Cameroon', 'China', 'Denmark', 'Ethiopia',
       'Finland', 'France', 'German', 'Germany', 'Ghana', 'India', 'Ireland',
       'Israel', 'Italy', 'Italy: Milan', 'Italy: Vercelli', 'Japan',
       'Kazakhstan', 'Madagascar', 'Mongolia', 'Netherlands', 'New Zealand',
       'North America', 'Peru', 'South Korea', 'Spain', 'Sweden', 'Tanzania',
       'USA', 'United Kingdom', 'target'],
      dtype='object')

In [92]:
# Select whole studies (first index level) for each side of the split; features come
# from X, labels from the region + target table y_new.
X_train, X_test = X.loc[train_studies], X.loc[test_studies]
y_train, y_test = y_new.loc[train_studies], y_new.loc[test_studies]

In [98]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, accuracy_score

# Features are binarised: a feature counts as present when its value exceeds c.
c = 0.00001

# Only the last column of the label tables ("target") is used by this baseline.
target_train = y_train.iloc[:, -1]
target_test = y_test.iloc[:, -1]

# L1-penalised logistic regression with class re-weighting.
clf = LogisticRegression(
    penalty="l1",
    solver="liblinear",
    C=0.02,
    class_weight="balanced",
    random_state=42,
)
clf.fit(X_train > c, target_train)
y_hat = clf.predict(X_test > c)

# (balanced accuracy, plain accuracy) on the held-out studies.
balanced_accuracy_score(target_test, y_hat), accuracy_score(target_test, y_hat)

(0.7113164889774928, 0.7187789084181314)

In [99]:
from sklearn.neural_network import MLPClassifier

In [101]:
# Multi-output MLP: it is fitted on every column of y_train (region indicators and
# "target") at once, using the same presence/absence features as the baseline.
# `* 1.0` turns the boolean target column into floats alongside the 0/1 indicators.
# Note that this rebinds clf and y_hat; y_hat now has one column per label.
presence_train = X_train.values > c
presence_test = X_test.values > c
labels_train = y_train * 1.0

clf = MLPClassifier(random_state=42)
clf.fit(presence_train, labels_train)
y_hat = clf.predict(presence_test)



ValueError: Classification metrics can't handle a mix of binary and multilabel-indicator targets

In [105]:
# Score the multi-output model on the last column ("target") only:
# (balanced accuracy, plain accuracy) on the held-out studies.
target_true, target_pred = y_test.iloc[:, -1], y_hat[:, -1]
balanced_accuracy_score(target_true, target_pred), accuracy_score(target_true, target_pred)

(0.6830794400624078, 0.6484736355226642)

In [95]:
# Wrap the predictions in a frame labelled like y_test so the two can be compared
# column by column.
# NOTE(review): this only fits while y_hat is the 2-D multi-output prediction of the
# MLP cell; after the logistic-regression cell y_hat is 1-D and will not match these
# columns.
pd.DataFrame(y_hat, index=y_test.index, columns=y_test.columns)

Unnamed: 0_level_0,Unnamed: 1_level_0,Austria,Burkina Faso,Cameroon,China,Denmark,Ethiopia,Finland,France,German,Germany,...,New Zealand,North America,Peru,South Korea,Spain,Sweden,Tanzania,USA,United Kingdom,target
Study_ID,Sample Accession,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
GMHI-9,SAMEA2042385,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
GMHI-9,SAMEA2042199,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
GMHI-9,SAMEA2041938,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
GMHI-9,SAMEA2042581,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
GMHI-9,SAMEA2042491,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P21,SAMEA3708723,0,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
P21,SAMEA3708721,0,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
P21,SAMEA3708719,0,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
P21,SAMEA3708717,0,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [96]:
# Balanced accuracy of the multi-output predictions on the last column ("target").
# NOTE(review): this repeats the first half of the two-metric cell, and the execution
# counts are out of order; the saved value differs from that cell's, so it probably
# comes from an earlier fit — re-run top to bottom before quoting either number.
balanced_accuracy_score(y_test.iloc[:, -1], y_hat[:, -1])

0.6877654502903701

In [71]:
# Last column of the 2-D predictions; when y_hat comes from the multi-output model
# this is the predicted "target" for every test sample.
y_hat[:, -1]

array([1, 1, 0, ..., 0, 0, 1])

In [97]:
# Ground-truth labels of the held-out studies: region indicators plus "target".
y_test

Unnamed: 0_level_0,Unnamed: 1_level_0,Austria,Burkina Faso,Cameroon,China,Denmark,Ethiopia,Finland,France,German,Germany,...,New Zealand,North America,Peru,South Korea,Spain,Sweden,Tanzania,USA,United Kingdom,target
Study_ID,Sample Accession,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
GMHI-9,SAMEA2042385,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,True
GMHI-9,SAMEA2042199,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,True
GMHI-9,SAMEA2041938,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,True
GMHI-9,SAMEA2042581,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,True
GMHI-9,SAMEA2042491,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P21,SAMEA3708723,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
P21,SAMEA3708721,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
P21,SAMEA3708719,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
P21,SAMEA3708717,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,False
