In [1]:
import os
os.environ["RPY2_CFFI_MODE"] = "ABI"
import rpy2.robjects as ro
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter
from rpy2.rinterface_lib.callbacks import logger as rpy2_logger
rpy2_logger.setLevel("ERROR")   # show errors only; suppress R messages and warnings
rpy2_logger.propagate = False   # do not forward these records to the root logger

import sys
sys.path.append("../../")

# from scSurvival.scsurvival import scSurvival, scSurvivalRun, PredictIndSample
from scSurvival_beta import scSurvivalRun, PredictIndSample
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import os 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from tqdm import tqdm, trange
import scanpy as sc
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

from sklearn.metrics import classification_report
from sklearn.model_selection import KFold
import io
import contextlib
f = io.StringIO()
from lifelines.utils import concordance_index
from scipy.stats import percentileofscore
from utils import *

In [2]:
def load_r_packages():
    """Reset the embedded R session and attach the R packages used below.

    The R script first runs ``rm(list=ls())``, so every object in the R global
    environment is removed (including anything assigned earlier through
    ``ro.globalenv``). It then attaches splatter, scran, Seurat, pROC, ggplot2,
    dplyr and caret, and finally calls ``set.seed(1)``.

    Returns nothing; the effect is entirely on the R session state.
    """
    ro.r('''  
    rm(list=ls())
    # library("scater")
    library("splatter")
    library(scran)
    library(Seurat)
    # library(preprocessCore)
    library(pROC)
    # library(APML2)
    # library(APML1)
    # library(APML0)

    library(ggplot2)
    library(dplyr)
    library(caret)
    set.seed(1)
    ''')


def simulated_base_sc_dataset(seed=42, plot=False, cell_surv_ratio=0.5):
    """Simulate the base single-cell dataset inside the embedded R session.

    A splatter simulation with 10000 cells, 5000 genes and three groups is run
    and wrapped in a Seurat object. The groups are relabelled in the
    ``Actual.cond`` metadata column: Group1 -> 'other', Group2 ->
    'good.survival', Group3 -> 'bad.survival'. The object is log-normalised and
    2000 variable features are selected.

    Nothing is returned to Python. The results stay in the R global
    environment as ``seed``, ``alpha``, ``sim.groups``, ``data``,
    ``select_gene_ids`` and ``var_features_genes``.

    Parameters
    ----------
    seed : int
        Interpolated into the R script as ``seed`` and passed to
        ``splatSimulateGroups``.
    plot : bool
        When True, additionally runs ScaleData / RunPCA / RunUMAP /
        FindNeighbors / FindClusters and writes three UMAP PDFs. The output
        location is the R global ``save_path``, which this function does not
        define: the caller must have assigned it (and created the directory)
        beforehand.
    cell_surv_ratio : float
        Interpolated as ``alpha``; the group probabilities are
        ``c(1 - 2*alpha, alpha, alpha)``.
        NOTE(review): for values >= 0.5 (including the default) the first
        probability is zero or negative -- confirm that callers always pass a
        smaller value.
    """
    ro.r(f'''
    seed <- {seed}
    alpha = {cell_surv_ratio}
    sim.groups <- splatSimulateGroups(
    batchCells = 10000, nGenes=5000,
    #group.prob = c(0.9, 0.05, 0.05),
    group.prob = c(1 - 2*alpha, alpha, alpha),
    de.prob = c(0.2, 0.06, 0.06), 
    de.facLoc = c(0.1, 0.1, 0.1),
    de.facScale = 0.4,
    seed = seed)

    data <- CreateSeuratObject(counts = counts(sim.groups), project = 'Scissor_Single_Cell')
    data <- AddMetaData(object = data, metadata = sim.groups$Group, col.name = "sim.group")
    data$Actual.cond <- recode(data$sim.group,'Group1'='other', 'Group2'='good.survival', 'Group3'='bad.survival')

    select_gene_ids <- 1:2000
    data <- NormalizeData(object = data, normalization.method = "LogNormalize", 
                          scale.factor = 10000)
    data <- FindVariableFeatures(object = data, selection.method = 'vst', nfeatures=2000)
    var_features_genes = VariableFeatures(data)
    ''')

    # Optional diagnostics: embed, cluster and save UMAP figures. The R code
    # below reads `save_path` from the R global environment.
    if plot:
        ro.r('''
        data <- ScaleData(object = data)
        data <- RunPCA(object = data, features = VariableFeatures(data)[select_gene_ids])

        data <- RunUMAP(object = data, dims = 1:10, n.neighbors = 5, min.dist=0.5, spread=1.5)
        # data <- RunUMAP(object = data, dims = 1:10)
        data <- FindNeighbors(object = data, dims = 1:10, k.param=20)
        # data <- FindNeighbors(object = data, dims = 1:10, k.param=20,prune.SNN = 0.2)
        data <- FindClusters(object = data, resolution = 0.5)

        DimPlot(object = data, reduction = 'umap', group.by = 'seurat_clusters', label = F, label.size = 10,pt.size=0.5)
        ggsave(paste0(save_path, 'simu_seurat_cluster_umap.pdf'), height = 5, width = 7)

        DimPlot(object = data, reduction = 'umap', group.by = 'sim.group', pt.size = 0.5, label = T)
        ggsave(paste0(save_path, 'simu_group_umap.pdf'), height = 5, width = 7)

        # DimPlot(object = data, reduction = 'umap',  cols = c('grey','blue', 'red'), group.by = 'sim.group', pt.size = 0.5, label = T)
        # 
        DimPlot(object = data, reduction = 'umap',  cols = c('grey','blue', 'red'), group.by = 'Actual.cond', pt.size = 0.5, label = T)
        ggsave(paste0(save_path, 'simu_surv_group_umap.pdf'), height = 5, width = 7)
        ''')

