In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list

In [None]:
import pre_sal_ii.libs.plot_context as pc

from importlib import reload
import localizable_resources as lr

def reload_libs_env():
    """Re-read ``.env`` and reload the helper modules used by this notebook.

    ``load_dotenv`` is called with ``override=True``, so values found in
    ``.env`` replace variables already set in the process environment.
    Afterwards the ``pc`` and ``lr`` modules are reloaded, in that order.
    """
    from dotenv import load_dotenv

    load_dotenv(".env", override=True)

    for module in (pc, lr):
        reload(module)

# Load the environment and refresh the helper modules once before building
# the plotting context.
reload_libs_env()

# Sizes handed to the plot-context factory below.
# NOTE(review): the meaning of the rc_sizes arguments is defined in
# pre_sal_ii.libs.plot_context and is not visible here - presumably three
# font sizes plus a figure size; confirm.
global_sizes = pc.rc_sizes(16, 21, 24, [8, 8])
# Plot context used by the figure cells (`with MyPlot(path, ...)`); it is also
# given reload_libs_env, whose use is decided inside plot_context.
MyPlot = pc.create_plot_context(global_sizes, reload_libs_env)

In [None]:
import os
import json
import glob
from importlib import reload
from tqdm.notebook import tqdm

import re
def match_group(regex, string, default=None, convert=str):
    """Search ``string`` for ``regex`` and return its first capture group.

    Args:
        regex: Pattern containing at least one capturing group.
        string: Text to search.
        default: Value returned as-is (not passed through ``convert``) when
            the pattern does not match.
        convert: Callable applied to the captured text on a match.

    Returns:
        ``convert(<group 1>)`` on a match, otherwise ``default``.
    """
    found = re.search(regex, string)
    return default if found is None else convert(found.group(1))

def add_items(all_data, data, info, type):
    """Append one flat record per evaluation criterion to ``all_data``.

    Args:
        all_data: List that receives the records (mutated in place).
        data: Mapping from criterion name to a dict of values for that
            criterion; every criterion listed below must be present.
        info: Extra key/value pairs merged into every record. They are merged
            last, so they win on key collisions.
        type: Label stored under the ``"type"`` key of every record.
    """
    criteria_names = (
        "50p",
        "90p",
        "summed",
        "pores_1_50p",
        "pores_1_90p",
        "pores_25p_50p",
        "pores_25p_90p",
        "pores_50p_50p",
        "pores_50p_90p",
        "pores_25p_sum",
        "pores_50p_sum",
        "pores_75p_sum",
    )
    for criterion in criteria_names:
        record = {**data[criterion], "type": type, "criteria": criterion, **info}
        all_data.append(record)

