In [None]:
import pandas as pd
from openpyxl.styles import PatternFill, Font, Alignment, Border, Side
from openpyxl.utils import get_column_letter

from zero_shot_validation_scripts.utils import SUFFIX_PREFIX_DICT


In [None]:
# Normalization direction for the confusion matrix: "row" or "col".
row_or_col_norm = "row"
assert row_or_col_norm in ("row", "col"), "row_or_col_norm must be either 'row' or 'col'"

# The sort column and the confusion-matrix header prefix both follow the
# normalization direction chosen above.
if row_or_col_norm == "row":
    sort_by = "Percent Accuracy"
    confusion_matrix_col_prefix = "Percent Predicted as Class "
else:
    sort_by = "Percent Precision"
    confusion_matrix_col_prefix = "Percent of All Cells Predicted as "

In [35]:
def interpolate_color(color_start, color_end, t):
    """Blend two 6-digit hex colors (no leading '#') linearly.

    ``t`` is the blend position: 0.0 yields ``color_start`` and 1.0 yields
    ``color_end``. Each channel is truncated to an integer and the result is
    returned as an upper-case 6-digit hex string.
    """
    channels = []
    for offset in (0, 2, 4):  # start positions of the R, G and B byte pairs
        lo = int(color_start[offset:offset + 2], 16)
        hi = int(color_end[offset:offset + 2], 16)
        channels.append(int(lo + t * (hi - lo)))
    return "".join(format(channel, "02X") for channel in channels)


def apply_continuous_coloring(sheet, start_col):
    """Apply a continuous blue-hue color scale to the confusion matrix.

    Every numeric cell from row 2 downwards and from column ``start_col + 1``
    rightwards is filled with a color between light blue (smallest value) and
    dark blue (largest value). The font is black on the lighter half of the
    scale and white on the darker half.

    Args:
        sheet: openpyxl worksheet to style in place.
        start_col: Number of columns to the left of the confusion matrix.
    """
    numeric_cells = [
        cell
        for row in sheet.iter_rows(min_row=2, min_col=start_col + 1)
        for cell in row
        if isinstance(cell.value, (int, float))
    ]
    if not numeric_cells:
        return  # No numeric values found

    values = [cell.value for cell in numeric_cells]
    min_val, max_val = min(values), max(values)
    span = max_val - min_val
    color_min, color_max = "A1E0F7", "003366"  # Light Blue -> Dark Blue

    for cell in numeric_cells:
        # When all values are equal there is no range to scale over; use the
        # light end of the scale instead of dividing by zero.
        scale = (cell.value - min_val) / span if span else 0.0
        color = interpolate_color(color_min, color_max, scale)
        cell.fill = PatternFill(start_color=color, end_color=color, fill_type="solid")
        cell.font = Font(color="000000" if scale < 0.5 else "FFFFFF")  # Keep text readable


def draw_diagonal_borders(sheet, start_col, num_classes):
    """Outline the diagonal cells of the confusion matrix with a thick border.

    The i-th diagonal cell (0-based) sits at row ``i + 2`` (row 1 is the
    header) and column ``start_col + i + 1``.
    """
    box = Border(**{edge: Side(style="thick") for edge in ("left", "right", "top", "bottom")})
    for offset in range(num_classes):
        sheet.cell(row=offset + 2, column=start_col + offset + 1).border = box