def simulated_sc_datasets(plot=False):
    """Build the per-sample single-cell datasets and survival table, then pull them into Python.

    The R code relies on objects that must already exist in the R global
    environment: ``data`` (the Seurat object), ``seed`` and
    ``var_features_genes`` -- the names that ``simulated_base_sc_dataset``
    assigns.

    For each of ``bulk_num = 100`` samples ``i``:

    * ``ratio = (i-1)/(bulk_num-1)``; ``round(num_good * ratio)`` good-survival
      cells and ``round(num_bad * (1-ratio))`` bad-survival cells are drawn
      with replacement and pooled with all 'other' cells;
    * ``sampled_cells = 1000`` columns are drawn with replacement from that
      pool and stored as ``sc_data_list[['bulk<i>']]``;
    * with probability ``censor_prob = 0.1`` the sample is censored
      (``status = 0``, time drawn by ``sample(i, 1)``); otherwise
      ``status = 1`` and the survival time is ``i``.

    Parameters
    ----------
    plot : bool
        When True, draws UMAP panels for six example samples and saves them to
        ``survival.time.simulated.pdf`` under the R global ``save_path``
        (not defined here). The cells shown are re-sampled for the figure.
        NOTE(review): no dimensional reduction is computed in this function;
        presumably ``simulated_base_sc_dataset(plot=True)`` must have run
        first -- confirm.

    Returns
    -------
    dict
        ``'sc_data_list'``, ``'surv_info_df'``, ``'Expression_pbmc_df'``,
        ``'labels_df'`` and ``'features'`` (a dict with ``'all_genes'`` and
        ``'hvg'`` lists).
    """
    ro.r('''
    Expression_pbmc <- as.matrix(data@assays[["RNA"]]@layers[["data"]])
    rownames(Expression_pbmc) <- rownames(data)
    colnames(Expression_pbmc) <- colnames(data)
    Expression_pbmc <- as.data.frame(Expression_pbmc)
    all_genes <- rownames(Expression_pbmc)
         
    set.seed(seed)
    sampled_cells = 1000
    bulk_num=100

    other_cells <- colnames(Expression_pbmc)[data$Actual.cond=='other']
    good_cells <- colnames(Expression_pbmc)[data$Actual.cond=='good.survival']
    bad_cells <- colnames(Expression_pbmc)[data$Actual.cond=='bad.survival']
    num_good <- length(good_cells)
    num_bad <- length(bad_cells)

    bulk_condition = NULL
    censor_prob = 0.1

    status = NULL
    surv_time = NULL

    num_good_cond_cells = NULL
    num_bad_cond_cells = NULL

    sc_data_list = list()
    pb <- txtProgressBar(min = 1, max = bulk_num, style = 3)
    for (i in 1:bulk_num){
      setTxtProgressBar(pb, i)
      ratio <- (i-1) / (bulk_num-1)
      # ratio <- plogis((ratio - 0.5) * 2 * 6)
      num_good_cond_cells_i = round(num_good * ratio)
      num_bad_cond_cells_i = round(num_bad * (1-ratio))
      condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
      condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
      condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
      # condition_cells <- c(condition_bad_cells, other_cells)
  
      num_good_cond_cells = c(num_good_cond_cells, num_good_cond_cells_i)
      num_bad_cond_cells = c(num_bad_cond_cells, num_bad_cond_cells_i)
  
      Expression_condition = Expression_pbmc[, condition_cells]
      Expression_selected <- Expression_condition[, sample(ncol(Expression_condition),size=sampled_cells,replace=TRUE)]
  
      # filter_cells = intersect(c(condition_bad_cells, other_cells), colnames(Expression_selected))
      # Expression_selected <- Expression_selected[, filter_cells]
  
      # write.csv(Expression_selected, file = sprintf('./source_data/single_cell_revision/%d.csv', i))
      sc_data_list[[sprintf('bulk%d', i)]] <- Expression_selected

      if (runif(1, min = 0, max = 1) < censor_prob){
        status = c(status, 0)
        surv_time = c(surv_time, sample(i, 1))
      }
      else{
        surv_time = c(surv_time, i)
        status = c(status, 1)
      }
    }

    bulk_names <- paste0('bulk', 1:bulk_num)
    surv_info <- data.frame(
      time=surv_time,
      status=status,
      num.good.cells = num_good_cond_cells,
      num.bad.cells = num_bad_cond_cells,
      row.names = bulk_names
    )

    dim(surv_info)
    dim(Expression_pbmc)
         
    labels <- data$Actual.cond
    labels <- as.data.frame(labels)
    row.names(labels) <- colnames(data)
    
    ''')

    # Optional figure: UMAP panels for six example samples.
    if plot:
        ro.r('''
        library(gridExtra)
        library(ggpubr)

        plot_list <- list()

        for (i in c(2, 10, 40, 60, 90, 99)){
          ratio <- (i-1) / (bulk_num-1)
          # ratio <- plogis((ratio - 0.5) * 2 * 6)
          num_good_cond_cells_i = round(num_good * ratio)
          num_bad_cond_cells_i = round(num_bad * (1-ratio))
          condition_good_cells <- good_cells[sample(num_good, num_good_cond_cells_i , replace=TRUE)]
          condition_bad_cells <- bad_cells[sample(num_bad, num_bad_cond_cells_i, replace=TRUE)]
          condition_cells <- c(condition_good_cells, condition_bad_cells, other_cells)
          # condition_cells <- c(condition_bad_cells, other_cells)
  
  
          p <- DimPlot(data[, condition_cells], group.by = 'Actual.cond', cols = c('grey','blue', 'red'), pt.size = 0.5) +
          ggtitle(sprintf("survival.time: %d months", i))
          plot_list[[length(plot_list) + 1]] <- p
        }

        # combined_plot <- do.call(grid.arrange, c(plot_list, ncol = 3))
        # combined_plot
        ggarrange(plotlist = plot_list, ncol = 3, nrow=2, common.legend = TRUE, legend = "bottom")
        ggsave(paste0(save_path, 'survival.time.simulated.pdf'), height = 7, width = 10.5)
        ''')

    # Collect sc_data_list, surv_info, Expression_pbmc and labels from R and convert them for Python.
    # r_to_pandas / r_list_to_pydict_df are not defined in this notebook; presumably they come
    # from the `from utils import *` star import -- verify.
    surv_info_df     = r_to_pandas("surv_info")
    Expression_pbmc_df = r_to_pandas("Expression_pbmc")
    sc_data_list     = r_list_to_pydict_df("sc_data_list")  # dict: {'bulk1': DataFrame, ...} (R names come from sprintf('bulk%d', i); assumes the helper keeps them)
    labels_df       = r_to_pandas("labels")
    features = {
        'all_genes': list(ro.r("all_genes")),
        'hvg': list(ro.r("var_features_genes"))
    }

    return_data = {
        'sc_data_list': sc_data_list,
        'surv_info_df': surv_info_df,
        'Expression_pbmc_df': Expression_pbmc_df,
        'labels_df': labels_df,
        'features': features
    }

    return return_data


