# Applying a model trained on another image to detect moldic pores

In this notebook we apply a model trained on a different thin-section image to this one and inspect how well it detects moldic pores.
The model was trained on an image in which every pore had been labelled as one of three classes: moldic, vugular, or interparticle. Only the moldic class was actually used during training, since it was the most abundant pore type.

In [None]:
import os

# Every path in this notebook is relative ("../data", "../models", "../out"),
# so show which directory the kernel is running from.
working_dir = os.getcwd()
print(working_dir)

In [None]:
from pre_sal_ii.improc import colorspace

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from pre_sal_ii.models.EncoderNN import EncoderNN

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device = {device}")

# Rebuild the architecture and load the weights trained on the other thin section.
# map_location remaps stored tensors onto `device`, so a checkpoint saved from a
# GPU session still loads on a CPU-only machine instead of failing to deserialize.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
model = EncoderNN().to(device)
model.load_state_dict(torch.load("../models/supervised-1b.bin", map_location=device))
# Switch to evaluation mode (affects layers such as dropout/batch-norm, if present).
model.eval()


In [None]:
from pre_sal_ii.improc import scale_image_and_save

# Full-resolution thin section to analyse.
image_name = "122.20_jpeg_escal"
path = f"../data/thin_sections/{image_name}.jpg"

# Write a downscaled copy that the rest of the notebook works on.
# NOTE(review): 25 is presumably a percentage (1/4 size, hence the "_4x" folder);
# confirm against scale_image_and_save.
scale_image_and_save(
    path,
    "../out/thin_sections_4x/",
    25,
)

In [None]:
# Read the downscaled image. The directory must match the output directory passed
# to scale_image_and_save in the cell above ("../out/thin_sections_4x/"); reading
# from a different folder silently picks up a stale file from an earlier run.
path = f"../out/thin_sections_4x/{image_name}_25.jpg"
inputImage = cv2.imread(path)
if inputImage is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError(f"Could not read image: {path}")
# OpenCV loads BGR; reverse the channel axis to RGB for matplotlib.
plt.imshow(inputImage[:,:,::-1])

# BGR to CMYK:
inputImageCMYK = colorspace.bgr2cmyk(inputImage)

In [None]:
# Segment candidate pore pixels by thresholding in CMYK space.
# NOTE(review): assumes colorspace.bgr2cmyk returns 0-255 channels in C, M, Y, K
# order; confirm. cv2.inRange bounds are inclusive per channel, and the result is
# a uint8 mask with 255 where every channel is in range and 0 elsewhere.
cmyk_lower = (92, 0, 0, 0)
cmyk_upper = (255, 255, 64, 196)
binaryImage = cv2.inRange(inputImageCMYK, cmyk_lower, cmyk_upper)
binaryImage

In [None]:
# Clean the mask with a 10x10 elliptical element:
#   opening (erode, then dilate) removes specks smaller than the element,
#   closing (dilate, then erode) fills small holes and gaps inside pores.
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
binaryImage = cv2.morphologyEx(binaryImage, cv2.MORPH_OPEN, kernel, iterations=1)
binaryImage = cv2.morphologyEx(binaryImage, cv2.MORPH_CLOSE, kernel, iterations=1)

plt.imshow(binaryImage, cmap='gray')
# NOTE(review): JPEG is lossy, so the saved mask will carry compression artefacts;
# consider PNG if this file is read back as a mask.
cv2.imwrite("../out/some.jpg", binaryImage)

# Porosity = fraction of pixels that are foreground (the mask holds only 0 and 255).
porosidade = np.count_nonzero(binaryImage) / binaryImage.size
print(f"porosidade = {porosidade}")

In [None]:
from skimage.measure import label, regionprops

# Connected-component labelling of the cleaned mask (0 is background). Full
# connectivity (8-neighbourhood for a 2-D image) is spelled out explicitly.
label_img = label(binaryImage, connectivity=binaryImage.ndim)
# One RegionProperties object per labelled component.
regions = regionprops(label_img)

