# Pore type prediction from thin-section images 1.2

In this notebook, we continue with the analysis of the models trained in the 1.2.0 notebook.

For this analysis, we take a mask of the pores that we want to analyse. The mask will consider:
- all sizes of pores together
- different sizes of pores

From the different sizes of pores, we will create distributions over the sizes of the extracted information.
The info to be extracted is:
- positive pixels (50% threshold)
- negative pixels (50% threshold)
- positive pixels (90% threshold)
- negative pixels (90% threshold)
- true pixels
- false pixels
- summed positive pixels
- summed negative pixels
- true positive pixels (50% threshold)
- false positive pixels (50% threshold)
- false negative pixels (50% threshold)
- true negative pixels (50% threshold)
- true positive pixels (90% threshold)
- false positive pixels (90% threshold)
- false negative pixels (90% threshold)
- true negative pixels (90% threshold)
- summed true positive pixels
- summed false positive pixels
- summed true negative pixels
- summed false negative pixels
- pores with one 50% pixel
- pores without one 50% pixel
- pores with 25% of 50% pixels
- pores without 25% of 50% pixels
- pores with majority 50% pixels
- pores without majority 50% pixels
- pores with one 90% pixel
- pores without one 90% pixel
- pores with 25% of 90% pixels
- pores without 25% of 90% pixels
- pores with majority 90% pixels
- pores without majority 90% pixels

In [None]:
import os
# Show the working directory: the relative "../out" and "../models" paths used
# in later cells are resolved against it.
print(os.getcwd())

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

from importlib import reload

# reload() re-executes the already-imported project modules, so edits made on
# disk are picked up without restarting the kernel.
import pre_sal_ii.models as models
reload(models)
import pre_sal_ii.custom.module1 as m1
reload(m1)
# Presumably seeds the random number generators with a fixed value - confirm in
# pre_sal_ii.models.
models.set_all_seeds(0)

# NOTE(review): cv2, numpy and reload are imported a second time below;
# harmless but redundant.
import cv2
import numpy as np

from importlib import reload
import pre_sal_ii.models.nn as nn_models
import pre_sal_ii.models.ds as ds_models
reload(ds_models)
from torch.utils.data import DataLoader

import pre_sal_ii.improc as improc
reload(improc)

from typing import cast
# NOTE(review): a later cell imports `label` from scipy.ndimage, which rebinds
# this name. The cells below unpack `label(...)` into two values, so they rely
# on that later import; re-running this cell afterwards would likely break them.
from skimage.measure import label, regionprops


import pre_sal_ii
from tqdm.notebook import tqdm
# Use the notebook flavour of tqdm as the package-wide progress wrapper.
pre_sal_ii.progress = tqdm

In [None]:
import re
def match_group(regex, string, default=None, convert=str):
    """Return the first capture group of ``regex`` found in ``string``.

    The captured text is passed through ``convert`` before being returned.
    When the pattern does not occur anywhere in ``string``, ``default`` is
    returned as-is (it is *not* converted).
    """
    found = re.search(regex, string)
    return convert(found.group(1)) if found is not None else default

In [None]:
def combine_pred_true(pred, true, mean=None, std=None):
    """Stack prediction and ground truth into one H x W x 3 uint8 image.

    Channel layout of the returned NumPy array:
      0 -> ``mean`` if given, otherwise ``std`` if given, otherwise zeros
      1 -> ``true``
      2 -> ``pred``

    All inputs are converted to uint8 through ``torch.tensor``. ``true`` must
    be two-dimensional, since its shape defines the output size.
    """
    import torch

    composite = torch.zeros((*true.shape, 3), dtype=torch.uint8)
    # Channel-first view onto the same storage: writing a plane here fills the
    # corresponding channel of ``composite``.
    planes = composite.permute(2, 0, 1)

    extra = mean if mean is not None else std
    if extra is not None:
        planes[0, :, :] = torch.tensor(extra, dtype=torch.uint8)
    planes[1, :, :] = torch.tensor(true, dtype=torch.uint8)
    planes[2, :, :] = torch.tensor(pred, dtype=torch.uint8)

    return composite.numpy()

In [None]:
def get_filename_args(
            selector_power=0.0, use_channels=False,
            stdev_channel_power=0.0,
            mean_channel_weight=1.0,
            stdev_channel_weight=1.0,
            color_channels_weight=1.0,
            normalize_stdev=True,
        ):
    """Build the filename suffix that encodes a model configuration.

    The selector power is always written. The channel part is either the
    stdev channel power (``use_channels`` true) or the literal
    ``_channels=False``. The three weights and the stdev normalisation flag
    are appended only when they differ from their defaults, in the order
    colour, mean, stdev, normalisation.
    """
    parts = [f"_selector_pwr={selector_power:.2f}"]
    parts.append(
        f"_channels_pwr={stdev_channel_power:.2f}" if use_channels
        else "_channels=False")

    # (include?, text) pairs for the parts that appear only when non-default.
    optional = (
        (color_channels_weight != 1.0, f"_color_wt={color_channels_weight:.2f}"),
        (mean_channel_weight != 1.0, f"_mean_wt={mean_channel_weight:.2f}"),
        (stdev_channel_weight != 1.0, f"_stdev_wt={stdev_channel_weight:.2f}"),
        (not normalize_stdev, "_stdev_norm=False"),
    )
    parts.extend(text for wanted, text in optional if wanted)
    return "".join(parts)

