In [8]:
import anndata as ad
import numpy as np
import pandas as pd
import os

In [39]:
CELL_TYPE_MAPPING = {
    "granulocytes": "Granulocytes",
    "granulocytes (cd45-cd66+)": "Granulocytes",
    "neutrophils": "Granulocytes",
    "bcells": "Bcells",
    "b-cells (cd19+cd3-)": "Bcells",
    "b-cells": "Bcells",
    "ncmc": "ncMCs",
    "non-classical monocytes (cd14-cd16+)": "ncMCs",
    "ncmcs": "ncMCs",
    "cmc": "cMCs",
    "classical monocytes (cd14+cd16-)": "cMCs",
    "cmcs": "cMCs",
    "mdsc": "MDSCs",
    "mdscs (lin-cd11b-cd14+hladrlo)": "MDSCs",
    "mdscs": "MDSCs",
    "mdc": "mDCs",
    "mdcs (cd11c+hladr+)": "mDCs",
    "mdcs": "mDCs",
    "pdc": "pDCs",
    "pdcs(cd123+hladr+)": "pDCs",
    "pdcs": "pDCs",
    "intmc": "intMCs",
    "intermediate monocytes (cd14+cd16+)": "intMCs",
    "intmcs": "intMCs",
    "cd56brightcd16negnk": "CD56hiCD16negNK",
    "cd56+cd16- nk cells": "CD56hiCD16negNK",
    "cd16- cd56+ nk": "CD56hiCD16negNK",
    "cd56dimcd16posnk": "CD56loCD16posNK",
    "cd56locd16+nk cells": "CD56loCD16posNK",
    "cd16+ cd56lo nk": "CD56loCD16posNK",
    "nk": "NKcells",
    "nk cells (cd7+)": "NKcells",
    "cd4tcells": "CD4Tcells",
    "cd4 t-cells": "CD4Tcells",
    "cd4 tcells": "CD4Tcells",
    "tregs": "Tregs",
    "treg": "Tregs",
    "tregs (cd25+foxp3+)": "Tregs",
    "cd8tcells": "CD8Tcells",
    "cd8 t-cells": "CD8Tcells",
    "cd8 tcells": "CD8Tcells",
    "cd4negcd8neg": "CD4negCD8negTcells",
    "cd8-cd4- t-cells": "CD4negCD8negTcells",
    "cd4- cd8- t-cells": "CD4negCD8negTcells"
}

def normalize_cell_type(cell):
    key = cell.strip().lower()
    if key not in CELL_TYPE_MAPPING:
        raise AssertionError(f"Unrecognized cell type: {cell}")
    return CELL_TYPE_MAPPING[key]

STIM_MAPPING = {
    "unstim": "Unstim",
    "tnfa": "TNFa",
    "lps": "LPS",
    "p. gingivalis": "LPS",
    "ifna": "IFNa"
}

def normalize_stim(stim):
    key = stim.strip().lower()
    if key not in STIM_MAPPING:
        raise AssertionError(f"Unrecognized stimulation: {stim}")
    return STIM_MAPPING[key]

def normalize_marker(raw_marker):
    marker = raw_marker.strip().lower()
    # Remove leading "p" characters (except for "pp38")
    while marker.startswith("p"):
        marker = marker[1:]
    if marker == "creb":
        return "pCREB"
    elif marker in ["erk", "erk12"]:
        return "pERK"
    elif marker == "ikb":
        return "IkB"
    elif marker in ["mk2", "mapkapk2"]:
        return "pMK2"
    elif marker == "nfkb":
        return "pNFkB"
    elif marker in ["38", "p38"]:
        return "pp38"
    elif marker == "s6":
        return "pS6"
    elif marker == "stat1":
        return "pSTAT1"
    elif marker == "stat3":
        return "pSTAT3"
    elif marker == "stat5":
        return "pSTAT5"
    elif marker == "stat6":
        return "pSTAT6"
    elif marker == "hladr":
        return "HLADR"
    elif marker == "cd25":
        return "CD25"
    else:
        raise AssertionError(f"Unrecognized marker: {raw_marker}")

############################
# Global constants
############################

# The 13 canonical markers
MARKERS = ["pCREB", "pERK", "IkB", "pMK2", "pNFkB", "pp38", "pS6",
           "pSTAT1", "pSTAT3", "pSTAT5", "pSTAT6", "HLADR", "CD25"]

