# Analyze continuous assay prediction

In [1]:
import polars as pl


In [11]:
# Root folder of the snakemake outputs, relative to this notebook.
res_dir = "../../1_snakemake/outputs"


def _classifier_result(representation, filename):
    """Return the path of `filename` inside one representation's classifier_results folder."""
    return f"{res_dir}/{representation}/mad_featselect/classifier_results/{filename}"


# Prediction tables, one per image representation.
# NOTE(review): CP-CNN files are named `axiom_continuous_*` while DINO and
# CellProfiler use `axiom_assay_*` — confirm this difference is intentional.
dino_res = _classifier_result("dino", "axiom_assay_predictions.parquet")
cpcnn_res = _classifier_result("cpcnn", "axiom_continuous_predictions.parquet")
cellprofiler_res = _classifier_result("cellprofiler", "axiom_assay_predictions.parquet")

# Companion `*_null` tables (tagged with Model_type "Random" when loaded further down).
dino_res_null = _classifier_result("dino", "axiom_assay_null.parquet")
cpcnn_res_null = _classifier_result("cpcnn", "axiom_continuous_null.parquet")
cellprofiler_res_null = _classifier_result("cellprofiler", "axiom_assay_null.parquet")

In [10]:
# Preview the raw DINO prediction table straight from disk. Reading it here
# (instead of referencing `dino`, which is only assigned in a later cell)
# keeps this cell runnable after Restart Kernel -> Run All.
pl.read_parquet(dino_res)

Metadata_Plate,Metadata_Well,Metadata_Compound,Metadata_Log10Conc,Observed,Predicted,Variable_Name
str,str,str,f64,f32,f32,str
"""plate_41002687""","""J24""","""DMSO""",0.0,0.054173,0.087573,"""Metadata_ldh_normalized"""
"""plate_41002687""","""B03""","""DMSO""",0.0,0.012341,-0.014785,"""Metadata_ldh_normalized"""
"""plate_41002687""","""N17""","""DMSO""",0.0,-0.018336,-0.030356,"""Metadata_ldh_normalized"""
"""plate_41002687""","""L03""","""DMSO""",0.0,0.020707,-0.021875,"""Metadata_ldh_normalized"""
"""plate_41002687""","""D13""","""DMSO""",0.0,-0.035069,-0.013877,"""Metadata_ldh_normalized"""
…,…,…,…,…,…,…
"""plate_41002960""","""N23""","""DMSO""",0.0,548.0,590.220398,"""Metadata_Count_Cells"""
"""plate_41002960""","""A11""","""DMSO""",0.0,708.0,789.383667,"""Metadata_Count_Cells"""
"""plate_41002960""","""B05""","""DMSO""",0.0,770.0,819.089294,"""Metadata_Count_Cells"""
"""plate_41002960""","""L04""","""DMSO""",0.0,585.0,721.137634,"""Metadata_Count_Cells"""


In [21]:
def _load_labelled(path, model_type, representation):
    """Read one result parquet and tag every row with its model type and representation."""
    return pl.read_parquet(path).with_columns(
        pl.lit(model_type).alias("Model_type"),
        pl.lit(representation).alias("Representation"),
    )


# Predictions read from the `*_predictions` files are labelled "Actual".
dino = _load_labelled(dino_res, "Actual", "DINO")
cpcnn = _load_labelled(cpcnn_res, "Actual", "CP-CNN")
cellprofiler = _load_labelled(cellprofiler_res, "Actual", "CellProfiler")

# Predictions read from the `*_null` files are labelled "Random".
dino_null = _load_labelled(dino_res_null, "Random", "DINO")
cpcnn_null = _load_labelled(cpcnn_res_null, "Random", "CP-CNN")
cellprofiler_null = _load_labelled(cellprofiler_res_null, "Random", "CellProfiler")

# Stack all six tables into one long frame; "vertical" requires matching schemas.
res = pl.concat(
    [dino, cpcnn, cellprofiler, dino_null, cpcnn_null, cellprofiler_null],
    how="vertical",
)