In [3]:
def organize_data_for_model(datasets):
    """Turn the dict returned by ``simulated_sc_datasets`` into model inputs.

    Parameters
    ----------
    datasets : dict
        Must provide ``'sc_data_list'``, ``'surv_info_df'``,
        ``'Expression_pbmc_df'``, ``'labels_df'`` and ``'features'``
        (with ``'all_genes'`` and ``'hvg'``).

    Returns
    -------
    adata : AnnData
        All per-sample matrices stacked cell-wise, restricted to the HVG
        columns; the full-gene version is kept in ``adata.raw`` and the
        sample key of every cell is in ``obs['sample']``.
    surv : pandas.DataFrame
        Copy of the ``time`` (float) and ``status`` (int) columns.
    adata_new : AnnData
        The base expression matrix over all genes with the label of every
        cell in ``obs['sim_group']``.
    """
    gene_names = datasets['features']['all_genes']

    def _as_anndata(matrix, obs_values, obs_column):
        # One obs column, positional integer index, genes as var index.
        obs = pd.DataFrame(obs_values, index=np.arange(matrix.shape[0]), columns=[obs_column])
        return sc.AnnData(matrix, obs=obs, var=pd.DataFrame(index=gene_names))

    # Stack the per-sample frames (transposed) and remember which sample
    # each resulting row came from.
    cell_blocks = []
    sample_of_cell = []
    for sample_name, expr_df in tqdm(datasets['sc_data_list'].items()):
        cell_blocks.append(expr_df.values.T)
        sample_of_cell += [sample_name] * expr_df.shape[1]

    stacked = _as_anndata(np.concatenate(cell_blocks, axis=0), sample_of_cell, 'sample')
    stacked.raw = stacked.copy()
    adata = stacked[:, datasets['features']['hvg']]

    # Survival table with explicit dtypes.
    surv = datasets['surv_info_df'][['time', 'status']].copy()
    surv['time'] = surv['time'].astype(float)
    surv['status'] = surv['status'].astype(int)

    # Base dataset with its simulated labels.
    base_matrix = datasets['Expression_pbmc_df'].values.T
    base_labels = datasets['labels_df']['labels'].values
    adata_new = _as_anndata(base_matrix, base_labels, 'sim_group')

    return adata, surv, adata_new

