# Imports

In [1]:
import io
import pandas as pd
import scanpy as sc
import pyarrow.dataset as ds
import gcsfs
import os
import numpy as np
import ast

# Functions

In [2]:
def print_gc_files(fs):
    # List the top‐level contents of the bucket
    print(fs.ls('arc-ctc-tahoe100'))

    # If you want to drill down into a folder, e.g. the date directory:
    print(fs.ls('arc-ctc-tahoe100/2025-02-25'))

    # To list everything recursively, you can use:
    all_files = fs.find('arc-ctc-tahoe100')
    print(len(all_files), "items found")
    for file in all_files:
        print(file)


def calc_remote_file_size(fs, h5ad_file):
    all_paths = fs.find(h5ad_file)

    # Sum up file sizes (ignoring directories)
    total_bytes = 0
    for path in all_paths:
        print(path)
        info = fs.info(path)
        if info.get('type') == 'file':
            total_bytes += info.get('size', 0)

    # Print result in bytes and GB
    print(f"Total size: {total_bytes:,} bytes")
    print(f"Which is roughly {total_bytes/1e9:.2f} GB")
    
def parse_condition(s):
    parsed = ast.literal_eval(s)
    if len(parsed) == 0:
        raise ValueError(f"Condition {s} is empty")
    compound, dose, unit = parsed[0]
    x = f"{compound}_{dose}{unit}"
    x = x.replace('.', '')
    return x

def get_dose_from_condition(s):
    parsed = ast.literal_eval(s)
    if len(parsed) == 0:
        raise ValueError(f"Condition {s} is empty")
    compound, dose, unit = parsed[0]
    return float(dose)

# Read the data and subsample

In [None]:
# Download the h5ad file from google storage # mounts Google Cloud Storage via gcsfs and downloads the file
local_file  = "plate9_full.h5ad"
if not os.path.exists(local_file):
    # Initialize GCS file system for reading data from GCS
    fs = gcsfs.GCSFileSystem()
    
    # File to Download and USE
    h5ad_file = 'arc-ctc-tahoe100/2025-02-25/h5ad/plate9_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad'
    
    # print_gc_files(fs)
    # calc_remote_file_size(fs, h5ad_file)

    print(f"Downloading {h5ad_file} → {local_file} …")
    fs.get(h5ad_file, local_file)
    print("Download complete.")
else:
    print(f"Local file '{local_file}' already exists; skipping download.")    # Otherwise, if the backed file already exists, it skips the download.


# ─── 1) Open in backed mode (only obs. and var. loaded into RAM, the rest of the expression matrix is on the disc, important if there is not enoufh RAM) ────────────────────────
print("Reading all data...")
adata_full = sc.read_h5ad(local_file, backed="r")  
print("-> Total cells on Plate 9:", adata_full.n_obs)  


# ─── 2) Identify DMSO_TF (vehicle) controls vs. drug‐treated cells ────────────
print(adata_full.obs.head())                #observing the obs table and column
print(adata_full.obs.columns)

# a) Look at all the unique drug labels and confirm how controls appear (we are looking for DMSO):
print("------ All drug labels -------")
all_drug_labels = adata_full.obs['drug'].unique()


# b) Build a mask that picks up either 'DMSO_TF' or any plain 'DMSO' variant:
ctrl_mask = adata_full.obs['drug'].str.contains(
    r"\bDMSO_TF\b|\bDMSO\b",  # matches "DMSO_TF" or "DMSO"
    case=False,
    na=False
)
# The ctrl_mask is a boolean array that is True for cells that match the condition, and False for cells that don't match. Prints total / control / treated counts and the percentage of controls.
print(f"Total cells: {len(ctrl_mask)}")
print(f"Control cells: {ctrl_mask.sum()}")
print(f"Treated cells: {(~ctrl_mask).sum()}")
print(f"Percentage controls: {ctrl_mask.mean()*100:.2f}%")
control_percentage = ctrl_mask.mean()
# c) The treated cells are simply the inverse of your control mask:
treated_mask = ~ctrl_mask



