# Compute and Analyze Classifier Metrics

The snakemake pipeline outputs, for each classifier, the predicted probability of class 0 / 1 for every cell. Here, we compute and save many common metrics from these probabilities.

This notebook is modified from the previous strategy: classification results are now processed one batch at a time.

In [9]:
# imports
import os
import polars as pl
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    average_precision_score,
    balanced_accuracy_score,
    confusion_matrix,
    f1_score,
    roc_auc_score,
)

%matplotlib inline

## Raise the polars table display limits for this session
pl.Config.set_tbl_rows(20)  # Print up to 20 rows before truncating
pl.Config.set_tbl_cols(40)  # Print up to 40 columns before truncating

## Control gene/allele labels, grouped by control type.
## NOTE(review): the abbreviations are not expanded anywhere in this notebook;
## NC/PC are presumably negative/positive controls -- confirm TC, cNC and cPC.
TC = ["EGFP"]
NC = ["RHEB", "MAPK9", "PRKACB", "SLIRP"]
PC = ["ALK", "ALK_Arg1275Gln", "PTK2B"]
cNC = ["Renilla"]
cPC = ["KRAS", "PTK2B", "GHSR", "ABL1", "BRD4", "OPRM1", "RB1", "ADA", "WT PMP22", "LYN", "TNF", 
       "CYP2A6", "CSK", "PAK1", "ALDH2", "CHRM3", "KCNQ2", "ALK T1151M", "PRKCE", "LPAR1", "PLP1"]

# ctrls = TC + NC + PC
# import json
# formatted_list = json.dumps(cPC)
# print(formatted_list)

## Visualize platemaps

In [10]:
def plot_platemap(df, plate_name):
    """Draw a 384-well platemap coloured by control type and return the joined grid.

    Args:
        df: polars DataFrame describing the wells of one plate. The code reads
            the columns ``well_position`` (first character = row letter, the
            rest = column label, which must be zero-padded like ``"A01"`` to
            match the grid), ``symbol``, ``gene_allele`` and ``control_type``.
            The pivots use no aggregation, so each well is expected to appear
            at most once.
        plate_name: Text appended to the figure title.

    Returns:
        The long-format 16x24 grid left-joined with ``df`` (wells absent from
        ``df`` are kept with empty ``symbol``/``gene_allele``), including the
        helper columns ``row_label``, ``col_label`` and ``label``.
    """
    # Create a 16x24 grid for the 384-well plate
    rows = list('ABCDEFGHIJKLMNOP')
    cols = [f'{i:02d}' for i in range(1, 25)]
    # Initialize the plate grid with empty values
    plate_grid = pl.DataFrame({col: ['' for _ in rows] for col in cols}, schema={col: pl.Utf8 for col in cols})

    # Add a row index to the plate grid
    plate_grid = plate_grid.with_row_index('row')
    # Unpivot the plate grid to long format (one record per well)
    plate_grid = plate_grid.unpivot(index='row', on=cols, variable_name='col', value_name='value')

    # Add row and column labels
    plate_grid = plate_grid.with_columns(
        pl.col('row').map_elements(lambda x: rows[x], return_dtype=pl.Utf8).alias('row_label'),  # Map row index to row label (A-P)
        pl.col('col').alias('col_label')  # Use the column name directly as the column label
    )
    # Split the well position of the input data into the same two labels
    df = df.with_columns(
        pl.col('well_position').str.head(1).alias('row_label'),
        pl.col('well_position').str.slice(1).alias('col_label')
    )
    # Left join so that wells missing from the input stay in the grid
    plate_grid = plate_grid.join(df, on=['row_label', 'col_label'], how='left')
    # Fill missing values and build the cell annotation: the first "_" of
    # gene_allele becomes a line break so gene and variant sit on two lines
    plate_grid = plate_grid.with_columns(
        pl.col('symbol').fill_null(''),
        pl.col('gene_allele').fill_null(''),
        pl.col('gene_allele').str.replace("_", '\n')
          .alias('label')
    )

    # Reshape the grid for plotting. Rows are sorted and columns selected
    # explicitly so both matrices line up with the tick labels below no matter
    # what row order the join/pivot happened to produce.
    heatmap_labels = (
        plate_grid.pivot(index='row_label', on='col_label', values='label')
        .fill_null('')
        .sort('row_label')
        .select(cols)
        .to_numpy()
    )
    control_types = (
        plate_grid.pivot(index='row_label', on='col_label', values='control_type')
        .fill_null('')
        .sort('row_label')
        .select(cols)
        .to_numpy()
    )

    # Background colour per control type
    color_map = {
        '_TC_': 'slategrey',
        'TC': 'slategrey',
        'NC': 'gainsboro',
        'PC': 'plum',
        'cPC': 'pink',
        'allele': 'salmon',
        'disease_wt': 'lightskyblue',
        '': 'white'  # White for wells missing from the input
    }

    # Map the colours onto the grid. Control types without an entry fall back
    # to the "missing" colour instead of yielding None, which matplotlib would
    # otherwise turn into a default or transparent patch.
    fallback_color = color_map['']
    heatmap_colors = np.array(
        [[color_map.get(value, fallback_color) for value in row] for row in control_types]
    )

    # Plot the heatmap; the zero matrix only provides the grid and annotations
    plt.figure(figsize=(35, 13.5))
    ax = sns.heatmap(
        np.zeros_like(heatmap_labels, dtype=int),  # Dummy data for heatmap
        annot=heatmap_labels,
        fmt='',
        cbar=False,
        linewidths=1,
        linecolor='black',
        square=True,
        annot_kws={'size': 8.5, 'color': 'black'}
    )

    # Apply colors manually, one rectangle per well
    n_rows, n_cols = heatmap_colors.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax.add_patch(plt.Rectangle((j, i), 1, 1, color=heatmap_colors[i, j], fill=True))

    # Customize the plot
    plt.title(f"384-Well Plate Map: {plate_name}", fontsize=16)
    plt.xlabel('Columns', fontsize=12)
    plt.ylabel('Rows', fontsize=12)
    plt.xticks(ticks=np.arange(1, 25) - .5, labels=cols, rotation=0)
    plt.yticks(ticks=np.arange(16) + 0.5, labels=rows, rotation=0)

    # Show the plot
    plt.tight_layout()
    plt.show()

    return plate_grid

