In [1]:
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedGroupKFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import metrics

from ssrl_rnaseq.data import load_tcga

In [2]:
# Load the labelled TCGA samples together with their mRNA expression columns.
data = load_tcga("../../data/label.parquet", "../../data/mRNA.omics.parquet")

# Keep a sample only if it has a patient ID (used later as the CV group) and a
# cancer-type label.
is_labelled = (
    data["clinical", "patient"].notna()
    & data["clinical", "cancer_type"].notna()
)
# ...it has no missing value anywhere in its gene-expression profile...
has_full_profile = data["gene_expression"].notna().all(axis="columns")
# ...and it is a tumour sample rather than normal tissue.
is_tumour = (
    (data["clinical", "sample_type"] == "Tumour")
    & (data["clinical", "cancer_type"] != "Normal")
)

data = data.loc[is_labelled & has_full_profile & is_tumour]

In [3]:
# Feature matrix: every column under the "gene_expression" top-level header.
X = data["gene_expression"]

# Prediction target, plus the patient ID that StratifiedGroupKFold uses as the
# grouping key so one patient's samples never straddle a train/test split.
y, g = (data[("clinical", field)] for field in ("cancer_type", "patient"))

In [4]:
N_FOLDS = 5
N_SELECTED_GENES = 1000
SEED = 0

# Stratified by cancer type and grouped by patient: all samples from a single
# patient land in the same fold, so no patient is both trained on and scored.
cv = StratifiedGroupKFold(N_FOLDS, shuffle=True, random_state=SEED)

# Univariate ANOVA F-test gene selection. Because it sits inside the pipeline,
# it is refit on each training fold only and never sees held-out samples.
select_k_best = SelectKBest(f_classif, k=N_SELECTED_GENES)

scaler = StandardScaler()

# For a multiclass target, the lbfgs solver already fits a multinomial
# (softmax) model by default. `multi_class="multinomial"` is therefore
# omitted: that argument is deprecated since scikit-learn 1.5 and only emits a
# FutureWarning. `n_jobs` is omitted too, because LogisticRegression only uses
# it to parallelise one-vs-rest fits, so it had no effect here.
classifier = LogisticRegression(
    solver="lbfgs",
    max_iter=2000,
    # Much looser than the library default (1e-4); trades optimality for speed.
    # NOTE(review): check whether any ConvergenceWarning is raised.
    tol=1e-2,
    class_weight="balanced",
    random_state=SEED,
)

model = make_pipeline(select_k_best, scaler, classifier)

# Out-of-fold predictions: each sample is predicted by the model trained on the
# folds that exclude it.
y_pred = cross_val_predict(model, X, y, groups=g, cv=cv)

report = metrics.classification_report(y, y_pred)
print(report)
print(f"\nAccuracy: {metrics.accuracy_score(y, y_pred):.4f}")

              precision    recall  f1-score   support

   TCGA-BLCA       0.91      0.88      0.89       406
   TCGA-BRCA       0.99      0.99      0.99      1101
   TCGA-CESC       0.82      0.84      0.83       306
   TCGA-COAD       0.98      0.99      0.99       460
   TCGA-HNSC       0.86      0.87      0.86       522
   TCGA-KIRC       0.98      0.95      0.97       534
   TCGA-KIRP       0.94      0.96      0.95       291
    TCGA-LGG       1.00      0.99      1.00       534
   TCGA-LIHC       0.99      0.99      0.99       374
   TCGA-LUAD       0.93      0.94      0.94       518
   TCGA-LUSC       0.86      0.85      0.85       501
     TCGA-OV       1.00      0.99      1.00       429
   TCGA-PRAD       1.00      1.00      1.00       498
   TCGA-SARC       0.89      0.97      0.93       263
   TCGA-SKCM       0.97      0.97      0.97       472
   TCGA-STAD       0.99      1.00      0.99       412
   TCGA-THCA       1.00      1.00      1.00       513
   TCGA-UCEC       0.97    