# ─── 3) Collect all unique drug‐dose conditions (excluding controls) ───────────
drug_conds = adata_full.obs.loc[treated_mask, "drugname_drugconc"].unique().tolist()  #Extracts all non-control conditions



# ─── 4) Decide sampling scheme ───────────────────────────────────────────────
N_per_drug = 500                      # cells per drug              
expected_total_cells = N_per_drug * len(drug_conds)    
actual_total_cells = 0
# keep the same percentage of controls- Compute expected total treated cells, and then compute how many controls (N_ctrl) to sample so the overall control fraction in your subset matches the full plate.
N_ctrl = N_per_drug * len(drug_conds) ######* control_percentage
# ceil = round up to nearest integer
N_ctrl = int(np.ceil(N_ctrl))

print(f"Sampling {N_ctrl} controls...")

# ─── 5) Stratified sampling of treatment conditions ──────────────────────────
print("-> Sampling conditions...")
sampled_idxs = []
for index, cond in enumerate(drug_conds):
    idxs_tuple = np.where(adata_full.obs["drugname_drugconc"] == cond)
    idxs = idxs_tuple[0]
    # Always sample without replacement (replace=False) so that we keep unique cells
    # That means that we may not have exactly N_per_drug cells for each condition
    n_to_pick  = min(len(idxs), N_per_drug)
    pick = np.random.choice(idxs, size=n_to_pick, replace=False)
    sampled_idxs.append(pick)
    actual_total_cells += n_to_pick
    print(f'{index+1:2d}) {cond:>{45}} n_of_cells: {len(idxs):7d} n_to_pick: {n_to_pick}...')
print("Done.")
print(f"Total cells sampled: {actual_total_cells}, expected: {expected_total_cells}")

# ─── 6) Random sampling of DMSO controls ──────────────────────────────────────
print("-> Sampling controls...")
ctrl_idxs = np.where(ctrl_mask)[0]
n_to_pick = min(len(ctrl_idxs), N_ctrl)
ctrl_pick = np.random.choice(ctrl_idxs, size=n_to_pick, replace=False)
sampled_idxs.append(ctrl_pick)
print("Done.")
print(f"Total control sampled: {len(ctrl_pick)}, expected: {N_ctrl}")

# ─── 7) Combine all sampled indices and load subset into memory ───────────────
print("-> Loading subset into memory...")          
all_idxs   = np.concatenate(sampled_idxs)
adata_sub  = adata_full[all_idxs, :].to_memory()
print("Subset shape:", adata_sub.shape)   # e.g., (num_drugs*N_per_drug + N_ctrl, n_genes)


Local file 'plate9_full.h5ad' already exists; skipping download.
Reading all data...
-> Total cells on Plate 9: 5866669
                       sample  gene_count  tscp_count  mread_count  \
BARCODE_SUB_LIB_ID                                                   
01_001_019-lib_1585  smp_2263        1079        1464         1725   
01_001_031-lib_1585  smp_2263        1122        1476         1771   
01_001_132-lib_1585  smp_2263         977        1335         1577   
01_001_133-lib_1585  smp_2263        2022        3251         3869   
01_001_159-lib_1585  smp_2263        1178        1560         1823   

                                                     drugname_drugconc  \
BARCODE_SUB_LIB_ID                                                       
01_001_019-lib_1585  [('Sivelestat (sodium tetrahydrate)', 5.0, 'uM')]   
01_001_031-lib_1585  [('Sivelestat (sodium tetrahydrate)', 5.0, 'uM')]   
01_001_132-lib_1585  [('Sivelestat (sodium tetrahydrate)', 5.0, 'uM')]   
01_001_133-lib_1585

# Data Preprocessing for CPA

In [None]:
print(adata_sub)
adata_sub_log = adata_sub.copy()
adata_sub_log.layers['counts'] = adata_sub_log.X.copy()

print("\nCells per drug condition:")
print(adata_sub_log.obs["drugname_drugconc"].value_counts())

# ─── 1) Define CPA-required obs keys ─────────────────────────────────────────
adata_sub_log.obs["condition_ID"] = adata_sub_log.obs["drugname_drugconc"].apply(parse_condition)
print("\nCondition IDs:\n", adata_sub_log.obs["condition_ID"])