In [None]:
# Load the inputs shared by the cells below; all three loaders come from the
# project module m1.
print("Getting input images...")
inputImage, inputImage_no_gamma = m1.get_input_image()
pores_image3 = m1.get_probability_maps_simple(inputImage)
binaryImage_clRed = m1.load_manually_categorized_image()
# Boolean mask of the pixels where pores_image3 is exactly 1.0; later cells use
# it as the region in which statistics are computed.
cond = pores_image3 == 1.0

# Sanity check: prints True when at least one pixel satisfies the condition.
print(f"max(cond.flatten())={max(cond.flatten())}")
plt.imshow(cond)
plt.show()

# Module-level cache filled on first use by get_mean_stdev().
mean_image, stdev_image = None, None
def get_mean_stdev():
    """Return the (mean, stdev) images of ``inputImage_no_gamma``, computing them once.

    The result of ``m1.get_mean_stdev`` is stored in the globals
    ``mean_image`` / ``stdev_image`` and reused on every later call.
    """
    global mean_image, stdev_image
    if mean_image is not None and stdev_image is not None:
        return mean_image, stdev_image
    mean_image, stdev_image = m1.get_mean_stdev(inputImage_no_gamma)
    return mean_image, stdev_image

def open_image_and_split(file):
    """Read a colour image and return its (green, red) channel planes.

    Raises:
        ValueError: if OpenCV cannot read ``file``.
    """
    bgr = cv2.imread(file, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"Could not read image file: {file}")

    # cv2.imread delivers the channels in blue, green, red order; the blue
    # plane is not needed by the callers.
    _, green_plane, red_plane = cv2.split(bgr)
    return green_plane, red_plane