def detect_subpopulations(adata, surv, adata_new, entropy_threshold=0.7):
    """Train scSurvival and score how well it recovers the simulated cell groups.

    Steps:

    1. Fit the model on ``adata`` / ``surv`` with ``scSurvivalRun``.
    2. Derive an attention cut-off from the training cells: 2-cluster KMeans
       on ``adata.obs['attention']``, cut-off = mean of the two centres.
    3. Score ``adata_new`` with ``PredictIndSample`` and label every cell:
       attention >= cut-off and ``hazard_adj > 0``  -> 'bad.survival',
       attention >= cut-off and ``hazard_adj <= 0`` -> 'good.survival',
       everything else -> 'other'.
    4. Compare those labels with ``adata_new.obs['sim_group']``.

    Parameters
    ----------
    adata : AnnData
        Training cells; must carry ``obs['sample']``.
    surv : pandas.DataFrame
        Survival table passed to ``scSurvivalRun``.
    adata_new : AnnData
        Cells to label; must carry ``obs['sim_group']``.
    entropy_threshold : float
        Forwarded to ``scSurvivalRun``.

    Returns
    -------
    clf_report_df : pandas.DataFrame
        ``sklearn.metrics.classification_report`` (dict form), transposed so
        that rows are classes / averages.
    adata_new : AnnData
        The object returned by ``PredictIndSample``.
    """
    adata, surv, model = scSurvivalRun(
        adata,
        sample_column='sample',
        surv=surv,
        # batch_key='batch',
        feature_flavor='AE',
        entropy_threshold=entropy_threshold,
        lambdas=(0.01, 1.0),
        pretrain_epochs=200,
        epochs=500,
        weight_decay=0.01,
        lr=0.001,
        patience=100,
        rec_likelihood='ZIG',
        do_scale_ae=False,
        beta=0.1, tau=0.2,
        sample_size_ae=None,
        finetue_lr_factor=0.1,
        gene_weight_alpha=0.2,
        gamma_beta_weight=(0.1, 0.0),
        once_load_to_gpu=True,
        use_amp=False,
        fitnetune_strategy='alternating',  # jointly, alternating, alternating_lightly,
    )

    # Attention cut-off: midpoint between the two KMeans centres fitted on
    # the training cells' attention values.
    train_attention = adata.obs['attention'].values.reshape(-1, 1)
    kmeans = KMeans(n_clusters=2, random_state=42)
    kmeans.fit(train_attention)
    atten_thr = kmeans.cluster_centers_.flatten().mean()

    adata_new, _ = PredictIndSample(adata_new, adata, model)

    attention = adata_new.obs['attention'].values
    hazard_adj = adata_new.obs['hazard_adj'].values

    # Label cells directly with the names used in `sim_group`. Cells that
    # satisfy neither comparison (including NaN scores) keep the default
    # 'other', exactly as before.
    attentive = attention >= atten_thr
    predicted_group = np.full(attention.shape[0], 'other', dtype=object)
    predicted_group[np.logical_and(attentive, hazard_adj > 0)] = 'bad.survival'
    predicted_group[np.logical_and(attentive, hazard_adj <= 0)] = 'good.survival'

    clf_report = classification_report(
        adata_new.obs['sim_group'].values,
        predicted_group,
        output_dict=True,
        zero_division=0,
    )

    clf_report_df = pd.DataFrame(clf_report).T
    return clf_report_df, adata_new

