# Pore type prediction from thin-section images 1.2

In this notebook, we take all the code from version 1.0 and execute it many times with multiple parameter combinations.

Unlike version 1.1 of this notebook, instead of treating the selector and the channels as yes/no options, we consider a continuum of powers ranging from 0.0 to 1.0, in intervals of 0.1. With 11 values per parameter, this gives us 11 × 11 = 121 combinations to train and get data from.


## Initial setup

In [None]:
import os
# Show the working directory: later cells read and write paths relative to it
# ("../out/...", "../models/...").
print(os.getcwd())

In [None]:
# Third-party imports (each listed once).
from importlib import reload

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm.notebook import tqdm

# Project imports. `models` and `md` are two aliases of the same package;
# both names are used further down in the notebook.
import pre_sal_ii.models as models
import pre_sal_ii.custom.module1 as m1
import pre_sal_ii.models.nn as nn_models
import pre_sal_ii.models.ds as ds_models
import pre_sal_ii.models as md
import pre_sal_ii.improc as improc
import pre_sal_ii

# Re-execute the project modules so edits made on disk are picked up without
# restarting the kernel.
reload(models)
reload(m1)
models.set_all_seeds(0)
reload(ds_models)
reload(improc)

# Imported AFTER reload(improc): a `from ... import name` executed before the
# reload would leave the name bound to the function object of the old module.
from pre_sal_ii.improc import generate_region_map_from_centroids

# Progress-bar factory used by the project code (read back by run()).
pre_sal_ii.progress = tqdm

## Visualizations

In [None]:
# Load the input image pair and derive the mean / stdev images from the
# second one returned by m1.get_input_image().
inputImage, inputImage_no_gamma = m1.get_input_image()

mean_image, stdev_image = m1.get_mean_stdev(inputImage_no_gamma)
# Normalise so the largest value is 1. ndarray.max() reduces in native code;
# the builtin max() over .flatten() iterates every pixel in Python.
stdev_image = stdev_image / stdev_image.max()

plt.imshow(stdev_image, cmap='gray')

In [None]:
# Peak value of each image (stdev_image was divided by its own maximum in the
# previous cell). ndarray.max() avoids a Python-level loop over all pixels.
print(mean_image.max())
print(stdev_image.max())

In [None]:
# NOTE(review): this preview uses area_threshold=0.05 while run() builds its
# selector mask with area_threshold=0.03 -- confirm the difference is intended.
selector_mask = improc.preprocess_segments(
    mean_image, area_threshold=0.05, morphological_processing={"grow": 25})
# ndarray.max() instead of the builtin max() over a flattened copy.
print(selector_mask.max())

# Overlay: mean image and selector mask as two 8-bit planes, third plane empty.
b = np.clip(mean_image, 0, 255).astype(np.uint8)
g = np.clip(selector_mask*255, 0, 255).astype(np.uint8)
r = np.zeros_like(g, dtype=np.uint8)
# Reverse the channel axis of the merged image before handing it to matplotlib.
plt.imshow(cv2.merge([b, g, r])[:,:,::-1])

In [None]:
# 8-bit version of the selector mask (values scaled by 255), plus a quick
# summary of its dtype, value range and shape.
selmask2 = (255 * selector_mask).astype(np.uint8)
print(f"{selmask2.dtype} {selmask2.min()} {selmask2.max()} {selmask2.shape}")

In [None]:
# cv2.imwrite signals failure by returning False rather than raising, and it
# does not create missing directories -- so create the folder and check.
out_path = "../out/selector_mask.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
if not cv2.imwrite(out_path, selmask2):
    raise IOError(f"Could not write {out_path}")

In [None]:
# Same expression run() uses for its probability map, evaluated with an
# exponent of 1.0 (plus the 1e-6 offset).
plt.imshow((selector_mask * stdev_image) ** (1.0 + 1e-6))

In [None]:
# Preview the effect of raising the normalised stdev image to a power:
# show the transformed image and the histogram of its values.
power = 0.5
stdev_image2 = stdev_image ** power
# ndarray.min()/max() instead of the builtins over a flattened copy.
print(stdev_image2.min(), stdev_image2.max())
plt.imshow(stdev_image2)
plt.show()
plt.hist(stdev_image2.flatten(), bins=100)
plt.show()

In [None]:
# Probability map computed by m1.get_probability_maps_simple, shown in greyscale.
prob_base = m1.get_probability_maps_simple(inputImage)
plt.imshow(
    prob_base,
    cmap="gray",
)

## Training

In [None]:
# Scratch check: what a small value (1e-4) becomes after a power of 0.1.
pow(0.0001, 0.1)