def run_images_with_best_models(model_file, debug=False, show_images=False):
    """Predict the whole input image with the best fold of a saved 8-fold model.

    The model configuration is parsed back out of ``model_file``'s name (the
    same ``_key=value`` parts that ``get_filename_args`` writes). If a
    composite PNG for that configuration already exists under
    ``../out/with_pwr/``, the prediction is read from its red channel instead
    of running inference. Otherwise the fold with the lowest recorded loss is
    run over the dataset and the composite PNG is written.

    Args:
        model_file: path of the checkpoint; it must contain ``"models"`` (one
            state dict per fold) and ``"fold_losses"``.
        debug: print progress messages.
        show_images: display the composite image after inference.

    Returns:
        ``(binaryImage_clRed, prediction, mean_image, stdev_image)`` where
        ``prediction`` is a uint8 image.

        NOTE(review): the mean/stdev pair differs between the two paths. The
        cached path returns the raw ``get_mean_stdev()`` result; the inference
        path returns ``(None, None)`` without extra channels, and a stdev
        image scaled by its maximum with them.
    """
    import torch
    if debug: print("Starting run...")
    fold_count = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Recover the configuration from the file name. Apart from use_channels,
    # these values are only used to rebuild the output file name below.
    selector_power = match_group(r"_selector_pwr=([0-9]+\.[0-9]+)", model_file, 0.0, float)
    use_channels = match_group(r"_channels=(False|True)", model_file, True, lambda x: x == "True")
    stdev_channel_power = match_group(r"_channels_pwr=([0-9]+\.[0-9]+)", model_file, 0.0, float)
    mean_channel_weight = match_group(r"_mean_wt=([0-9]+\.[0-9]+)", model_file, 1.0, float)
    stdev_channel_weight = match_group(r"_stdev_wt=([0-9]+\.[0-9]+)", model_file, 1.0, float)
    color_channels_weight = match_group(r"_color_wt=([0-9]+\.[0-9]+)", model_file, 1.0, float)
    normalize_stdev = match_group(r"_stdev_norm=(False|True)", model_file, True, lambda x: x == "True")

    # Single source of truth for the suffix format (shared with the sweep cells).
    args = get_filename_args(
        selector_power=selector_power,
        use_channels=use_channels,
        stdev_channel_power=stdev_channel_power,
        mean_channel_weight=mean_channel_weight,
        stdev_channel_weight=stdev_channel_weight,
        color_channels_weight=color_channels_weight,
        normalize_stdev=normalize_stdev,
    )
    composite_file = f"../out/with_pwr/image_pred_8fold_true1.2{args}.png"

    if os.path.exists(composite_file):
        # Cached result: the red channel of the composite holds the prediction.
        _, pred = open_image_and_split(composite_file)
        mean_image2, stdev_image2 = get_mean_stdev()
        return binaryImage_clRed, pred, mean_image2, stdev_image2

    channels = 3 if not use_channels else 5
    models2 = [nn_models.EncoderNN(initial_dim=channels*32*32).to(device) for _ in range(fold_count)]
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(model_file, map_location=device)
    for i, m in enumerate(models2):
        m.load_state_dict(checkpoint["models"][i])
    fold_losses2 = checkpoint["fold_losses"]

    # NOTE(review): the model is never switched to eval mode here; confirm
    # whether the trainer does that.
    model = models2[int(np.argmin(fold_losses2))]

    if debug: print("Adjusting input image...")
    inputImage2 = inputImage.astype(np.float32)/255.
    mean_image2, stdev_image2 = None, None
    if use_channels:
        mean_image2, stdev_image2 = get_mean_stdev()
        # NOTE(review): normalize_stdev is parsed above but not consulted; the
        # stdev image is always scaled by its maximum.
        stdev_image2 = stdev_image2 / stdev_image2.max()
        # NOTE(review): this is a second division by 255 on top of the one
        # above. Left unchanged because it may match how the checkpoints were
        # trained - confirm against the training notebook before touching it.
        inputImage2 = inputImage2.astype(np.float32)/255.
        inputImage2 = np.dstack((inputImage2, mean_image2, stdev_image2)) # pyright: ignore[reportPossiblyUnboundVariable]

    if debug: print("Creating dataset...")
    dataset2 = ds_models.WhitePixelRegionDataset(
        pores_image3, inputImage2, binaryImage_clRed/255.,
        num_samples=-1, seed=None, use_img_to_tensor=True)
    dataloader2 = DataLoader(dataset2, batch_size=1024, shuffle=False)

    trainer_best = m1.MyTrainer101x101to32x32(model, None, None, device=device, channels=channels)

    if debug: print("Inferring...")
    pred_image = np.zeros_like(binaryImage_clRed, dtype=np.uint8)

    from pre_sal_ii import progress
    with torch.no_grad():
        for inputs in progress(dataloader2):
            _, _, coords = inputs
            _, outputs = trainer_best.train_epoch_step(inputs)

            # coords are indexed as (row, column) per sample; the first output
            # column is scaled by 255 and stored in the uint8 image.
            xs = coords[:,1].cpu().numpy()
            ys = coords[:,0].cpu().numpy()
            vs = outputs[:,0].cpu().numpy()
            pred_image[ys, xs] = vs*255

    if debug: print("Creating images...")
    image_pred_true = combine_pred_true(pred_image, binaryImage_clRed)
    if show_images:
        plt.imshow(image_pred_true[:,:,::-1])
        plt.show()
    os.makedirs("../out/with_pwr/", exist_ok=True)
    cv2.imwrite(composite_file, image_pred_true)

    return binaryImage_clRed, pred_image, mean_image2, stdev_image2


In [None]:

import numpy as np
import cv2
from scipy.ndimage import distance_transform_edt, label


def open_but_keep_small_objects(mask, thickness):
    """Morphologically open ``mask`` but keep components thinner than the kernel.

    A plain opening with an elliptical ``thickness`` x ``thickness`` kernel
    removes every structure narrower than the kernel, including whole small
    objects. Here each connected component whose widest point (twice the
    largest distance-transform value inside it) is below ``thickness`` is
    copied back from ``mask``, so only thin parts of larger objects are
    removed.

    Args:
        mask: 2-D uint8 foreground mask; any non-zero value is foreground.
        thickness: kernel diameter in pixels.

    Returns:
        The opened mask, same shape and dtype as ``mask``. Restored components
        carry the values they had in ``mask`` (so a 0/1 mask stays 0/1 and a
        0/255 mask stays 0/255).
    """

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (thickness, thickness))
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    # Distance of every foreground pixel to the nearest background pixel.
    dist = distance_transform_edt(mask)

    labeled_mask, _ = label(mask)

    for r in regionprops(labeled_mask):
        coords = tuple(r.coords.T)
        diam = 2 * dist[coords].max()

        if diam < thickness:
            # Restore with the mask's own values rather than a fixed 255, so
            # the output uses the same foreground value as the input.
            opened[coords] = mask[coords]

    return opened