def cross_validation_samples(adata, surv, entropy_threshold=0.7, n_splits=5):
    """Patient-level K-fold cross-validation of scSurvival.

    Patients (the distinct values of ``obs['sample']``) are split into
    ``n_splits`` shuffled folds. In every fold the highly variable genes are
    selected on the training cells only, a model is trained with
    ``scSurvivalRun`` and each held-out patient is scored separately with
    ``PredictIndSample``. The concordance index is computed on the training
    patients (printed) and on the held-out patients (collected).

    Parameters
    ----------
    adata : AnnData
        Must have ``.raw`` set; the function works on ``adata.raw.to_adata()``.
    surv : pandas.DataFrame
        Indexed by patient, with ``time`` and ``status``. NOTE: modified in
        place -- the columns ``cv_hazards_adj_patient`` and
        ``cv_hazard_percentile_patient`` are added / overwritten.
    entropy_threshold : float
        Forwarded to ``scSurvivalRun``.
    n_splits : int
        Number of folds (default 5, the value that used to be hard-coded).

    Returns
    -------
    dict
        ``'mean_cindex'``, ``'std_cindex'`` and ``'cindexs_df'`` (one row per
        fold), plus the per-fold results that were previously computed and
        then dropped: ``'surv_cv'`` (held-out rows of ``surv`` for all folds)
        and ``'cv_hazards_adj_cells'`` (held-out ``hazard_adj`` per cell, in
        the row order of ``adata.raw.to_adata()``).
    """
    # Work on the full gene set; HVGs are re-selected inside every fold.
    adata = adata.raw.to_adata()
    adata.obs['patient_no'] = adata.obs['sample']
    patients = adata.obs['patient_no'].unique()

    # K fold cross validation
    cv_hazards_adj_cells = np.zeros((adata.shape[0], ))
    surv['cv_hazards_adj_patient'] = 0.0
    surv['cv_hazard_percentile_patient'] = 0.0
    cindexs = []
    surv_test_all_folds = []

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    for fold, (train_index, test_index) in enumerate(kf.split(patients)):

        print(f'fold {fold}, train_size: {train_index.shape[0]}, test_size: {test_index.shape[0]}')
        train_patients = patients[train_index]
        test_patients = patients[test_index]

        # train
        adata_train = adata[adata.obs['patient_no'].isin(train_patients), :].copy()

        ## select HVGs on training set only
        sc.pp.highly_variable_genes(adata_train, n_top_genes=2000, subset=False, flavor='seurat')
        hvgs = adata_train.var[adata_train.var['highly_variable']].index.tolist()
        # Materialise the subset instead of passing a view: the object is
        # handed to the trainer and comes back with results attached, and
        # writing into an AnnData view forces an implicit copy plus a warning.
        adata_train = adata_train[:, hvgs].copy()

        surv_train = surv.loc[surv.index.isin(train_patients), :].copy()

        adata_train, surv_train, model = scSurvivalRun(
            adata_train,
            sample_column='sample',
            surv=surv_train,
            # batch_key='batch',
            feature_flavor='AE',
            entropy_threshold=entropy_threshold,
            validate=True,
            validate_ratio=0.2,
            validate_metric='ccindex',
            lambdas=(0.01, 1.0),
            pretrain_epochs=200,
            epochs=500,
            weight_decay=0.01,
            lr=0.001,
            patience=100,
            rec_likelihood='ZIG',
            do_scale_ae=False,
            beta=0.1, tau=0.2,
            sample_size_ae=None,
            finetue_lr_factor=0.1,
            gene_weight_alpha=0.2,
            gamma_beta_weight=(0.1, 0.0),
            once_load_to_gpu=True,
            use_amp=False,
            fitnetune_strategy='alternating',  # jointly, alternating, alternating_lightly,
        )

        # The score is negated because a larger hazard should mean a shorter time.
        train_cindex = concordance_index(surv_train['time'], -surv_train['patient_hazards'], surv_train['status'])
        print(f'train c-index: {train_cindex:.4f}')

        # test
        print('testing...')
        adata_test = adata[adata.obs['patient_no'].isin(test_patients), :][:, hvgs].copy()

        # Silence the per-patient prints of PredictIndSample. A throw-away
        # buffer is used so nothing accumulates across folds and calls.
        with contextlib.redirect_stdout(io.StringIO()):
            for test_patient in test_patients:
                adata_test_patient = adata_test[adata_test.obs['patient_no'] == test_patient, :].copy()
                adata_test_patient, patient_hazard = PredictIndSample(adata_test_patient, adata_train, model)

                cell_mask = (adata.obs['patient_no'] == test_patient).to_numpy()
                cv_hazards_adj_cells[cell_mask] = adata_test_patient.obs['hazard_adj'].values

                patient_row = surv.index == test_patient
                surv.loc[patient_row, 'cv_hazards_adj_patient'] = patient_hazard
                # Rank of the held-out hazard among the training patients' hazards.
                surv.loc[patient_row, 'cv_hazard_percentile_patient'] = percentileofscore(
                    surv_train['patient_hazards'], patient_hazard, kind='rank')

        surv_test = surv.loc[surv.index.isin(test_patients), :]
        c_index = concordance_index(surv_test['time'], -surv_test['cv_hazards_adj_patient'], surv_test['status'])

        cindexs.append(c_index)
        surv_test_all_folds.append(surv_test)

        print(f'c-index: {c_index:.4f}')
        print('='*50)

    mean_cindex = np.mean(cindexs)
    std_cindex = np.std(cindexs)

    print(f'mean c-index: {mean_cindex:.4f} ± {std_cindex:.4f}')
    # Label rows from the folds actually evaluated rather than a fixed count.
    cindexs_df = pd.DataFrame(
        cindexs,
        columns=['c-index'],
        index=['fold%d' % k for k in range(len(cindexs))],
    )

    cindex_results = {
        'mean_cindex': mean_cindex,
        'std_cindex': std_cindex,
        'cindexs_df': cindexs_df,
        'surv_cv': pd.concat(surv_test_all_folds),
        'cv_hazards_adj_cells': cv_hazards_adj_cells,
    }

    return cindex_results