In [22]:
# Append the per-well cell count, read from the DINO mad_featselect profiles.
cc = pl.read_parquet(f"{res_dir}/dino/mad_featselect/profiles/mad_featselect.parquet").select(
    ["Metadata_Plate", "Metadata_Well", "Metadata_Count_Cells"]
)

# `res` is reassigned from itself, so guard on the column to keep this cell
# idempotent: re-running it would otherwise join a second, suffixed copy of
# the count column onto `res`.
# Inner join (the polars default, spelled out here): rows of `res` whose
# plate/well has no entry in `cc` are dropped.
# NOTE(review): assumes one row per (plate, well) in the profile table;
# duplicate keys would multiply rows in `res` — confirm.
if "Metadata_Count_Cells" not in res.columns:
    res = res.join(cc, on=["Metadata_Plate", "Metadata_Well"], how="inner")

In [15]:
from scipy.stats import linregress

def compute_r2_pvalue(x, y):
    """Regress ``y`` on ``x`` with ``scipy.stats.linregress`` and summarize the fit.

    Parameters
    ----------
    x, y : array-like
        Paired measurements, passed unchanged to ``linregress``.

    Returns
    -------
    tuple
        ``(r_squared, p_value)``: the squared correlation coefficient of the
        fit and the p-value reported by ``linregress``.
    """
    fit = linregress(x, y)
    return fit.rvalue ** 2, fit.pvalue

In [25]:
# For every (variable, model type, representation) group, measure how well the
# predictions track the observed values, and how much of the observed value is
# already explained by the cell count alone.
results = []
for name, group in res.group_by(["Variable_Name", "Model_type", "Representation"]):
    # Grouping by a list of columns yields the group key as a tuple.
    var_name, model_type, representation = name

    # Step 1: R² and p-value between Observed and Predicted
    r2_obs_pred, pval_obs_pred = compute_r2_pvalue(group["Observed"].to_list(), group["Predicted"].to_list())

    # Step 2: R² and p-value between Metadata_Count_Cells and Observed
    # (the regression is against Observed, not Predicted, matching the
    # `*_CountCells_vs_Observed` keys stored below).
    r2_count_obs, pval_count_obs = compute_r2_pvalue(group["Metadata_Count_Cells"].to_list(), group["Observed"].to_list())

    # Store results
    results.append({
        "Variable_Name": var_name,
        "Model_type": model_type,
        "Representation": representation,
        "R2_Observed_vs_Predicted": r2_obs_pred,
        "pval_Observed_vs_Predicted": pval_obs_pred,
        "R2_CountCells_vs_Observed": r2_count_obs,
        "pval_CountCells_vs_Observed": pval_count_obs
    })

# Convert results to a Polars DataFrame for easy viewing
results_df = pl.DataFrame(results).sort(["Variable_Name", "Model_type", "Representation"])
results_df

Variable_Name,Model_type,Representation,R2_Observed_vs_Predicted,pval_Observed_vs_Predicted,R2_CountCells_vs_Observed,pval_CountCells_vs_Observed
str,str,str,f64,f64,f64,f64
"""Metadata_Count_Cells""","""Actual""","""CP-CNN""",0.886887,0.0,1.0,0.0
"""Metadata_Count_Cells""","""Actual""","""CellProfiler""",0.927043,0.0,1.0,0.0
"""Metadata_Count_Cells""","""Actual""","""DINO""",0.925722,0.0,1.0,0.0
"""Metadata_Count_Cells""","""Random""","""CP-CNN""",0.000161,0.06321,0.000002,0.823299
"""Metadata_Count_Cells""","""Random""","""CellProfiler""",0.000145,0.078262,0.000023,0.481131
…,…,…,…,…,…,…
"""Metadata_mtt_normalized""","""Actual""","""CellProfiler""",0.736985,0.0,0.374796,0.0
"""Metadata_mtt_normalized""","""Actual""","""DINO""",0.729268,0.0,0.375802,0.0
"""Metadata_mtt_normalized""","""Random""","""CP-CNN""",0.000012,0.615625,4.9091e-9,0.991812
"""Metadata_mtt_normalized""","""Random""","""CellProfiler""",7.8692e-7,0.896623,2.0139e-9,0.994756