def thickness_based_components_fast(mask, thicknesses):
    """
    Label the components of a pore mask after cutting thin connections.

    mask: binary mask of pores (uint8 0/1)
    thicknesses: kernel sizes for the successive openings, applied in order;
        each opening works on the result of the previous one

    Returns (labels, n_labels): the label image produced by the watershed and
    the number of seed components found after the openings.
    """

    # 1. Successive openings (small objects are kept by the helper)
    seed_mask = mask
    for size in thicknesses:
        seed_mask = open_but_keep_small_objects(seed_mask, size)

    # 2. Label the surviving cores; they become the watershed seeds
    seed_labels, n_labels = label(seed_mask)

    # 3. Grow the seeds with the watershed. cv2.watershed needs 32-bit markers
    #    and a 3-channel image; the image is built from the mask itself.
    placeholder = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR)
    grown = cv2.watershed(placeholder, seed_labels.astype(np.int32))

    # 4. The watershed marks boundaries with -1 -> fold them into label 0
    grown[grown < 0] = 0

    return grown, n_labels

def merge_ground_truth_labels(labels, gt):
    """Replace labelled regions that belong to the ground truth by the ground-truth components.

    Works on a copy of ``labels``:
      * a region is cleared when at least half of its pixels are ground truth,
        or when its area is below 100 pixels;
      * every connected ground-truth component of at least 16 pixels is then
        written in with a fresh label, unless one of its pixels lies at or
        beyond the largest normalised distance from the image centre reached
        by any original region.

    Returns ``(labels, max_label)`` with ``max_label`` the highest label in use.
    """
    merged = labels.copy()

    next_label = merged.max()
    half_w = merged.shape[1] / 2
    half_h = merged.shape[0] / 2

    def centre_distance(coords):
        # Distance from the image centre with both axes scaled to [-1, 1].
        ys = (coords.T[0] - half_h)/half_h
        xs = (coords.T[1] - half_w)/half_w
        return (xs**2 + ys**2)**0.5

    reach = 0
    for region in regionprops(merged):
        coords = region.coords
        rows, cols = coords[:, 0], coords[:, 1]
        gt_share = np.sum(gt[rows, cols] > 0) / len(rows)
        if gt_share >= 0.5 or region.area < 100:
            merged[rows, cols] = 0
        # Cleared regions still contribute to the reach.
        reach = max(reach, centre_distance(coords).max())

    # relabel using gt
    labeled_gt, _ = label(gt)
    for region in regionprops(labeled_gt):
        coords = region.coords
        if region.area < 16:
            continue
        if (centre_distance(coords) >= reach).any():
            continue
        next_label += 1
        merged[coords[:, 0], coords[:, 1]] = next_label

    return merged, next_label

In [None]:
# Alternative checkpoints; uncomment one of these lines (and comment out the
# active call below) to analyse it instead.
#gt, pred, mean, std = run_images_with_best_models("../models/supervised-8-folds-1.2_selector_pwr=1.00_channels_pwr=0.00.pt")
#gt, pred, mean, std = run_images_with_best_models("../models/supervised-8-folds-1.2_selector_pwr=1.00_channels=False.pt")
# Active run: its gt / pred results are the notebook globals used by the cells below.
gt, pred, mean, std = run_images_with_best_models("../models/supervised-8-folds-1.2_selector_pwr=0.00_channels_pwr=0.00_mean_wt=255.00.pt")

In [None]:
# Displays the largest ground-truth and prediction values as the cell output.
max(gt.flatten()), max(pred.flatten())

In [None]:
# Synthetic 50x50 check image: a one-pixel-high bar (6 px wide) and two 15x15 squares.
test_image = np.zeros([50,50], dtype=np.uint8)
test_image[10:11,9:15] = 255
test_image[30:45,30:45] = 255
test_image[10:25,30:45] = 255
# Plain opening with a 15x15 elliptical kernel.
plt.imshow(cv2.morphologyEx(test_image, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))))
plt.show()
# Opening that keeps small objects. NOTE(review): this uses thickness 10, not
# the 15 used for the plain opening above, so the two plots are not a
# like-for-like comparison.
plt.imshow(open_but_keep_small_objects(test_image, 10))
plt.show()

In [None]:
def show_random_colors(labels):
    """Plot a label image with one random colour per label; label 0 is black.

    labels: 2D array of ints (0..N). The plot is drawn on the current
    matplotlib axes with the axes hidden; the caller decides when to show it.
    """
    from matplotlib.colors import ListedColormap

    # One RGB triple in [0, 1] per label value, background forced to black.
    palette = np.random.rand(labels.max() + 1, 3)
    palette[0] = [0, 0, 0]

    plt.imshow(labels, cmap=ListedColormap(palette), interpolation='nearest')
    plt.axis('off')

import numpy as np
import cv2

def save_label_image_cv2(labels, filename):
    """Write a label image to disk with one random colour per label.

    labels: 2D array of ints (0..N)
    filename: path to save, e.g. 'out.png'

    Label 0 is drawn black. The colour components are drawn from 0..254,
    because the upper bound of ``np.random.randint`` is exclusive.
    """
    label_ids = labels.astype(np.int32)

    palette = np.random.randint(0, 255, (label_ids.max() + 1, 3), dtype=np.uint8)
    palette[0] = [0, 0, 0]

    # Fancy indexing turns the (H, W) ids into an (H, W, 3) colour image.
    rgb = palette[label_ids]

    # cv2.imwrite expects BGR channel order.
    cv2.imwrite(filename, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))