In [4]:
# Display the library search paths of the embedded R session.
ro.r('.libPaths()')

0,1
'/home/groups/XiaLab/re...,'/arc/software/25Q1/spa...


In [5]:
# Attach the R packages; the R script also starts with rm(list=ls()), clearing the R global environment.
load_r_packages()


    an issue that caused a segfault when used with rpy2:
    https://github.com/rstudio/reticulate/pull/1188
    Make sure that you use a version of that package that includes
    the fix.
    

In [None]:
from utils import Logger
from itertools import product

# Attach the R packages (the R script also clears the R global environment).
load_r_packages()

# Every (seed, cell_surv_ratio) combination is simulated and scored once.
param_grid = {
    'seed': range(1, 11),
    'cell_surv_ratio': [0.01, 0.03, 0.05, 0.10, 0.15]
}
keys, values = zip(*param_grid.items())
combos = [dict(zip(keys, v)) for v in product(*values)]

save_root_path = './results/revision-sim1-python/'
# The log file lives directly under save_root_path, and the first log call
# happens before any per-run directory is created, so create the root here.
os.makedirs(save_root_path, exist_ok=True)
logger = Logger(save_path=f'{save_root_path}cell_subpopulation_logs.csv')

for i, params in enumerate(combos):
    logger.log_dict(params)
    seed = params['seed']
    cell_surv_ratio = params['cell_surv_ratio']

    print(f'Running {i+1}/{len(combos)}: seed={seed}, cell_surv_ratio={cell_surv_ratio}')

    # The R plotting code reads `save_path` from the R global environment.
    save_path = f'{save_root_path}ratio-{cell_surv_ratio}_seed-{seed}/'
    ro.globalenv['save_path'] = save_path

    # Figures are only produced for the first seed of every ratio.
    make_plots = seed == 1
    if make_plots:
        ro.r('dir.create(save_path, recursive=T)')
    simulated_base_sc_dataset(seed=seed, plot=make_plots, cell_surv_ratio=cell_surv_ratio)
    datasets = simulated_sc_datasets(plot=make_plots)

    adata, surv, adata_new = organize_data_for_model(datasets)

    # The entropy threshold is chosen per ratio band.
    if cell_surv_ratio < 0.05:
        entropy_threshold = 0.3
    elif cell_surv_ratio <= 0.1:
        entropy_threshold = 0.5
    else:
        entropy_threshold = 0.7

    clf_report_df, adata_new = detect_subpopulations(adata, surv, adata_new, entropy_threshold=entropy_threshold)

    # Macro averages first, then one entry per class and metric.
    clf_rst = {
        'precision': clf_report_df.loc['macro avg', 'precision'],
        'recall': clf_report_df.loc['macro avg', 'recall'],
        'f1-score': clf_report_df.loc['macro avg', 'f1-score'],
    }

    for cls in ['good.survival', 'bad.survival', 'other']:
        for metric in ['precision', 'recall', 'f1-score']:
            key = f'{cls}_{metric}'
            # A class absent from the report is recorded as 0.0.
            clf_rst[key] = clf_report_df.loc[cls, metric] if cls in clf_report_df.index else 0.0

    logger.log_dict(clf_rst)
    # NOTE(review): the return value is discarded and nothing is displayed from
    # inside a loop; kept in case Logger persists the table here -- confirm.
    logger.get_logs_df()

In [None]:
from utils import Logger
from itertools import product

# Attach the R packages (the R script also clears the R global environment,
# so no `save_path` exists in R afterwards; every call below uses plot=False).
load_r_packages()

# NOTE(review): the output stored with this cell reports 4 combinations
# starting at cell_surv_ratio=0.03, while this grid defines 5 -- the output
# is from an earlier version of the grid; re-run to refresh it.
param_grid = {
    'seed': [1],
    'cell_surv_ratio': [0.01, 0.03, 0.05, 0.10, 0.15]
}
keys, values = zip(*param_grid.items())
combos = [dict(zip(keys, v)) for v in product(*values)]

save_root_path = './results/revision-sim1-python/'
# Nothing else in this cell creates the output directory; make sure it
# exists so the cell does not depend on another cell having run first.
os.makedirs(save_root_path, exist_ok=True)
logger = Logger(save_path=f'{save_root_path}cv_logs.csv')

