# Visualization of Variant Painting Images and Cells

In [22]:
import os
import glob
import polars as pl
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.io import imread
from tqdm import tqdm
import re
import sys
import subprocess

sys.path.append("../..")
from img_utils import *
from display_img import *

## 1. Load metadata and variant classification results

In [23]:
# Output roots for rendered well images and cell crops; one sub-folder per bio-replicate batch is used below.
OUT_IMG_DIR = "../../2.snakemake_pipeline/outputs/visualize_imgs"
OUT_CELL_DIR = "../../2.snakemake_pipeline/outputs/visualize_cells"

### 1.1 Read in the metadata

In [24]:
# Build, for every bio-replicate batch in BIO_REP_BATCHES_DICT:
#   * allele_meta_df_dict[bio_rep]  -> platemap rows with a non-null node_type
#   * img_well_qc_sum_dict[bio_rep] -> per plate/well/channel image-QC summary (is_bg, s2n_ratio)
#     plus an extra "Morph" pseudo-channel aggregated over all non-GFP channels.
# allele_meta_df accumulates the platemaps of every batch; img_well_qc_sum_df is initialised but not filled here.
allele_meta_df, img_well_qc_sum_df = pl.DataFrame(), pl.DataFrame()
allele_meta_df_dict, img_well_qc_sum_dict = {}, {}

for bio_rep, bio_rep_batches in BIO_REP_BATCHES_DICT.items():
    for batch_id in bio_rep_batches:
        allele_meta_df_batch = pl.DataFrame()
        # Every tab-separated platemap file (*.txt) in this batch's platemap folder.
        platemaps = [file for file in os.listdir(PLATEMAP_DIR.format(batch_id=batch_id)) if file.endswith(".txt")]
        for platemap in platemaps:
            platemap_df = pl.read_csv(os.path.join(PLATEMAP_DIR.format(batch_id=batch_id), platemap), separator="\t", infer_schema_length=100000)
            # Append only annotated wells (non-null node_type); diagonal_relaxed tolerates
            # differing column sets/dtypes between files. Re-sorted by plate_map_name after each file.
            allele_meta_df_batch = pl.concat([allele_meta_df_batch, 
                                        platemap_df.filter((pl.col("node_type").is_not_null()))], # (~pl.col("node_type").is_in(["TC","NC","PC"]))&
                                        how="diagonal_relaxed").sort("plate_map_name")
            # Expose plate_map_name under the alias "plate_map" as well.
            allele_meta_df_batch = allele_meta_df_batch.with_columns(pl.col("plate_map_name").alias("plate_map")) ## str.split('_').list.get(0).
            # display(allele_meta_df.head())
        allele_meta_df = pl.concat([
            allele_meta_df,
            allele_meta_df_batch
        ], how="diagonal_relaxed")#.sort("plate_map_name") ## (~pl.col("node_type").is_in(["TC","NC","PC"]))&
    # NOTE(review): this assignment runs AFTER the batch loop, so it keeps only the LAST
    # batch's platemaps for this bio_rep (allele_meta_df, by contrast, holds every batch).
    # Presumably intentional if all batches of a bio_rep share the same platemaps -- confirm.
    # If bio_rep_batches were empty, a previous bio_rep's dataframe (or a NameError) would result.
    allele_meta_df_dict[bio_rep] = allele_meta_df_batch

    # Well-level image-QC summary for this bio_rep; the "DAPI" channel is renamed to "DNA".
    img_well_qc_sum = pl.read_csv(f"{IMGS_QC_METRICS_DIR}/{bio_rep}/plate-well-level_img_qc_sum.csv")
    img_well_qc_sum = img_well_qc_sum.with_columns(
        pl.col("channel").replace("DAPI", "DNA").alias("channel")
    )
    # "Morph" pseudo-channel: per plate/well, the max of is_bg and the mean s2n_ratio
    # across all non-GFP channels.
    img_well_qc_sum_morph = img_well_qc_sum.filter(pl.col("channel")!="GFP")
    img_well_qc_sum_morph = img_well_qc_sum_morph.group_by(["plate","well"]).agg(
        pl.col("is_bg").max().alias("is_bg"),
        pl.col("s2n_ratio").mean().alias("s2n_ratio")
    ).with_columns(pl.lit("Morph").alias("channel"))
    # Stack the per-channel rows and the Morph rows on a common column set.
    img_well_qc_sum = pl.concat([
        img_well_qc_sum.select(pl.col(["plate","well","channel","is_bg","s2n_ratio"])),
        img_well_qc_sum_morph.select(pl.col(["plate","well","channel","is_bg","s2n_ratio"])),
    ], how="vertical_relaxed")
    img_well_qc_sum_dict[bio_rep] = img_well_qc_sum

In [25]:
# auroc_df = pl.read_csv(f"{CLASS_SUMMARY_DIR}/imaging_analyses_classification_summary_all.csv")
# auroc_df.unique("gene_allele")

# Per-allele classification summary (AUROC per channel and bio-rep, plus "Altered" flags).
AUROC_SUMMARY_CSV = "../outputs/2.classification_results/imaging_analyses_classification_summary_all.csv"
auroc_df = pl.read_csv(
    AUROC_SUMMARY_CSV,
    separator=",",
    infer_schema_length=100000,
)
auroc_df.unique("gene_allele")

## pillar results
# auroc_df = pl.read_csv(f"/home/shenrunx/igvf/varchamp/2025_Pillar_VarChAMP/2_individual_assay_analyses/imaging/3_outputs/pillar_img_overlapped_gene_variants.csv", 
#                        infer_schema_length=100000, separator=",").with_columns(
#                            pl.lit("2025_01_Batch_13-14").alias("Metadata_Bio_Batch")
#                        )
# auroc_df.unique("gene_allele")