In [None]:
# Connected components of the ground truth. `label` is unpacked as
# (label image, component count).
gt_labels, n_gt_labels = label(gt > 0)
plt.imshow(gt_labels, cmap='nipy_spectral')
plt.show()
regions_gt = regionprops(gt_labels)
# Smallest and largest ground-truth pore area, used as the colour-scale limits below.
min_pore_size = min([r.area for r in regions_gt if r.area > 0])
max_pore_size = max([r.area for r in regions_gt if r.area > 0])
print(f"min_pore_size={min_pore_size}")
# Histogram of pore areas, restricted to areas up to 2000 pixels.
plt.hist([r.area for r in regions_gt if r.area > 0], bins=50, range=(0, 2000))
print(sorted([r.area for r in regions_gt if r.area > 0]))
plt.show()

import math
from pre_sal_ii.improc import rescale
# Colour every ground-truth pore by the logarithm of its area (viridis colormap).
new_image_of_gt_by_area = np.zeros([gt.shape[0], gt.shape[1], 3], dtype=np.uint8)
cmap = plt.get_cmap('viridis')
for r in regions_gt:
    coords = r.coords
    area = r.area
    # NOTE(review): rescale's signature is not visible here; presumably it maps
    # log(area) from [log(min), log(max)] onto [1.0, 1/255] - confirm.
    color = cmap(rescale(math.log(area), math.log(min_pore_size), math.log(max_pore_size), 1.0, 1./255))[0:3]
    # The colormap returns floats; scale to 0..255 for the uint8 image.
    new_image_of_gt_by_area[r.coords[:, 0], r.coords[:, 1], :] = np.array(color) * 255
plt.imshow(new_image_of_gt_by_area)
plt.show()
# The channel axis is reversed (RGB -> BGR) before handing the image to cv2.imwrite.
cv2.imwrite("../out/gt_pore_size_visualization.png", new_image_of_gt_by_area[:,:,::-1])

In [None]:
# Split the condition mask with successive openings of size 30 and 40; the
# result is multiplied by cond so that labels outside the mask are zeroed.
labels_cond = thickness_based_components_fast(cond.astype(np.uint8), [30, 40])[0]*cond
# Swap in the ground-truth components of the manually categorised image.
labels_gt_cond = merge_ground_truth_labels(labels_cond, (binaryImage_clRed > 0).astype(np.uint8))[0]
show_random_colors(labels_cond)
plt.show()
save_label_image_cv2(labels_cond, "../out/labels_cond.png")
show_random_colors(labels_gt_cond)
plt.show()
save_label_image_cv2(labels_gt_cond, "../out/labels_gt_cond.png")

In [None]:
# Mark every connected component of cond that contains at least one
# ground-truth pixel; the whole component is set to 1 in new_gt.
label_cond, n_labels_cond = label(cond.astype(np.uint8))
regions_cond = regionprops(label_cond)
new_gt = np.zeros_like(gt)
for r in regions_cond:
    coords = r.coords
    gt_values = gt[coords[:, 0], coords[:, 1]]
    if np.any(gt_values > 0):
        new_gt[coords[:, 0], coords[:, 1]] = 1

# Composite channels: 0 = cond, 1 = original gt, 2 = expanded gt.
gt_composite = combine_pred_true(new_gt*255, gt, cond*255)
# NOTE(review): shown without the [:,:,::-1] channel flip that the other
# composite plots in this notebook apply.
plt.imshow(gt_composite)
plt.show()

In [None]:
# Reload the input images and compute mean/stdev directly, bypassing the cache
# in get_mean_stdev(). These *2 names are notebook globals; they are separate
# from the locals of the same name inside run_images_with_best_models.
inputImage2, inputImage_no_gamma = m1.get_input_image()
mean_image2, stdev_image2 = m1.get_mean_stdev(inputImage_no_gamma)

In [None]:
# Go through the caching accessor instead of reading the bare `mean_image`
# global: that global is still None until get_mean_stdev() has been called at
# least once, which depends on which cells ran before this one.
selector_mask = improc.preprocess_segments(
    get_mean_stdev()[0], area_threshold=0.05, morphological_processing={"grow": 25})
plt.imshow(selector_mask)
plt.show()

In [None]:
# Clamp the mean image into 0..255 and convert to uint8 so it can be used as an image channel.
mean_image2 = np.clip(mean_image2, 0, 255).astype(np.uint8)
# Composite channels: 0 = mean image, 1 = ground truth, 2 = prediction.
pred_true_image = combine_pred_true(pred, gt, mean_image2)
# Channel axis reversed for display.
plt.imshow(pred_true_image[:,:,::-1])
plt.show()