In [None]:
def run(
            selector_power=0.0, use_channels=False,
            stdev_channel_power=0.0, debug=False,
            mean_channel_weight=1.0,
            stdev_channel_weight=1.0,
            color_channels_weight=1.0,
            normalize_stdev=True,
        ):
    """
    Runs the supervised training with options to use selector mask and additional channels.

    The result is written to a checkpoint under ``../models/`` whose name
    encodes the parameters. If that file already exists the call returns
    immediately, so the function can be re-invoked safely over a whole grid.

    Args:
        selector_power: Exponent (plus a 1e-6 offset) applied to
            ``stdev_image * selector_mask`` to build the probability map that
            is handed to ``ProbabilityMapPixelRegionDataset``.
        use_channels: When True, the mean image and the stdev image are stacked
            onto the colour image as two extra channels (5 channels instead of 3).
        stdev_channel_power: Exponent applied to the normalised stdev image
            before it is stacked. Only used when ``use_channels`` is True.
        debug: Print progress messages; also forwarded to ``m1.get_kmc_model``
            and ``cross_validate``.
        mean_channel_weight: The mean channel is ``clip(mean, 0, 255) * (w / 255)``.
            Only used when ``use_channels`` is True (it still affects the
            checkpoint name otherwise).
        stdev_channel_weight: Multiplier applied to the stdev channel.
            Only used when ``use_channels`` is True.
        color_channels_weight: The colour image is scaled by ``w / 255``.
        normalize_stdev: When False, the stdev channel is additionally
            multiplied by ``max_stdev_pixel / 255``.

    Returns:
        None. The checkpoint is a dict with keys ``"models"`` (one state dict
        per fold), ``"fold_losses"`` and ``"epochs"``, saved with ``torch.save``.
    """
    filename = _checkpoint_path(
        selector_power, use_channels, stdev_channel_power,
        mean_channel_weight, stdev_channel_weight,
        color_channels_weight, normalize_stdev,
    )

    # An existing checkpoint is the "already done" marker for this combination.
    if os.path.exists(filename):
        return

    # Fail fast: make sure the output folder exists before the (long) training
    # starts, instead of discovering it is missing at save time.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    fold_count = 8
    batch_size = 128
    num_samples = int(10000/(fold_count - 1)//batch_size*batch_size)

    if debug: print("Starting run...")
    inputImage, inputImage_no_gamma = m1.get_input_image()

    mean_image, stdev_image = m1.get_mean_stdev(inputImage_no_gamma)
    # ndarray.max() reduces in native code; builtin max() would loop per pixel.
    max_stdev_pixel = stdev_image.max()
    stdev_image = stdev_image / max_stdev_pixel
    selector_mask = improc.preprocess_segments(
        mean_image, area_threshold=0.03, morphological_processing={"grow": 25})

    if debug: print("Getting probability maps...")
    # The small offset keeps the exponent strictly positive when selector_power is 0.
    prob_base = (stdev_image*selector_mask)**(selector_power + 0.000001) # pyright: ignore[reportOperatorIssue, reportPossiblyUnboundVariable]

    binaryImage_clRed = m1.load_manually_categorized_image()

    # Create 8-folds from the probability image
    if debug: print("Loading 8-fold divisions of binaryImage_clRed...")
    kmc_model = m1.get_kmc_model(binaryImage_clRed, debug=debug)

    # Seeded here (after the k-means step) so everything below is reproducible.
    md.set_all_seeds(42)

    centroids = kmc_model.cluster_centers_
    regions4 = generate_region_map_from_centroids(np.ones_like(prob_base, dtype=np.uint8), centroids)

    if debug: print("Creating 8-fold probability masks...")
    # Imported at call time on purpose: the notebook swaps pre_sal_ii.progress
    # before calling run(), and this picks up the current value.
    from pre_sal_ii import progress
    prob_masks = []
    for i in progress(range(fold_count)):
        # Restrict the probability map to the pixels of region i.
        mask_i = (regions4 == i).astype(float)
        prob_masks.append(prob_base * mask_i)

    if debug: print(f"num_samples = {num_samples}")
    if debug: print(f"batch_size = {batch_size}")
    if debug: print(f"fold_count = {fold_count}")

    if debug: print("Adjusting input image...")
    inputImage = inputImage.astype(np.float32)*(color_channels_weight/255.0)
    if use_channels:
        mean_image = np.clip(mean_image, 0, 255)*(mean_channel_weight/255.0)

        stdev_image2 = stdev_image**stdev_channel_power
        stdev_image2 = stdev_image2*stdev_channel_weight
        if not normalize_stdev:
            # Undo the division by the peak and express it on the same /255 scale.
            stdev_image2 = stdev_image2*(max_stdev_pixel/255.0)

        inputImage = np.dstack((inputImage, mean_image, stdev_image2)) # pyright: ignore[reportPossiblyUnboundVariable]

    if debug: print("Creating datasets...")
    datasets = [
        ds_models.ProbabilityMapPixelRegionDataset(
                prob_map, inputImage, binaryImage_clRed/255.,
                num_samples=num_samples,
                region_size=101, target_region_size=1, seed=4290
            ) for prob_map in prob_masks
        ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if debug: print("Creating models...")
    channels = 5 if use_channels else 3
    # Named fold_models so the module-level `models` import is not shadowed.
    fold_models = [nn_models.EncoderNN(initial_dim=channels*32*32).to(device) for _ in range(fold_count)]
    criterion = nn.MSELoss()
    optimizers = [optim.AdamW(fold_model.parameters(),
                            lr=1e-4,
                            weight_decay=1e-5,
                        ) for fold_model in fold_models]

    if debug: print("Creating trainers...")
    from pre_sal_ii.training import cross_validate
    trainers = [m1.MyTrainer101x101to32x32(
            fold_models[fold],
            optimizers[fold],
            criterion,
            device=device,
            channels=channels,
        ) for fold in range(fold_count)]

    #
    # TRAINING WITH CROSS-VALIDATION AND EARLY STOPPING
    #
    if debug: print("Training...")
    msg_info = f"selpwr={selector_power:.2f}_ch={'T' if use_channels else 'F'}"
    if use_channels:
        msg_info += f"_chpwr={stdev_channel_power:.2f}"
    best_models, best_losses, best_epochs = cross_validate(
        trainers, datasets, num_epochs=200, patience=15,
        msg_info=msg_info,
        debug=debug
        )

    # Write to a temporary name and rename: an interrupted save must not leave
    # a truncated file that the os.path.exists() check above would accept.
    tmp_filename = f"{filename}.tmp"
    torch.save({
        "models": [m.state_dict() for m in best_models],
        "fold_losses": best_losses,
        "epochs": best_epochs,
    }, tmp_filename)
    os.replace(tmp_filename, filename)


def _checkpoint_path(
            selector_power, use_channels, stdev_channel_power,
            mean_channel_weight, stdev_channel_weight,
            color_channels_weight, normalize_stdev,
        ):
    """Return the checkpoint path that encodes one parameter combination.

    Weights and the normalisation flag are appended as suffixes only when they
    differ from their defaults (1.0 / True).
    """
    args = f"_selector_pwr={selector_power:.2f}"
    if use_channels:
        args += f"_channels_pwr={stdev_channel_power:.2f}"
    else:
        args += "_channels=False"
    if color_channels_weight != 1.0:
        args += f"_color_wt={color_channels_weight:.2f}"
    if mean_channel_weight != 1.0:
        args += f"_mean_wt={mean_channel_weight:.2f}"
    if stdev_channel_weight != 1.0:
        args += f"_stdev_wt={stdev_channel_weight:.2f}"
    if not normalize_stdev:
        args += "_stdev_norm=False"
    return f"../models/supervised-8-folds-1.2{args}.pt"
    

### Training with a queue (distributed)

In [None]:
import requests
import json
import itertools

# Work items are (selector, channel) index pairs; index k stands for power k/10.
A = range(11)  # 0 to 10
B = range(11)
items = list(itertools.product(A, B))

url = "http://localhost:8191/create"
# One JSON-encoded pair per line.
serialized_items = "\n".join(json.dumps(item) for item in items)
print(f"Submitting {len(items)} items to the queue...")
# requests has no default timeout; bound the wait so an unreachable queue
# server cannot hang the cell forever.
response = requests.post(url, data={"items": serialized_items}, timeout=30)

# Check for errors
if response.status_code == 200:
    guid = response.text.strip()  # The server returns the GUID as plain text
    print("Queue created with ID:", guid)
else:
    print("Error:", response.status_code, response.text)

In [None]:
import requests
import json
import itertools

# Base URL of the queue server and the ID of the queue to work on.
url = "http://localhost:8191/"
guid = "1c6d6274-b07e-492c-90a2-8a4f8f56eeb3"
total_items = 121  # 11 x 11 index pairs
# Number of items still in the queue (timeout: requests waits forever by default).
requests.get(f"{url}size?id={guid}", timeout=30).json()

In [None]:
# Remove from the queue every pair whose checkpoint already exists on disk.
A = range(11)  # 0 to 10
B = range(11)
items = list(itertools.product(A, B))

for item in items:
    i, j = item
    # Same naming scheme run() uses for use_channels=True with default weights.
    filename = f"../models/supervised-8-folds-1.2_selector_pwr={i/10:.2f}_channels_pwr={j/10:.2f}.pt"
    if os.path.exists(filename):
        # timeout: requests waits forever by default.
        requests.get(f"{url}remove?id={guid}&item={json.dumps((i,j))}", timeout=30).json()

In [None]:
import pre_sal_ii
reload(pre_sal_ii)
prev_progress = pre_sal_ii.progress
got = None  # pair currently checked out of the queue; None when idle
try:
    from tqdm.notebook import tqdm
    bar = tqdm(total=total_items)
    # Progress bars created by the project code during run() are not left behind.
    pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)
    while True:
        # timeout: requests waits forever by default.
        response = requests.get(f"{url}get?id={guid}", timeout=30)
        if response.status_code != 200:
            # Any non-200 answer is treated as "queue is empty".
            print("No more items to process.")
            bar.n = total_items
            bar.refresh()
            break
        i, j = response.json()
        got = (i, j)
        bar.set_description(f"selpwr={i/10:.2f}_ch=True_chpwr={j/10:.2f}")
        missing = requests.get(f"{url}size?id={guid}", timeout=30).json()
        bar.n = total_items - missing
        bar.refresh()
        run(selector_power=i/10, use_channels=True, stdev_channel_power=j/10, debug=False)
        got = None
except BaseException as e:
    # BaseException so that interrupting the kernel (KeyboardInterrupt) also
    # hands the unfinished item back to the queue.
    if got is not None:
        requests.get(f"{url}add?id={guid}&item={json.dumps(got)}", timeout=30).json()
        # Cleared only after a successful re-add, so the item is never queued twice.
        got = None
    print(f"Raising exception {e}...")
    raise
finally:
    # Restore the progress hook on success as well as on failure.
    pre_sal_ii.progress = prev_progress

In [None]:
# Manual fallback: if the worker cell stopped while an item was checked out,
# hand it back to the queue.
if got is not None:
    # timeout: requests waits forever by default.
    requests.get(f"{url}add?id={guid}&item={json.dumps(got)}", timeout=30).json()
    got = None

### Training all locally

In [None]:
import itertools
A = range(11)  # 0 to 10
B = range(11)
pairs = list(itertools.product(A, B))
import pre_sal_ii
reload(pre_sal_ii)
prev_progress = pre_sal_ii.progress
try:
    from tqdm.notebook import tqdm
    bar = tqdm(pairs)
    pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)
    # run() returns immediately for combinations whose checkpoint already exists.
    for i, j in bar:
        bar.set_description(f"selpwr={i/10:.2f}_ch=True_chpwr={j/10:.2f}")
        run(selector_power=i/10, use_channels=True, stdev_channel_power=j/10, debug=False)
