# Compute and Analyze Classifier Metrics After Controlling for GFP

The snakemake pipeline output the probability of 0 / 1 for each cell for each classifier. Here, we compute and save many common metrics from these probabilities. Here we analyze the classification results of each allele.

In [35]:
# imports
import os
import glob
import operator
import polars as pl
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from functools import reduce
import sys
sys.path.append("../..")
from img_utils import *

letter_dict_rev = {v: k for k, v in letter_dict.items()}
channel_dict_rev = {v: k for k, v in channel_dict.items()}
channel_list = list(channel_dict_rev.values())[:-3]

%matplotlib inline
## Disable truncation globally
# pl.Config.set_tbl_rows(20)  # Show all rows
# pl.Config.set_tbl_cols(40)  # Show all columns

## define control types
TC = ["EGFP"]
NC = ["RHEB", "MAPK9", "PRKACB", "SLIRP"]
PC = ["ALK", "ALK_Arg1275Gln", "PTK2B"]
cNC = ["Renilla"]
cPC = ["KRAS", "PTK2B", "GHSR", "ABL1", "BRD4", "OPRM1", "RB1", "ADA", "WT PMP22", "LYN", "TNF", 
       "CYP2A6", "CSK", "PAK1", "ALDH2", "CHRM3", "KCNQ2", "ALK T1151M", "PRKCE", "LPAR1", "PLP1"]

TRN_IMBAL_THRES = 3
MIN_CLASS_NUM = 2
AUROC_THRESHOLDS = [0.95] ## , 0.99

In [36]:
allele_meta_df = pl.DataFrame()
allele_meta_df_dict, img_well_qc_sum_dict = {}, {}

for bio_rep, bio_rep_batches in BIO_REP_BATCHES_DICT.items():
    for batch_id in bio_rep_batches:
        allele_meta_df_batch = pl.DataFrame()
        platemaps = [file for file in os.listdir(PLATEMAP_DIR.format(batch_id=batch_id)) if file.endswith(".txt")]
        for platemap in platemaps:
            platemap_df = pl.read_csv(os.path.join(PLATEMAP_DIR.format(batch_id=batch_id), platemap), separator="\t", infer_schema_length=100000)
            allele_meta_df_batch = pl.concat([allele_meta_df_batch, 
                                        platemap_df.filter((pl.col("node_type").is_not_null()))], # (~pl.col("node_type").is_in(["TC","NC","PC"]))&
                                        how="diagonal_relaxed").sort("plate_map_name")
            allele_meta_df_batch = allele_meta_df_batch.with_columns(pl.col("plate_map_name").alias("plate_map")) ## str.split('_').list.get(0).
            # display(allele_meta_df.head())
        allele_meta_df = pl.concat([
            allele_meta_df,
            allele_meta_df_batch
        ], how="diagonal_relaxed")#.sort("plate_map_name") ## (~pl.col("node_type").is_in(["TC","NC","PC"]))&
    allele_meta_df_dict[bio_rep] = allele_meta_df_batch

    img_well_qc_sum = pl.read_csv(f"{IMGS_QC_METRICS_DIR}/{bio_rep}/plate-well-level_img_qc_sum.csv")
    img_well_qc_sum = img_well_qc_sum.with_columns(
        pl.col("channel").replace("DAPI", "DNA").alias("channel")
    )
    img_well_qc_sum_morph = img_well_qc_sum.filter(pl.col("channel")!="GFP")
    img_well_qc_sum_morph = img_well_qc_sum_morph.group_by(["plate","well"]).agg(
        pl.col("is_bg").max().alias("is_bg"),
        pl.col("s2n_ratio").mean().alias("s2n_ratio")
    ).with_columns(pl.lit("Morph").alias("channel"))
    img_well_qc_sum = pl.concat([
        img_well_qc_sum.select(pl.col(["plate","well","channel","is_bg","s2n_ratio"])),
        img_well_qc_sum_morph.select(pl.col(["plate","well","channel","is_bg","s2n_ratio"])),
    ], how="vertical_relaxed")
    img_well_qc_sum_dict[bio_rep] = img_well_qc_sum

In [37]:
class_info_dir = "{}/{}/profiles_tcdropped_filtered_var_mad_outlier_featselect_filtcells/metrics_gfp_adj.csv"