In [None]:
# Areas of the ground-truth components, ascending. cast() only informs the
# type checker; at runtime the label(...) result is unpacked as a 2-tuple.
label_img, n_labels_img = cast(np.ndarray, label(gt > 0))
regions = regionprops(label_img)
areas = sorted([int(r.area) for r in regions])

print(f"areas={areas}")

# Build (low, high] area intervals: cut at the element 75% of the way through
# the remaining sorted areas, keep only the areas above the cut, and repeat.
# Once that position drops below 10, close with an open-ended interval.
# Note that `areas` is consumed by this loop.
intervals = []
prev = 0
while True:
    pos = int(len(areas)*0.75)
    if pos < 10:
        intervals.append((prev, float('inf')))
        break
    mid = areas[pos]
    areas = areas[pos + 1:]
    intervals.append((prev, mid))
    prev = mid

print(f"intervals={intervals}")

In [None]:
# Colour each ground-truth component (`regions` from the previous cell) by the
# area interval it falls into.
cmap = plt.get_cmap('viridis')
img = np.zeros([*gt.shape, 3], dtype=np.uint8)
for r in regions:
    area = int(r.area)
    for i, (low, high) in enumerate(intervals):
        color = cmap((i + 1) / len(intervals))[0:3]
        # Intervals are half-open: low excluded, high included.
        if low < area <= high:
            img[r.coords[:, 0], r.coords[:, 1], :] = np.array(color) * 255
            break

# Number of size classes; a later cell asserts that this equals 4.
max_label = len(intervals)
print(f"max_label={max_label}")

plt.imshow(img, vmin=0, vmax=255)

In [None]:
# Same size-class colouring for the components of the prediction. This rebinds
# the notebook globals label_img / regions.
label_img, n_labels_img = cast(np.ndarray, label(pred > 0))
regions = regionprops(label_img)

cmap = plt.get_cmap('viridis')

img = np.zeros([*gt.shape, 3], dtype=np.uint8)
for r in regions:
    area = int(r.area)
    for i, (low, high) in enumerate(intervals):
        color = cmap((i + 1) / len(intervals))[0:3]
        if low < area <= high:
            # The class colour is dimmed per pixel by the prediction value
            # (pred / 255); `color` is a tuple and is broadcast by NumPy.
            img[r.coords[:, 0], r.coords[:, 1], :] = (color * (pred[r.coords[:, 0], r.coords[:, 1]] / 255.0)[:, None]) * 255
            break

plt.imshow(img)

In [None]:
def _size_class_map(binary, intervals, like):
    """Paint each connected component of ``binary`` with 1 + the index of its area interval."""
    label_img, _ = label(binary)
    classes = np.zeros_like(like, dtype=np.uint8)
    for r in regionprops(label_img):
        area = int(r.area)
        for i, (low, high) in enumerate(intervals):
            if low < area <= high:
                # Index by the region's own pixel coordinates instead of
                # comparing the whole label image against r.label per region.
                classes[tuple(r.coords.T)] = i + 1
                break
    return classes


def get_filters(gt, pred, intervals):
    """Build one boolean mask per pore-size class.

    The connected components of ``pred > 0`` and of ``gt > 0`` are each
    assigned to the first interval ``(low, high]`` that contains their area.
    Mask ``k`` (0-based) is true wherever either image has a component of
    class ``k + 1``.

    Args:
        gt: ground-truth image; also defines the output shape.
        pred: prediction image of the same shape.
        intervals: sequence of ``(low, high)`` area bounds.

    Returns:
        A list of exactly four boolean masks (classes 1 to 4), regardless of
        ``len(intervals)``. Components that fit no interval belong to no mask.
    """
    pred_classes = _size_class_map(pred > 0, intervals, gt)
    gt_classes = _size_class_map(gt > 0, intervals, gt)

    max_label = max(pred_classes.max(), gt_classes.max())
    print(f"max label={max_label}")

    return [(pred_classes == k) | (gt_classes == k) for k in range(1, 5)]


In [None]:
print(max(pores_image3.flatten()))
# we multiply by pores_image3 because predictions are only valid where pores are present
# NOTE: `filter` shadows the Python builtin and is read as a notebook global by
# get_all_stats.
filter = get_filters(gt*pores_image3, pred*pores_image3, intervals)
# One panel per size class.
plt.subplots(1, 4, figsize=(16, 4))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(filter[i])

In [None]:
# Scratch check of bool arithmetic: True/False compare equal to 1/0 and can be
# added to ints. get_stats relies on this for its min == 0 / max == 1 test on
# boolean masks.
x1 = [True, False]
min(x1) == 0 and max(x1) == 1
1 + True

In [None]:
def _count(mask):
    """Return the number of truthy entries of ``mask`` as a plain Python int."""
    return int(np.count_nonzero(mask))