In [None]:
def check_batch_pms(batch_dir):
    """Check that every platemap file of a batch matches its ``_T1`` reference.

    File names are grouped by the part before their first ``_T``. For each
    group, ``<platemap_id>_T1.txt`` is read as the reference and every other
    ``<platemap_id>_T*`` file is compared to it on all columns except the first.
    A confirmation line is printed for each matching file.

    Args:
        batch_dir: Directory holding the tab-separated platemap files.

    Raises:
        AssertionError: If a platemap differs from its reference.
    """
    batch_dir = Path(batch_dir)
    # Entries without a "_T" tag cannot belong to a platemap group; skipping
    # them avoids looking for a non-existent "<name>_T1.txt" reference.
    pm_files = sorted(name for name in os.listdir(batch_dir) if "_T" in name)
    platemap_set = sorted({name.split('_T')[0] for name in pm_files})
    print("Platemaps available for batch:", platemap_set)
    for platemap_id in platemap_set:
        ref_name = f"{platemap_id}_T1.txt"
        pm_ref_df = pl.read_csv(batch_dir / ref_name, separator='\t', has_header=True)
        # The first column is excluded from the comparison
        columns_to_include = pm_ref_df.columns[1:]
        df1_selected = pm_ref_df.select(columns_to_include)
        for pm_file in pm_files:
            # Exact prefix match: "X_P1" must not pick up "X_P10_T2.txt", and
            # only the reference itself is skipped (not every name containing "T1").
            if pm_file == ref_name or not pm_file.startswith(f"{platemap_id}_T"):
                continue
            pm_df = pl.read_csv(batch_dir / pm_file, separator='\t', has_header=True)
            df2_selected = pm_df.select(columns_to_include)
            # Raised explicitly so the check also runs under ``python -O``
            if not df1_selected.equals(df2_selected):
                raise AssertionError(f"{pm_file} differs from ref. {ref_name}")
            print(f"{pm_file} is equal to ref. {platemap_id}_T1.txt")

In [None]:
# Load one platemap file of batch 7 and draw it twice: first restricted to
# wells that carry a gene_allele, then unfiltered.
plate_map_dir = ""
plate_map_id = "B7A1R1_P1.txt"
batch_name = "2024_01_23_Batch_7"
platemap_path = Path(plate_map_dir) / batch_name / plate_map_id
meta_dat_b134_corrected = pl.read_csv(platemap_path, separator="\t", has_header=True)

# Drop fully empty rows and rows without a gene_allele before plotting
row_is_blank = pl.all_horizontal(pl.all().is_null())
plate_map = meta_dat_b134_corrected.filter(
    ~row_is_blank,
    pl.col("gene_allele").is_not_null(),
)
plate_grid = plot_platemap(plate_map, batch_name)

# Same plate without any filtering
plate_grid = plot_platemap(meta_dat_b134_corrected, batch_name)

In [6]:
# Define a function to compute metrics for each group
def compute_aubprc(auprc, prior):
    """Rescale ``auprc`` by the positive-class ``prior``.

    Returns ``a * (1 - p) / (a * (1 - p) + (1 - a) * p)`` with ``a = auprc``
    and ``p = prior``; for ``prior == 0.5`` this is ``auprc`` itself.
    """
    weighted_hit = auprc * (1 - prior)
    weighted_miss = (1 - auprc) * prior
    return weighted_hit / (weighted_hit + weighted_miss)


