# Subsampling of Tahoe Dataset 
## Predicting perturbation responses for unseen cell-types (context transfer)

# Imports

In [8]:
import io
import pandas as pd
import scanpy as sc
import pyarrow.dataset as ds
import gcsfs
import os
import numpy as np
import ast

# Functions

In [9]:
# List Files in Google Cloud Storage Bucket

def print_gc_files(fs):
    # List the top‐level contents of the bucket
    print(fs.ls('arc-ctc-tahoe100'))

    # Drill down into a folder:
    print(fs.ls('arc-ctc-tahoe100/2025-02-25'))

    # List everything recursively:
    all_files = fs.find('arc-ctc-tahoe100')
    print(len(all_files), "items found")
    for file in all_files:
        print(file)


# Calculate Total Remote File Size in GCS

def calc_remote_file_size(fs, h5ad_file):
    all_paths = fs.find(h5ad_file)

    # Sum up file sizes
    total_bytes = 0
    for path in all_paths:
        print(path)
        info = fs.info(path)
        if info.get('type') == 'file':
            total_bytes += info.get('size', 0)

    # Print result in bytes and GB
    print(f"Total size: {total_bytes:,} bytes")
    print(f"Which is roughly {total_bytes/1e9:.2f} GB")
    
# Parse Drug Condition Identifier   

def parse_condition(s):
    parsed = ast.literal_eval(s)
    if len(parsed) == 0:
        raise ValueError(f"Condition {s} is empty")
    compound, dose, unit = parsed[0]
    x = f"{compound}_{dose}{unit}"
    x = x.replace('.', '')
    return x

# Extract Numerical Dose from Condition String

def get_dose_from_condition(s):
    parsed = ast.literal_eval(s)
    if len(parsed) == 0:
        raise ValueError(f"Condition {s} is empty")
    compound, dose, unit = parsed[0]
    return float(dose)

# Read the data and subsample

In [10]:
# Download the h5ad file from google storage # mounts Google Cloud Storage via gcsfs and downloads the file
local_file  = "plate9_full.h5ad"
if not os.path.exists(local_file):
    # Initialize GCS file system for reading data from GCS
    fs = gcsfs.GCSFileSystem()
    
    # File to Download and use
    h5ad_file = 'arc-ctc-tahoe100/2025-02-25/h5ad/plate9_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad'
    
    # print_gc_files(fs)
    # calc_remote_file_size(fs, h5ad_file)

    print(f"Downloading {h5ad_file} → {local_file} …")
    fs.get(h5ad_file, local_file)
    print("Download complete.")
else:
    print(f"Local file '{local_file}' already exists; skipping download.")    # Otherwise, if the backed file already exists, it skips the download.


# Open in backed mode (only obs. and var. loaded into RAM, the rest of the expression matrix is on the disc)
print("Reading all data...")
adata_full = sc.read_h5ad(local_file, backed="r")  
print("-> Total cells on Plate 9:", adata_full.n_obs)  


# Explore observation metadata 
print(adata_full.obs.head())
print(adata_full.obs.columns)

#  Identify control vs. treated cells
print("------ All drug labels -------")
all_drug_labels = adata_full.obs['drug'].unique()
# for index, label in enumerate(all_drug_labels):
#     print(f"{index+1}) {label}")

# Build a mask that picks up either 'DMSO_TF' or any plain 'DMSO':
ctrl_mask = adata_full.obs['drug'].str.contains(
    r"\bDMSO_TF\b|\bDMSO\b",  # matches "DMSO_TF" or "DMSO"
    case=False,
    na=False
)

# Prints total / control / treated counts and the percentage of controls.
print(f"Total cells: {len(ctrl_mask)}")
print(f"Control cells: {ctrl_mask.sum()}")
print(f"Treated cells: {(~ctrl_mask).sum()}")
print(f"Percentage controls: {ctrl_mask.mean()*100:.2f}%")
# control_percentage = ctrl_mask.mean()

# The treated cells are simply the inverse of the control mask:
treated_mask = ~ctrl_mask


# Collect all unique drug‐dose conditions (excluding controls) 
drug_conds = adata_full.obs.loc[treated_mask, "drugname_drugconc"].unique().tolist()  #Extracts all non-control conditions


# Stratified Subsampling of Treated Cells per Condition
N_per_drug = 500                                    #<--------choose the number N of cells per drug you want to keep!
expected_total_cells = N_per_drug * len(drug_conds)    
actual_total_cells = 0
N_ctrl = N_per_drug * len(drug_conds) 
N_ctrl = int(np.ceil(N_ctrl))
print(f"Sampling {N_ctrl} controls...")

# Stratified sampling for treated cells 
print("-> Sampling conditions...")
sampled_idxs = []
for index, cond in enumerate(drug_conds):
    idxs_tuple = np.where(adata_full.obs["drugname_drugconc"] == cond)
    idxs = idxs_tuple[0]
    n_to_pick  = min(len(idxs), N_per_drug)
    pick = np.random.choice(idxs, size=n_to_pick, replace=False)   #(replace=False) to keep unique cells
    sampled_idxs.append(pick)
    actual_total_cells += n_to_pick
    print(f'{index+1:2d}) {cond:>{45}} n_of_cells: {len(idxs):7d} n_to_pick: {n_to_pick}...')
