In [1]:
import numpy as np
from sklearn.feature_selection import VarianceThreshold
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import metrics

from ssrl_rnaseq.split import GroupKShotsFold
from ssrl_rnaseq.data import load_tcga

In [2]:
# Load labels + mRNA expression, then keep only rows usable for supervised
# classification: a known patient, a known (non-"Normal") cancer type, a
# "Tumour" sample type and no missing value anywhere in the expression block.
data = load_tcga(
    "../../data/label.parquet",
    "../../data/mRNA.omics.parquet",
).loc[
    lambda frame: (
        frame["clinical", "patient"].notna()
        & frame["clinical", "cancer_type"].notna()
        & frame["gene_expression"].notna().all(axis=1)
        & frame["clinical", "cancer_type"].ne("Normal")
        & frame["clinical", "sample_type"].eq("Tumour")
    )
]

In [3]:
# Model inputs, all row-aligned with the filtered `data` frame.
X, y, g = (
    data["gene_expression"],          # features: the gene-expression column block
    data["clinical", "cancer_type"],  # target: cancer type label
    data["clinical", "patient"],      # group labels, passed as `groups=` to the CV splitter
)

In [4]:
# Number of CV folds fitted in parallel by `cross_val_predict`.
# Lower this if workers run out of memory on the wide expression matrix.
N_JOBS = 8

# Patient-grouped splitter (`groups=g` is passed below).
# NOTE(review): exact meaning of `k=1` comes from ssrl_rnaseq.split — confirm.
cv = GroupKShotsFold(10, k=1, random_state=0)

# Default threshold (0.0) drops features that are constant within a training
# fold; they carry no information for the classifier.
non_zero_variance = VarianceThreshold()

scaler = StandardScaler()

# `multi_class` is intentionally not passed: it is deprecated since
# scikit-learn 1.5, and lbfgs already fits a multinomial model when there are
# more than two classes. `n_jobs` is not passed either: LogisticRegression only
# uses it to parallelise one-vs-rest fits, so it had no effect on this
# multinomial model. Parallelism is applied across CV folds instead (N_JOBS).
classifier = LogisticRegression(
    solver="lbfgs",
    max_iter=2000,
    tol=1e-2,  # loose stopping tolerance; presumably chosen to keep fits fast — revisit if convergence matters
    class_weight="balanced",  # class sizes differ widely (see `support` in the report)
    random_state=0,  # not used by lbfgs; kept so results stay seeded if the solver changes
)

model = make_pipeline(non_zero_variance, scaler, classifier)

# Out-of-fold predictions: every sample is predicted by a pipeline fitted on
# the folds that exclude it, so feature selection and scaling never see it.
y_pred = cross_val_predict(model, X, y, groups=g, cv=cv, n_jobs=N_JOBS)

# zero_division=np.nan (scikit-learn >= 1.3) excludes undefined
# precision/recall values from the averages instead of counting them as 0.
report = metrics.classification_report(y, y_pred, zero_division=np.nan)
print(report)
print(f"\nAccuracy: {100 * metrics.accuracy_score(y, y_pred):.2f}%")

              precision    recall  f1-score   support

   TCGA-BLCA       0.65      0.39      0.49       406
   TCGA-BRCA       0.87      0.57      0.69      1101
   TCGA-CESC       0.37      0.39      0.38       306
   TCGA-COAD       0.78      0.82      0.80       460
   TCGA-HNSC       0.61      0.56      0.58       522
   TCGA-KIRC       0.69      0.66      0.67       534
   TCGA-KIRP       0.58      0.59      0.58       291
    TCGA-LGG       0.82      0.99      0.90       534
   TCGA-LIHC       0.77      0.91      0.83       374
   TCGA-LUAD       0.63      0.49      0.55       518
   TCGA-LUSC       0.47      0.47      0.47       501
     TCGA-OV       0.51      0.84      0.63       429
   TCGA-PRAD       0.80      0.94      0.86       498
   TCGA-SARC       0.41      0.33      0.36       263
   TCGA-SKCM       0.76      0.68      0.71       472
   TCGA-STAD       0.37      0.86      0.52       412
   TCGA-THCA       0.93      0.94      0.93       513
   TCGA-UCEC       0.67    