for bio_rep, bio_rep_batches in BIO_REP_BATCHES_DICT.items():
    for batch_id in bio_rep_batches:
        if bio_rep != "2024_12_Batch_11-12":
            batch_id_tmp = batch_id + "_morph"
        class_res_df = pl.read_csv(class_info_dir.format(CLASS_ANALYSES_DIR, batch_id_tmp))
        class_res_df = class_res_df.with_columns(
            pl.when((pl.col("well_0").str.contains(r"(?:01|P|A|24)"))|(pl.col("well_1").str.contains(r"(?:01|P|A|24)")))
            .then(pl.lit(True))
            .otherwise(pl.lit(False))
            .alias("Well_On_Edge")
        )
        # display(class_res_df)
        for channel in ["GFP"]:
            class_res_df_channel = class_res_df.filter(pl.col("Metadata_Feature_Type")==channel)
            class_res_df_channel = class_res_df_channel.join(
                img_well_qc_sum.filter(pl.col("channel")==channel).select(pl.col(["plate","well","is_bg"])),
                left_on=["Plate","well_0"],
                right_on=["plate","well"]
            ).rename({"is_bg": "well_0_is_bg"})
            class_res_df_channel = class_res_df_channel.join(
                img_well_qc_sum.filter(pl.col("channel")==channel).select(pl.col(["plate","well","is_bg"])),
                left_on=["Plate","well_1"],
                right_on=["plate","well"]
            ).rename({"is_bg": "well_1_is_bg"})
            
            plate_maps = sorted(set(["_".join(pm.split("_")[:-1]) for pm in class_res_df_channel["Plate"].unique()]))
            for pm in plate_maps:
                class_res_ch_pm = class_res_df_channel.filter(pl.col("Plate").str.contains(pm))
                plates = sorted(class_res_ch_pm["Plate"].unique().to_list())
                # fig, axes = plt.subplots(2,2,figsize=(48,23)) ## sharey=True,sharex=True
                for plate in plates:
                    plate_info = class_res_ch_pm.filter(pl.col("Plate")==plate)
    
                    ## plot the ctrls alleles first
                    plate_info_ctrl = plate_info.filter(pl.col("Metadata_Control"))
                    ctrls_wells = pl.concat([plate_info_ctrl["well_0"], plate_info_ctrl["well_1"]]).unique()
                    agg_group_by_well_0 = plate_info_ctrl.group_by("well_0","allele_0").agg([
                        pl.col("AUROC").mean().alias("AUROC_Mean"),
                        pl.col("well_0_is_bg").max().alias("is_bg")
                    ]).rename(
                        {"well_0": "well", "allele_0": "allele"}
                    )#.with_columns(pl.col("Well_On_Edge").cast(pl.Boolean).alias("Well_On_Edge"))
                    agg_group_by_well_1 = plate_info_ctrl.filter(
                        (pl.col("well_1").is_in(ctrls_wells))&\
                        (~pl.col("well_1").is_in(agg_group_by_well_0["well"]))
                    ).group_by("well_1","allele_1").agg([
                        pl.col("AUROC").mean().alias("AUROC_Mean"),
                        pl.col("well_1_is_bg").max().alias("is_bg")
                    ]).rename(
                        {"well_1": "well", "allele_1": "allele"}
                    )
                    agg_group_by_well = pl.concat(
                        [
                            agg_group_by_well_0.select(pl.col("well","AUROC_Mean","allele","is_bg")),
                            agg_group_by_well_1.select(pl.col("well","AUROC_Mean","allele","is_bg"))
                        ]
                    )
                    ## plot the auroc per each plate for a platemap
                #     plot_platemap(
                #         agg_group_by_well,
                #         plate+f"_{channel}",
                #         well_pos_col="well",
                #         # this is the column to color by (categorical or continuous)
                #         value_col="AUROC_Mean",
                #         # these columns will be concatenated into the annotation text
                #         label_cols=("allele","AUROC_Mean"),
                #         ax=axes[plates.index(plate)//2, plates.index(plate)%2],
                #         value_type="continuous",   # or "continuous"
                #         continuous_cmap="vlag",  # matplotlib colormap for continuous mode
                #         categorical_colors={True: "tomato", False: "skyblue"},     # dict for categorical → color
                #         grid_square="is_bg"
                #     )
                # fig.subplots_adjust(wspace=-.55, hspace=.05)
                # plt.tight_layout()

In [38]:
metric_df_dict = {}