def _confusion(pos, neg, true, false):
    """Cross-tabulate predicted (pos/neg) against actual (true/false) boolean masks."""
    return {
        "tp": _count(pos & true),
        "fp": _count(pos & false),
        "fn": _count(neg & true),
        "tn": _count(neg & false),
    }


def get_stats(gt, pred, cond):
    """Compute pixel-level and pore-level statistics of ``pred`` against ``gt``.

    Args:
        gt: ground-truth image; a pixel is "true" where ``gt > 0``.
        pred: prediction image on a 0..255 scale; 127 and 229 are the 50% and
            90% thresholds.
        cond: mask with minimum 0 and maximum 1 selecting the analysed pixels.
            Its connected components are the "pores" of the pore-level
            statistics.

    Positive pixels (``pred > t``, ``gt > 0``) are counted as given, while
    negative pixels (``pred <= t``, ``gt == 0``) are counted only inside
    ``cond``. Callers are expected to have masked ``gt`` and ``pred`` with
    ``cond`` already.

    Returns:
        A JSON-serialisable dict with:
          * pixel counts: ``positive_50p``, ``negative_50p``, ``positive_90p``,
            ``negative_90p``, ``true_pixels``, ``false_pixels``;
          * soft sums: ``summed_positive_pixels``, ``summed_negative_pixels``;
          * confusion dicts (``tp``/``fp``/``fn``/``tn``): ``50p``, ``90p`` and
            ``summed`` at pixel level, plus one ``pores_*`` entry per
            pore-level decision rule.
        In every confusion dict ``fp`` means predicted positive but not in the
        ground truth, and ``fn`` means in the ground truth but predicted
        negative.

    Raises:
        AssertionError: if ``cond`` does not span exactly 0..1, or ``gt`` has
            a non-zero minimum or no positive pixel.
    """
    # The tuple unpacking used on label(...) in this notebook requires scipy's
    # label; import it here so the function does not depend on which `label`
    # the notebook currently has bound.
    from scipy import ndimage

    cond_lo, cond_hi = cond.min(), cond.max()
    gt_lo, gt_hi = gt.min(), gt.max()

    if not (cond_lo == 0 and cond_hi == 1):
        print("Error: cond image is not binary!")
        print(f"min(cond.flatten())={cond_lo}, max(cond.flatten())={cond_hi}")
        raise AssertionError("cond image is not binary")
    if not (gt_lo == 0 and gt_hi > 0):
        print("Error: gt image is not binary!")
        print(f"min(gt.flatten())={gt_lo}, max(gt.flatten())={gt_hi}")
        plt.imshow(gt)
        plt.show()
        raise AssertionError("gt image is not binary")

    in_region = cond == 1
    gt_pos = gt > 0
    gt_neg = (gt == 0) & in_region

    # Soft (probability-like) prediction mass for and against "positive".
    soft_pos = pred / 255.0
    soft_neg = (255 - pred) / 255.0

    # ---- pixel level -----------------------------------------------------
    pred_pos_50 = pred > 127
    pred_neg_50 = (pred <= 127) & in_region
    pred_pos_90 = pred > 229
    pred_neg_90 = (pred <= 229) & in_region

    # ---- pore level: one decision per connected component of cond ---------
    label_img, n_regions = ndimage.label(in_region)
    flat_labels = label_img.ravel()

    def per_region_sum(values):
        # Sum ``values`` over each labelled region; label 0 (background) is dropped.
        weights = np.asarray(values, dtype=np.float64).ravel()
        totals = np.bincount(flat_labels, weights=weights, minlength=n_regions + 1)
        return totals[1:]

    region_size = per_region_sum(in_region)
    region_true = per_region_sum(gt_pos) > 0
    region_false = ~region_true
    n_over_50 = per_region_sum(pred_pos_50)
    n_over_90 = per_region_sum(pred_pos_90)
    # Integer pixel values are summed first and scaled once, which keeps the
    # threshold comparisons below free of accumulated rounding error.
    region_soft = per_region_sum(pred) / 255.0

    def pore_confusion(region_positive):
        return _confusion(region_positive, ~region_positive, region_true, region_false)

    return {
        "positive_50p": _count(pred_pos_50),
        "negative_50p": _count(pred_neg_50),
        "positive_90p": _count(pred_pos_90),
        "negative_90p": _count(pred_neg_90),
        "true_pixels": _count(gt_pos),
        "false_pixels": _count(gt_neg),
        "summed_positive_pixels": float(np.sum(soft_pos)),
        "summed_negative_pixels": float(np.sum(soft_neg * in_region)),

        "50p": _confusion(pred_pos_50, pred_neg_50, gt_pos, gt_neg),
        "90p": _confusion(pred_pos_90, pred_neg_90, gt_pos, gt_neg),
        "summed": {
            "tp": float(np.sum(soft_pos * gt_pos)),
            "fp": float(np.sum(soft_pos * gt_neg)),
            "tn": float(np.sum(soft_neg * gt_neg)),
            "fn": float(np.sum(soft_neg * gt_pos)),
        },

        # A pore is positive if at least one pixel exceeds the threshold ...
        "pores_1_50p": pore_confusion(n_over_50 > 0),
        "pores_1_90p": pore_confusion(n_over_90 > 0),
        # ... if at least 25% / 50% of its pixels exceed the threshold ...
        "pores_25p_50p": pore_confusion(n_over_50 >= 0.25 * region_size),
        "pores_25p_90p": pore_confusion(n_over_90 >= 0.25 * region_size),
        "pores_50p_50p": pore_confusion(n_over_50 >= 0.50 * region_size),
        "pores_50p_90p": pore_confusion(n_over_90 >= 0.50 * region_size),
        # ... or if its mean soft prediction reaches 25% / 50% / 75%.
        "pores_25p_sum": pore_confusion(region_soft >= 0.25 * region_size),
        "pores_50p_sum": pore_confusion(region_soft >= 0.50 * region_size),
        "pores_75p_sum": pore_confusion(region_soft >= 0.75 * region_size),
    }