def compute_metrics(group):
    """Compute binary-classification metrics for the predictions of one classifier.

    Args:
        group: DataFrame with the rows of a single classifier. The code reads
            the columns ``Label`` (positives are the entries equal to 1),
            ``Prediction`` (scores, thresholded at 0.5 for the hard calls) and
            ``Classifier_ID``.

    Returns:
        dict with the keys ``AUROC``, ``AUPRC``, ``AUBPRC``, ``Macro_F1``,
        ``Sensitivity``, ``Specificity``, ``Balanced_Accuracy`` and
        ``Classifier_ID``.
    """
    y_true = group["Label"].to_numpy()
    y_prob = group["Prediction"].to_numpy()
    y_pred = (y_prob > 0.5).astype(int)
    # Fraction of positive labels, computed vectorised rather than with a
    # Python-level sum over the array
    prior = np.mean(y_true == 1)

    class_ID = group["Classifier_ID"].unique()[0]

    # Compute AUROC
    auroc = roc_auc_score(y_true, y_prob)

    # Compute AUPRC and its prior-rescaled variant
    auprc = average_precision_score(y_true, y_prob)
    aubprc = compute_aubprc(auprc, prior)

    # Compute macro-averaged F1 score
    macro_f1 = f1_score(y_true, y_pred, average="macro")

    # Compute sensitivity and specificity. Fixing the labels guarantees a 2x2
    # matrix, so the four-way unpacking cannot break when a class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    # Compute balanced accuracy
    balanced_acc = balanced_accuracy_score(y_true, y_pred)

    return {
        "AUROC": auroc,
        "AUPRC": auprc,
        "AUBPRC": aubprc,
        "Macro_F1": macro_f1,
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "Balanced_Accuracy": balanced_acc,
        "Classifier_ID": class_ID,
    }

In [None]:
def calculate_class_metrics(classifier_info: str, predictions: str, metrics_file: str):
    """Compute per-classifier metrics for one batch and write them to CSV.

    Args:
        classifier_info: Path to the classifier info CSV. One "/"-separated
            component must contain "Batch"; the text after "Batch_" becomes the
            batch tag (e.g. "B7"). The columns ``trainsize_0/1``,
            ``testsize_0/1`` and ``Classifier_ID`` are read.
        predictions: Path to the predictions parquet file. The columns
            ``Classifier_ID`` and ``Metadata_Protein`` are read here, plus
            those consumed by ``compute_metrics``.
        metrics_file: Destination CSV path; missing parent directories are
            created.

    Returns:
        The metrics DataFrame joined with the classifier info and extended
        with ``Training_imbalance`` and ``Testing_imbalance``.
    """
    batch_id = [subdir for subdir in classifier_info.split("/") if "Batch" in subdir][0]
    batch_id = f"B{batch_id.split('Batch_')[-1]}"

    # read in classifier info and add the fraction of class-1 cells per split
    class_info = pl.read_csv(classifier_info)
    class_info = class_info.with_columns(
        (pl.col("trainsize_1") / (pl.col("trainsize_0") + pl.col("trainsize_1"))).alias(
            "train_prob_1"
        ),
        (pl.col("testsize_1") / (pl.col("testsize_0") + pl.col("testsize_1"))).alias(
            "test_prob_1"
        ),
    )

    # read in predictions and build an ID that is unique per classifier,
    # Metadata_Protein value and batch
    preds = pl.scan_parquet(predictions)
    preds = preds.with_columns(pl.lit(batch_id).alias("Batch")).collect()
    preds = preds.with_columns(
        pl.concat_str(
            [pl.col("Classifier_ID"), pl.col("Metadata_Protein"), pl.col("Batch")],
            separator="_",
        ).alias("Full_Classifier_ID")
    )

    # Split the table once into one frame per classifier instead of
    # re-filtering the whole table for every ID.
    results = []
    for class_preds in tqdm(preds.partition_by("Full_Classifier_ID")):
        metrics = compute_metrics(class_preds)
        metrics["Full_Classifier_ID"] = class_preds["Full_Classifier_ID"][0]
        results.append(metrics)

    # Convert the results to a Polars DataFrame
    metrics_df = pl.DataFrame(results)

    # Add classifier info and the majority/minority class-size ratios
    metrics_df = metrics_df.join(class_info, on="Classifier_ID")
    metrics_df = metrics_df.with_columns(
        (
            pl.max_horizontal(["trainsize_0", "trainsize_1"])
            / pl.min_horizontal(["trainsize_0", "trainsize_1"])
        ).alias("Training_imbalance"),
        (
            pl.max_horizontal(["testsize_0", "testsize_1"])
            / pl.min_horizontal(["testsize_0", "testsize_1"])
        ).alias("Testing_imbalance"),
    )

    # Make sure the destination exists so a missing folder does not waste the
    # whole computation
    Path(metrics_file).parent.mkdir(parents=True, exist_ok=True)
    metrics_df.write_csv(metrics_file)

    return metrics_df