# List of 15 cell types (exactly as used in your prediction script)
CELL_TYPES = [
    'Granulocytes_(CD45-CD66+)',
    'B-Cells_(CD19+CD3-)',
    'Classical_Monocytes_(CD14+CD16-)',
    'MDSCs_(lin-CD11b-CD14+HLADRlo)',
    'mDCs_(CD11c+HLADR+)',
    'pDCs(CD123+HLADR+)',
    'Intermediate_Monocytes_(CD14+CD16+)',
    'Non-classical_Monocytes_(CD14-CD16+)',
    'CD56+CD16-_NK_Cells',
    'CD56loCD16+NK_Cells',
    'NK_Cells_(CD7+)',
    'CD4_T-Cells',
    'Tregs_(CD25+FoxP3+)',
    'CD8_T-Cells',
    'CD8-CD4-_T-Cells'
]

############################
# Helper functions for loading & processing AnnData files
############################

def load_anndata_to_df(anndata_path):
    """
    Load an AnnData file and return a DataFrame containing expression values (for MARKERS)
    plus all obs columns. Marker, cell type and stim names are normalized.
    """
    try:
        adata = ad.read_h5ad(anndata_path)
    except Exception as e:
        raise AssertionError(f"Error reading file {anndata_path}: {e}")
    df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs.index)
    # Rename marker columns if they can be normalized (only keep canonical markers)
    new_cols = {}
    for col in df.columns:
        try:
            new_name = normalize_marker(col.split("_")[1])
            if new_name in MARKERS:
                new_cols[col] = new_name
        except Exception as e:
            continue
    df = df.rename(columns=new_cols)
    # Append obs columns
    for col in adata.obs.columns:
        df[col] = adata.obs[col]
    # Normalize cell type if present
    if "cell_type" in df.columns:
        df["cell_type"] = df["cell_type"].apply(normalize_cell_type)
    # Normalize stim if present
    if "stim" in df.columns:
        df["stim"] = df["stim"].apply(normalize_stim)
    return df

def compute_group_medians(df, markers=MARKERS):
    """
    Melt the DataFrame so each row is one marker value and then compute the median
    per combination of patient, cell_type, and marker.
    """
    present_markers = [m for m in markers if m in df.columns]
    if len(present_markers) != len(markers):
        missing = set(markers) - set(present_markers)
        raise AssertionError(f"Missing marker columns: {missing} in file with columns {list(df.columns)}")
    # Identify non-marker columns to use as id_vars
    id_vars = [col for col in df.columns if col not in markers]
    df_melt = df.melt(id_vars=id_vars, value_vars=markers, var_name="marker", value_name="value")
    medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()
    medians = medians.rename(columns={"value": "median", "patient": "sampleID"})
    return medians
    
def load_and_compute_medians(anndata_path):
    df = load_anndata_to_df(anndata_path)
    medians = compute_group_medians(df)
    return medians

In [15]:
ad.read('pred_surge_corrected.h5ad').obs

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


Unnamed: 0,drug,cell_type,patient,stim,state
0,TNFa,Bcells,HCCC_BL,TNFa,predicted
1,TNFa,Bcells,HCCC_BL,TNFa,predicted
2,TNFa,Bcells,HCCC_BL,TNFa,predicted
3,TNFa,Bcells,HCCC_BL,TNFa,predicted
4,TNFa,Bcells,HCCC_BL,TNFa,predicted
...,...,...,...,...,...
8464,TNFa,Bcells,HCBB_IDX,TNFa,true_corrected
8465,TNFa,Bcells,HCBB_IDX,TNFa,true_corrected
8466,TNFa,Bcells,HCBB_IDX,TNFa,true_corrected
8467,TNFa,Bcells,HCBB_IDX,TNFa,true_corrected