# Extract numeric dose (µM) from strings like "DrugX_5.0uM"
adata_sub_log.obs["dose"] = adata_sub_log.obs["drugname_drugconc"].apply(get_dose_from_condition)

print("\nDoses (µM):\n", adata_sub_log.obs["dose"])

adata_sub_log.obs['log_dose'] = np.log10(adata_sub_log.obs['dose'] + 1e-6)
print("\nLog - Doses (µM):\n", adata_sub_log.obs["log_dose"])



#Add cell_type + cov_drug_dose
print("\n-> Adding cell_type and cov_drug_dose column...")
adata_sub_log.obs["cell_type"] = adata_sub_log.obs["cell_line"]
adata_sub_log.obs["cov_drug_dose"] = (
    adata_sub_log.obs["condition_ID"].astype(str)
    + "|" +
    adata_sub_log.obs["log_dose"].astype(str)
)
print("-> Done.")

# 1) Filter out cells with low number of counts
# before = adata_sub_log.n_obs
# sc.pp.filter_cells(adata_sub_log, min_counts=100)
# after = adata_sub_log.n_obs
# print(f" Filtered cells based on min_counts ≥ 100:")
# print(f"   • Before filtering: {before} cells")
# print(f"   • After filtering:  {after} cells")
# print(f"   • Cells removed:    {before - after}")

# print("\n ➜ Sample of remaining cells:")
# print(adata_sub_log.obs.head())

# 3) Normalize counts to 10k UMIs per cell
print("-> Normalizing...")
sc.pp.normalize_total(adata_sub_log, target_sum=1e4, exclude_highly_expressed=True)
    
# 4) Log-transform
print("-> Log transforming...")
sc.pp.log1p(adata_sub_log)

# 5) Select HVGs (≈2 000 genes)
print("-> Highly Variable Genes...")
sc.pp.highly_variable_genes(adata_sub_log, n_top_genes=2000, subset=True)

# 6) Scale & PCA
#sc.pp.scale(adata_sub_log, max_value=10)
#sc.tl.pca(adata_sub_log, n_comps=50)



AnnData object with n_obs × n_vars = 93000 × 62710
    obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate'

Cells per drug condition:
drugname_drugconc
[('DMSO_TF', 0.0, 'uM')]                         33451
[('Trametinib (DMSO_TF solvate)', 5.0, 'uM')]    13049
[('(S)-Crizotinib', 5.0, 'uM')]                    500
[('Nevirapine', 5.0, 'uM')]                        500
[('Posaconazole', 5.0, 'uM')]                      500
                                                 ...  
[('Docetaxel (Trihydrate)', 5.0, 'uM')]            500
[('Diphenhydramine', 5.0, 'uM')]                   500
[('Dinaciclib', 5.0, 'uM')]                        500
[('Dimethyl fumarate', 5.0, 'uM')]                 500
[('crizotinib', 5.0, 'uM')]                        500
Name: count, Length: 95, dtype: int64

Condition IDs:
 BARCODE_SUB_LIB_ID
01_039_015-lib

In [9]:
# For RDkit use


#7) Compute DE per condition vs. DMSO_TF               #run this only for RDkit use, DE computation for cell line prediction happens before training
adata_sub_log.raw = adata_sub_log                               
print("-> Compute DE per condition...")
sc.tl.rank_genes_groups(
    adata_sub_log,
    groupby='cov_drug_dose',       
    reference='DMSO_TF_00uM|-6.0',     # control id
    method='t-test_overestim_var',
    key_added='rank_genes_groups',
    use_raw=True
)


# ─── 11) Save the downsampled, balanced subset ───────────────────────────────
# 8) Now save your CPA-ready subset
save_path = "plate9_preprocessed_for_CPA_2.h5ad"               #"plate*_preprocessed_for_CPA.h5ad"  plate9_preprocessed_for_CPA_2 has groupby='cov_drug_dose
adata_sub_log.write_h5ad(save_path)
print(f"Wrote fully preprocessed AnnData to: {save_path}")

-> Compute DE per condition...


  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group

Wrote fully preprocessed AnnData to: plate9_preprocessed_for_CPA_2.h5ad
