In [1]:
import numpy as np
import pandas as pd
from sklearn.feature_selection import VarianceThreshold
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedGroupKFold, cross_val_predict, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import metrics
from sklearn.utils import check_X_y

In [2]:
def load(clinical_path, gene_expression_path):
    clinical = pd.read_parquet(clinical_path)
    gene_expression = pd.read_parquet(gene_expression_path)
    
    clinical.set_index("sampleID", inplace=True)
    gene_expression.set_index("caseID", inplace=True)
    
    gene_expression.index = gene_expression.index.str.split("-").str[:4].str.join("-")

    if not clinical.index.is_unique:
        raise ValueError

    if not gene_expression.index.is_unique:
        raise ValueError

    common_case = clinical.index.intersection(gene_expression.index)

    clinical = clinical.loc[common_case]
    gene_expression = gene_expression.loc[common_case]

    data = pd.concat({"clinical": clinical, "gene_expression": gene_expression}, axis=1)
    data.index.name = "caseID"

    return data

In [3]:
class GroupKShotsFold:
    def __init__(self, n_splits, *, k, random_state=None):
        self.n_splits = n_splits
        self.k = k
        self.random_state = random_state

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.n_splits

    def split(self, X, y, groups):
        _, y = check_X_y(X, y)

        n_classes = len(np.unique(y))
        train_size = self.k * n_classes

        cv = StratifiedGroupKFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)

        for train, test in cv.split(X, y, groups):
            train, _ = train_test_split(
                train,
                train_size=train_size,
                stratify=y[train],
                shuffle=True,
                random_state=self.random_state,
            )

            yield train, test

In [4]:
data = load("../data/label.parquet", "../data/mRNA.omics.parquet")

data = data.loc[
    data["clinical", "patient"].notnull()
    & data["clinical", "cancer_type"].notnull()
    & data["gene_expression"].notnull().all(axis=1)
    & (data["clinical", "cancer_type"] != "Normal")
    & (data["clinical", "sample_type"] == "Tumour")
]

In [5]:
X = data["gene_expression"]
y = data["clinical", "cancer_type"]
g = data["clinical", "patient"]

In [6]:
cv = GroupKShotsFold(10, k=1, random_state=0)

non_zero_variance = VarianceThreshold()

scaler = StandardScaler()

classifier = LogisticRegression(
    solver="lbfgs",
    multi_class="multinomial",
    max_iter=2000,
    tol=1e-2,
    class_weight="balanced",
    n_jobs=8,
    random_state=0,
)

model = make_pipeline(non_zero_variance, scaler, classifier)

y_pred = cross_val_predict(model, X, y, groups=g, cv=cv)

report = metrics.classification_report(y, y_pred, zero_division=np.nan)
print(report)
print(f"\nAccuracy: {100 * metrics.accuracy_score(y, y_pred):.2f}%")

              precision    recall  f1-score   support

   TCGA-BLCA       0.33      0.18      0.23       406
   TCGA-BRCA       0.97      0.57      0.72      1101
   TCGA-CESC       0.39      0.41      0.40       306
   TCGA-COAD       0.83      0.66      0.73       460
   TCGA-HNSC       0.48      0.66      0.56       522
   TCGA-KIRC       0.53      0.83      0.65       534
   TCGA-KIRP       0.69      0.57      0.63       291
    TCGA-LGG       0.67      1.00      0.80       534
   TCGA-LIHC       0.85      0.93      0.89       374
   TCGA-LUAD       0.70      0.32      0.44       518
   TCGA-LUSC       0.39      0.45      0.42       501
     TCGA-OV       0.51      0.91      0.65       429
   TCGA-PRAD       0.86      0.95      0.91       498
   TCGA-SARC        nan      0.00      0.00       263
   TCGA-SKCM       0.79      0.75      0.77       472
   TCGA-STAD       0.51      0.79      0.62       412
   TCGA-THCA       0.91      0.95      0.93       513
   TCGA-UCEC       0.64    