# Pore type prediction from thin-section images 1.1

In this notebook, we reuse all the code from version 1.0 and run it several times, once for each combination of parameters.

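The parameters are:
- use the $C_{min}{\times}K_{max}$ data as the input selector (i.e. as the probability map from which training pixels are sampled): yes/no
- use the $C_{min}{\times}K_{max}$ data (mean and standard-deviation maps) as extra input channels: yes/no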

These two yes/no options give 4 combinations, so 4 separate sets of models are trained (each set is an 8-fold cross-validated ensemble saved to its own file).

In [None]:
import os

# Show where the notebook runs from: all data/model paths below are relative ("../...").
working_directory = os.getcwd()
print(working_directory)

In [None]:
from pre_sal_ii.improc import colorspace
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from importlib import reload
from pre_sal_ii.improc import scale_image_and_save, adjust_gamma

import pre_sal_ii.models as models
reload(models)
models.set_all_seeds(0)

import numpy as np
import pandas as pd
from skimage.measure import label, regionprops

import pickle
from pathlib import Path

from k_means_constrained import KMeansConstrained
from pre_sal_ii.training.image_clustering import cluster_pixels_kmeans_constrained_model

import numpy as np
import cv2

import numpy as np
from pre_sal_ii.improc import generate_region_map_from_centroids
from skimage.measure import label, regionprops
from typing import cast

import torch.nn.functional as F
from pre_sal_ii.training import Trainer
import torch


import pre_sal_ii.models.features.data_models as data_models

from importlib import reload
import pre_sal_ii.models.nn as nn_models
import pre_sal_ii.models.ds as ds_models
reload(ds_models)
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pre_sal_ii.training import Trainer

import pre_sal_ii.models as md

from pre_sal_ii import improc
reload(improc)

In [None]:
def get_input_image(image_name="ML-tste_original", gamma=0.5):
    """Load the thin-section input image and a gamma-adjusted copy of it.

    Reads ``../out/classificada_01/{image_name}_25.jpg`` with OpenCV (so the
    pixel channels are in BGR order) and applies ``adjust_gamma`` to it.

    Args:
        image_name: Base file name, without the ``_25.jpg`` suffix.
        gamma: Gamma value passed to ``adjust_gamma`` (0.5 by default).

    Returns:
        ``(inputImage, inputImage_no_gamma)``: the gamma-adjusted image and
        the image exactly as read from disk.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            would otherwise surface later as an obscure error.
    """
    path = f"../out/classificada_01/{image_name}_25.jpg"
    inputImage_no_gamma = cv2.imread(path)
    if inputImage_no_gamma is None:
        raise FileNotFoundError(f"Could not read input image: {path}")
    inputImage = adjust_gamma(inputImage_no_gamma, gamma)

    return inputImage, inputImage_no_gamma


In [None]:
def get_probability_maps_simple(inputImage):
    """Build a binary pore map by colour-thresholding the image in CMYK space.

    Steps:
      1. Convert the BGR image to CMYK with ``colorspace.bgr2cmyk``.
      2. Keep the pixels inside a fixed 4-channel range (``cv2.inRange``).
      3. Clean the mask with a 10x10 elliptical kernel: an opening
         (erode, then dilate) followed by a closing (dilate, then erode).
      4. Keep only the connected components near the image centre, using
         ``filter_central_objects``.

    Args:
        inputImage: BGR image, e.g. the first value returned by
            ``get_input_image``.

    Returns:
        Float array of 0.0/1.0 values with the input's height and width.
    """
    # BGR to CMYK:
    inputImageCMYK = colorspace.bgr2cmyk(inputImage)

    # Presumably the channels are ordered C, M, Y, K — TODO confirm against
    # colorspace.bgr2cmyk. In that order the range keeps C >= 92, Y <= 64,
    # K <= 196, with M unrestricted.
    binaryImage = cv2.inRange(
        inputImageCMYK,
        (92,   0,   0,   0),
        (255, 255,  64, 196))

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    # Opening (erode -> dilate) removes specks smaller than the kernel...
    binaryImage = cv2.morphologyEx(binaryImage, cv2.MORPH_ERODE, kernel, iterations=1)
    binaryImage = cv2.morphologyEx(binaryImage, cv2.MORPH_DILATE, kernel, iterations=1)
    # ...then closing (dilate -> erode) fills small holes and gaps.
    binaryImage = cv2.morphologyEx(binaryImage, cv2.MORPH_DILATE, kernel, iterations=1)
    binaryImage = cv2.morphologyEx(binaryImage, cv2.MORPH_ERODE, kernel, iterations=1)

    # Same centre-distance filter that is applied to the manual annotation;
    # it returns a 0/255 uint8 mask. (It is defined in a later cell, but it
    # is only looked up when this function is called.)
    pores_image3 = filter_central_objects(binaryImage)

    return pores_image3 / 255.0


In [None]:
def load_manually_categorized_image():
    """Load the manually classified image and extract its pure-red markings.

    Reads ``../out/classificada_01/ML-tste_classidicada_25.jpg`` with OpenCV
    and keeps the pixels whose BGR value satisfies B in [0, 5],
    G in [0, 5] and R in [240, 255].

    Returns:
        The ``cv2.inRange`` mask: uint8, 255 where the pixel is in range and
        0 elsewhere.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (``cv2.imread``
            returns ``None`` instead of raising).
    """
    # NOTE(review): "classidicada" looks like a typo of "classificada", but it
    # must match the file name on disk, so it is kept as is.
    image_name = "ML-tste_classidicada"
    path = f"../out/classificada_01/{image_name}_25.jpg"
    inputImage_cl = cv2.imread(path)
    if inputImage_cl is None:
        raise FileNotFoundError(f"Could not read classified image: {path}")
    binaryImage_clRed: np.ndarray = cv2.inRange(
        inputImage_cl,
        #  B,   G,   R
        (  0,   0, 240),
        (  5,   5, 255))
    return binaryImage_clRed


In [None]:
def filter_central_objects(image: np.ndarray, max_dist_fraction: float = 0.8) -> np.ndarray:
    """Keep only the connected components that stay close to the image centre.

    Each connected component of ``image`` (from ``skimage.measure.label``)
    gets a score: the largest distance of any of its pixels from the image
    centre. Each axis is normalized so the centre is 0 and the edges are
    approximately -1/+1. Components scoring above ``max_dist_fraction``
    times the highest score are discarded. Because of this, the component
    that reaches farthest out is always removed.

    Args:
        image: 2-D image; non-zero pixels are treated as foreground.
        max_dist_fraction: Relative distance cutoff. The default 0.8 keeps
            the historical behaviour.

    Returns:
        uint8 array with the shape of ``image``: 255 on kept components and
        0 elsewhere. If ``image`` has no foreground, the result is all zeros.
    """
    label_img = label(image)
    regions = regionprops(label_img)

    filtered = np.zeros(label_img.shape, dtype=np.uint8)
    if not regions:
        # No components: there is no reference distance to compare against.
        return filtered

    half_h = label_img.shape[0] / 2
    half_w = label_img.shape[1] / 2

    region_max_dists = []
    for region in regions:
        ys = (region.coords.T[0] - half_h) / half_h
        xs = (region.coords.T[1] - half_w) / half_w
        region_max_dists.append(np.max((ys**2 + xs**2) ** 0.5))

    cutoff = max(region_max_dists) * max_dist_fraction
    for region, region_max_dist in zip(regions, region_max_dists):
        if region_max_dist <= cutoff:
            filtered[region.coords.T[0], region.coords.T[1]] = 255
    return filtered


In [None]:
def get_kmc_model(binaryImage_clRed, cache_path=Path("../models/kmc_model_1.pkl"), fraction=10) -> "KMeansConstrained":
    """Return the constrained k-means model that splits the annotation into folds.

    If ``cache_path`` exists, the pickled model is loaded from it. Otherwise
    the model is trained with ``cluster_pixels_kmeans_constrained_model`` on
    the centre-filtered annotation (``filter_central_objects``) and then
    pickled to ``cache_path``.

    NOTE: the cache is keyed only by its file name. It is NOT invalidated
    when ``binaryImage_clRed`` or ``fraction`` change; delete the file to
    retrain.

    Args:
        binaryImage_clRed: Manual annotation mask. It is only used when no
            cached model exists.
        cache_path: Where the pickled model is read from / written to.
        fraction: Forwarded to ``cluster_pixels_kmeans_constrained_model``.

    Returns:
        The loaded or freshly trained model.
    """
    cache_path = Path(cache_path)
    cache_path.parent.mkdir(parents=True, exist_ok=True)

    if cache_path.exists():
        # pickle.load can execute arbitrary code: only point cache_path at
        # files produced by this notebook.
        with open(cache_path, "rb") as f:
            kmc_model = pickle.load(f)
        print("Loaded cached model from disk.")
    else:
        # Train the model
        binaryImage_clRed_mx = filter_central_objects(binaryImage_clRed)
        kmc_model = cluster_pixels_kmeans_constrained_model(binaryImage_clRed_mx, fraction=fraction)
        # Save to cache
        with open(cache_path, "wb") as f:
            pickle.dump(kmc_model, f)
        print("Saved trained model to cache.")

    return kmc_model


In [None]:
def get_probability_maps_stdev(inputImage_no_gamma, csv_path="../data/c_min_k_max_params.csv"):
    """Compute the C_min x K_max mean image and a max-normalized stdev map.

    The reference pixel positions are read from the ``clicked_x`` /
    ``clicked_y`` columns of ``csv_path``. They are passed to
    ``data_models.compute_mean_image`` and ``data_models.compute_std_image``.

    Args:
        inputImage_no_gamma: Input image without gamma correction.
        csv_path: CSV with the clicked reference coordinates.

    Returns:
        ``(stdev_image / stdev_image.max(), mean_img)``. If the stdev map's
        maximum is not positive, the map is returned undivided instead of
        producing NaNs.
    """
    df = pd.read_csv(csv_path)
    xs = df["clicked_x"].astype(int)
    ys = df["clicked_y"].astype(int)
    mean_img, _, _ = data_models.compute_mean_image(inputImage_no_gamma, xs, ys, show_progress=True)
    stdev_image = data_models.compute_std_image(inputImage_no_gamma, xs, ys, mean_img, show_progress=True)

    # .max() is computed natively; the builtin max() over a flattened
    # full-size image iterates element by element in Python.
    peak = stdev_image.max()
    # Assuming compute_std_image returns non-negative values, peak == 0 means
    # an all-zero map; dividing by it would only produce NaNs.
    scale = peak if peak > 0 else 1.0
    return stdev_image / scale, mean_img


In [None]:
# When True, MyTrainer checks tensor shapes at every step (debugging aid).
do_asserts = False

class MyTrainer(Trainer):
    """Trainer that feeds EncoderNN flattened, down-sampled image patches.

    Batches are indexed as ``inputs[0]`` (image patches) and ``inputs[1]``
    (targets). Patches are expected to be ``(batch, channels, 101, 101)``;
    this is checked only when ``do_asserts`` is True. They are bilinearly
    resized to 32x32 and flattened to ``channels * 32 * 32`` features per
    sample before reaching the model.
    """

    def __init__(self, model, optimizer, criterion, device: str | torch.device = "cuda", channels=3, criterion_kwargs=None):
        """Create the trainer.

        Args:
            model, optimizer, criterion, device: Forwarded to ``Trainer``.
            channels: Image channels per patch. ``run`` uses 3 for BGR only
                and 5 when the mean and stdev maps are stacked on top.
            criterion_kwargs: Extra keyword arguments for ``criterion``;
                ``None`` (the default) means no extra arguments.
        """
        # A fresh dict per trainer; a literal ``{}`` default would be one
        # object shared (and possibly mutated) by every instance.
        if criterion_kwargs is None:
            criterion_kwargs = {}
        super().__init__(model, optimizer, criterion, device, criterion_kwargs)
        self.channels = channels

    def train_epoch_step(self, inputs):
        """Run the model on one batch.

        Returns:
            ``(batch_size, outputs)``.
        """
        imgs = inputs[0].to(self.device)
        if do_asserts: assert (*imgs.shape[1:],) == (self.channels, 101, 101)
        imgs = F.interpolate(
            imgs, size=(32, 32), mode='bilinear',
            align_corners=False)
        if do_asserts: assert (*imgs.shape[1:],) == (self.channels, 32, 32)
        # Flatten each patch: EncoderNN takes a channels*32*32 feature vector.
        imgs = imgs.reshape(-1, self.channels*32*32)
        if do_asserts: assert (*imgs.shape[1:],) == (self.channels*32*32,)
        outputs = self.model(imgs)
        return imgs.shape[0], outputs

    def train_epoch_loss(self, inputs, outputs):
        """Compute ``criterion(outputs, targets)`` for one batch."""
        expected = inputs[1].to(self.device)
        # Drop singleton dims 1 and 2 (in that order) so targets end up as
        # (batch, 1); presumably the dataset yields 1x1 target regions —
        # TODO confirm the exact incoming shape.
        expected = torch.squeeze(expected, 1)
        expected = torch.squeeze(expected, 2)
        if do_asserts: assert (*expected.shape[1:],) == (1,)
        loss = self.criterion(outputs, expected, **self.criterion_kwargs)
        return loss

In [None]:
def run(use_selector=False, use_channels=False):
    """Train one 8-fold EncoderNN ensemble for a single parameter combination.

    Args:
        use_selector: If True, training pixels are sampled from the
            normalized C_min x K_max standard-deviation map multiplied by the
            segment mask from ``improc.preprocess_segments``. Otherwise they
            are sampled from the CMYK threshold map of
            ``get_probability_maps_simple``.
        use_channels: If True, the scaled mean image and the normalized
            standard-deviation map are stacked onto the BGR input as two
            extra channels (5 channels instead of 3).

    Side effects:
        Saves each fold's best state dict, loss and epoch to
        ``../models/supervised-8-folds-1.1_selector=<bool>_channels=<bool>.pt``.
    """
    from tqdm import tqdm
    from pre_sal_ii.training import cross_validate

    # Number of folds. Each fold is one region of the k-means region map,
    # so this single constant sizes the masks, datasets, models and trainers.
    fold_count = 8
    batch_size = 128

    print("Starting run...")
    inputImage, inputImage_no_gamma = get_input_image()

    if use_selector or use_channels:
        stdev_image, mean_image = get_probability_maps_stdev(inputImage_no_gamma)
        selector_mask = improc.preprocess_segments(mean_image, area_threshold=0.03, morphological_processing="grow")

    print("Getting probability maps...")
    if use_selector:
        prob_base = stdev_image * selector_mask # pyright: ignore[reportOperatorIssue, reportPossiblyUnboundVariable]
    else:
        prob_base = get_probability_maps_simple(inputImage)

    binaryImage_clRed = load_manually_categorized_image()

    # Create 8-folds from the probability image
    print("Loading 8-fold divisions of binaryImage_clRed...")
    kmc_model = get_kmc_model(binaryImage_clRed)

    md.set_all_seeds(42)

    # Region-id map built from the k-means centroids; region i becomes fold i.
    centroids = kmc_model.cluster_centers_
    regions4 = generate_region_map_from_centroids(np.ones_like(prob_base, dtype=np.uint8), centroids)

    print("Creating 8-fold probability masks...")
    # Fold i may only sample pixels that lie inside region i.
    prob_masks = [prob_base * (regions4 == i).astype(float) for i in tqdm(range(fold_count))]

    # Presumably sized so that the (fold_count - 1) training folds together
    # hold ~10000 samples; rounded down to a whole number of batches.
    num_samples = int(10000/(fold_count - 1)//batch_size*batch_size)

    print(f"num_samples = {num_samples}")
    print(f"batch_size = {batch_size}")
    print(f"fold_count = {fold_count}")

    print("Adjusting input image...")
    inputImage = inputImage.astype(np.float32)/255.
    if use_channels:
        mean_image = np.clip(mean_image, 0, 255)/255. # pyright: ignore[reportPossiblyUnboundVariable]
        # The extra maps may not be float32; casting back keeps the 3- and
        # 5-channel variants on the same dtype as the model weights.
        inputImage = np.dstack((inputImage, mean_image, stdev_image)).astype(np.float32) # pyright: ignore[reportPossiblyUnboundVariable]

    print("Creating datasets...")
    datasets = [
        ds_models.ProbabilityMapPixelRegionDataset(
            prob_map, inputImage, binaryImage_clRed/255.,
            num_samples=num_samples,
            region_size=101, target_region_size=1, seed=4290,
        )
        for prob_map in prob_masks
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Creating models...")
    channels = 3 if not use_channels else 5
    # Named fold_models so it does not shadow the module-level `models` import.
    fold_models = [nn_models.EncoderNN(initial_dim=channels*32*32).to(device) for _ in range(fold_count)]
    criterion = nn.MSELoss()
    optimizers = [
        optim.AdamW(fold_model.parameters(), lr=1e-4, weight_decay=1e-5)
        for fold_model in fold_models
    ]

    print("Creating trainers...")
    trainers = [
        MyTrainer(fold_model, optimizer, criterion, device=device, channels=channels)
        for fold_model, optimizer in zip(fold_models, optimizers)
    ]

    #
    # TRAINING WITH CROSS-VALIDATION AND EARLY STOPPING
    #
    print("Training...")
    # cross_validate's default num_epochs/patience are used here (an earlier
    # variant passed num_epochs=200, patience=15).
    best_models, best_losses, best_epochs = cross_validate(trainers, datasets)

    torch.save({
        "models": [m.state_dict() for m in best_models],
        "fold_losses": best_losses,
        "epochs": best_epochs,
    }, f"../models/supervised-8-folds-1.1_selector={use_selector}_channels={use_channels}.pt")

In [None]:
run(use_channels=True, use_selector=True)  # C_min x K_max selector ON, extra channels ON

In [None]:
run(use_channels=False, use_selector=True)  # C_min x K_max selector ON, extra channels OFF


In [None]:
run(use_channels=True, use_selector=False)  # C_min x K_max selector OFF, extra channels ON


In [None]:
run(use_channels=False, use_selector=False)  # baseline: selector OFF, extra channels OFF


## Producing images with the best models

In [None]:
def run_images_with_best_models(model_file):
    """Predict with the best fold of a saved ensemble and write result images.

    The configuration is recovered from ``model_file``'s name: selector and
    channels count as enabled when the name contains ``"selector=True"`` /
    ``"channels=True"``, the naming used by ``run``. The fold with the lowest
    stored loss is evaluated on the pixels of the
    ``get_probability_maps_simple`` mask. ``num_samples=-1`` presumably means
    every mask pixel — TODO confirm in WhitePixelRegionDataset.

    Args:
        model_file: Path to a checkpoint saved by ``run``.

    Side effects:
        Shows two figures and writes two images:
        ``../out/sup_pred_8fold_1.1_selector=..._channels=....jpg`` is the
        grayscale prediction (0-255).
        ``../out/image_pred_8fold_true1.1_selector=..._channels=....jpg`` is
        a BGR overlay, with green = manual annotation and red = prediction.
    """
    from tqdm import tqdm

    print("Starting run...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    use_selector = "selector=True" in model_file
    use_channels = "channels=True" in model_file

    channels = 3 if not use_channels else 5
    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    checkpoint = torch.load(model_file, map_location=device)
    best_fold = int(np.argmin(checkpoint["fold_losses"]))
    # Only the best fold is used, so only that state dict is loaded.
    model = nn_models.EncoderNN(initial_dim=channels*32*32).to(device)
    model.load_state_dict(checkpoint["models"][best_fold])
    # Inference mode for any dropout/batch-norm layers the model may contain.
    model.eval()

    print("Getting input images...")
    inputImage, inputImage_no_gamma = get_input_image()
    # NOTE(review): the pixels to predict always come from the simple CMYK
    # mask, even for models trained with use_selector=True — confirm intended.
    pores_image3 = get_probability_maps_simple(inputImage)
    binaryImage_clRed = load_manually_categorized_image()

    print("Adjusting input image...")
    # Must mirror the preprocessing in run(): BGR is scaled to [0, 1] exactly
    # once, and the mean image is clipped and scaled the same way.
    inputImage = inputImage.astype(np.float32)/255.
    if use_channels:
        stdev_image, mean_image = get_probability_maps_stdev(inputImage_no_gamma)
        mean_image = np.clip(mean_image, 0, 255)/255.
        inputImage = np.dstack((inputImage, mean_image, stdev_image)).astype(np.float32)

    print("Creating dataset...")
    dataset2 = ds_models.WhitePixelRegionDataset(
        pores_image3, inputImage, binaryImage_clRed/255.,
        num_samples=-1, seed=None, use_img_to_tensor=True)
    dataloader2 = DataLoader(dataset2, batch_size=1024, shuffle=False)

    trainer_best = MyTrainer(model, None, None, device=device, channels=channels)

    print("Inferring...")
    pred_image = np.zeros_like(binaryImage_clRed, dtype=np.uint8)

    with torch.no_grad():
        for inputs in tqdm(dataloader2):
            coords = inputs[2]
            _, outputs = trainer_best.train_epoch_step(inputs)

            xs = coords[:, 1].cpu().numpy()
            ys = coords[:, 0].cpu().numpy()
            vs = outputs[:, 0].cpu().numpy()
            # The MSE-trained output is unbounded; clip so out-of-range values
            # saturate at 0/255 instead of wrapping when cast to uint8.
            pred_image[ys, xs] = (np.clip(vs, 0.0, 1.0) * 255).astype(np.uint8)

    print("Creating images...")
    plt.imshow(pred_image, vmin=0, vmax=255, cmap="gray")
    plt.show()
    cv2.imwrite(f"../out/sup_pred_8fold_1.1_selector={use_selector}_channels={use_channels}.jpg", pred_image)

    # BGR overlay: B = 0, G = manual annotation, R = prediction.
    image_pred_true = np.dstack((
        np.zeros_like(pred_image),
        binaryImage_clRed.astype(np.uint8),
        pred_image,
    ))
    plt.imshow(image_pred_true[:,:,::-1])  # BGR -> RGB for display
    plt.show()
    cv2.imwrite(f"../out/image_pred_8fold_true1.1_selector={use_selector}_channels={use_channels}.jpg", image_pred_true)


In [None]:
run_images_with_best_models(model_file="../models/supervised-8-folds-1.1_selector=True_channels=True.pt")  # selector ON, channels ON


In [None]:
run_images_with_best_models(model_file="../models/supervised-8-folds-1.1_selector=True_channels=False.pt")  # selector ON, channels OFF


In [None]:
run_images_with_best_models(model_file="../models/supervised-8-folds-1.1_selector=False_channels=True.pt")  # selector OFF, channels ON


In [None]:
run_images_with_best_models(model_file="../models/supervised-8-folds-1.1_selector=False_channels=False.pt")  # baseline: selector OFF, channels OFF