for bio_rep, bio_rep_batches in BIO_REP_BATCHES_DICT.items():
    for batch_id in bio_rep_batches:
        if bio_rep != "2024_12_Batch_11-12":
            batch_id_tmp = batch_id + "_morph"
            
        metrics_df = pl.read_csv(class_info_dir.format(CLASS_ANALYSES_DIR, batch_id_tmp))
        batch_id = f"B{batch_id.split('Batch_')[-1]}"
        # display(metrics_df)
        print(f"====================================={batch_id} metrics=====================================")
        metrics_df = metrics_df.with_columns(
            pl.col("Metadata_Feature_Type").alias("Classifier_type"),
            # Extract the substring that:
            #  1. Has a digit (\d) immediately before it (anchors the match at a number)
            #  2. Starts with 'A' and then as few characters as needed (A.*?), captured as group 1
            #  3. Stops right before the literal 'T'
            pl.col("Plate").str.extract(r"\d(A.*?)T", 1).alias("Allele_set"),
            pl.col("Full_Classifier_ID").str.split("_").list.last().alias("Batch")
        )
    
        metrics_df = metrics_df.join(
            img_well_qc_sum_dict[bio_rep].select(pl.col("plate", "well", "channel", "is_bg")),
            left_on=["Plate", "well_0", "Metadata_Feature_Type"],
            right_on=["plate", "well", "channel"]
        ).rename({"is_bg": "well_0_is_bg"})
        metrics_df = metrics_df.join(
            img_well_qc_sum_dict[bio_rep].select(pl.col("plate", "well", "channel", "is_bg")),
            left_on=["Plate", "well_1", "Metadata_Feature_Type"],
            right_on=["plate", "well", "channel"]
        ).rename({"is_bg": "well_1_is_bg"})
            
        print("==========================================================================")
        metric_df_dict[f"{batch_id}_met"] = metrics_df
        # print(metrics_df_thres.head())
        metrics_df_thres = metrics_df
        
        ## Must be at least min_class_num classifiers per batch
        ## Number of classifiers is the same for localization and morph, so just use morph
        classifier_count = (
            metrics_df_thres.filter(
                (~pl.col("Metadata_Control")) & \
                (~pl.col("well_0_is_bg")) & (~pl.col("well_1_is_bg")) & \
                (pl.col("Training_imbalance") < TRN_IMBAL_THRES)
            )
            .group_by(["allele_0", "Allele_set", "Batch", "allele_1"])
            .agg([pl.len().alias("Number_classifiers")])
        )
        classifier_count = classifier_count.pivot(
            index=["allele_0", "allele_1", "Allele_set"],
            on="Batch",
            values="Number_classifiers",
        )
        print(f"Total number of unique classifiers for GFP:", classifier_count.shape[0])
        print("Total number of unique variant alleles:", len(classifier_count.select("allele_0").to_series().unique().to_list()))
        print("Total number of unique WT genes:", len(classifier_count.select("allele_1").to_series().unique().to_list()))
        print("==========================================================================")
    
        # Must be at least min_class_num classifiers per batch
        # Number of classifiers is the same for localization and morph, so just use morph
        classifier_count = (
            metrics_df_thres.filter(
                (pl.col("Training_imbalance") < TRN_IMBAL_THRES)
                & (~pl.col("Metadata_Control"))
                & (pl.col("Classifier_type") == "GFP")
            )
            .group_by(["allele_0", "Allele_set", "Batch", "allele_1"])
            .agg([pl.len().alias("Number_classifiers")])
        )
        classifier_count = classifier_count.pivot(
            index=["allele_0", "allele_1", "Allele_set"],
            on="Batch",
            values="Number_classifiers",
        )
        # display(classifier_count)
        # break
        print(f"After filtering out classifiers with training imbalance > {TRN_IMBAL_THRES}:")
        print("Total number of unique classifiers:", classifier_count.shape)
        print("Total number of unique variant alleles:", len(classifier_count.select("allele_0").to_series().unique().to_list()))
        print("Total number of unique WT genes:", len(classifier_count.select("allele_1").to_series().unique().to_list()))
        print("==========================================================================")
    
        # classifier_count = classifier_count.filter(
        #     (pl.col(batch_id) >= MIN_CLASS_NUM)
        # )
        # print("After filtering out alleles with available number of classifiers < 2:")
        # print("Total number of unique classifiers:", classifier_count.shape)
        # print("Total number of unique variant alleles:", len(classifier_count.select("allele_0").to_series().unique().to_list()))
        # print("Total number of unique WT genes:", len(classifier_count.select("allele_1").to_series().unique().to_list()))
    
        # # filter based on this
        # keep_alleles = classifier_count.select("allele_0").to_series().unique().to_list()
        # metrics_df_thres = metrics_df_thres.filter(
        #     ~((~pl.col("Metadata_Control")) & ~pl.col("allele_0").is_in(keep_alleles))
        # )
        # Filter by imbalance and calculate mean AUROC for each batch
        metrics_wtvar = (
            (
                metrics_df_thres.filter(
                    (~pl.col("well_0_is_bg")) & (~pl.col("well_1_is_bg")) & \
                    (pl.col("Training_imbalance") < TRN_IMBAL_THRES) & (~pl.col("Metadata_Control"))
                )
            )
            .select([
                "AUROC",
                "Classifier_type",
                "Batch",
                "allele_0",
                "trainsize_0",
                "testsize_0",
                "trainsize_1",
                "testsize_1",
                "Allele_set",
                "Training_imbalance",
            ])#+[f"AUROC_thresh_{auroc_thres*100:.0f}" for auroc_thres in AUROC_THRESHOLDS])
            .group_by(["Classifier_type", "allele_0", "Allele_set", "Batch"]) ## +[f"AUROC_thresh_{auroc_thres*100:.0f}" for auroc_thres in AUROC_THRESHOLDS]
            .agg([
                pl.all()
                .exclude(["Classifier_type", "allele_0", "Allele_set", "Batch"]) ## +[f"AUROC_thresh_{auroc_thres*100:.0f}" for auroc_thres in AUROC_THRESHOLDS]
                .mean()
                .name.suffix("_mean")
            ])
        )
        metric_df_dict[f"{batch_id}_met_thres"] = metrics_df_thres
        metric_df_dict[f"{batch_id}_allele_summary"] = metrics_wtvar

        # os.makedirs(f"{CLASS_SUMMARY_DIR}/{bio_rep}/{batch_id}", exist_ok=True)
        # metrics_wtvar.write_csv(f"{CLASS_SUMMARY_DIR}/{bio_rep}/{batch_id}/metrics_summary.csv")

