In [8]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

import pickle

from skimage.util import view_as_windows
from sklearn.neighbors import KNeighborsClassifier
from skimage.measure import moments_central, moments_hu, moments_normalized

import segmentation_models_pytorch as smp
import torch

from torchvision.transforms import v2

In [9]:
# Sample test image and its ground-truth mask, used by the demo calls in the last cell.
img_path, mask_path = './data/test/image/0.png', './data/test/mask/0.png'

In [10]:
def preprocess(img):
    """Return the CLAHE-equalised green channel of an image.

    Args:
        img: image array indexed as ``img[:, :, c]``; channel 1 is used
            (green, when ``img`` is RGB as returned by ``get_img_and_mask``).

    Returns:
        Channel 1 after CLAHE with clip limit 2.0 and an 8x8 tile grid.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return equalizer.apply(img[:, :, 1])

def get_img_and_mask(img_path, mask_path):
    """Load an image as RGB and its mask as a single-channel uint8 array.

    Args:
        img_path: path to the image file; anything ``str()``-able (e.g. ``pathlib.Path``).
        mask_path: path to the mask file; read as colour, then converted to grayscale.

    Returns:
        ``(img, mask)``: ``img`` is the RGB image, ``mask`` is the grayscale uint8 mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read either file.
    """
    img_bgr = cv2.imread(str(img_path))
    if img_bgr is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cvtColor fails later with an opaque assertion.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    mask_bgr = cv2.imread(str(mask_path))
    if mask_bgr is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY).astype(np.uint8)
    return img, mask

def basic_approach_postprocess(masks, thresh=255, kernel_size=(5, 5)):
    """Binarise response maps, keep only their largest blob, then close small gaps.

    For each map: pixels strictly greater than ``thresh`` become 255 and the rest 0;
    every white region except the largest connected component (OpenCV's default
    8-connectivity) is erased; finally a morphological closing with an elliptical
    kernel is applied.

    Args:
        masks: iterable of 2-D single-channel arrays of a dtype ``cv2.threshold``
            accepts (e.g. the float64 Scharr magnitude produced by ``sharr``).
        thresh: binarisation threshold. The default 255 keeps the original
            behaviour; note that for uint8 input no pixel exceeds 255, so the
            result is all black.
        kernel_size: ``(width, height)`` of the elliptical closing kernel.

    Returns:
        A list of uint8 masks with values in {0, 255}, one per input map.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
    result = []
    for mask in masks:
        binary = cv2.threshold(mask, thresh, 255, cv2.THRESH_BINARY)[1].astype(np.uint8)

        # Make black everything that is not part of the biggest white object.
        # Label 0 is the background, so search labels 1..nlabels-1; argmax
        # picks the first maximum, matching a strict '>' scan.
        nlabels, labels, stats, _ = cv2.connectedComponentsWithStats(binary)
        if nlabels > 1:
            largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
            binary[labels != largest] = 0

        result.append(cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel))
    return result