for i, params in enumerate(combos):
    logger.log_dict(params)
    seed = params['seed']
    cell_surv_ratio = params['cell_surv_ratio']

    print(f'Running {i+1}/{len(combos)}: seed={seed}, cell_surv_ratio={cell_surv_ratio}')

    simulated_base_sc_dataset(seed=seed, plot=False, cell_surv_ratio=cell_surv_ratio)
    datasets = simulated_sc_datasets(plot=False)

    adata, surv, adata_new = organize_data_for_model(datasets)

    # The entropy threshold is chosen per ratio band.
    if cell_surv_ratio < 0.05:
        entropy_threshold = 0.3
    elif cell_surv_ratio <= 0.1:
        entropy_threshold = 0.5
    else:
        entropy_threshold = 0.7

    cv_results = cross_validation_samples(adata, surv, entropy_threshold=entropy_threshold)

    # Only the summary statistics are logged.
    cindex_results = {
        'mean_cindex': cv_results['mean_cindex'],
        'std_cindex': cv_results['std_cindex']
    }

    logger.log_dict(cindex_results)
    # NOTE(review): the return value is discarded and nothing is displayed from
    # inside a loop; kept in case Logger persists the table here -- confirm.
    logger.get_logs_df()


Running 1/4: seed=1, cell_surv_ratio=0.03

100%|██████████| 100/100 [00:00<00:00, 91980.35it/s]


fold 0, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [01:51<00:00,  1.80it/s, ae_loss=-95.2] 
Finetuning:  58%|█████▊    | 291/500 [04:06<02:56,  1.18it/s, ae_loss=-111, atten_entropy=0.271, ccindex_val=0.646, cox_loss=1.53, loss=0.418]


Early stopping with best validation ccindex: 0.6819.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.8718
testing...
c-index: 0.6162
fold 1, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [01:50<00:00,  1.81it/s, ae_loss=-77.5] 
Finetuning:  45%|████▍     | 224/500 [03:08<03:52,  1.19it/s, ae_loss=-92.4, atten_entropy=0.41, ccindex_val=0.913, cox_loss=1.47, loss=0.658] 


Early stopping with best validation ccindex: 0.9245.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9676
testing...
c-index: 0.8873
fold 2, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [01:50<00:00,  1.82it/s, ae_loss=-75.4] 
Finetuning: 100%|██████████| 500/500 [07:00<00:00,  1.19it/s, ae_loss=-93.4, atten_entropy=0.292, ccindex_val=0.756, cox_loss=1.42, loss=0.481]
  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9128
testing...
c-index: 0.7211
fold 3, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [01:50<00:00,  1.81it/s, ae_loss=-67]   
Finetuning:  98%|█████████▊| 490/500 [06:53<00:08,  1.19it/s, ae_loss=-85.8, atten_entropy=0.267, ccindex_val=0.629, cox_loss=1.49, loss=0.628]


Early stopping with best validation ccindex: 0.6415.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.8572
testing...
c-index: 0.4842
fold 4, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [01:50<00:00,  1.81it/s, ae_loss=-80.1] 
Finetuning:  93%|█████████▎| 463/500 [05:59<00:28,  1.29it/s, ae_loss=-96.8, atten_entropy=0.279, ccindex_val=0.845, cox_loss=1.42, loss=0.455]


Early stopping with best validation ccindex: 0.8559.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9406
testing...
c-index: 0.7007
mean c-index: 0.6819 ± 0.1323
Running 2/4: seed=1, cell_surv_ratio=0.05

100%|██████████| 100/100 [00:00<00:00, 84665.00it/s]


fold 0, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [01:46<00:00,  1.87it/s, ae_loss=-90.3] 
Finetuning:  42%|████▏     | 210/500 [05:16<07:17,  1.51s/it, ae_loss=-104, atten_entropy=0.498, ccindex_val=0.934, cox_loss=1.97, loss=0.93] 


Early stopping with best validation ccindex: 0.9479.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9786
testing...
c-index: 0.9399
fold 1, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-87.8] 
Finetuning:  89%|████████▉ | 446/500 [11:55<01:26,  1.60s/it, ae_loss=-107, atten_entropy=0.479, ccindex_val=0.92, cox_loss=1.57, loss=0.503]  


Early stopping with best validation ccindex: 0.9333.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9694
testing...
c-index: 0.9317
fold 2, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [02:54<00:00,  1.15it/s, ae_loss=-103]  
Finetuning:  87%|████████▋ | 435/500 [11:23<01:42,  1.57s/it, ae_loss=-120, atten_entropy=0.441, ccindex_val=0.741, cox_loss=1.66, loss=0.459]


Early stopping with best validation ccindex: 0.7687.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9018
testing...
c-index: 0.7950
fold 3, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-99.1] 
Finetuning:  79%|███████▉  | 395/500 [09:36<02:33,  1.46s/it, ae_loss=-117, atten_entropy=0.469, ccindex_val=0.873, cox_loss=1.83, loss=0.659]