symbol,gene_allele,Metadata_Bio_Batch,Altered_95th_perc_both_batches_GFP,Altered_95th_perc_both_batches_DNA,Altered_95th_perc_both_batches_Mito,Altered_95th_perc_both_batches_AGP,Altered_95th_perc_both_batches_Morph,Altered_99th_perc_both_batches_GFP,Altered_99th_perc_both_batches_DNA,Altered_99th_perc_both_batches_Mito,Altered_99th_perc_both_batches_AGP,Altered_99th_perc_both_batches_Morph,AUROC_Mean_GFP,AUROC_Mean_DNA,AUROC_Mean_Mito,AUROC_Mean_AGP,AUROC_Mean_Morph,AUROC_BioRep1_Morph,AUROC_BioRep1_AGP,AUROC_BioRep1_Mito,AUROC_BioRep1_DNA,AUROC_BioRep1_GFP,AUROC_BioRep2_Morph,AUROC_BioRep2_AGP,AUROC_BioRep2_Mito,AUROC_BioRep2_DNA,AUROC_BioRep2_GFP
str,str,str,bool,bool,bool,bool,bool,bool,bool,bool,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""LITAF""","""LITAF_Pro39Ser""","""2025_03_Batch_15-16""",false,false,false,false,false,false,false,false,false,false,0.926616,0.679637,0.777447,0.843655,0.666559,0.789873,0.896789,0.830513,0.783924,0.892523,0.543244,0.790521,0.724382,0.575349,0.960708
"""GBA""","""GBA_ASP492VAL""","""2025_06_Batch_18-19""",true,true,false,true,true,false,true,false,false,false,0.979159,0.936748,0.937798,0.97844,0.971709,0.976899,0.986891,0.929917,0.931687,0.974761,0.96652,0.96999,0.945679,0.941809,0.983557
"""BST1""","""BST1_Arg160Gln""","""2025_06_Batch_18-19""",true,true,true,true,false,false,false,false,false,false,0.976165,0.882559,0.97601,0.950532,0.891639,0.889801,0.943894,0.973908,0.831941,0.965802,0.893478,0.957169,0.978111,0.933176,0.986528
"""RAD51D""","""RAD51D_Arg165Gln""","""2025_01_Batch_13-14""",false,false,false,false,false,false,false,false,false,false,0.941599,0.839206,0.897381,0.864691,0.874168,0.816743,0.808744,0.849891,0.805973,0.922273,0.931593,0.920638,0.944872,0.872439,0.960925
"""PHYH""","""PHYH_Pro29Ser""","""2025_03_Batch_15-16""",true,true,true,true,false,true,false,false,true,false,0.996655,0.893055,0.989204,0.97469,0.518785,0.464114,0.968008,0.987556,0.806196,0.99602,0.573456,0.981373,0.990852,0.979914,0.997291
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""GFAP""","""GFAP_Glu371Val""","""2025_03_Batch_15-16""",false,false,false,false,false,false,false,false,false,false,0.939703,0.720653,0.85466,0.866453,0.807977,0.796074,0.863472,0.837184,0.747863,0.937377,0.81988,0.869434,0.872136,0.693443,0.94203
"""FGD4""","""FGD4_Arg194His""","""2024_01_Batch_7-8""",false,false,false,false,false,false,false,false,false,false,0.571888,0.529995,0.568074,0.615584,0.566349,0.474366,0.62047,0.557388,0.450001,0.553471,0.658331,0.610699,0.578759,0.609989,0.590304
"""SFTPA2""","""SFTPA2_Gly231Val""","""2025_03_Batch_15-16""",true,false,true,true,false,false,false,false,false,false,0.970028,0.493889,0.926204,0.942286,0.500055,0.613853,0.952075,0.936355,0.624767,0.964566,0.386258,0.932498,0.916052,0.36301,0.975491
"""ALAS2""","""ALAS2_Asp153Val""","""2024_01_Batch_7-8""",false,false,false,false,false,false,false,false,false,false,0.593106,0.64851,0.562089,0.44825,0.558997,0.558997,0.44825,0.562089,0.64851,0.593106,,,,,


## 2. Plot Variant Painting Well Images

In [26]:
def get_allele_batch(allele, score_df=auroc_df):
    """Return the ``Metadata_Bio_Batch`` of *allele* in *score_df*.

    The first row whose ``gene_allele`` equals *allele* is used. An IndexError
    is raised when *allele* does not occur in *score_df*.
    """
    matching_rows = score_df.filter(pl.col("gene_allele") == allele)
    batches = matching_rows.get_column("Metadata_Bio_Batch").to_list()
    return batches[0]