In [32]:
# calculate_class_metrics() takes only classifier_info, predictions and
# metrics_file; the former feat_info argument is no longer part of its signature.
metrics_df = calculate_class_metrics(
    classifier_info="../outputs/results/2024_01_23_Batch_7/profiles_tcdropped_filtered_var_mad_outlier_featselect/classifier_info.csv",
    predictions="../outputs/results/2024_01_23_Batch_7/profiles_tcdropped_filtered_var_mad_outlier_featselect/predictions.parquet",
    metrics_file="../outputs/analyses/2024_01_23_Batch_7/profiles_tcdropped_filtered_var_mad_outlier_featselect/metrics.csv",
)

100%|██████████| 5826/5826 [03:17<00:00, 29.56it/s]


## Compute metrics

In [43]:
# Reload the per-classifier metrics saved for this batch
metrics_file = "../outputs/analyses/2024_01_23_Batch_7/profiles_tcdropped_filtered_var_mad_outlier_featselect/metrics.csv"
metrics_df = pl.read_csv(metrics_file)

# Short batch tag ("B" + text after "Batch_") from the first path component
# that mentions "Batch"
batch_dir_name = next(part for part in metrics_file.split("/") if "Batch" in part)
batch_id = "B" + batch_dir_name.split("Batch_")[-1]

metrics_df