def plot(img, mask, pred):
    """Show the image, ground-truth mask and prediction side by side in one row.

    Args:
        img: image drawn with matplotlib's default colormap (ignored for RGB).
        mask: ground-truth mask, drawn in grayscale.
        pred: predicted mask, drawn in grayscale.
    """
    panels = (
        (img, 'Image', None),
        (mask, 'Mask', 'gray'),
        (pred, 'Prediction', 'gray'),
    )
    plt.figure(figsize=(12, 6))
    for position, (data, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.show()

def sharr(img_path, mask_path):
    """Segment by thresholding the Scharr gradient magnitude, then plot the result.

    The CLAHE-enhanced green channel is differentiated with the Scharr operator
    along x and y (float64 output). The gradient magnitude is binarised and
    cleaned by ``basic_approach_postprocess``, then shown next to the image and mask.
    """
    img, mask = get_img_and_mask(img_path, mask_path)
    enhanced = preprocess(img)

    grad_x, grad_y = (
        cv2.Scharr(enhanced, cv2.CV_64F, dx, dy) for dx, dy in ((1, 0), (0, 1))
    )
    edges = cv2.magnitude(grad_x, grad_y)

    (pred,) = basic_approach_postprocess([edges])
    plot(img, mask, pred)

def features_extracting(photo_batch):
    """Compute per-patch features: intensity variance followed by the 7 Hu moments.

    Args:
        photo_batch: batch of 2-D patches indexed as ``photo_batch[i]``
            (``create_dataset`` passes an ``(n, size, size)`` array).

    Returns:
        ``(n, 8)`` float array. Column 0 is each patch's variance and columns 1-7
        are its Hu moments. Non-finite Hu moments are replaced by 0 (see below).
    """
    photo_batch = np.asarray(photo_batch)
    color_variance = np.var(photo_batch, axis=(1, 2))

    hu_moments = np.zeros((len(photo_batch), 7))
    # An all-zero patch (e.g. the zero padding added by create_dataset) has
    # mu[0, 0] == 0, so normalisation computes 0/0; silence those warnings here
    # and clean up the resulting NaNs below.
    with np.errstate(divide='ignore', invalid='ignore'):
        for i, photo in enumerate(photo_batch):
            # Hu moments only use normalised moments of order <= 3, so central
            # moments of higher order would be computed and discarded.
            mu = moments_central(photo, order=3)
            nu = moments_normalized(mu, order=3)
            hu_moments[i] = moments_hu(nu)

    # scikit-learn estimators reject NaN/inf input, so predict() would fail on
    # such rows. NOTE(review): confirm 0 matches how training features were cleaned.
    hu_moments = np.nan_to_num(hu_moments, nan=0.0, posinf=0.0, neginf=0.0)

    return np.concatenate((color_variance.reshape(-1, 1), hu_moments), axis=1)

def create_dataset(photos, masks, size=5, step=1):
    """Turn images into per-window feature rows and matching mask labels.

    Each image and mask is zero-padded by ``size // 2`` on every side and cut
    into ``size x size`` sliding windows taken every ``step`` pixels. Each image
    window becomes one feature row (via ``features_extracting``). Its label is
    the mask value at index ``size // 2`` along both axes of the matching mask
    window (the exact centre when ``size`` is odd).

    Args:
        photos: iterable of 2-D images.
        masks: iterable of 2-D masks, paired with ``photos`` by position.
        size: window side length.
        step: stride between consecutive windows.

    Returns:
        ``(features, labels)``: two lists with one entry per window, across all images.
    """
    half = size // 2
    padding = ((half, half), (half, half))
    features, labels = [], []

    for photo, mask in zip(photos, masks):
        photo_windows = view_as_windows(
            np.pad(photo, padding, mode='constant'), (size, size), step=step
        ).reshape(-1, size, size)
        mask_windows = view_as_windows(
            np.pad(mask, padding, mode='constant'), (size, size), step=step
        ).reshape(-1, size, size)

        features.extend(features_extracting(photo_windows))
        labels.extend(mask_windows[:, half, half])

    return features, labels

def kneighbors(img_path, mask_path, model_path="./models/knnclassifier.pkl"):
    """Segment an image with a pickled per-pixel classifier and plot the result.

    Each pixel of the CLAHE-enhanced green channel is described by the features
    of the 5x5 window centred on it (``create_dataset`` with ``step=1``). The
    classifier predicts one label per pixel, and the labels are reshaped back
    into an image.

    Args:
        img_path: path to the RGB image.
        mask_path: path to the ground-truth mask (plotted for comparison; the
            labels ``create_dataset`` derives from it are discarded).
        model_path: pickled estimator exposing ``predict``. Only load pickles
            you trust, because unpickling can execute arbitrary code.
    """
    img, mask = get_img_and_mask(img_path, mask_path)
    prep_img = preprocess(img)

    # size=5, step=1 yields exactly one window per pixel, in row-major order.
    features, _ = create_dataset([prep_img], [mask], size=5, step=1)

    # Pickles must be read in binary mode; `with` also closes the file handle.
    with open(model_path, 'rb') as model_file:
        model = pickle.load(model_file)
    pred = model.predict(np.asarray(features))

    # Reshape to the image's own size instead of assuming 512x512; the flat
    # prediction vector cannot be drawn by imshow.
    pred_img = np.asarray(pred).reshape(prep_img.shape)

    # Plot the original image, consistent with sharr() and unet().
    plot(img, mask, pred_img)

def unet(img_path, mask_path, model_path="./models/model.pth"):
    """Segment an image with a trained U-Net and plot the per-pixel probabilities.

    Args:
        img_path: path to the RGB image.
        mask_path: path to the ground-truth mask (plotted for comparison).
        model_path: path to a ``state_dict`` saved from an ``smp.Unet``.
    """
    img, mask = get_img_and_mask(img_path, mask_path)
    prep_img = preprocess(img)

    # encoder_weights=None: the checkpoint overwrites every weight, so there is
    # no point downloading the ImageNet encoder weights smp.Unet() fetches by default.
    # NOTE(review): smp.Unet() defaults to in_channels=3, but the tensor built
    # below has one channel. Confirm this matches how the checkpoint was trained.
    model = smp.Unet(encoder_weights=None)
    # weights_only=True: a plain state_dict needs no arbitrary unpickling.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_image_transform = v2.Compose([
        v2.ToPILImage(),
        v2.Grayscale(num_output_channels=1),
        # Replacement for the deprecated v2.ToTensor(): uint8 -> float32 in [0, 1].
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485], std=[0.229]),
    ])

    # (C, H, W) -> (1, C, H, W): add a batch dimension in front.
    x = test_image_transform(prep_img).unsqueeze(0)

    # Inference only: no_grad skips building the autograd graph, which also
    # allows the output to be converted to NumPy for plotting.
    with torch.no_grad():
        logits = model(x)

    # smp.Unet returns raw logits by default (activation=None); sigmoid maps them
    # to probabilities, and squeeze drops the batch/channel dims for imshow.
    pred = torch.sigmoid(logits).squeeze().cpu().numpy()

    plot(img, mask, pred)

In [None]:
# Run one segmentation approach at a time. To compare the others, replace the
# call below with sharr(img_path, mask_path) or kneighbors(img_path, mask_path).
unet(img_path=img_path, mask_path=mask_path)