In [None]:
# get_all_stats iterates over max_label size classes and get_filters produces
# exactly four masks, so the two have to agree.
print(max_label)
assert max_label == 4
def get_all_stats(gt, pred, cond):
    """Compute statistics for the whole masked image and for each size class.

    ``gt`` and ``pred`` are masked with ``cond`` first. The per-class
    statistics use the notebook globals ``filter`` (one boolean mask per
    class) and ``max_label`` (the number of classes).

    Returns ``{"all": <stats>, "groups": [<stats for class 0>, ...]}``.
    """
    masked_gt = gt*cond
    masked_pred = pred*cond

    result = {"all": get_stats(masked_gt, masked_pred, cond)}
    result["groups"] = [
        get_stats(masked_gt * filter[grp], masked_pred * filter[grp], filter[grp])
        for grp in range(max_label)
    ]
    return result

In [None]:
import itertools
A = range(11)  # 0 to 10
B = range(11)
pairs = list(itertools.product(A, B))
import pre_sal_ii
reload(pre_sal_ii)
prev_progress = pre_sal_ii.progress
from tqdm.notebook import tqdm

# Four sweeps over the saved models, run in this order:
#   1. extra channels on,  selector power x channel power grid
#   2. extra channels off, selector power only
#   3. as 1, with mean channel weight 255
#   4. as 2, with mean channel weight 255
sweeps = [
    [get_filename_args(i/10, True, j/10) for i, j in pairs],
    [get_filename_args(i/10, False) for i in A],
    [get_filename_args(i/10, True, j/10, mean_channel_weight=255.0) for i, j in pairs],
    [get_filename_args(i/10, False, mean_channel_weight=255.0) for i in A],
]

for sweep in sweeps:
    try:
        bar = tqdm(sweep)
        # Inner progress bars should disappear when done so that only the
        # sweep bar remains; the previous hook is restored afterwards.
        pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)
        for args in bar:
            bar.set_description(args)
            gt, pred, mean, std = run_images_with_best_models(f"../models/supervised-8-folds-1.2{args}.pt")
    finally:
        pre_sal_ii.progress = prev_progress

In [None]:
import os
import json
import glob
from importlib import reload
from tqdm.notebook import tqdm

def process_image_files():
    """Compute and store the statistics JSON for every saved composite image.

    Each ``../out/with_pwr/image_pred_8fold_true1.2*.png`` is split into its
    ground-truth (green) and prediction (red) channels, masked with the
    notebook global ``cond``, and passed to ``get_all_stats``. The result is
    written to ``../out/with_pwr/stats/``. Images whose JSON file already
    exists are skipped.
    """
    import pre_sal_ii
    reload(pre_sal_ii)
    saved_progress = pre_sal_ii.progress

    # Find all matching images
    image_paths = sorted(glob.glob("../out/with_pwr/image_pred_8fold_true1.2*.png"))

    try:
        bar = tqdm(image_paths)
        pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)

        for image_path in bar:
            # The configuration suffix is whatever sits between the fixed
            # prefix and the extension: image_pred_8fold_true1.2{args}.png
            args = (os.path.basename(image_path)
                    .removeprefix("image_pred_8fold_true1.2")
                    .removesuffix(".png"))

            out_file = f"../out/with_pwr/stats/stats_pred_8fold_1.2{args}.json"
            if not os.path.exists(out_file):
                bar.set_description(args)

                gt, pred = open_image_and_split(image_path)
                data = get_all_stats(gt*cond, pred*cond, cond)

                os.makedirs("../out/with_pwr/stats/", exist_ok=True)

                with open(out_file, "w") as f:
                    json.dump(data, f, indent=2)

    finally:
        pre_sal_ii.progress = saved_progress

process_image_files()