AUROC,AUPRC,AUBPRC,Macro_F1,Sensitivity,Specificity,Balanced_Accuracy,Classifier_ID,Full_Classifier_ID,Plate,trainsize_0,testsize_0,well_0,allele_0,trainsize_1,testsize_1,well_1,allele_1,Metadata_Control,train_prob_1,test_prob_1,Training_imbalance,Testing_imbalance
f64,f64,f64,f64,f64,f64,f64,str,str,str,i64,i64,str,str,i64,i64,str,str,bool,f64,f64,f64,f64
0.948019,0.925634,0.933927,0.886822,0.877966,0.895522,0.886744,"""2024_01_17_B7A1R1_P1T3_H06_J21""","""2024_01_17_B7A1R1_P1T3_H06_J21…","""2024_01_17_B7A1R1_P1T3""",748,335,"""H06""","""DES_Arg150Gln""",661,295,"""J21""","""DES""",false,0.469127,0.468254,1.131619,1.135593
0.600219,0.723602,0.587594,0.523385,0.463807,0.640394,0.552101,"""2024_01_22_B7A1R1_P4T4_H09_F09""","""2024_01_22_B7A1R1_P4T4_H09_F09…","""2024_01_22_B7A1R1_P4T4""",1153,203,"""H09""","""TTPA_Asp64Gly""",1070,373,"""F09""","""TTPA""",false,0.481332,0.647569,1.07757,1.837438
0.529359,0.540754,0.533133,0.441133,0.701887,0.237354,0.46962,"""2024_01_22_B7A1R1_P4T3_C14_I12""","""2024_01_22_B7A1R1_P4T3_C14_I12…","""2024_01_22_B7A1R1_P4T3""",635,257,"""C14""","""TH_Ile219Thr""",663,265,"""I12""","""TH""",false,0.510786,0.507663,1.044094,1.031128
0.707959,0.723445,0.748561,0.654405,0.493724,0.823529,0.658627,"""2024_01_17_B7A1R1_P1T3_M02_K02""","""2024_01_17_B7A1R1_P1T3_M02_K02…","""2024_01_17_B7A1R1_P1T3""",702,272,"""M02""","""AMPD2_Glu697Asp""",579,239,"""K02""","""AMPD2""",false,0.451991,0.46771,1.212435,1.138075
0.850989,0.90201,0.844164,0.728705,0.646091,0.874126,0.760108,"""2024_01_17_B7A1R1_P1T4_C01_A01""","""2024_01_17_B7A1R1_P1T4_C01_A01…","""2024_01_17_B7A1R1_P1T4""",520,143,"""C01""","""ACSF3_Ala197Thr""",832,243,"""A01""","""ACSF3""",false,0.615385,0.629534,1.6,1.699301
0.83739,0.870089,0.85311,0.751481,0.842995,0.657382,0.750188,"""2024_01_17_B7A1R1_P2T4_C08_M06""","""2024_01_17_B7A1R1_P2T4_C08_M06…","""2024_01_17_B7A1R1_P2T4""",670,359,"""C08""","""GSS_Arg283Cys""",1074,414,"""M06""","""GSS""",false,0.615826,0.535576,1.602985,1.153203
0.822665,0.838119,0.799517,0.744888,0.851351,0.631579,0.741465,"""2024_01_17_B7A1R1_P1T3_H20_P18""","""2024_01_17_B7A1R1_P1T3_H20_P18…","""2024_01_17_B7A1R1_P1T3""",148,57,"""H20""","""EMD_Ser54Phe""",135,74,"""P18""","""EMD""",false,0.477032,0.564885,1.096296,1.298246
0.628188,0.437454,0.613327,0.573296,0.357616,0.785714,0.571665,"""2024_01_19_B7A1R1_P3T4_D20_P18""","""2024_01_19_B7A1R1_P3T4_D20_P18…","""2024_01_19_B7A1R1_P3T4""",1782,616,"""D20""","""SDC3_Val136Ile""",1170,302,"""P18""","""SDC3""",false,0.396341,0.328976,1.523077,2.039735
0.53289,0.593545,0.5555,0.52284,0.477089,0.576378,0.526733,"""2024_01_17_B7A1R1_P1T2_C09_E07""","""2024_01_17_B7A1R1_P1T2_C09_E07…","""2024_01_17_B7A1R1_P1T2""",1796,635,"""C09""","""AGXT_Phe152Ile""",1379,742,"""E07""","""AGXT""",false,0.434331,0.538853,1.302393,1.168504
0.997898,0.998109,0.998656,0.99106,0.994764,0.988848,0.991806,"""2024_01_17_B7A1R1_P2T3_M04_K04""","""2024_01_17_B7A1R1_P2T3_M04_K04…","""2024_01_17_B7A1R1_P2T3""",1530,269,"""M04""","""GOSR2_Gly144Trp""",492,191,"""K04""","""GOSR2""",false,0.243323,0.415217,3.109756,1.408377


In [44]:
# thresh: upper bound (exclusive) on Training_imbalance for a classifier to be used
# min_class_num: minimum number of classifiers required per allele (used in a later cell)
thresh = 3
min_class_num = 2

# Add useful columns (allele set, type, batch)
metrics_df = metrics_df.with_columns(
    pl.format(
        "A{}P{}",  # Combined format, e.g. "A1P2"
        pl.col("Plate").str.extract(r'A(\d+)', 1),  # Digits following the first "A"
        pl.col("Plate").str.extract(r'_P(\d+)T', 1)  # Digits between "_P" and "T" (the part after "T" is excluded)
    ).alias("Allele_set")
)

# Classifier_type is "localization" when the full ID contains "true", else "morphology".
# NOTE(review): presumably "true" comes from the Metadata_Protein part of the ID -- confirm.
# Batch is the last "_"-separated token of the full ID.
metrics_df = metrics_df.with_columns(
    pl.when(pl.col("Full_Classifier_ID").str.contains("true"))
    .then(pl.lit("localization"))
    .otherwise(pl.lit("morphology"))
    .alias("Classifier_type"),
    pl.col("Full_Classifier_ID").str.split("_").list.last().alias("Batch"),
)

# 0.99 quantile of AUROC over control classifiers with Training_imbalance below
# thresh, computed per (Classifier_type, Batch)
metrics_ctrl = (
    metrics_df.filter(
        (pl.col("Training_imbalance") < thresh) & (pl.col("Metadata_Control"))
    )
    .select(["Classifier_type", "Batch", "AUROC"])
    .group_by(["Classifier_type", "Batch"])
    .quantile(0.99)
).rename({"AUROC": "AUROC_thresh"})

# Attach the threshold to every classifier row; the pass/fail comparison
# against AUROC_thresh is not done in this cell.
# NOTE(review): join() is called without `how`, so rows of a (type, batch)
# group without any qualifying control would be dropped -- confirm this is intended.
metrics_df = metrics_df.join(metrics_ctrl, on=["Classifier_type", "Batch"])

metrics_df

AUROC,AUPRC,AUBPRC,Macro_F1,Sensitivity,Specificity,Balanced_Accuracy,Classifier_ID,Full_Classifier_ID,Plate,trainsize_0,testsize_0,well_0,allele_0,trainsize_1,testsize_1,well_1,allele_1,Metadata_Control,train_prob_1,test_prob_1,Training_imbalance,Testing_imbalance,Allele_set,Classifier_type,Batch,AUROC_thresh
f64,f64,f64,f64,f64,f64,f64,str,str,str,i64,i64,str,str,i64,i64,str,str,bool,f64,f64,f64,f64,str,str,str,f64
0.948019,0.925634,0.933927,0.886822,0.877966,0.895522,0.886744,"""2024_01_17_B7A1R1_P1T3_H06_J21""","""2024_01_17_B7A1R1_P1T3_H06_J21…","""2024_01_17_B7A1R1_P1T3""",748,335,"""H06""","""DES_Arg150Gln""",661,295,"""J21""","""DES""",false,0.469127,0.468254,1.131619,1.135593,"""A1P1""","""morphology""","""B7""",0.997315
0.600219,0.723602,0.587594,0.523385,0.463807,0.640394,0.552101,"""2024_01_22_B7A1R1_P4T4_H09_F09""","""2024_01_22_B7A1R1_P4T4_H09_F09…","""2024_01_22_B7A1R1_P4T4""",1153,203,"""H09""","""TTPA_Asp64Gly""",1070,373,"""F09""","""TTPA""",false,0.481332,0.647569,1.07757,1.837438,"""A1P4""","""localization""","""B7""",0.71418
0.529359,0.540754,0.533133,0.441133,0.701887,0.237354,0.46962,"""2024_01_22_B7A1R1_P4T3_C14_I12""","""2024_01_22_B7A1R1_P4T3_C14_I12…","""2024_01_22_B7A1R1_P4T3""",635,257,"""C14""","""TH_Ile219Thr""",663,265,"""I12""","""TH""",false,0.510786,0.507663,1.044094,1.031128,"""A1P4""","""morphology""","""B7""",0.997315
0.707959,0.723445,0.748561,0.654405,0.493724,0.823529,0.658627,"""2024_01_17_B7A1R1_P1T3_M02_K02""","""2024_01_17_B7A1R1_P1T3_M02_K02…","""2024_01_17_B7A1R1_P1T3""",702,272,"""M02""","""AMPD2_Glu697Asp""",579,239,"""K02""","""AMPD2""",false,0.451991,0.46771,1.212435,1.138075,"""A1P1""","""localization""","""B7""",0.71418
0.850989,0.90201,0.844164,0.728705,0.646091,0.874126,0.760108,"""2024_01_17_B7A1R1_P1T4_C01_A01""","""2024_01_17_B7A1R1_P1T4_C01_A01…","""2024_01_17_B7A1R1_P1T4""",520,143,"""C01""","""ACSF3_Ala197Thr""",832,243,"""A01""","""ACSF3""",false,0.615385,0.629534,1.6,1.699301,"""A1P1""","""morphology""","""B7""",0.997315
0.83739,0.870089,0.85311,0.751481,0.842995,0.657382,0.750188,"""2024_01_17_B7A1R1_P2T4_C08_M06""","""2024_01_17_B7A1R1_P2T4_C08_M06…","""2024_01_17_B7A1R1_P2T4""",670,359,"""C08""","""GSS_Arg283Cys""",1074,414,"""M06""","""GSS""",false,0.615826,0.535576,1.602985,1.153203,"""A1P2""","""localization""","""B7""",0.71418
0.822665,0.838119,0.799517,0.744888,0.851351,0.631579,0.741465,"""2024_01_17_B7A1R1_P1T3_H20_P18""","""2024_01_17_B7A1R1_P1T3_H20_P18…","""2024_01_17_B7A1R1_P1T3""",148,57,"""H20""","""EMD_Ser54Phe""",135,74,"""P18""","""EMD""",false,0.477032,0.564885,1.096296,1.298246,"""A1P1""","""morphology""","""B7""",0.997315
0.628188,0.437454,0.613327,0.573296,0.357616,0.785714,0.571665,"""2024_01_19_B7A1R1_P3T4_D20_P18""","""2024_01_19_B7A1R1_P3T4_D20_P18…","""2024_01_19_B7A1R1_P3T4""",1782,616,"""D20""","""SDC3_Val136Ile""",1170,302,"""P18""","""SDC3""",false,0.396341,0.328976,1.523077,2.039735,"""A1P3""","""localization""","""B7""",0.71418
0.53289,0.593545,0.5555,0.52284,0.477089,0.576378,0.526733,"""2024_01_17_B7A1R1_P1T2_C09_E07""","""2024_01_17_B7A1R1_P1T2_C09_E07…","""2024_01_17_B7A1R1_P1T2""",1796,635,"""C09""","""AGXT_Phe152Ile""",1379,742,"""E07""","""AGXT""",false,0.434331,0.538853,1.302393,1.168504,"""A1P1""","""morphology""","""B7""",0.997315
0.997898,0.998109,0.998656,0.99106,0.994764,0.988848,0.991806,"""2024_01_17_B7A1R1_P2T3_M04_K04""","""2024_01_17_B7A1R1_P2T3_M04_K04…","""2024_01_17_B7A1R1_P2T3""",1530,269,"""M04""","""GOSR2_Gly144Trp""",492,191,"""K04""","""GOSR2""",false,0.243323,0.415217,3.109756,1.408377,"""A1P2""","""localization""","""B7""",0.71418


In [None]:
def _count_classifiers(df, max_imbalance=None):
    """Count non-control localization classifiers per allele, one column per batch.

    When ``max_imbalance`` is given, only classifiers whose
    ``Training_imbalance`` is strictly below it are counted.
    """
    # The original note states that localization and morphology have the same
    # number of classifiers, so only the localization ones are counted.
    keep = (~pl.col("Metadata_Control")) & (pl.col("Classifier_type") == "localization")
    if max_imbalance is not None:
        keep = (pl.col("Training_imbalance") < max_imbalance) & keep
    counts = (
        df.filter(keep)
        .group_by(["allele_0", "Allele_set", "Batch", "allele_1"])
        .agg([pl.len().alias("Number_classifiers")])
    )
    return counts.pivot(
        index=["allele_0", "allele_1", "Allele_set"],
        on="Batch",
        values="Number_classifiers",
    )


def _report_counts(counts):
    """Print the number of rows, variant alleles and WT genes in ``counts``."""
    print("Total number of unique classifiers:", counts.shape[0])
    print("Total number of unique variant alleles:", len(counts.select("allele_0").to_series().unique().to_list()))
    print("Total number of unique WT genes:", len(counts.select("allele_1").to_series().unique().to_list()))


# Baseline: every non-control allele, no imbalance filter
classifier_count = _count_classifiers(metrics_df)
_report_counts(classifier_count)
print("==========================================================================")

# Only classifiers with an acceptable training class balance
classifier_count = _count_classifiers(metrics_df, max_imbalance=thresh)
print(f"After filtering out classifiers with training imbalance >= {thresh}:")
_report_counts(classifier_count)
print("==========================================================================")

# Must be at least min_class_num classifiers in this batch
classifier_count = classifier_count.filter(pl.col(batch_id) >= min_class_num)
print(f"After filtering out alleles with available number of classifiers < {min_class_num}:")
_report_counts(classifier_count)

# Keep all control rows, and non-control rows only for the surviving alleles
keep_alleles = classifier_count.select("allele_0").to_series().unique().to_list()
metrics_df = metrics_df.filter(
    pl.col("Metadata_Control") | pl.col("allele_0").is_in(keep_alleles)
)

Total number of unique classifiers: 644
Total number of unique variant alleles: 644
Total number of unique WT genes: 124
After filtering out classifiers with imbalance > 3:
Total number of unique classifiers: (604, 4)
Total number of unique variant alleles: 604
Total number of unique WT genes: 118
After filtering out alleles with available number of classifiers < 2:
Total number of unique classifiers: (591, 4)
Total number of unique variant alleles: 591
Total number of unique WT genes: 115


In [None]:
# Filter by imbalance and calculate mean AUROC for each batch
# Average AUROC, class sizes and training imbalance over the balanced,
# non-control classifiers of each allele (per classifier type and batch).
group_keys = ["Classifier_type", "allele_0", "Allele_set", "Batch", "AUROC_thresh"]

balanced_non_controls = metrics_df.filter(
    pl.col("Training_imbalance") < thresh,
    ~pl.col("Metadata_Control"),
)

metrics_wtvar = (
    balanced_non_controls.select([
        "AUROC",
        "Classifier_type",
        "Batch",
        "AUROC_thresh",
        "allele_0",
        "trainsize_0",
        "testsize_0",
        "trainsize_1",
        "testsize_1",
        "Allele_set",
        "Training_imbalance",
    ])
    .group_by(group_keys)
    # every selected column that is not a grouping key is averaged
    .agg([pl.all().exclude(group_keys).mean().name.suffix("_mean")])
)

metrics_wtvar

# Optional export of the summary:
# metrics_wtvar.write_csv(f"{metrics_dir}/metrics_summary.csv")

Classifier_type,allele_0,Allele_set,Batch,AUROC_thresh,AUROC_mean,trainsize_0_mean,testsize_0_mean,trainsize_1_mean,testsize_1_mean,Training_imbalance_mean
str,str,str,str,f64,f64,f64,f64,f64,f64,f64
"""localization""","""FTH1_Lys54Arg""","""A1P2""","""B7""",0.71418,0.570533,875.25,291.75,1008.75,336.25,1.157161
"""morphology""","""ALAS2_Gly254Ser""","""A1P1""","""B7""",0.997315,0.715024,523.5,174.5,553.5,184.5,1.057588
"""localization""","""EFHC1_Cys259Tyr""","""A1P1""","""B7""",0.71418,0.939442,1052.25,350.75,883.5,294.5,1.190766
"""morphology""","""AGXT_Arg289His""","""A1P1""","""B7""",0.997315,0.996732,557.0,96.0,1499.0,622.0,2.711352
"""localization""","""EFHC1_Asp210Asn""","""A1P1""","""B7""",0.71418,0.927736,1023.75,341.25,883.5,294.5,1.155941
"""morphology""","""MVK_Val377Ile""","""A1P3""","""B7""",0.997315,0.653887,966.75,322.25,977.25,325.75,1.068489
"""morphology""","""HPRT1_Asp194Glu""","""A1P2""","""B7""",0.997315,0.631202,1395.75,465.25,1864.5,621.5,1.335934
"""morphology""","""DES_Arg16Cys""","""A1P1""","""B7""",0.997315,0.539689,498.75,166.25,717.0,239.0,1.4413
"""localization""","""PKP2_Thr482Ala""","""A1P3""","""B7""",0.71418,0.516515,145.333333,40.666667,337.666667,129.333333,2.298867
"""morphology""","""PKP2_Ser140Phe""","""A1P3""","""B7""",0.997315,0.83379,336.0,112.0,350.25,116.75,1.313026


## Old implementation

In [None]:
# paths
# Paths and batch names for the old two-batch implementation.
# bio_rep_2 and bio_rep_combined must be defined here because the cells below
# build res_br_2 and metrics_dir from them.
snakemake_dir = ".."
pipeline = "profiles_tcdropped_filtered_var_mad_outlier_featselect"
bio_rep_1 = "2024_01_23_Batch_7"
bio_rep_2 = "2024_02_06_Batch_8"
# Combined tag built from the text after "Batch_" of each replicate
bio_rep_combined = f"B{bio_rep_1.split('Batch_')[-1]}-B{bio_rep_2.split('Batch_')[-1]}"
bio_rep_combined

'B7-B8'

In [3]:
# Result folder of each biological replicate and the folder for the combined metrics
outputs_dir = f"{snakemake_dir}/outputs"
res_br_1 = "/".join([outputs_dir, "results", bio_rep_1, pipeline])
res_br_2 = "/".join([outputs_dir, "results", bio_rep_2, pipeline])
metrics_dir = "/".join([outputs_dir, "classification_metrics", bio_rep_combined, pipeline])

In [None]:
def _load_classifier_info(results_dir):
    """Read ``classifier_info.csv`` from ``results_dir`` and add the class-1 fractions."""
    train_total = pl.col("trainsize_0") + pl.col("trainsize_1")
    test_total = pl.col("testsize_0") + pl.col("testsize_1")
    return pl.read_csv(f"{results_dir}/classifier_info.csv").with_columns(
        (pl.col("trainsize_1") / train_total).alias("train_prob_1"),
        (pl.col("testsize_1") / test_total).alias("test_prob_1"),
    )


# Classifier info of both replicates, stacked into one table
info_b7 = _load_classifier_info(res_br_1)
info_b8 = _load_classifier_info(res_br_2)
info = pl.concat([info_b7, info_b8])

In [None]:
def _load_predictions(results_dir, batch_label):
    """Read ``predictions.parquet`` from ``results_dir`` and tag it with ``batch_label``."""
    tagged = pl.scan_parquet(f"{results_dir}/predictions.parquet").with_columns(
        pl.lit(batch_label).alias("Batch")
    )
    return tagged.collect()


# Predictions of both replicates
preds_b7 = _load_predictions(res_br_1, "batch7")
preds_b8 = _load_predictions(res_br_2, "batch8")

# ID combining classifier, Metadata_Protein value and batch
full_classifier_id = pl.concat_str(
    [pl.col("Classifier_ID"), pl.col("Metadata_Protein"), pl.col("Batch")],
    separator="_",
).alias("Full_Classifier_ID")
preds = pl.concat([preds_b7, preds_b8]).with_columns(full_classifier_id)

In [None]:
# Compute the metrics of every classifier. The table is split once into one
# frame per Full_Classifier_ID instead of being re-filtered for every ID, and
# the loop variable no longer shadows the builtin `id`.
results = []
for class_preds in tqdm(preds.partition_by("Full_Classifier_ID")):
    metrics = compute_metrics(class_preds)
    metrics["Full_Classifier_ID"] = class_preds["Full_Classifier_ID"][0]
    results.append(metrics)

# Convert the results to a Polars DataFrame
metrics_df = pl.DataFrame(results)

# Add classifier info and the majority/minority class-size ratios
metrics_df = metrics_df.join(info, on="Classifier_ID")
metrics_df = metrics_df.with_columns(
    (
        pl.max_horizontal(["trainsize_0", "trainsize_1"])
        / pl.min_horizontal(["trainsize_0", "trainsize_1"])
    ).alias("Training_imbalance"),
    (
        pl.max_horizontal(["testsize_0", "testsize_1"])
        / pl.min_horizontal(["testsize_0", "testsize_1"])
    ).alias("Testing_imbalance"),
)

# Create the output folder first so the save cannot fail on a missing directory
Path(metrics_dir).mkdir(parents=True, exist_ok=True)
metrics_df.write_csv(f"{metrics_dir}/metrics.csv")