finally:
    # Restore the progress hook whether training finished or failed; errors
    # propagate instead of being silently discarded.
    pre_sal_ii.progress = prev_progress

In [None]:
import itertools
A = range(11)  # 0 to 10
B = range(11)
pairs = list(itertools.product(A, B))
import pre_sal_ii
reload(pre_sal_ii)
prev_progress = pre_sal_ii.progress
try:
    from tqdm.notebook import tqdm
    bar = tqdm(pairs)
    pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)
    # Same grid as the previous cell, with the mean channel weighted by 255.0.
    for i, j in bar:
        bar.set_description(f"selpwr={i/10:.2f}_ch=True_chpwr={j/10:.2f}_mean_wt=255.0")
        run(selector_power=i/10, use_channels=True, stdev_channel_power=j/10,
            mean_channel_weight=255.0,
            debug=False)
finally:
    # Restore the progress hook whether training finished or failed; errors
    # propagate instead of being silently discarded.
    pre_sal_ii.progress = prev_progress

In [None]:
A = range(11)  # 0 to 10
import pre_sal_ii
reload(pre_sal_ii)
prev_progress = pre_sal_ii.progress
try:
    from tqdm.notebook import tqdm
    bar = tqdm(A)
    pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)
    # Colour channels only: sweep the selector power alone.
    for i in bar:
        bar.set_description(f"selpwr={i/10:.2f}_ch=False")
        run(selector_power=i/10, use_channels=False, debug=False)
finally:
    # Restore the progress hook whether training finished or failed; errors
    # propagate instead of being silently discarded.
    pre_sal_ii.progress = prev_progress

In [None]:
A = range(11)  # 0 to 10
import pre_sal_ii
reload(pre_sal_ii)
prev_progress = pre_sal_ii.progress
try:
    from tqdm.notebook import tqdm
    bar = tqdm(A)
    pre_sal_ii.progress = lambda *args, **kwargs: tqdm(*args, leave=False, **kwargs)
    # NOTE(review): run() only applies mean_channel_weight when use_channels is
    # True, so here it appears to change the checkpoint name only -- confirm
    # this sweep is intended.
    for i in bar:
        bar.set_description(f"selpwr={i/10:.2f}_ch=False_mean_wt=255.0")
        run(selector_power=i/10, use_channels=False,
            mean_channel_weight=255.0,
            debug=False)
finally:
    # Restore the progress hook whether training finished or failed; errors
    # propagate instead of being silently discarded.
    pre_sal_ii.progress = prev_progress