Total number of unique classifiers for GFP: 372
Total number of unique variant alleles: 372
Total number of unique WT genes: 77
After filtering out classifiers with training imbalance > 3:
Total number of unique classifiers: (375, 4)
Total number of unique variant alleles: 375
Total number of unique WT genes: 77
Total number of unique classifiers for GFP: 259
Total number of unique variant alleles: 259
Total number of unique WT genes: 62
After filtering out classifiers with training imbalance > 3:
Total number of unique classifiers: (263, 4)
Total number of unique variant alleles: 263
Total number of unique WT genes: 65
Total number of unique classifiers for GFP: 0
Total number of unique variant alleles: 0
Total number of unique WT genes: 0
After filtering out classifiers with training imbalance > 3:
Total number of unique classifiers: (0, 3)
Total number of unique variant alleles: 0
Total number of unique WT genes: 0
Total number of unique classifiers for GFP: 0
Total number of unique

In [39]:
allele_df = pl.DataFrame()
for key, df in metric_df_dict.items():
    if "allele_summary" in key:
        allele_df = pl.concat([allele_df, df], how="diagonal_relaxed")
        
allele_df = allele_df.filter(
    ~((pl.col("allele_0")=="AGXT_Asp201Asn")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8"))) &
    ~((pl.col("allele_0")=="MVK_Pro288Leu")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8"))) &
    ~((pl.col("allele_0")=="GSS_Arg125Cys")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8"))) &
    ~((pl.col("allele_0")=="GSS_Arg125Cys")&((pl.col("Batch")=="B11")|(pl.col("Batch")=="B12"))) &
    ~((pl.col("allele_0")=="HPRT1_His204Asp")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8"))) &
    ~((pl.col("allele_0")=="MVK_Ser329Asn")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8"))) &
    ~((pl.col("allele_0")=="RAD51D_Arg165Gln")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8"))) &
    ~((pl.col("allele_0")=="BCL10_Leu8Leu")&((pl.col("Batch")=="B7")|(pl.col("Batch")=="B8")))
)
allele_df.group_by("allele_0").agg(
    pl.col("AUROC_mean").mean().alias("AUROC_Mean_GFP_Adj"),
).filter(
    # pl.col("AUROC_Mean_GFP_Adj") > 0.9577,
    ~pl.col("allele_0").str.contains("KCNJ2|STXBP1")
)

allele_0,AUROC_Mean_GFP_Adj
str,f64
"""GBA_Asp176Asn""",0.857824
"""KLHL3_Cys164Phe""",0.860255
"""CCM2_Ala295Thr""",0.849189
"""CCM2_Glu290Lys""",0.807653
"""DES_Leu274Pro""",0.929777
…,…
"""PSEN1_Arg269His""",0.997795
"""COMP_Asp319Val""",0.959706
"""CDKN1A_Asp149Gly""",0.967215
"""CCM2_Ala205Thr""",0.802699


In [40]:
allele_df.group_by("allele_0").agg(
    pl.col("AUROC_mean").mean().alias("AUROC_Mean_GFP_Adj"),
).filter(
    # pl.col("AUROC_Mean_GFP_Adj") > 0.9577,
    ~pl.col("allele_0").str.contains("KCNJ2|STXBP1")
).write_csv("../outputs/2.classification_results/imaging_analyses_classification_summary_all_gfp_adj_auroc.csv")