def process_image_files():
    """Read every matching stats JSON and flatten it into a list of records.

    For each file matching ``stats_pred_8fold_1.2*.json`` the run parameters
    are parsed out of the file name, the JSON is loaded, and ``add_items`` is
    called for the ``"all"`` entry and for the first four entries of
    ``"groups"`` (labelled micro, small, medium, large).

    While the files are processed, ``pre_sal_ii.progress`` is replaced by a
    ``tqdm`` wrapper that passes ``leave=False``; the previous value is
    restored on exit, including when an exception is raised.

    Returns:
        list[dict]: The records appended by ``add_items``.
    """
    import pre_sal_ii
    reload(pre_sal_ii)
    saved_progress = pre_sal_ii.progress

    def quiet_progress(*args, **kwargs):
        return tqdm(*args, leave=False, **kwargs)

    def is_true(text):
        return text == "True"

    # (info key, file-name pattern, default when the pattern is absent, converter)
    arg_specs = (
        ("selector_power", r"_selector_pwr=([0-9]+\.[0-9]+)", 0.0, float),
        ("use_channels", r"_channels=(False|True)", True, is_true),
        ("stdev_channel_power", r"_channels_pwr=([0-9]+\.[0-9]+)", 0.0, float),
        ("mean_channel_weight", r"_mean_wt=([0-9]+\.[0-9]+)", 1.0, float),
        ("stdev_channel_weight", r"_stdev_wt=([0-9]+\.[0-9]+)", 1.0, float),
        ("color_channels_weight", r"_color_wt=([0-9]+\.[0-9]+)", 1.0, float),
        ("normalize_stdev", r"_stdev_norm=(False|True)", True, is_true),
    )
    # Position in data["groups"] -> label stored in the "type" column.
    group_labels = ("micro", "small", "medium", "large")

    # Find all matching stats files
    stats_paths = sorted(glob.glob("../out/with_pwr/stats/stats_pred_8fold_1.2*.json"))

    try:
        outer_bar = tqdm(stats_paths)
        pre_sal_ii.progress = quiet_progress
        records = []

        for stats_path in outer_bar:
            # File name layout: stats_pred_8fold_1.2{args}.json
            base_name = os.path.basename(stats_path)
            encoded_args = base_name.removeprefix("stats_pred_8fold_1.2").removesuffix(".json")

            info = {
                key: match_group(pattern, encoded_args, fallback, convert)
                for key, pattern, fallback, convert in arg_specs
            }

            # Load the JSON data
            with open(stats_path, 'r') as handle:
                stats = json.load(handle)

            add_items(records, stats["all"], info, "all")
            for position, label in enumerate(group_labels):
                add_items(records, stats["groups"][position], info, label)

        return records
    finally:
        pre_sal_ii.progress = saved_progress

# Parse every stats file into flat records (one per file, type and criteria)
# and collect them into a single DataFrame.
all_data = process_image_files()
df = pd.DataFrame(all_data)

In [None]:
def add_metrics(df):
    """Return a copy of ``df`` with classification metrics appended.

    Reads the ``tp``, ``fp``, ``fn`` and ``tn`` columns and adds, in this
    order, ``accuracy``, ``precision``, ``recall``, ``f1`` and ``iou``.
    The input frame is not modified. No guard is applied for zero
    denominators; the result of such divisions is whatever pandas produces.
    """
    scored = df.copy()
    tp, fp, fn, tn = scored["tp"], scored["fp"], scored["fn"], scored["tn"]

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)

    scored["accuracy"] = (tp + tn) / (tp + fp + fn + tn)
    scored["precision"] = precision
    scored["recall"] = recall
    scored["f1"] = 2 * precision * recall / (precision + recall)
    scored["iou"] = tp / (tp + fp + fn)

    return scored


In [None]:
# Append the metric columns, then discard every row that contains a NaN
# (for example a metric whose division was 0/0).
df = add_metrics(df).dropna()
df

In [None]:
# Write the metrics table to CSV in the stats output directory, without the
# row index.
df.to_csv("../out/with_pwr/stats/pore_type_supervised_1.2.2_stats.csv", index=False)

In [None]:
from sklearn.tree import DecisionTreeRegressor

def get_feature_importances(df, max_depth=3, random_state=0):
    """Rank the run parameters by how much a shallow tree uses them to predict F1.

    The ``type``, ``criteria``, ``mean_channel_weight`` and ``use_channels``
    columns are one-hot encoded (first level dropped). The confusion counts
    and all derived metrics are removed from the inputs so the tree cannot
    read the target back from them. A ``DecisionTreeRegressor`` is then fitted
    on the remaining columns with ``f1`` as the target.

    Args:
        df: Frame holding the parameter columns plus ``tp``, ``fp``, ``fn``,
            ``tn``, ``f1``, ``accuracy``, ``precision``, ``recall``, ``iou``.
        max_depth: Maximum depth of the regression tree.
        random_state: Seed passed to the tree. scikit-learn permutes the
            candidate features at every split, so without a fixed seed equally
            good splits can be chosen differently between runs and the
            reported importances change.

    Returns:
        pandas.Series: Feature importances indexed by encoded column name,
        sorted in descending order.
    """
    df_encoded = pd.get_dummies(
        df,
        columns=["type", "criteria", "mean_channel_weight", "use_channels"],
        drop_first=True,
    )
    features = df_encoded.drop(
        columns=["tp", "fp", "fn", "tn", "f1", "accuracy", "precision", "recall", "iou"]
    )
    target = df_encoded["f1"]
    tree = DecisionTreeRegressor(max_depth=max_depth, random_state=random_state)
    tree.fit(features, target)
    importances = pd.Series(tree.feature_importances_, index=features.columns)
    return importances.sort_values(ascending=False)