def save_allele_imgs(variant, feat, auroc_df=auroc_df, display=False, save_img=False):
    """Plot (and optionally save) reference-vs-variant well images of one channel.

    The bio-replicate batch of ``variant`` is looked up with ``get_allele_batch``.
    Reference (text before the first "_" of ``variant``) and variant wells come
    from that batch's platemap in ``allele_meta_df_dict``. Batch
    "2024_12_Batch_11-12" is drawn with ``plot_allele_single_plate`` in groups of
    4 wells; every other batch is drawn with ``plot_allele``, one ref/var well
    pair per call when there is more than one of either.

    Parameters
    ----------
    variant : str
        Variant allele name, e.g. "GENE_Aaa123Bbb".
    feat : str
        Channel to plot; also selects the ``AUROC_Mean_{feat}`` column.
    auroc_df : polars.DataFrame
        Classification summary table (defaults to the one loaded above).
    display : bool
        Forwarded to the plotting helpers. If an image for ``variant``/``feat``
        already exists and ``display`` is False, the function returns early.
    save_img : bool
        Save figures under ``OUT_IMG_DIR/<bio_rep>`` (created if missing).
        Images that already exist are never re-saved; with ``display=True``
        they are only re-drawn.

    Returns
    -------
    None
    """
    bio_rep = get_allele_batch(variant)
    pm = allele_meta_df_dict[bio_rep]
    # Expose the channel-specific AUROC and the allele under the generic column
    # names ("AUROC_Mean", "allele_0") handed to the plotting helpers.
    auroc_df_batch = auroc_df.with_columns(
        pl.col(f"AUROC_Mean_{feat}").alias("AUROC_Mean"),
        pl.col("gene_allele").alias("allele_0"),
    )
    ref_allele = variant.split("_")[0]
    ref_wells = pm.filter(pl.col("gene_allele") == ref_allele)["imaging_well"].to_list()
    var_wells = pm.filter(pl.col("gene_allele") == variant)["imaging_well"].to_list()
    if not ref_wells or not var_wells:
        print(f"No {ref_allele} or {variant} wells found in {bio_rep}; nothing to plot.")
        return None

    batch_out_dir = f"{OUT_IMG_DIR}/{bio_rep}"
    # A missing output folder (e.g. first run) just means nothing has been saved yet.
    target_file = []
    if os.path.isdir(batch_out_dir):
        target_file = [f for f in os.listdir(batch_out_dir) if f.startswith(f"{variant}_{feat}")]
    if target_file:
        print(target_file, "exists.")
        if not display:
            return None

    # Images that already exist are only re-displayed, never re-saved.
    output_dir = ""
    if save_img and not target_file:
        output_dir = batch_out_dir
        os.makedirs(output_dir, exist_ok=True)
        print(f"Img output at {output_dir}")

    # Arguments shared by every plotting call below.
    plot_kwargs = dict(
        variant=variant,
        sel_channel=feat,
        auroc_df=auroc_df_batch,
        plate_img_qc=img_well_qc_sum_dict[bio_rep],
        site="05",
        max_intensity=0.99,
        display=display,
        imgs_dir=TIFF_IMGS_DIR,
        output_dir=output_dir,
    )

    if bio_rep != "2024_12_Batch_11-12":
        if len(ref_wells) == 1 and len(var_wells) == 1:
            plot_allele(pm, **plot_kwargs)
        else:
            # One call per (reference well, variant well) pair.
            for ref_well in ref_wells:
                for var_well in var_wells:
                    plot_allele(pm, ref_well=[ref_well], var_well=[var_well], **plot_kwargs)
    else:
        # Batch 11-12 uses plot_allele_single_plate with groups of 4 wells per allele.
        if len(ref_wells) == 4 and len(var_wells) == 4:
            plot_allele_single_plate(pm, **plot_kwargs)
        else:
            if len(ref_wells) % 4 or len(var_wells) % 4:
                print(f"Warning: {len(ref_wells)} ref / {len(var_wells)} var wells for {variant} "
                      "is not a multiple of 4; trailing wells are skipped.")
            for rw_idx in range(len(ref_wells) // 4):
                for vw_idx in range(len(var_wells) // 4):
                    plot_allele_single_plate(
                        pm,
                        ref_well=ref_wells[rw_idx * 4:rw_idx * 4 + 4],
                        var_well=var_wells[vw_idx * 4:vw_idx * 4 + 4],
                        **plot_kwargs,
                    )

In [20]:
# Gene whose variants are plotted. Earlier selections included
# ["FOXP3_Leu242Pro", "TPM1_Asn279His"] and the "MLH1" variants.
SEL_SYMBOL = "CCM2"
allele_list = auroc_df.filter(pl.col("symbol") == SEL_SYMBOL).get_column("gene_allele").unique()
# allele_list

# for variant in tqdm(allele_list):
#     for feat in ["AGP","Mito","GFP"]:
#         # print(variant)
#         save_allele_imgs(variant, feat, display=False, save_img=True)

In [29]:
def plot_allele_separate_plot(pm, variant, sel_channel, plate_img_qc, auroc_df=None, site="05", ref_well=[], var_well=[], max_intensity=0.99, display=False, imgs_dir="", output_dir=""):
    """Save every reference/variant field image of one channel as its own PNG.

    For each platemap row of ``variant`` and of its reference allele (text before
    the first "_" of ``variant``) inside the variant's platemap(s), eight images
    are written: T1-T4 of ``imaging_plate_R1`` and T1-T4 of ``imaging_plate_R2``,
    all for field ``site``. Each PNG is named
    ``{sel_channel}-{plate}_T{k}_Well{well}_Site{site}_{allele}.png``.

    Parameters
    ----------
    pm : polars.DataFrame
        Platemap with ``gene_allele``, ``imaging_well``, ``plate_map_name``,
        ``node_type``, ``imaging_plate_R1`` and ``imaging_plate_R2`` columns.
    variant : str
        Variant allele, e.g. "GENE_Aaa123Bbb".
    sel_channel : str
        Channel key understood by ``channel_to_cmap`` and ``channel_dict``.
    plate_img_qc, auroc_df, display :
        Accepted for signature compatibility; not used by this function.
    site : str
        Field-of-view id used in the TIFF file name.
    ref_well, var_well : list of str
        If non-empty, restrict the reference / variant wells to these (not mutated).
    max_intensity : float
        Quantile (0-1) of pixel intensities used as the display ``vmax``.
    imgs_dir : str
        Root TIFF directory (required).
    output_dir : str
        Destination folder, created if missing; "" writes to the working directory.
    """
    assert imgs_dir != "", "Image directory has to be input!"
    cmap = channel_to_cmap(sel_channel)
    channel = channel_dict[sel_channel]

    ## wells and platemaps of the reference and the variant
    wt = variant.split("_")[0]
    plate_map = pm.filter(pl.col("gene_allele") == variant).get_column("plate_map_name").to_list()
    wt_wells = pm.filter(pl.col("gene_allele") == wt).get_column("imaging_well").to_list()
    var_wells = pm.filter(pl.col("gene_allele") == variant).get_column("imaging_well").to_list()
    if ref_well:
        wt_wells = [well for well in wt_wells if well in ref_well]
    if var_well:
        var_wells = [well for well in var_wells if well in var_well]
    # Plain list concatenation keeps string dtype even when both lists are empty
    # (np.concatenate of two empty lists yields a float array).
    pm_var = pm.filter(
        pl.col("imaging_well").is_in(wt_wells + var_wells)
        & pl.col("plate_map_name").is_in(plate_map)
    ).sort("node_type")

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for pm_row in pm_var.iter_rows(named=True):
        node_type = pm_row["node_type"]
        # Rows whose node_type mentions "allele" are classified by node_type,
        # the others by whether their well is one of the reference wells.
        if "allele" in node_type:
            is_variant = node_type == "allele"
        else:
            is_variant = pm_row["imaging_well"] not in wt_wells
        # NOTE(review): always the FIRST matching well rather than pm_row["imaging_well"];
        # with several ref/var wells every row re-plots the same well -- confirm intent.
        well, allele = (var_wells[0], variant) if is_variant else (wt_wells[0], wt)

        row = letter_dict[well[0]]
        col = well[1:3]
        img_file = f"r{row}c{col}f{site}p01-ch{channel}sk1fk1fl1.tiff"

        # Images 0-3 come from the R1 plate and 4-7 from the R2 plate; T1..T4 cycle within each.
        for i in range(8):
            sel_plate = pm_row["imaging_plate_R1"] if i < 4 else pm_row["imaging_plate_R2"]
            batch = batch_dict[sel_plate.split("_")[0]]
            tp = f"T{i % 4 + 1}"
            plate_img_dir = plate_dict[sel_plate][tp]
            img = imread(f"{imgs_dir}/{batch}/images/{plate_img_dir}/Images/{img_file}", as_gray=True)
            plot_label = f"{sel_channel}-{sel_plate}_{tp}_Well{well}_Site{site}_{allele}"

            fig, ax = plt.subplots()
            try:
                ax.imshow(img, vmin=0, vmax=np.percentile(img, max_intensity * 100), cmap=cmap)
                ax.axis("off")
                fig.tight_layout()
                fig.savefig(os.path.join(output_dir, f"{plot_label}.png"), dpi=400, bbox_inches="tight")
            finally:
                # Always release the figure so long loops do not accumulate open figures.
                plt.close(fig)

In [35]:
# plot_allele_separate_plot(allele_meta_df_dict["2025_01_Batch_13-14"],
#                           variant="CCM2_Ile432Thr", sel_channel="GFP", 
#                           plate_img_qc=img_well_qc_sum_dict[bio_rep], 
#                           imgs_dir=TIFF_IMGS_DIR, 
#                           site="05", max_intensity=0.99, output_dir=".")

## 3. Plot Variant Painting Cell Crop Images

In [13]:
# Per-batch path templates; the "{}" placeholder is filled with a batch_id via str.format.
BATCH_PROFILES = (
    "../../2.snakemake_pipeline/outputs/batch_profiles/{}/profiles.parquet"
)
IMG_ANALYSIS_DIR = (
    "../../1.image_preprocess_qc/inputs/cpg_imgs/{}/analysis"
)

### 3.1 Load and store the batch profiles for cell crop look-up

__Run only once; the result is stored as one large dict of per-batch profiles for fast reloading.__

In [14]:
# # Filter thresholds
# min_area_ratio = 0.15
# max_area_ratio = 0.3
# min_center = 50
# max_center = 1030

# # num_mad = 5
# # min_cells = 250

# batch_profiles = {}
# for bio_rep, bio_rep_batches in BIO_REP_BATCHES_DICT.items():
#     for batch_id in BIO_REP_BATCHES_DICT[bio_rep]:
#         imagecsv_dir = IMG_ANALYSIS_DIR.format(batch_id) #f"../../../8.1_upstream_analysis_runxi/2.raw_img_qc/inputs/images/{batch_id}/analysis"
#         prof_path = BATCH_PROFILES.format(batch_id)
#         # Get metadata
#         profiles = pl.scan_parquet(prof_path).select(
#             ["Metadata_well_position", "Metadata_plate_map_name", "Metadata_ImageNumber", "Metadata_ObjectNumber",
#             "Metadata_symbol", "Metadata_gene_allele", "Metadata_node_type", "Metadata_Plate",
#             "Nuclei_AreaShape_Area", "Cells_AreaShape_Area", "Nuclei_AreaShape_Center_X", "Nuclei_AreaShape_Center_Y",
#             "Cells_AreaShape_BoundingBoxMaximum_X", "Cells_AreaShape_BoundingBoxMaximum_Y", "Cells_AreaShape_BoundingBoxMinimum_X",
#             "Cells_AreaShape_BoundingBoxMinimum_Y",	"Cells_AreaShape_Center_X",	"Cells_AreaShape_Center_Y",
#             "Cells_Intensity_MeanIntensity_GFP", "Cells_Intensity_MedianIntensity_GFP", "Cells_Intensity_IntegratedIntensity_GFP"],
#         ).collect()
#         # print(profiles["Metadata_Plate"])
    
#         # Filter based on cell to nucleus area
#         profiles = profiles.with_columns(
#                         (pl.col("Nuclei_AreaShape_Area")/pl.col("Cells_AreaShape_Area")).alias("Nucleus_Cell_Area"),
#                         pl.concat_str([
#                             "Metadata_Plate", "Metadata_well_position", "Metadata_ImageNumber", "Metadata_ObjectNumber",
#                             ], separator="_").alias("Metadata_CellID"),
#                 ).filter((pl.col("Nucleus_Cell_Area") > min_area_ratio) & (pl.col("Nucleus_Cell_Area") < max_area_ratio))
    
#         # Filter cells too close to image edge
#         profiles = profiles.filter(
#             ((pl.col("Nuclei_AreaShape_Center_X") > min_center) & (pl.col("Nuclei_AreaShape_Center_X") < max_center) &
#             (pl.col("Nuclei_AreaShape_Center_Y") > min_center) & (pl.col("Nuclei_AreaShape_Center_Y") < max_center)),
#         )
    
#         # Calculate mean, median and mad of gfp intensity for each allele
#         ## mean
#         means = profiles.group_by(["Metadata_Plate", "Metadata_well_position"]).agg(
#             pl.col("Cells_Intensity_MeanIntensity_GFP").mean().alias("WellIntensityMean"),
#         )
#         profiles = profiles.join(means, on=["Metadata_Plate", "Metadata_well_position"])
#         ## median
#         medians = profiles.group_by(["Metadata_Plate", "Metadata_well_position"]).agg(
#             pl.col("Cells_Intensity_MedianIntensity_GFP").median().alias("WellIntensityMedian"),
#         )
#         profiles = profiles.join(medians, on=["Metadata_Plate", "Metadata_well_position"])
#         ## mad
#         profiles = profiles.with_columns(
#             (pl.col("Cells_Intensity_MedianIntensity_GFP") - pl.col("WellIntensityMedian")).abs().alias("Abs_dev"),
#         )
#         mad = profiles.group_by(["Metadata_Plate", "Metadata_well_position"]).agg(
#             pl.col("Abs_dev").median().alias("Intensity_MAD"),
#         )
#         profiles = profiles.join(mad, on=["Metadata_Plate", "Metadata_well_position"])
    
#         # ## Threshold is 5X
#         # ## Used to be median well intensity + 5*mad implemented by Jess
#         # ## Switching to mean well intensity + 5*mad implemented by Runxi
#         # profiles = profiles.with_columns(
#         #     (pl.col("WellIntensityMedian") + num_mad*pl.col("Intensity_MAD")).alias("Intensity_upper_threshold"), ## pl.col("WellIntensityMedian")
#         #     (pl.col("WellIntensityMedian") - num_mad*pl.col("Intensity_MAD")).alias("Intensity_lower_threshold"), ## pl.col("WellIntensityMedian")
#         # )
#         # ## Filter by intensity MAD
#         # profiles = profiles.filter(
#         #     pl.col("Cells_Intensity_MeanIntensity_GFP") <= pl.col("Intensity_upper_threshold"),
#         # ).filter(
#         #     pl.col("Cells_Intensity_MeanIntensity_GFP") >= pl.col("Intensity_lower_threshold"),
#         # )
    
#         # Filter out alleles with fewer than 250 cells
#         # keep_alleles = profiles.group_by("Metadata_gene_allele").count().filter(
#         #     pl.col("count") >= min_cells,
#         #     ).select("Metadata_gene_allele").to_series().to_list()
#         # profiles = profiles.filter(pl.col("Metadata_gene_allele").is_in(keep_alleles))
    
#         # add full crop coordinates
#         profiles = profiles.with_columns(
#             (pl.col("Nuclei_AreaShape_Center_X") - 50).alias("x_low").round().cast(pl.Int16),
#             (pl.col("Nuclei_AreaShape_Center_X") + 50).alias("x_high").round().cast(pl.Int16),
#             (pl.col("Nuclei_AreaShape_Center_Y") - 50).alias("y_low").round().cast(pl.Int16),
#             (pl.col("Nuclei_AreaShape_Center_Y") + 50).alias("y_high").round().cast(pl.Int16),
#         )
    
#         # Read in all Image.csv to get ImageNumber:SiteNumber mapping and paths
#         image_dat = []
#         icfs = glob.glob(os.path.join(imagecsv_dir, "**/*Image.csv"), recursive=True)
#         for icf in tqdm(icfs):
#             fp = icf.split('/')[-2]
#             # print(fp)
#             plate, well = "-".join(fp.split("-")[:-2]), fp.split("-")[-2]
#             # print(plate, well)
#             image_dat.append(pl.read_csv(icf).select(
#                 [
#                     "ImageNumber",
#                     "Metadata_Site",
#                     "PathName_OrigDNA",
#                     "FileName_OrigDNA",
#                     "FileName_OrigGFP",
#                     ],
#                 ).with_columns(
#                 pl.lit(plate).alias("Metadata_Plate"),
#                 pl.lit(well).alias("Metadata_well_position"),
#                 ))
#         image_dat = pl.concat(image_dat).rename({"ImageNumber": "Metadata_ImageNumber"})
    
#         # Create useful filepaths
#         image_dat = image_dat.with_columns(
#             pl.col("PathName_OrigDNA").str.replace(".*cpg0020-varchamp/", "").alias("Path_root"),
#         )
    
#         image_dat = image_dat.drop([
#             "PathName_OrigDNA",
#             "FileName_OrigDNA",
#             "FileName_OrigGFP",
#             "Path_root",
#         ])
#         # print(image_dat)
    
#         # Append to profiles
#         profiles = profiles.join(image_dat, on = ["Metadata_Plate", "Metadata_well_position", "Metadata_ImageNumber"])
    
#         # Sort by allele, then image number
#         profiles = profiles.with_columns(
#             pl.concat_str(["Metadata_Plate", "Metadata_well_position", "Metadata_Site"], separator="_").alias("Metadata_SiteID"),
#             pl.col("Metadata_gene_allele").str.replace("_", "-").alias("Protein_label"),
#         )
#         profiles = profiles.sort(["Protein_label", "Metadata_SiteID"])
#         alleles = profiles.select("Protein_label").to_series().unique().to_list()
#         batch_profiles[batch_id] = profiles

In [15]:
# Pickle the metadata dictionary
# with open("../../2.snakemake_pipeline/outputs/visualize_cells/batch_prof_dict.pkl", "wb") as f:
#     pickle.dump(batch_profiles, f, pickle.HIGHEST_PROTOCOL)

### 3.2 Plot the cell crops

In [16]:
import pickle

# Reload the per-batch profile dict that was built once and pickled.
# NOTE: pickle.load can execute arbitrary code -- only load this locally generated file.
BATCH_PROF_PKL = "../../2.snakemake_pipeline/outputs/visualize_cells/batch_prof_dict.pkl"
with open(BATCH_PROF_PKL, "rb") as pkl_file:
    batch_profiles = pickle.load(pkl_file)

In [18]:
def save_allele_cell_imgs(variant, feat, batch_profile_dict, auroc_df=auroc_df, display=False, save_img=False):
    """Plot (and optionally save) reference-vs-variant cell-crop images of one channel.

    Mirrors ``save_allele_imgs`` but draws cell crops, using ``batch_profile_dict``
    to locate cells. Batch "2024_12_Batch_11-12" is drawn with
    ``plot_allele_cell_single_plate`` in groups of 4 wells; every other batch is
    drawn with ``plot_allele_cell``, one ref/var well pair per call when there is
    more than one of either.

    Parameters
    ----------
    variant : str
        Variant allele name, e.g. "GENE_Aaa123Bbb"; the reference allele is the
        text before the first "_".
    feat : str
        Channel to plot; also selects the ``AUROC_Mean_{feat}`` column.
    batch_profile_dict : dict
        Per-batch cell profiles, forwarded to the plotting helpers.
    auroc_df : polars.DataFrame
        Classification summary table (defaults to the one loaded above).
    display : bool
        Forwarded to the plotting helpers. If an image for ``variant``/``feat``
        already exists and ``display`` is False, the function returns early.
    save_img : bool
        Save figures under ``OUT_CELL_DIR/<bio_rep>`` (created if missing).
        Images that already exist are never re-saved; with ``display=True``
        they are only re-drawn.

    Returns
    -------
    None
    """
    bio_rep = get_allele_batch(variant)
    pm = allele_meta_df_dict[bio_rep]
    # Expose the channel-specific AUROC and the allele under the generic column
    # names ("AUROC_Mean", "allele_0") handed to the plotting helpers.
    auroc_df_batch = auroc_df.with_columns(
        pl.col(f"AUROC_Mean_{feat}").alias("AUROC_Mean"),
        pl.col("gene_allele").alias("allele_0"),
    )
    ref_allele = variant.split("_")[0]
    ref_wells = pm.filter(pl.col("gene_allele") == ref_allele)["imaging_well"].to_list()
    var_wells = pm.filter(pl.col("gene_allele") == variant)["imaging_well"].to_list()
    if not ref_wells or not var_wells:
        print(f"No {ref_allele} or {variant} wells found in {bio_rep}; nothing to plot.")
        return None

    batch_out_dir = f"{OUT_CELL_DIR}/{bio_rep}"
    # A missing output folder (e.g. first run) just means nothing has been saved yet.
    target_file = []
    if os.path.isdir(batch_out_dir):
        target_file = [f for f in os.listdir(batch_out_dir) if f.startswith(f"{variant}_{feat}")]
    if target_file:
        print(target_file, "exists.")
        if not display:
            return None

    # Images that already exist are only re-displayed, never re-saved.
    output_dir = ""
    if save_img and not target_file:
        output_dir = batch_out_dir
        os.makedirs(output_dir, exist_ok=True)
        print(f"Img output at {output_dir}")

    # Arguments shared by every plotting call below.
    plot_kwargs = dict(
        variant=variant,
        sel_channel=feat,
        batch_profile_dict=batch_profile_dict,
        auroc_df=auroc_df_batch,
        plate_img_qc=img_well_qc_sum_dict[bio_rep],
        site="05",
        max_intensity=0.99,
        display=display,
        imgs_dir=TIFF_IMGS_DIR,
        output_dir=output_dir,
    )

    if bio_rep != "2024_12_Batch_11-12":
        if len(ref_wells) == 1 and len(var_wells) == 1:
            plot_allele_cell(pm, **plot_kwargs)
        else:
            # One call per (reference well, variant well) pair.
            for ref_well in ref_wells:
                for var_well in var_wells:
                    plot_allele_cell(pm, ref_well=[ref_well], var_well=[var_well], **plot_kwargs)
    else:
        # Batch 11-12 uses plot_allele_cell_single_plate with groups of 4 wells per allele.
        if len(ref_wells) == 4 and len(var_wells) == 4:
            plot_allele_cell_single_plate(pm, **plot_kwargs)
        else:
            if len(ref_wells) % 4 or len(var_wells) % 4:
                print(f"Warning: {len(ref_wells)} ref / {len(var_wells)} var wells for {variant} "
                      "is not a multiple of 4; trailing wells are skipped.")
            for rw_idx in range(len(ref_wells) // 4):
                for vw_idx in range(len(var_wells) // 4):
                    plot_allele_cell_single_plate(
                        pm,
                        ref_well=ref_wells[rw_idx * 4:rw_idx * 4 + 4],
                        var_well=var_wells[vw_idx * 4:vw_idx * 4 + 4],
                        **plot_kwargs,
                    )

In [27]:
# save_allele_cell_imgs("FOXP3_Leu242Pro", feat="GFP", batch_profile_dict=batch_profiles, display=True)
# save_allele_cell_imgs("TPM1_Asn279His", feat="GFP", batch_profile_dict=batch_profiles, display=True)
# save_allele_cell_imgs("RAD51D_Glu233Gly", feat="GFP", batch_profile_dict=batch_profiles, display=True)
# save_allele_cell_imgs("RAD51D_Val28Leu", feat="GFP", batch_profile_dict=batch_profiles, display=True)

# for variant in tqdm(allele_list): ## ["F9_Ile316Thr"]
#     for feat in ["AGP","Mito","GFP"]:
#         save_allele_cell_imgs(variant, feat=feat, batch_profile_dict=batch_profiles, display=False, save_img=True)

# plot_allele_cell_multi(allele_meta_df_dict["2025_01_Batch_13-14"], variant, 
#                  sel_channel=["AGP","Mito","GFP"], 
#                  plate_img_qc=img_well_qc_sum_dict["2025_01_Batch_13-14"],
#                  auroc_df=None, batch_profile_dict=batch_profiles, display=True, 
#                  # output_dir=f"{OUT_CELL_DIR}/2025_01_Batch_13-14"
# )

In [63]:
_CELL_CROP_HALF = 64       # crops are (2 * 64) x (2 * 64) pixels around the chosen centre
_MIN_CELL_AREA = 5000      # cells with Cells_AreaShape_Area <= this are never shown


def _has_items(wells):
    """True if ``wells`` is a non-empty list/tuple/array (safe for numpy arrays, unlike ``if wells:``)."""
    return wells is not None and len(wells) > 0


def _lookup_is_bg(plate_img_qc, plate, well, channel):
    """Return the ``is_bg`` QC flag for one plate/well/channel.

    Returns False when no QC table is given or when it has no matching row,
    instead of raising.
    """
    if plate_img_qc is None:
        return False
    flags = plate_img_qc.filter(
        (pl.col("plate") == plate) & (pl.col("well") == well) & (pl.col("channel") == channel)
    )["is_bg"]
    return bool(flags[0]) if len(flags) > 0 else False


def _crop_around(img, x, y, half=_CELL_CROP_HALF):
    """Return ``img[y-half:y+half, x-half:x+half]``; may be empty or truncated near the image border."""
    return img[y - half:y + half, x - half:x + half]


def _is_poor_quality_crop(img_sub, img_median):
    """Heuristically reject a crop whose signal is too weak or too flat to show a cell.

    A crop is rejected if it is empty, if its 90th percentile does not exceed
    ``img_median`` (median of the full image), if its variance is below 1e4, or
    if its 99th/25th percentile ratio is below 2.
    """
    if img_sub.shape[0] == 0 or img_sub.shape[1] == 0:
        return True
    p25, p90, p99 = np.percentile(img_sub, [25, 90, 99])
    # ``p99 < 2 * p25`` is the division-free form of ``p99 / p25 < 2`` for
    # non-negative intensities; it avoids a divide-by-zero when p25 == 0.
    return bool(p90 <= img_median or np.var(img_sub) < 1e4 or p99 < 2 * p25)


def _show_crop(ax, channel_subs, channels, is_multichannel, max_intensity, cmap):
    """Draw a crop on ``ax``: an RGB composite for multichannel input, else one colormapped channel."""
    if is_multichannel:
        ax.imshow(create_multichannel_composite(channel_subs, channels, max_intensity))
    else:
        img_sub = channel_subs[channels[0]]
        ax.imshow(img_sub, vmin=0, vmax=np.percentile(img_sub, max_intensity * 100), cmap=cmap)


def _boxed_text(ax, x, y, text, color, facecolor, va, ha, fontsize=10):
    """Write ``text`` at axes-fraction position (x, y) inside a semi-transparent box."""
    ax.text(x, y, text, color=color, fontsize=fontsize,
            verticalalignment=va, horizontalalignment=ha, transform=ax.transAxes,
            bbox=dict(facecolor=facecolor, alpha=0.3, linewidth=2))


def plot_allele_cell_multi(pm, variant, sel_channel, batch_profile_dict, auroc_df, plate_img_qc, site="05", ref_well=None, var_well=None, max_intensity=0.99, display=False, imgs_dir=TIFF_IMGS_DIR, output_dir=""):
    """Plot one representative cell crop per imaging replicate for a variant and its reference allele.

    Each plate-map row of ``variant`` or of its reference allele gets eight panels.
    The reference allele is the part of ``variant`` before the first ``_``.
    The eight panels are the four technical replicates (T1-T4) of
    ``imaging_plate_R1`` followed by the four of ``imaging_plate_R2``. Each panel
    shows a crop centred on the first segmented cell that passes the quality
    heuristics. Cells are ordered by distance to the edge, then by area, and only
    cells with area > 5000 are considered. When no such cell exists, the image
    centre is shown with a "No high-quality cell available" note.

    Args:
        pm: polars plate-map DataFrame with ``gene_allele``, ``imaging_well``,
            ``plate_map_name``, ``node_type``, ``imaging_plate_R1`` and
            ``imaging_plate_R2`` columns.
        variant: variant identifier such as ``"GENE_Abc123Def"``.
        sel_channel: one channel name, or a list/tuple of channel names that are
            blended into a single RGB composite.
        batch_profile_dict: mapping batch -> profiles, passed to ``crop_allele``.
        auroc_df: optional polars DataFrame with ``allele_0`` and ``AUROC_Mean``.
            When given, the variant's mean AUROC is appended to the file name.
        plate_img_qc: optional polars DataFrame with ``plate``, ``well``, ``channel``
            and ``is_bg`` columns. It is used to flag background-only images; the
            first channel is the one checked.
        site: imaging site string; its last character is passed to ``crop_allele``.
        ref_well, var_well: optional lists/arrays restricting which reference /
            variant wells are plotted.
        max_intensity: display ceiling as a percentile fraction (e.g. 0.99).
        display: show the figure with ``plt.show()``.
        imgs_dir: root directory holding ``{batch}/images``.
        output_dir: if non-empty, the figure is saved there as a PNG.
    """
    is_multichannel = isinstance(sel_channel, (list, tuple))
    channels = list(sel_channel) if is_multichannel else [sel_channel]
    channel_label = "+".join(channels)  # equals ``sel_channel`` for a single channel
    cmap = None if is_multichannel else channel_to_cmap(sel_channel)

    ## wells of the reference allele and of the variant
    wt = variant.split("_")[0]
    wt_wells = pm.filter(pl.col("gene_allele") == wt)["imaging_well"].to_list()
    var_wells = pm.filter(pl.col("gene_allele") == variant)["imaging_well"].to_list()
    plate_map = pm.filter(pl.col("gene_allele") == variant)["plate_map_name"].to_list()
    if _has_items(ref_well):
        wt_wells = [well for well in wt_wells if well in ref_well]
    if _has_items(var_well):
        var_wells = [well for well in var_wells if well in var_well]

    # Restrict to the two alleles. A reference well ID from another plate map can
    # hold a different allele on the variant's plate map; this keeps it out.
    pm_var = pm.filter(
        pl.col("gene_allele").is_in([wt, variant])
        & pl.col("imaging_well").is_in(wt_wells + var_wells)
        & pl.col("plate_map_name").is_in(plate_map)
    ).sort("node_type")
    if pm_var.is_empty():
        print(f"No plate-map wells found for {variant} (reference {wt}); nothing to plot.")
        return

    # Two grid rows (of 4 panels) per plate-map row: replicate plate R1, then R2.
    fig, axes = plt.subplots(pm_var.height * 2, 4, figsize=(15, 16), sharex=True, sharey=True)
    axes = axes.flatten()
    for ax in axes:
        ax.axis("off")

    for row_idx, pm_row in enumerate(pm_var.iter_rows(named=True)):
        well = pm_row["imaging_well"]
        allele = pm_row["gene_allele"]
        for i in range(8):
            ax = axes[row_idx * 8 + i]
            sel_plate = pm_row["imaging_plate_R1"] if i < 4 else pm_row["imaging_plate_R2"]
            tech_rep = f"T{i % 4 + 1}"

            batch = batch_dict[sel_plate.split("_")[0]]
            batch_img_dir = f'{imgs_dir}/{batch}/images'
            row, col = letter_dict[well[0]], well[1:3]
            plate_img_dir = plate_dict[sel_plate][tech_rep]

            # Load images for all channels; the first channel drives QC and quality checks.
            channel_imgs = {}
            for ch in channels:
                img_file = f"r{row}c{col}f{site}p01-ch{channel_dict[ch]}sk1fk1fl1.tiff"
                channel_imgs[ch] = imread(f"{batch_img_dir}/{plate_img_dir}/Images/{img_file}", as_gray=True)
            main_img = channel_imgs[channels[0]]

            is_bg = _lookup_is_bg(plate_img_qc, plate_img_dir.split("__")[0], well, channels[0])

            ## candidate cells: farthest from the image edge first, then largest
            cell_allele_coord_df = crop_allele(allele, batch_profile_dict[batch], sel_plate.split("P")[0], rep=tech_rep, site=site[-1])
            cell_allele_coord_df = cell_allele_coord_df.with_columns(
                pl.struct("Cells_AreaShape_Center_X", "Cells_AreaShape_Center_Y")
                .map_elements(lambda x: compute_distance_cell(x['Cells_AreaShape_Center_X'], x['Cells_AreaShape_Center_Y']), return_dtype=pl.Float32).cast(pl.Int16)
                .alias('dist2edge')
            ).sort(by=["dist2edge", "Cells_AreaShape_Area"], descending=[True, True]).filter(pl.col("Cells_AreaShape_Area") > _MIN_CELL_AREA)

            ## first candidate whose crop passes the quality heuristics
            img_median = np.median(main_img)  # loop-invariant across candidates
            cell_subs = None
            for cell_row in cell_allele_coord_df.iter_rows(named=True):
                x, y = int(cell_row["Nuclei_AreaShape_Center_X"]), int(cell_row["Nuclei_AreaShape_Center_Y"])
                subs = {ch: _crop_around(img, x, y) for ch, img in channel_imgs.items()}
                if not _is_poor_quality_crop(subs[channels[0]], img_median):
                    cell_subs = subs
                    break

            if cell_subs is None:
                # No large-enough cell, or none passed the checks: show the image centre instead.
                _boxed_text(ax, 0.97, 0.97, "No high-quality\ncell available", 'black', 'white', 'top', 'right', fontsize=12)
                y, x = main_img.shape[0] // 2, main_img.shape[1] // 2  # shape is (rows=y, cols=x)
                center_subs = {ch: _crop_around(img, x, y) for ch, img in channel_imgs.items()}
                _show_crop(ax, center_subs, channels, is_multichannel, max_intensity, cmap)
                continue

            _show_crop(ax, cell_subs, channels, is_multichannel, max_intensity, cmap)
            plot_label = f"{channel_label}:{sel_plate},{tech_rep}\nWell:{well},Site:{site}\n{allele}"
            _boxed_text(ax, 0.03, 0.97, plot_label, 'white', 'black', 'top', 'left')
            if is_bg:
                _boxed_text(ax, 0.03, 0.03, "FLAG:\nOnly Background\nNoise is Detected", 'red', 'white', 'bottom', 'left')
            int_95 = str(int(round(np.percentile(cell_subs[channels[0]], 95))))
            _boxed_text(ax, 0.95, 0.05, f"95th Intensity:{int_95}", 'white', 'black', 'bottom', 'right')

    fig.tight_layout()
    fig.subplots_adjust(wspace=.01, hspace=-0.2, top=.99)

    if display:
        plt.show()

    if output_dir:
        file_name = f"{variant}_{channel_label}_cells"
        if auroc_df is not None:  # truthiness of a polars DataFrame raises TypeError
            auroc = auroc_df.filter(pl.col("allele_0") == variant)["AUROC_Mean"].mean()
            if auroc is not None:  # no AUROC row for this variant
                file_name = f"{file_name}_{auroc:.3f}"
        if _has_items(ref_well):
            file_name = f"{file_name}_REF-{'_'.join(ref_well)}"
        if _has_items(var_well):
            file_name = f"{file_name}_VAR-{'_'.join(var_well)}"
        fig.savefig(os.path.join(output_dir, f"{file_name}.png"), dpi=400, bbox_inches='tight')

    plt.close(fig)


def create_multichannel_composite(channel_subs, channels, max_intensity):
    """Blend several single-channel crops into one RGB image by additive colouring.

    All requested channels are divided by one shared intensity ceiling. The
    ceiling is the ``max_intensity`` percentile of all their pixels pooled
    together, so relative brightness between channels is preserved. Each channel
    is then clipped to [0, 1], multiplied by its RGB colour and summed. The sum is
    clipped to [0, 1].

    Args:
        channel_subs: dict channel name -> 2-D array; all arrays share one shape.
        channels: channel names to blend; each must be a key of ``channel_subs``.
        max_intensity: ceiling percentile as a fraction (e.g. 0.99).

    Returns:
        Float array of shape (H, W, 3) with values in [0, 1]. It is all zeros if
        ``channels`` is empty.
    """
    shape = next(iter(channel_subs.values())).shape
    composite = np.zeros((*shape, 3))

    # RGB colour per channel (intended to match the single-channel colormaps);
    # unknown channels are rendered white.
    color_map = {
        'DAPI': [0, 0, 1],      # Blue (#0000FF)
        'GFP': [0.396, 0.996, 0.031],  # Green (#65fe08)
        'AGP': [1, 1, 0],       # Yellow (#FFFF00)
        'Mito': [1, 0, 0],      # Red (#FF0000)
        'Brightfield1': [1, 1, 1],     # White/Gray
        'Brightfield2': [1, 1, 1],     # White/Gray
        'Brightfield': [1, 1, 1]       # White/Gray
    }

    # Fixed summation order for determinism. Blending is additive, so the order
    # does not put any channel "on top"; every channel contributes equally.
    layer_order = ['DAPI', 'AGP', 'Mito', 'GFP', 'Brightfield1', 'Brightfield2', 'Brightfield']
    ordered_channels = [layer for layer in layer_order if layer in channels]
    for ch in channels:
        if ch not in ordered_channels:
            ordered_channels.append(ch)
    if not ordered_channels:
        return composite

    # One shared ceiling across all channels keeps their relative intensities comparable.
    all_values = np.concatenate([np.ravel(channel_subs[ch]) for ch in ordered_channels])
    global_max = np.percentile(all_values, max_intensity * 100)

    for ch in ordered_channels:
        img = channel_subs[ch]
        if global_max > 0:
            img_norm = np.clip(img / global_max, 0, 1)
        else:
            # Degenerate ceiling (mostly-zero crop). Dividing would give NaN for zero
            # pixels, so saturate every positive pixel and keep zeros black.
            img_norm = (img > 0).astype(float)

        color = color_map.get(ch, [1, 1, 1])
        for c in range(3):
            composite[:, :, c] += img_norm * color[c]

    return np.clip(composite, 0, 1)

In [58]:
# save_allele_cell_imgs("MVK_Ser329Asn", feat="GFP", batch_profile_dict=batch_profiles, display=True)

In [57]:
# save_allele_cell_imgs("MLH1_Ile36Asn", feat="GFP", batch_profile_dict=batch_profiles, display=True)

In [None]:
# rsync -avz username@server_ip:/home/shenrunx/igvf/varchamp/2025_varchamp_snakemake/2.snakemake_pipeline/outputs/visualize_imgs/2025_01_Batch_13-14/F9*.png /local/destination/path/