print("Done.")
print(f"Total cells sampled: {actual_total_cells}, expected: {expected_total_cells}")

# Random sampling of DMSO controls 
print("-> Sampling controls...")
ctrl_idxs = np.where(ctrl_mask)[0]
n_to_pick = min(len(ctrl_idxs), N_ctrl)
ctrl_pick = np.random.choice(ctrl_idxs, size=n_to_pick, replace=False)
sampled_idxs.append(ctrl_pick)
print("Done.")
print(f"Total control sampled: {len(ctrl_pick)}, expected: {N_ctrl}")

# Combine all sampled indices and load subset into memory 
print("-> Loading subset into memory...")          
all_idxs   = np.concatenate(sampled_idxs)
adata_sub  = adata_full[all_idxs, :].to_memory()
print("Subset shape:", adata_sub.shape)    


Local file 'plate9_full.h5ad' already exists; skipping download.
Reading all data...
-> Total cells on Plate 9: 5866669
                       sample  gene_count  tscp_count  mread_count  \
BARCODE_SUB_LIB_ID                                                   
01_001_019-lib_1585  smp_2263        1079        1464         1725   
01_001_031-lib_1585  smp_2263        1122        1476         1771   
01_001_132-lib_1585  smp_2263         977        1335         1577   
01_001_133-lib_1585  smp_2263        2022        3251         3869   
01_001_159-lib_1585  smp_2263        1178        1560         1823   

                                                     drugname_drugconc  \
BARCODE_SUB_LIB_ID                                                       
01_001_019-lib_1585  [('Sivelestat (sodium tetrahydrate)', 5.0, 'uM')]   
01_001_031-lib_1585  [('Sivelestat (sodium tetrahydrate)', 5.0, 'uM')]   
01_001_132-lib_1585  [('Sivelestat (sodium tetrahydrate)', 5.0, 'uM')]   
01_001_133-lib_1585

# Data Preprocessing for CPA

### Preprocessing AnnData for CPA: Raw Counts, Dose Extraction, Filtering, Normalization, and HVG Selection



In [11]:
# Create a Copy & Save Raw Counts for normalizing and log-transforming
print(adata_sub)
adata_sub_log = adata_sub.copy()
adata_sub_log.layers['counts'] = adata_sub_log.X.copy()

print("\nCells per drug condition:")
print(adata_sub_log.obs["drugname_drugconc"].value_counts())

# Define CPA-Specific Observation Keys
adata_sub_log.obs["condition_ID"] = adata_sub_log.obs["drugname_drugconc"].apply(parse_condition)
print("\nCondition IDs:\n", adata_sub_log.obs["condition_ID"])

# Extract numeric dose (µM) from strings
adata_sub_log.obs["dose"] = adata_sub_log.obs["drugname_drugconc"].apply(get_dose_from_condition)

print("\nDoses (µM):\n", adata_sub_log.obs["dose"])

adata_sub_log.obs['log_dose'] = np.log10(adata_sub_log.obs['dose'] + 1e-6)
print("\nLog - Doses (µM):\n", adata_sub_log.obs["log_dose"])

# Filter out cells with low number of counts
before = adata_sub_log.n_obs
sc.pp.filter_cells(adata_sub_log, min_counts=100)
after = adata_sub_log.n_obs
print(f"➡️ Filtered cells based on min_counts ≥ 100:")
print(f"   • Before filtering: {before} cells")
print(f"   • After filtering:  {after} cells")
print(f"   • Cells removed:    {before - after}")

print("\n ➜ Sample of remaining cells:")
print(adata_sub_log.obs.head())

# Normalize counts to 10k UMIs per cell
print("-> Normalizing...")
sc.pp.normalize_total(adata_sub_log, target_sum=1e4, exclude_highly_expressed=True)

# Log-transform
print("-> Log transforming...")
sc.pp.log1p(adata_sub_log)

# Select HVGs (2 000 genes)
print("-> Highly Variable Genes...")
sc.pp.highly_variable_genes(adata_sub_log, n_top_genes=2000, subset=True)


AnnData object with n_obs × n_vars = 93000 × 62710
    obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate'

Cells per drug condition:
drugname_drugconc
[('DMSO_TF', 0.0, 'uM')]                         33600
[('Trametinib (DMSO_TF solvate)', 5.0, 'uM')]    12900
[('Adagrasib', 0.05, 'uM')]                        500
[('Afatinib', 5.0, 'uM')]                          500
[('Almonertinib (mesylate)', 5.0, 'uM')]           500
                                                 ...  
[('Topotecan (hydrochloride)', 5.0, 'uM')]         500
[('Tranilast', 5.0, 'uM')]                         500
[('Verapamil', 5.0, 'uM')]                         500
[('Verteporfin', 5.0, 'uM')]                       500
[('crizotinib', 5.0, 'uM')]                        500
Name: count, Length: 95, dtype: int64