# Feature importances computed over every row (all types together).
importances_all = get_feature_importances(df)
print(importances_all)


In [None]:
# Feature importances using only the rows whose type is "large".
importances_large = get_feature_importances(df[df.type == "large"])
print(importances_large)


In [None]:
# Feature importances using only the rows whose type is "medium".
importances_medium = get_feature_importances(df[df.type == "medium"])
print(importances_medium)


In [None]:
# Feature importances using only the rows whose type is "small".
importances_small = get_feature_importances(df[df.type == "small"])
print(importances_small)


In [None]:
# Feature importances using only the rows whose type is "micro".
importances_micro = get_feature_importances(df[df.type == "micro"])
print(importances_micro)


In [None]:
# Combine the per-type importance series into one table. Rows that are
# missing (NaN) in any column are dropped, rows that are zero everywhere are
# removed, and the remainder is ordered by mean importance, highest first.
df_importances = (
    pd.DataFrame(
        {
            "all": importances_all,
            "large": importances_large,
            "medium": importances_medium,
            "small": importances_small,
            "micro": importances_micro,
        }
    )
    .dropna()
    .loc[lambda table: (table != 0).any(axis=1)]
    .loc[lambda table: table.mean(axis=1).sort_values(ascending=False).index]
)

# Emit the table as LaTeX source.
print(
    df_importances.to_latex(
        index=True,
        caption="Feature Importances for Model Performance",
        label="tab:feature_importances",
    )
)

In [None]:
# Criteria kept for the subset displayed and plotted below.
crits = ["90p", "summed", "pores_1_50p", "pores_25p_sum"]

# Keep only rows with mean_channel_weight == 255.0 and one of the criteria above.
# NOTE(review): why 255.0 is the configuration of interest is not stated in
# this notebook - confirm.
df_best = df[(df["mean_channel_weight"] == 255.0) & (df["criteria"].isin(crits))]
df_best

In [None]:
name = "selector_power_vs_f1_boxplot"
with MyPlot(f"../images/{name}.pdf", figsize=[10, 6]) as mp:

    fig, ax = plt.subplots()

    # One box of F1 scores per distinct selector_power value in df_best.
    df_best.boxplot(column="f1", by="selector_power", ax=ax, grid=False)

    ax.set_xlabel("selector power")
    ax.set_ylabel("F1 score")
    ax.set_title("Effect of selector power on model quality")
    # pandas adds a "Boxplot grouped by ..." super-title for `by=` plots; blank it.
    fig.suptitle("")

In [None]:
# Display the filtered table again for inspection.
df_best

In [None]:
colors = [
    (1, 0, 0),  # red for -1
    (1, 1, 1),  # white for 0
    (0, 0, 1),  # blue for +1
]

from matplotlib.colors import LinearSegmentedColormap
from scipy.spatial.distance import squareform
cmap = LinearSegmentedColormap.from_list("red_white_blue", colors, N=256)


# Keep the float/int columns, excluding the two weight columns dropped below,
# and discard columns with zero standard deviation: their correlation is
# undefined and would put NaNs into the matrix.
numeric_df = df.select_dtypes(include=[float, int]).drop(columns=["stdev_channel_weight", "color_channels_weight"])
numeric_df = numeric_df.loc[:, numeric_df.std() > 0]