In [None]:
# Describe each labelled region by its pixel area and by how far its farthest
# pixel lies from the image centre. Rows are normalised by half the height and
# columns by half the width, so 1.0 corresponds to the ellipse inscribed in the image.
half_h = label_img.shape[0] / 2
half_w = label_img.shape[1] / 2


def _region_stats(region):
    """Return {"area", "max-dist"} for one skimage region."""
    rows = (region.coords[:, 0] - half_h) / half_h
    cols = (region.coords[:, 1] - half_w) / half_w
    return {
        "area": region.area,
        "max-dist": ((rows**2 + cols**2) ** 0.5).max(),
    }


all_objs = [_region_stats(region) for region in regions]
df = pd.DataFrame(all_objs)

In [None]:
# Keep a region only if its farthest pixel lies within 80% of the largest
# "max-dist" across all regions; regions reaching further out are dropped.
max_dist = df["max-dist"].max()
is_central = (df["max-dist"] <= max_dist * 0.8).to_numpy()
kept_labels = [region.label for region, keep in zip(regions, is_central) if keep]

# Paint every pixel of the kept regions white (255) on a black uint8 canvas.
pores_image3 = np.zeros(label_img.shape, dtype=np.uint8)
pores_image3[np.isin(label_img, kept_labels)] = 255

In [None]:
# Inspect the filtered pore mask.
mask_shape = pores_image3.shape
print(mask_shape)
plt.imshow(pores_image3, cmap="gray")


In [None]:
from pre_sal_ii.models.WhitePixelRegionDataset import WhitePixelRegionDataset

# Dataset built from the filtered pore mask (pores_image3) and the input image;
# the inference cell below iterates it as (patch, <unused>, coords).
# NOTE(review): the meaning of the third positional argument (None here) and of
# num_samples=-1 / seed=None is defined inside WhitePixelRegionDataset. Presumably
# these mean "no label image" and "use every white mask pixel, no subsampling";
# confirm against the class.
dataset2 = WhitePixelRegionDataset(
    pores_image3, inputImage, None, num_samples=-1, seed=None)


In [None]:
# Run the model on every dataset sample and paint its score (scaled to 0-255)
# into a single-channel prediction map at that sample's (row, col) coordinates.
# A 2-D map is used so the display cell's cmap/vmin/vmax actually take effect
# (matplotlib ignores them for 3-channel images).
from tqdm.auto import tqdm

pred_image = np.zeros(inputImage.shape[:2], dtype=np.uint8)

# Number of samples whose model score exceeds 0.5.
count_gt_half = 0

with torch.no_grad():
    for imgX, _, coords in tqdm(dataset2):
        # The patch arrives channels-last; the model takes a flattened 3x32x32
        # vector in [0, 1]: add batch dim -> NCHW -> scale -> resize -> flatten.
        imgX = imgX.to(device).unsqueeze(0).permute(0, 3, 1, 2) / 255
        imgX = F.interpolate(
            imgX, size=(32, 32), mode='bilinear',
            align_corners=False)
        imgX = imgX.reshape(-1, 3*32*32)

        score = float(model(imgX)[0, 0])
        if score > 0.5:
            count_gt_half += 1

        # Clip before storing: casting an out-of-range float into uint8 does not
        # saturate (the result is platform-dependent), so a score outside [0, 1]
        # would otherwise wrap to a misleading intensity.
        pred_image[int(coords[0]), int(coords[1])] = np.clip(np.rint(score * 255), 0, 255)

In [None]:
# Show the prediction map and save it alongside the other outputs.
plt.imshow(pred_image, vmin=0, vmax=255, cmap="gray")
pred_path = f"../out/sup_pred_{image_name}_2.jpg"
# cv2.imwrite reports failure (e.g. a missing output directory) by returning
# False rather than raising, so check the result explicitly.
if not cv2.imwrite(pred_path, pred_image):
    raise OSError(f"Failed to write prediction image: {pred_path}")