Condition IDs:
 BARCODE_SUB_LIB_ID
01_166_159-lib

### Marking Control vs Stimulated Cells and Computing Differential Expression (DE) per Cell Line

In [12]:
# Mark DMSO_TF as control vs stimulated cells
adata_sub_log.obs['condition'] = adata_sub_log.obs['condition_ID'].apply(lambda x: 'ctrl' if 'DMSO_TF_' in str(x) else 'stimulated')

# Observe condition counts  and columns
print("\n=== condition value counts ===")
print(adata_sub_log.obs['condition'].value_counts())
print(adata_sub_log.obs.columns.tolist())
print(adata_sub_log.obs.head())

# Create Combined Covariate Column
adata_sub_log.obs['cov_cond'] = adata_sub_log.obs['cell_line'].astype(str) + '_' + adata_sub_log.obs['condition'].astype(str)

# Store Raw Counts for DE Analysis               
adata_sub_log.raw = adata_sub_log   

# Prepare for DE Analysis
adata_sub_log.uns['rank_genes_groups'] = {}
cell_lines = adata_sub_log.obs['cell_line'].unique()

# Loop Over Cell Lines and Compute DEGs
for cell in cell_lines:
    adata_cell = adata_sub_log[adata_sub_log.obs['cell_line'] == cell].copy()

    # Skips cell lines without both control and stimulated groups.
    cond_counts = adata_cell.obs['condition'].value_counts()
    if not {'ctrl', 'stimulated'}.issubset(cond_counts.index):
        print(f"⏭ Skipping {cell}: missing ctrl or stimulated")
        continue
    if cond_counts['ctrl'] < 2 or cond_counts['stimulated'] < 2:
        print(f" Skipping {cell}: not enough cells per group (ctrl={cond_counts['ctrl']}, stim={cond_counts['stimulated']})")
        continue

    #adata_cell.raw = adata_cell

    # Run DE analysis:
    sc.tl.rank_genes_groups(
        adata_cell,
        groupby='condition',
        reference='ctrl',
        method='t-test_overestim_var',
        use_raw=True
    )

    key = f"{cell}_stimulated"   # Save DE results for this cell line
    adata_sub_log.uns['rank_genes_groups'][key] = {k: v['stimulated'] for k, v in adata_cell.uns['rank_genes_groups'].items() if isinstance(v, np.ndarray)}

print(" Finished computing per-cell-line DEGs.")


# Save downsampled, preprocessed Data  
save_path = "plate9_preprocessed_for_CPA_pred_cell_1.h5ad"               
adata_sub_log.write_h5ad(save_path)
print(f"Wrote fully preprocessed AnnData to: {save_path}")


=== condition value counts ===
condition
stimulated    59400
ctrl          33600
Name: count, dtype: int64
['sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate', 'condition_ID', 'dose', 'log_dose', 'n_counts', 'condition']
                       sample  gene_count  tscp_count  mread_count  \
BARCODE_SUB_LIB_ID                                                   
01_166_159-lib_1645  smp_2263         647         827          966   
01_001_005-lib_2504  smp_2263        2167        3243         3820   
01_015_109-lib_1647  smp_2263        1027        1438         1687   
01_093_083-lib_1592  smp_2263         791        1010         1172   
01_149_149-lib_1632  smp_2263         873        1294         1498   

                                                     drugname_drugconc  \
BARCODE_SUB_LIB_ID                                            

In [17]:
print(adata_sub_log.obs[['cell_line', 'cell_name']].head(50))



                     cell_line   cell_name
BARCODE_SUB_LIB_ID                        
01_166_159-lib_1645  CVCL_0332     HS-578T
01_001_005-lib_2504  CVCL_1495   NCI-H1792
01_015_109-lib_1647  CVCL_0480      PANC-1
01_093_083-lib_1592  CVCL_1239          H4
01_149_149-lib_1632  CVCL_0152      AsPC-1
01_052_137-lib_1617  CVCL_1119     CFPAC-1
01_015_002-lib_1593  CVCL_1055       A-427
01_060_106-lib_1670  CVCL_0292       HCT15
01_164_152-lib_1669  CVCL_1693      SHP-77
01_062_162-lib_1647  CVCL_0504         RKO
01_085_079-lib_2507  CVCL_0293     HEC-1-A
01_031_148-lib_1597  CVCL_1731      SW 900
01_132_144-lib_1643  CVCL_0480      PANC-1
01_002_191-lib_2608  CVCL_0428  MIA PaCa-2
01_177_033-lib_1639  CVCL_1097         C32
01_029_148-lib_1650  CVCL_0152      AsPC-1
01_119_077-lib_1626  CVCL_C466  hTERT-HPNE
01_074_127-lib_1667  CVCL_1056        A498
01_025_087-lib_1613  CVCL_0504         RKO
01_096_120-lib_2504  CVCL_1381    LOX-IMVI
01_124_188-lib_1665  CVCL_0546       SW480
01_119_104-