def format_excel_sheets(input_files, sheet_names, col_names, output_file):
    """Format per-label performance metric CSVs into one styled Excel workbook.

    Each input CSV becomes one worksheet. Metric columns are renamed and the
    percentage metrics are multiplied by 100; the trailing confusion-matrix
    columns are converted to percentages per row or per column according to
    the module-level ``row_or_col_norm``; rows (and the matching matrix
    columns) are sorted in descending order of the module-level ``sort_by``
    column; recall columns that contain only NaN are dropped; finally the
    sheet is styled (header, alignment, widths, borders, number format,
    color scale and diagonal outline).

    Args:
        input_files: Paths of the CSV files, one per worksheet. Each file
            must contain a ``class`` column and the metric columns renamed
            below; the confusion matrix is taken to start at the 11th column.
        sheet_names: Worksheet titles, parallel to ``input_files``.
        col_names: Keys into ``SUFFIX_PREFIX_DICT``, parallel to
            ``input_files``; each entry supplies the ``(prefix, suffix)``
            removed from class names and column headers.
        output_file: Path of the Excel workbook to write.
    """
    confusion_matrix_start = 10  # Columns before confusion matrix
    pretty_names = {
        "class": "Class", "precision": "Percent Precision", "accuracy": "Percent Accuracy", "f1": "F1 Score",
        "rocauc": "AUROC", "recall_at_1": "Percent Recall @1", "recall_at_5": "Percent Recall @5",
        "recall_at_10": "Percent Recall @10", "recall_at_50": "Percent Recall @50", "n_samples_in_class": "Number of Samples in Class"
    }
    recall_cols = ["Percent Recall @1", "Percent Recall @5", "Percent Recall @10", "Percent Recall @50"]
    percent_cols = ["Percent Precision"] + recall_cols + ["Percent Accuracy"]

    with pd.ExcelWriter(output_file, engine="openpyxl") as writer:
        for input_file, sheet_name, col_name in zip(input_files, sheet_names, col_names):
            df = pd.read_csv(input_file)
            prefix, suffix = SUFFIX_PREFIX_DICT[col_name]
            # regex=False: prefix/suffix are literal text, independent of the
            # pandas version's default for `regex`.
            df["class"] = (
                df["class"].str.replace(prefix, "", regex=False).str.replace(suffix, "", regex=False)
            )
            df.columns = (
                df.columns.str.replace(prefix, "", regex=False)
                .str.replace(suffix, "", regex=False)
                .str.replace("n_samples_predicted_as_", confusion_matrix_col_prefix, regex=False)
            )
            df = df.rename(columns=pretty_names)

            # Normalize the confusion matrix to percentages. The block is cast
            # to float and the frame rebuilt, rather than assigning floats
            # into the existing (presumably integer count) columns in place.
            matrix = df.iloc[:, confusion_matrix_start:].astype(float)
            if row_or_col_norm == "col":
                matrix = matrix.apply(lambda col: (100 * col) / col.sum(), axis=0)
            else:
                matrix = matrix.apply(lambda row: (100 * row) / row.sum(), axis=1)
            df = pd.concat([df.iloc[:, :confusion_matrix_start], matrix], axis=1)

            for col in percent_cols:
                df[col] *= 100

            # Sort by given column (e.g. accuracy, precision, etc.). The matrix
            # columns are reordered with the same permutation, which assumes
            # matrix column j belongs to the class in row j.
            sort_index = df[sort_by].sort_values(ascending=False).index
            cols_new_order = df.columns[:confusion_matrix_start].tolist() + [df.columns[confusion_matrix_start:][i] for i in sort_index]
            df = df.iloc[sort_index].loc[:, cols_new_order]

            # Drop recall columns whose values are all NaN *before* writing.
            # Deleting them from the worksheet afterwards would shift the
            # cells but leave the column widths attached to the old letters.
            empty_recall_cols = [col for col in recall_cols if df[col].isna().all()]
            matrix_start = confusion_matrix_start - sum(
                df.columns.get_loc(col) < confusion_matrix_start for col in empty_recall_cols
            )
            df = df.drop(columns=empty_recall_cols)

            df.to_excel(writer, sheet_name=sheet_name, index=False)
            workbook = writer.book
            sheet = workbook[sheet_name]

            # Header formatting
            for cell in sheet[1]:
                cell.font = Font(bold=True)
                cell.alignment = Alignment(wrap_text=True, horizontal="center", vertical="center")

            # Center alignment for all other cells
            for row in sheet.iter_rows(min_row=2):
                for cell in row:
                    cell.alignment = Alignment(horizontal="center", vertical="center")

            # Column width adjustments: metric columns fit their longest
            # entry, confusion-matrix columns get a fixed width.
            for col_idx, col in enumerate(sheet.columns, start=1):
                if not col[0].value.startswith(confusion_matrix_col_prefix):
                    max_length = max(len(str(cell.value)) for cell in col if cell.value)
                    sheet.column_dimensions[get_column_letter(col_idx)].width = max_length + 2
                else:
                    sheet.column_dimensions[get_column_letter(col_idx)].width = 18

            # Tall header row so the wrapped header text fits.
            sheet.row_dimensions[1].height = 100

            # Borders
            thin_border = Border(left=Side(style="thin"), right=Side(style="thin"), top=Side(style="thin"), bottom=Side(style="thin"))
            for row in sheet.iter_rows():
                for cell in row:
                    cell.border = thin_border

            # Confusion matrix header coloring
            for row in sheet.iter_rows(min_row=1, min_col=matrix_start + 1, max_col=sheet.max_column, max_row=1):
                for cell in row:
                    cell.fill = PatternFill(start_color="DDDDDD", end_color="DDDDDD", fill_type="solid")

            # Number formatting: two decimals for non-zero floats.
            for row in sheet.iter_rows(min_row=2):
                for cell in row:
                    if isinstance(cell.value, float) and cell.value != 0:
                        cell.number_format = "0.00"

            # Confusion matrix visual enhancements (the thick diagonal borders
            # replace the thin borders set above on those cells).
            apply_continuous_coloring(sheet, matrix_start)
            num_classes = df["Class"].nunique()
            draw_diagonal_borders(sheet, matrix_start, num_classes)
                    

# Inputs, output path and parameters are read from the `snakemake` object.
input_files = list(snakemake.input.datasets_perlabel)
output_file = snakemake.output.confusion_mtx_table
datasets = snakemake.params.datasets
label_cols = snakemake.params.label_cols

# One worksheet per (dataset, label column) pair, titled with the pretty names.
sheet_names = [
    f"{dataset_name} - {column_name}"
    for dataset_name, column_name in zip(
        snakemake.params.dataset_names_pretty,
        snakemake.params.label_cols_pretty,
    )
]
format_excel_sheets(input_files, sheet_names, label_cols, output_file)