In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from sklearn.feature_selection import VarianceThreshold
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import metrics

from ssrl_rnaseq.split import pretrain_downstream_split, GroupKShotsFold
from ssrl_rnaseq.data import load_tcga

In [2]:
data = load_tcga("../../data/label.parquet", "../../data/mRNA.omics.parquet")

# Rows to keep: a known patient and cancer type, not the "Normal" class,
# sample_type == "Tumour", and no missing value anywhere in the expression profile.
clinical_ok = (
    data["clinical", "patient"].notna()
    & data["clinical", "cancer_type"].notna()
    & (data["clinical", "cancer_type"] != "Normal")
    & (data["clinical", "sample_type"] == "Tumour")
)
expression_complete = data["gene_expression"].notna().all(axis=1)

data = data.loc[clinical_ok & expression_complete]

In [3]:
# Model inputs: expression matrix (features), cancer type (target) and patient
# ID (passed as `groups` to the splitters below).
X, y, g = (
    data["gene_expression"],
    data["clinical", "cancer_type"],
    data["clinical", "patient"],
)

In [4]:
# Split into a pretraining set and a downstream set, with `groups=g` (patients)
# and `stratify=y` (cancer type). The pretraining labels are discarded (`_`):
# the encoder fit on X_pretrain further down uses no labels.
(
    X_pretrain, X_downstream,
    _, y_downstream,
    g_pretrain, g_downstream,
) = pretrain_downstream_split(
    X,
    y,
    g,
    pretrain_size=7000,
    downstream_size=1200,
    groups=g,
    stratify=y,
    random_state=0,
)

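## Without pretrained PCA

Baseline: a logistic regression trained directly on the downstream gene-expression profiles (the variance filter and standardisation are re-fit inside each CV training fold), evaluated with `GroupKShotsFold(10, k=1)`.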

In [6]:
# Few-shot CV splitter from ssrl_rnaseq; `g_downstream` (patient IDs) is passed
# as `groups` below — presumably so a patient's samples never straddle a split.
cv = GroupKShotsFold(10, k=1, random_state=0)

# VarianceThreshold() with its default threshold removes constant features.
non_zero_variance = VarianceThreshold()

scaler = StandardScaler()

# `multi_class="multinomial"` is deprecated since scikit-learn 1.5 and scheduled
# for removal; with lbfgs and a multiclass target the default already fits a
# multinomial model. `n_jobs` only parallelises one-vs-rest fits, so it had no
# effect here and is dropped.
classifier = LogisticRegression(
    solver="lbfgs",
    max_iter=2000,
    tol=1e-2,
    class_weight="balanced",
    random_state=0,
)

# Preprocessing lives inside the pipeline, so cross_val_predict re-fits it on
# each training fold only — no statistics leak from the held-out samples.
model = make_pipeline(non_zero_variance, scaler, classifier)

y_downstream_pred = cross_val_predict(model, X_downstream, y_downstream, groups=g_downstream, cv=cv)

# zero_division=np.nan: undefined per-class scores (e.g. a class that is never
# predicted) are reported as NaN instead of 0.
report = metrics.classification_report(y_downstream, y_downstream_pred, zero_division=np.nan)
print(report)
print(f"\nAccuracy: {100 * metrics.accuracy_score(y_downstream, y_downstream_pred):.2f}%")

              precision    recall  f1-score   support

   TCGA-BLCA       0.47      0.31      0.37        52
   TCGA-BRCA       0.77      0.40      0.53       151
   TCGA-CESC       0.36      0.39      0.38        41
   TCGA-COAD       0.79      0.71      0.75        62
   TCGA-HNSC       0.47      0.54      0.50        74
   TCGA-KIRC       0.49      0.66      0.56        76
   TCGA-KIRP       0.64      0.60      0.62        42
    TCGA-LGG       0.79      0.99      0.88        77
   TCGA-LIHC       0.80      0.87      0.83        55
   TCGA-LUAD       0.71      0.45      0.56        66
   TCGA-LUSC       0.45      0.54      0.49        70
     TCGA-OV       0.52      0.88      0.65        58
   TCGA-PRAD       0.68      0.94      0.79        67
   TCGA-SARC       0.60      0.30      0.40        40
   TCGA-SKCM       0.95      0.70      0.81        60
   TCGA-STAD       0.44      0.81      0.57        58
   TCGA-THCA       0.96      0.95      0.95        74
   TCGA-UCEC       0.74    

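## With pretrained PCA

The variance filter, standardisation and a 1000-component PCA are fit once on the pretraining split, without labels. The downstream logistic regression is then evaluated on the resulting embeddings with the same CV scheme as the baseline.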

In [7]:
# Unsupervised "pretraining": fit the encoder (constant-feature filter ->
# standardisation -> 1000-component PCA) on the pretraining split only; no
# labels are used.
pca = PCA(random_state=0, n_components=1000)
scaler = StandardScaler()
non_zero_variance = VarianceThreshold()

encoder = make_pipeline(non_zero_variance, scaler, pca)

encoder.fit(X_pretrain)

In [8]:
# Same few-shot CV scheme as the baseline, so the two reports are comparable.
cv = GroupKShotsFold(10, k=1, random_state=0)

# `multi_class="multinomial"` is deprecated since scikit-learn 1.5 and scheduled
# for removal; with lbfgs and a multiclass target the default already fits a
# multinomial model. `n_jobs` only parallelises one-vs-rest fits, so it had no
# effect here and is dropped.
model = LogisticRegression(
    solver="lbfgs",
    max_iter=2000,
    tol=1e-2,
    class_weight="balanced",
    random_state=0,
)

# The encoder was fit on X_pretrain only, so embedding the whole downstream set
# up front does not leak held-out-fold statistics into training.
e_downstream = encoder.transform(X_downstream)
y_downstream_pred = cross_val_predict(model, e_downstream, y_downstream, groups=g_downstream, cv=cv)

# zero_division=np.nan: undefined per-class scores (e.g. a class that is never
# predicted) are reported as NaN instead of 0.
report = metrics.classification_report(y_downstream, y_downstream_pred, zero_division=np.nan)
print(report)
print(f"\nAccuracy: {100 * metrics.accuracy_score(y_downstream, y_downstream_pred):.2f}%")

              precision    recall  f1-score   support

   TCGA-BLCA       0.53      0.38      0.44        52
   TCGA-BRCA       0.90      0.37      0.53       151
   TCGA-CESC       0.27      0.32      0.29        41
   TCGA-COAD       0.62      0.77      0.69        62
   TCGA-HNSC       0.40      0.61      0.48        74
   TCGA-KIRC       0.75      0.63      0.69        76
   TCGA-KIRP       0.55      0.67      0.60        42
    TCGA-LGG       0.84      1.00      0.91        77
   TCGA-LIHC       0.69      0.91      0.79        55
   TCGA-LUAD       0.58      0.42      0.49        66
   TCGA-LUSC       0.37      0.44      0.40        70
     TCGA-OV       0.73      0.78      0.75        58
   TCGA-PRAD       0.59      0.85      0.70        67
   TCGA-SARC       0.53      0.65      0.58        40
   TCGA-SKCM       0.83      0.75      0.79        60
   TCGA-STAD       0.64      0.62      0.63        58
   TCGA-THCA       0.89      0.86      0.88        74
   TCGA-UCEC       0.71    