# Compute correlation matrix
corr = numeric_df.corr()

# ---- Hierarchical clustering on the correlation matrix ----
# Convert correlation to a distance matrix
# (0 = perfectly correlated, 2 = perfectly anti-correlated)
distance = 1 - corr

# linkage() treats a 2-D array as raw observation vectors, not as pairwise
# distances, so the square matrix must be condensed first. checks=False skips
# the symmetry / zero-diagonal validation, which can trip on floating-point
# noise; only the upper triangle is read.
condensed_distance = squareform(distance.to_numpy(), checks=False)

# Perform clustering
link = linkage(condensed_distance, method='average')

# Get ordering of rows/columns
order = leaves_list(link)

# Reorder the correlation matrix
corr_reordered = corr.values[order][:, order]
labels_reordered = corr.columns[order]

# ---- Plot heatmap ----
fig, ax = plt.subplots(figsize=(10, 8))

# Fix the colour scale to [-1, 1] so white always corresponds to zero correlation.
cax = ax.imshow(corr_reordered, interpolation='nearest', cmap=cmap, vmin=-1, vmax=1)
fig.colorbar(cax)

ax.set_xticks(np.arange(len(labels_reordered)))
ax.set_yticks(np.arange(len(labels_reordered)))

ax.set_xticklabels(labels_reordered, rotation=90)
ax.set_yticklabels(labels_reordered)

plt.tight_layout()
plt.show()

In [None]:
# Split datasets
from pre_sal_ii.libs.plot_context import rc_sizes


def violinplot(df):
    """Compare metric distributions of runs without vs. with ``use_channels``.

    Creates one row of subplots, one per metric (F1, IoU, Recall, Precision,
    Accuracy). Each subplot holds two violins: position 1 for rows where
    ``use_channels`` is False and position 2 for rows where it is True.
    NaN values are dropped per metric before plotting.

    Args:
        df: Frame with a ``use_channels`` column and the metric columns
            ``f1``, ``iou``, ``recall``, ``precision`` and ``accuracy``.
    """
    metric_titles = {
        "f1": "F1",
        "iou": "IoU",
        "recall": "Recall",
        "precision": "Precision",
        "accuracy": "Accuracy",
    }

    without_channels = df[df["use_channels"] == False]
    with_channels = df[df["use_channels"] == True]

    # samples[column] = [values without user channels, values with user channels]
    samples = {
        column: [
            without_channels[column].dropna().values,
            with_channels[column].dropna().values,
        ]
        for column in metric_titles
    }

    fig, axes = plt.subplots(1, len(metric_titles), sharey=False)

    for ax, (column, title) in zip(axes, metric_titles.items()):
        ax.violinplot(
            samples[column],
            positions=[1, 2],
            showmeans=True,
            showextrema=True,
        )

        # Same x labels and y range on every panel so they can be compared.
        ax.set_xticks([1, 2])
        ax.set_xticklabels(["No user channels", "With user channels"], rotation=60)
        ax.set_ylim(-0.1, 1.1)
        ax.set_title(title)
        ax.set_ylabel("")

# Violin plot over all rows.
name = "user_channels_effect_violinplot"
with MyPlot(f"../images/{name}.pdf", figsize=[15, 5], sizes=pc.rc_sizes(12, 14, 16, [8, 8])) as mp:
    violinplot(df)

# One additional figure per distinct value of the "type" column.
# The loop variable is not called `type`: at notebook (module) scope that name
# would replace the builtin `type` for every cell executed afterwards.
types = df["type"].unique()
for pore_type in types:
    df_type = df[df["type"] == pore_type]
    name = f"user_channels_effect_violinplot_{pore_type}"
    with MyPlot(f"../images/{name}.pdf", figsize=[15, 5], sizes=pc.rc_sizes(12, 14, 16, [8, 8])) as mp:
        violinplot(df_type)