In [46]:
merged = pd.merge(pred_medians, baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_stim", "_baseline"), how="outer", indicator=True)
print(merged[merged["_merge"] != "both"])


Empty DataFrame
Columns: [sampleID, cell_type, marker, median_stim, stim_stim, median_baseline, stim_baseline, _merge]
Index: []


In [59]:
stim='TNFa'
baseline_store = {}
pred_rows = []
if stim == "LPS":
    stim_folder = "P._gingivalis"
else:
    stim_folder = stim
cohort="surge"
cell = "Bcells"
pred_path = f"./pred_{cohort}_corrected.h5ad"
if not os.path.exists(pred_path):
    raise AssertionError(f"Missing prediction file: {pred_path}")

df_all = load_anndata_to_df(pred_path)
df_baseline = df_all[(df_all["state"] == "true_corrected") & (df_all["drug"] == "Unstim")]
baseline_medians = compute_group_medians(df_baseline)
df_pred = df_all[(df_all["state"] == "predicted")& (df_all["drug"] == stim)]
pred_medians = compute_group_medians(df_pred)
patients_presents=baseline_medians['median'].notna()
pred_medians = pred_medians[patients_presents]
baseline_medians = baseline_medians[patients_presents]
#patients_with_nan = baseline_medians[baseline_medians['median'].isna()]["sampleID"].unique()

#print("Patients avec des NaN dans les marqueurs (df_baseline) :")
#print(patients_with_nan)
baseline_medians, pred_medians

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()
  medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()


(     sampleID cell_type  marker    median
 13    HCBB_BL    Bcells    CD25 -0.019437
 14    HCBB_BL    Bcells   HLADR  5.657043
 15    HCBB_BL    Bcells     IkB  0.474065
 16    HCBB_BL    Bcells   pCREB  0.296267
 17    HCBB_BL    Bcells    pERK  0.038416
 ..        ...       ...     ...       ...
 164  HCKK_IDX    Bcells  pSTAT1  0.097737
 165  HCKK_IDX    Bcells  pSTAT3  0.050058
 166  HCKK_IDX    Bcells  pSTAT5  0.257048
 167  HCKK_IDX    Bcells  pSTAT6  0.000193
 168  HCKK_IDX    Bcells    pp38  0.032914
 
 [117 rows x 4 columns],
      sampleID cell_type  marker    median
 13    HCBB_BL    Bcells    CD25  0.044681
 14    HCBB_BL    Bcells   HLADR  5.622467
 15    HCBB_BL    Bcells     IkB  0.511490
 16    HCBB_BL    Bcells   pCREB  0.340450
 17    HCBB_BL    Bcells    pERK  0.044005
 ..        ...       ...     ...       ...
 164  HCKK_IDX    Bcells  pSTAT1  0.106708
 165  HCKK_IDX    Bcells  pSTAT3  0.078087
 166  HCKK_IDX    Bcells  pSTAT5  0.253626
 167  HCKK_IDX    Bcells  p

In [88]:
stim='P. gingivalis'
baseline_store = {}
pred_rows = []

if stim == "P. gingivalis":
    stim_folder = "P._gingivalis"
else:
    stim_folder = stim
cohort="perio"
cell = "Granulocytes_(CD45-CD66+)"
pred_path = f"./pred_{cohort}_corrected.h5ad"

out_dir = "/home/groups/gbrice/ptb-drugscreen/ot/cellot/cellwise/results_dbl_corr"
#os.makedirs(out_dir, exist_ok=True)

cohorts_and_stims = {
"perio": ["P. gingivalis"]
}


if stim == "P. gingivalis":
    stim_folder = "P._gingivalis"
else:
    stim_folder = stim

pred_path = f"perio_dbl_corrected2/perio_data_sherlock{stim_folder}_{cell}.h5ad"
if not os.path.exists(pred_path):
    raise AssertionError(f"Missing prediction file: {pred_path}")

df_all = load_anndata_to_df(pred_path)
# Filtrer les lignes baseline (true unstim corrigé)
df_baseline = df_all[df_all["drug"] == "Unstim"]
baseline_medians = compute_group_medians(df_baseline)
baseline_medians["stim"] = "Unstim"

if cell in baseline_store:
    merged_base = pd.merge(baseline_store[cell], baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_old", "_new"))
    if not np.allclose(merged_base["median_old"], merged_base["median_new"], atol=1e-12):
        raise AssertionError(f"Inconsistent baseline medians for cell type '{cell}'")
else:
    baseline_store[cell] = baseline_medians

# Filtrer les lignes prédictions
df_pred = df_all[df_all["drug"] == stim]
pred_medians = compute_group_medians(df_pred)
pred_medians["stim"] = stim

patients_presents=baseline_medians['median'].notna()
pred_medians = pred_medians[patients_presents]
baseline_medians = baseline_medians[patients_presents]
print(pred_medians)
print(baseline_medians)
# Fusionner pour faire la soustraction
merged = pd.merge(pred_medians, baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_stim", "_baseline"))
if merged.empty:
    raise AssertionError(f"Merge failed for {cohort}, {stim}, {cell}")
if not np.allclose(merged["median_baseline"], baseline_medians["median"], atol=1e-12):
    raise AssertionError(f"Mismatch in baseline medians for {cohort}, {stim}, {cell}")

merged["median_diff"] = merged["median_stim"] - merged["median_baseline"]
pred_final = merged[["sampleID", "cell_type", "marker"]].copy()
pred_final["stim"] = stim
pred_final["median"] = merged["median_diff"]
pred_rows.append(pred_final)

baseline_all = pd.concat(list(baseline_store.values()), ignore_index=True)
baseline_all = baseline_all.drop_duplicates(subset=["sampleID", "cell_type", "marker", "stim"])

if pred_rows:
    pred_all = pd.concat(pred_rows, ignore_index=True)
else:
    pred_all = pd.DataFrame(columns=["sampleID", "cell_type", "marker", "stim", "median"])

final_df = pd.concat([baseline_all, pred_all], ignore_index=True)
final_df = final_df.rename(columns={"cell_type": "population"})
final_df = final_df[["sampleID", "population", "marker", "stim", "median"]]

out_path = os.path.join(out_dir, f"{cohort}_{stim}_transformed.csv")
final_df.to_csv(out_path, index=False)
print(f"Saved results for cohort '{cohort}', stim '{stim}' to {out_path}")

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()
  medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()


    sampleID     cell_type  marker    median           stim
0       PO1   Granulocytes    CD25  0.009572  P. gingivalis
1       PO1   Granulocytes   HLADR  0.283246  P. gingivalis
2       PO1   Granulocytes     IkB  0.351163  P. gingivalis
3       PO1   Granulocytes   pCREB  2.038272  P. gingivalis
4       PO1   Granulocytes    pERK  1.013392  P. gingivalis
..       ...           ...     ...       ...            ...
138    PO30   Granulocytes  pSTAT1  0.459453  P. gingivalis
139    PO30   Granulocytes  pSTAT3  0.390468  P. gingivalis
140    PO30   Granulocytes  pSTAT5  0.848648  P. gingivalis
141    PO30   Granulocytes  pSTAT6  0.105065  P. gingivalis
142    PO30   Granulocytes    pp38  2.105501  P. gingivalis

[143 rows x 5 columns]
    sampleID     cell_type  marker    median    stim
0       PO1   Granulocytes    CD25  0.005938  Unstim
1       PO1   Granulocytes   HLADR  0.297587  Unstim
2       PO1   Granulocytes     IkB  0.953270  Unstim
3       PO1   Granulocytes   pCREB  0.762365

OSError: Cannot save file into a non-existent directory: '/home/groups/gbrice/ptb-drugscreen/ot/cellot/cellwise/results_dbl_corr'

In [65]:
stim='TNFa'
baseline_store = {}
pred_rows = []
if stim == "LPS":
    stim_folder = "P._gingivalis"
else:
    stim_folder = stim
cohort="surge"
cell = "Bcells"
pred_path = f"./pred_{cohort}_corrected.h5ad"
if not os.path.exists(pred_path):
    raise AssertionError(f"Missing prediction file: {pred_path}")

df_all = load_anndata_to_df(pred_path)
print(df_all["state"].value_counts(dropna=False))
print(df_all[df_all["state"] == "predicted"].shape)
print(set(df_all["state"].unique()))
print(df_all[df_all["state"] == "predicted"].head())
print(df_all[df_all["state"] == "true_corrected"].head())

# Filtrer les lignes baseline (true unstim corrigé)
df_baseline = df_all[(df_all["state"] == "true_corrected") & (df_all["drug"] == "Unstim")]
baseline_medians = compute_group_medians(df_baseline)
baseline_medians["stim"] = "Unstim"

if cell in baseline_store:
    merged_base = pd.merge(baseline_store[cell], baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_old", "_new"))
    if not np.allclose(merged_base["median_old"], merged_base["median_new"], atol=1e-12):
        raise AssertionError(f"Inconsistent baseline medians for cell type '{cell}'")
else:
    baseline_store[cell] = baseline_medians

# Filtrer les lignes prédictions
df_pred = df_all[(df_all["state"] == "predicted")& (df_all["drug"] == stim)]
pred_medians = compute_group_medians(df_pred)
pred_medians["stim"] = stim
patients_presents=baseline_medians['median'].notna()
pred_medians = pred_medians[patients_presents]
baseline_medians = baseline_medians[patients_presents]

# Fusionner pour faire la soustraction
merged = pd.merge(pred_medians, baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_stim", "_baseline"))
print('pred_medians',pred_medians.head())
print('baseline_medians',baseline_medians.head())

#print('pred_med',pred_medians.head())
#print('baseline_med',baseline_medians.head())
if merged.empty:
    raise AssertionError(f"Merge failed for {cohort}, {stim}, {cell}")
if not np.allclose(merged["median_baseline"], baseline_medians["median"], atol=1e-12):
    raise AssertionError(f"Mismatch in baseline medians for {cohort}, {stim}, {cell}")

merged["median_diff"] = merged["median_stim"] - merged["median_baseline"]
pred_final = merged[["sampleID", "cell_type", "marker"]].copy()
pred_final["stim"] = stim
pred_final["median"] = merged["median_diff"]
pred_rows.append(pred_final)

baseline_all = pd.concat(list(baseline_store.values()), ignore_index=True)
baseline_all = baseline_all.drop_duplicates(subset=["sampleID", "cell_type", "marker", "stim"])

if pred_rows:
    pred_all = pd.concat(pred_rows, ignore_index=True)
else:
    pred_all = pd.DataFrame(columns=["sampleID", "cell_type", "marker", "stim", "median"])

final_df = pd.concat([baseline_all, pred_all], ignore_index=True)
final_df = final_df.rename(columns={"cell_type": "population"})
final_df = final_df[["sampleID", "population", "marker", "stim", "median"]]

patients = final_df[final_df["stim"] == "Unstim"]["sampleID"].unique()
expected_rows = len(patients) * len(CELL_TYPES) * len(MARKERS)
actual_rows = final_df[final_df["stim"] == "Unstim"].shape[0]
#if actual_rows != expected_rows:
    #raise AssertionError(f"Expected {expected_rows} baseline rows, got {actual_rows} for cohort {cohort}")

#out_path = os.path.join(out_dir, f"{cohort}_predicted_transformed.csv")
#final_df.to_csv(out_path, index=False)
#print(f"Saved results for cohort '{cohort}' to {out_path}")

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


state
true_corrected    104130
predicted          60458
Name: count, dtype: int64
(60458, 18)
{'true_corrected', 'predicted'}
      pCREB      pERK       IkB      pMK2     pNFkB      pp38       pS6  \
0  0.316666  0.114594  1.390323  2.167221  1.596827  0.971373  1.413708   
1  0.098386  0.013828  1.076254  1.006284  0.705473  0.203828  2.162868   
2  0.085149 -0.001163  0.700696  1.058345  1.309270  0.848566  0.517530   
3  0.629129  0.127166  0.978627  2.900368  0.826653  0.039490  1.016654   
4  0.064694  0.376558  1.126306  1.090124  0.114210  0.035125  0.385365   

     pSTAT1    pSTAT3    pSTAT5    pSTAT6     HLADR      CD25  drug cell_type  \
0  0.114148  0.082816  0.249317  0.007869  5.088123  0.281713  TNFa    Bcells   
1  0.096308  0.068358  0.259687 -0.007051  6.143575  0.563467  TNFa    Bcells   
2  0.099697  0.068444  0.247778  0.083037  5.433863  0.013237  TNFa    Bcells   
3  0.104914  0.810044  0.243995  0.084184  5.186887  0.090779  TNFa    Bcells   
4  0.100643  0.509

  medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()
  medians = df_melt.groupby(["patient", "cell_type", "marker"])["value"].median().reset_index()


In [12]:
print(len(patients),len(CELL_TYPES),len(MARKERS))

18 15 13


In [11]:
final_df

Unnamed: 0,sampleID,population,marker,stim,median
0,HCAA_BL,CD8Tcells,CD25,Unstim,0.117656
1,HCAA_BL,CD8Tcells,HLADR,Unstim,0.200606
2,HCAA_BL,CD8Tcells,IkB,Unstim,0.669407
3,HCAA_BL,CD8Tcells,pCREB,Unstim,0.243447
4,HCAA_BL,CD8Tcells,pERK,Unstim,0.398922
...,...,...,...,...,...
463,HCKK_IDX,CD8Tcells,pSTAT1,TNFa,0.003394
464,HCKK_IDX,CD8Tcells,pSTAT3,TNFa,-0.046704
465,HCKK_IDX,CD8Tcells,pSTAT5,TNFa,-0.025158
466,HCKK_IDX,CD8Tcells,pSTAT6,TNFa,-0.111774


In [None]:
def main():
    out_dir = "./"
    os.makedirs(out_dir, exist_ok=True)

    cohorts_and_stims = {
        "surge": ["TNFa", "LPS"]
    }

    for cohort, stims in cohorts_and_stims.items():
        pred_rows = []
        baseline_store = {}

        for stim in stims:
            for cell in CELL_TYPES:
                # Nouveau chemin unique vers le fichier contenant tout
                if stim == "LPS":
                    stim_folder = "P._gingivalis"
                else:
                    stim_folder = stim

                pred_path = f"/home/groups/gbrice/ptb-drugscreen/ot/cellot/results/perio_{cohort}_dbl_corr_training/{stim_folder}_{cell}/model-cellot/pred_{cohort}_corrected.h5ad"
                if not os.path.exists(pred_path):
                    raise AssertionError(f"Missing prediction file: {pred_path}")

                df_all = load_anndata_to_df(pred_path)

                # Filtrer les lignes baseline (true unstim corrigé)
                df_baseline = df_all[(df_all["state"] == "true_corrected") & (df_all["drug"] == "Unstim")]
                baseline_medians = compute_group_medians(df_baseline)
                baseline_medians["stim"] = "Unstim"

                if cell in baseline_store:
                    merged_base = pd.merge(baseline_store[cell], baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_old", "_new"))
                    if not np.allclose(merged_base["median_old"], merged_base["median_new"], atol=1e-12):
                        raise AssertionError(f"Inconsistent baseline medians for cell type '{cell}'")
                else:
                    baseline_store[cell] = baseline_medians

                # Filtrer les lignes prédictions
                df_pred = df_all[df_all["state"] == "predicted"]
                pred_medians = compute_group_medians(df_pred)
                pred_medians["stim"] = stim

                # Fusionner pour faire la soustraction
                merged = pd.merge(pred_medians, baseline_medians, on=["sampleID", "cell_type", "marker"], suffixes=("_stim", "_baseline"))
                if merged.empty:
                    raise AssertionError(f"Merge failed for {cohort}, {stim}, {cell}")
                if not np.allclose(merged["median_baseline"], baseline_medians["median"], atol=1e-12):
                    raise AssertionError(f"Mismatch in baseline medians for {cohort}, {stim}, {cell}")

                merged["median_diff"] = merged["median_stim"] - merged["median_baseline"]
                pred_final = merged[["sampleID", "cell_type", "marker"]].copy()
                pred_final["stim"] = stim
                pred_final["median"] = merged["median_diff"]
                pred_rows.append(pred_final)

        baseline_all = pd.concat(list(baseline_store.values()), ignore_index=True)
        baseline_all = baseline_all.drop_duplicates(subset=["sampleID", "cell_type", "marker", "stim"])

        if pred_rows:
            pred_all = pd.concat(pred_rows, ignore_index=True)
        else:
            pred_all = pd.DataFrame(columns=["sampleID", "cell_type", "marker", "stim", "median"])

        final_df = pd.concat([baseline_all, pred_all], ignore_index=True)
        final_df = final_df.rename(columns={"cell_type": "population"})
        final_df = final_df[["sampleID", "population", "marker", "stim", "median"]]

        patients = final_df[final_df["stim"] == "Unstim"]["sampleID"].unique()
        expected_rows = len(patients) * len(CELL_TYPES) * len(MARKERS)
        actual_rows = final_df[final_df["stim"] == "Unstim"].shape[0]
        if actual_rows != expected_rows:
            raise AssertionError(f"Expected {expected_rows} baseline rows, got {actual_rows} for cohort {cohort}")

        out_path = os.path.join(out_dir, f"{cohort}_predicted_transformed.csv")
        final_df.to_csv(out_path, index=False)
        print(f"Saved results for cohort '{cohort}' to {out_path}")