Early stopping with best validation ccindex: 0.8851.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9439
testing...
c-index: 0.8316
fold 4, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:45<00:00,  1.13s/it, ae_loss=-101]  
Finetuning:  63%|██████▎   | 313/500 [08:23<05:00,  1.61s/it, ae_loss=-117, atten_entropy=0.499, ccindex_val=0.935, cox_loss=1.56, loss=0.384]


Early stopping with best validation ccindex: 0.9379.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9678
testing...
c-index: 0.9008
mean c-index: 0.8798 ± 0.0570
Running 3/4: seed=1, cell_surv_ratio=0.1

100%|██████████| 100/100 [00:00<00:00, 88319.73it/s]


fold 0, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-117]  
Finetuning:  48%|████▊     | 240/500 [05:36<06:04,  1.40s/it, ae_loss=-132, atten_entropy=0.593, ccindex_val=0.93, cox_loss=1.5, loss=0.278]  


Early stopping with best validation ccindex: 0.9386.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9639
testing...
c-index: 0.9601
fold 1, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:43<00:00,  1.12s/it, ae_loss=-100]  
Finetuning:  54%|█████▍    | 269/500 [07:10<06:09,  1.60s/it, ae_loss=-116, atten_entropy=0.551, ccindex_val=0.944, cox_loss=1.59, loss=0.489]


Early stopping with best validation ccindex: 0.9541.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9811
testing...
c-index: 0.9484
fold 2, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-115]  
Finetuning:  45%|████▍     | 224/500 [05:23<06:38,  1.44s/it, ae_loss=-129, atten_entropy=0.6, ccindex_val=0.931, cox_loss=1.67, loss=0.479]  


Early stopping with best validation ccindex: 0.9457.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9767
testing...
c-index: 0.9618
fold 3, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:37<00:00,  1.09s/it, ae_loss=-108]  
Finetuning:  63%|██████▎   | 314/500 [08:19<04:55,  1.59s/it, ae_loss=-125, atten_entropy=0.505, ccindex_val=0.945, cox_loss=1.51, loss=0.265]


Early stopping with best validation ccindex: 0.9607.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9803
testing...
c-index: 0.9303
fold 4, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-121]  
Finetuning:  40%|███▉      | 199/500 [05:20<08:04,  1.61s/it, ae_loss=-135, atten_entropy=0.499, ccindex_val=0.947, cox_loss=1.56, loss=0.213]


Early stopping with best validation ccindex: 0.9502.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9751
testing...
c-index: 0.9470
mean c-index: 0.9495 ± 0.0113
Running 4/4: seed=1, cell_surv_ratio=0.15

100%|██████████| 100/100 [00:00<00:00, 84222.97it/s]


fold 0, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [02:59<00:00,  1.11it/s, ae_loss=-130]  
Finetuning:  82%|████████▏ | 409/500 [10:44<02:23,  1.57s/it, ae_loss=-147, atten_entropy=0.685, ccindex_val=0.949, cox_loss=1.61, loss=0.138]  


Early stopping with best validation ccindex: 0.9608.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9782
testing...
c-index: 0.9785
fold 1, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-140]  
Finetuning:  43%|████▎     | 213/500 [05:42<07:41,  1.61s/it, ae_loss=-155, atten_entropy=0.701, ccindex_val=0.947, cox_loss=1.61, loss=0.057]   


Early stopping with best validation ccindex: 0.9617.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9854
testing...
c-index: 0.9789
fold 2, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:02<00:00,  1.10it/s, ae_loss=-109]  
Finetuning:  45%|████▌     | 226/500 [05:50<07:05,  1.55s/it, ae_loss=-125, atten_entropy=0.724, ccindex_val=0.94, cox_loss=1.61, loss=0.382] 


Early stopping with best validation ccindex: 0.9471.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9596
testing...
c-index: 0.9107
fold 3, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:52<00:00,  1.16s/it, ae_loss=-124]  
Finetuning:  57%|█████▋    | 285/500 [07:39<05:46,  1.61s/it, ae_loss=-140, atten_entropy=0.698, ccindex_val=0.948, cox_loss=1.83, loss=0.426] 


Early stopping with best validation ccindex: 0.9534.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9784
testing...
c-index: 0.8596
fold 4, train_size: 80, test_size: 20
Validation mode is enabled, will split 20% of the data for validation.


Pretraining: 100%|██████████| 200/200 [03:01<00:00,  1.10it/s, ae_loss=-118]  
Finetuning:  39%|███▉      | 196/500 [04:49<07:29,  1.48s/it, ae_loss=-131, atten_entropy=0.708, ccindex_val=0.944, cox_loss=1.66, loss=0.352]


Early stopping with best validation ccindex: 0.9547.


  adata.obsm['X_ae'] = h.cpu().detach().numpy()


Added hazard and attention to adata.obs.
Added patient_hazards to surv.
train c-index: 0.9792
testing...
c-index: 0.9788
mean c-index: 0.9413 ± 0.0486
