In [1]:
from sklearn.feature_selection import VarianceThreshold, SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import metrics

from ssrl_rnaseq.split import pretrain_downstream_split
from ssrl_rnaseq.data import load_tcga

In [2]:
# Load the TCGA table from the label and mRNA parquet files and keep only the
# rows that pass every check below. The filter is handed to `.loc` as a
# callable, so it is evaluated against the freshly loaded frame and `data` is
# bound exactly once, to the filtered result.
# NOTE(review): the two-part column keys suggest MultiIndex columns with
# top-level groups "clinical" and "gene_expression" -- confirm against load_tcga.
data = load_tcga(
    "../../data/label.parquet",
    "../../data/mRNA.omics.parquet",
).loc[
    lambda frame: (
        # a patient identifier is present
        frame["clinical", "patient"].notna()
        # a cancer-type label is present ...
        & frame["clinical", "cancer_type"].notna()
        # ... every gene-expression column has a value for the row ...
        & frame["gene_expression"].notna().all(axis=1)
        # ... the label is not the "Normal" class ...
        & (frame["clinical", "cancer_type"] != "Normal")
        # ... and the sample itself is recorded as a tumour sample
        & (frame["clinical", "sample_type"] == "Tumour")
    )
]

In [3]:
# Split the filtered table into the three pieces used downstream.
# Feature matrix: everything under the "gene_expression" column group.
X = data["gene_expression"]
# Target: the cancer-type label of each row.
y = data[("clinical", "cancer_type")]
# Grouping key: the patient each row belongs to (passed as `groups` to the split).
g = data[("clinical", "patient")]

In [5]:
# Partition the samples with the project's split helper; the call yields four
# objects that are bound here as train/test features and labels.
# NOTE(review): presumably the "pretrain" portion becomes the training set and
# the "downstream" portion the test set, with `groups` keeping a patient's
# samples on one side -- confirm against ssrl_rnaseq.split.
X_train, X_test, y_train, y_test = pretrain_downstream_split(
    X,
    y,
    pretrain_size=930,
    downstream_size=900,
    groups=g,
    stratify=y,
    random_state=0,  # fixed seed so the split is repeatable
)

In [6]:
# --- Pipeline steps -----------------------------------------------------------
# Each step is kept in its own variable so the fitted objects stay reachable
# by name after training (make_pipeline uses these very instances).

# Default VarianceThreshold: removes features that are constant across samples.
non_zero_variance = VarianceThreshold()

# Univariate selection: keep the 50,000 features with the highest ANOVA F-score.
select_k_best = SelectKBest(score_func=f_classif, k=50_000)

# Standardise the surviving features before the linear model.
scaler = StandardScaler()

# Multinomial logistic regression on the selected, scaled features.
classifier = LogisticRegression(
    multi_class="multinomial",
    solver="lbfgs",
    class_weight="balanced",  # reweight classes inversely to their frequency
    tol=1e-2,  # stopping tolerance, set explicitly together with max_iter
    max_iter=2000,
    random_state=0,
    n_jobs=8,
)

# --- Train ----------------------------------------------------------------------
# Pipeline.fit returns the pipeline itself, so `model` is the fitted pipeline.
model = make_pipeline(
    non_zero_variance,
    select_k_best,
    scaler,
    classifier,
).fit(X_train, y_train)

# --- Evaluate on the held-out split -----------------------------------------------
y_test_pred = model.predict(X_test)

# Per-class precision / recall / F1, followed by the overall accuracy.
report = metrics.classification_report(y_true=y_test, y_pred=y_test_pred)
print(report)
print(f"\nAccuracy: {metrics.accuracy_score(y_test, y_test_pred)}")

              precision    recall  f1-score   support

   TCGA-BLCA       0.87      0.85      0.86        46
   TCGA-BRCA       0.99      0.96      0.97       113
   TCGA-CESC       0.86      0.86      0.86        29
   TCGA-COAD       0.98      0.96      0.97        53
   TCGA-HNSC       0.89      0.91      0.90        55
   TCGA-KIRC       0.94      0.91      0.93        55
   TCGA-KIRP       0.91      0.91      0.91        32
    TCGA-LGG       0.96      1.00      0.98        51
   TCGA-LIHC       0.91      0.98      0.94        42
   TCGA-LUAD       0.96      0.92      0.94        50
   TCGA-LUSC       0.88      0.87      0.88        53
     TCGA-OV       0.94      1.00      0.97        48
   TCGA-PRAD       0.98      1.00      0.99        47
   TCGA-SARC       0.96      0.92      0.94        26
   TCGA-SKCM       0.96      0.90      0.93        50
   TCGA-STAD       0.95      1.00      0.98        41
   TCGA-THCA       1.00      1.00      1.00        51
   TCGA-UCEC       0.90    