In [1]:
import numpy as np
import pandas as pd

In [2]:
# Load the long-format abundance table (one row per sample/glycan pair) and
# the per-sample group labels.
abundance = pd.read_csv("../../results/data/prepared/processed_abundance.csv")
groups = pd.read_csv("../../results/data/prepared/groups.csv")

# Reshape to a samples x glycans matrix, then log2-transform every value.
# np.log2 applied to a DataFrame keeps its index and column labels, so no
# manual re-wrapping of the underlying array is needed.
abundance = abundance.pivot(index="sample", columns="glycan", values="value")
abundance = np.log2(abundance)

# Attach the group label to each sample row.
# NOTE(review): with how="left", a sample that is absent from groups.csv gets
# a NaN group; NaN passes the `!= "QC"` filter and is then encoded as 0 by the
# `== "C"` comparison below. Confirm every sample has a label.
groups = groups.set_index("sample")
data = pd.merge(abundance, groups, left_index=True, right_index=True, how="left")

# Drop quality-control samples. The explicit .copy() makes `data` an
# independent frame, so the column assignment below does not write into a
# slice of the merged frame (which raises SettingWithCopyWarning).
data = data[data["group"] != "QC"].copy()

# Binary target: 1 when group == "C", 0 for every other remaining group.
data["group"] = (data["group"] == "C").astype(int)

In [3]:
from sklearn.model_selection import train_test_split

# Hold out a fixed number of samples (128) as the test set. Stratifying on the
# binary label keeps the class proportions approximately equal in both
# partitions; the fixed random_state makes the split reproducible.
train_data, test_data = train_test_split(
    data,
    test_size=128,
    shuffle=True,
    stratify=data["group"],
    random_state=42,
)

In [4]:
# Separate the feature matrix (every column except "group") from the target
# vector (the "group" column) for each partition.
X_train, y_train = train_data.drop(columns="group"), train_data["group"]
X_test, y_test = test_data.drop(columns="group"), test_data["group"]

In [12]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

# Candidate classifiers. The random forest is fit on the features as-is; the
# logistic regression and SVM are each wrapped in a Pipeline with a
# StandardScaler step, so the scaler is re-fit on the training portion of
# every cross-validation fold.
rf = RandomForestClassifier(n_estimators=1000, random_state=42)
lr = Pipeline(
    steps=[("scaler", StandardScaler()), ("lr", LogisticRegression(random_state=42))]
)
svm = Pipeline(
    steps=[("scaler", StandardScaler()), ("svm", SVC(random_state=42))]
)

# 10-fold cross-validated accuracy, computed on the training partition only.
# The three evaluations are independent of one another.
rf_scores, lr_scores, svm_scores = (
    cross_val_score(model, X_train, y_train, cv=10, scoring="accuracy")
    for model in (rf, lr, svm)
)

# Report mean +/- standard deviation of the per-fold accuracies.
print(f"Random Forest: {rf_scores.mean():.2f} +/- {rf_scores.std():.2f}")
print(f"Logistic Regression: {lr_scores.mean():.2f} +/- {lr_scores.std():.2f}")
print(f"SVM: {svm_scores.mean():.2f} +/- {svm_scores.std():.2f}")

Random Forest: 0.78 +/- 0.05
Logistic Regression: 0.77 +/- 0.04
SVM: 0.78 +/- 0.03
