This notebook segments algae cells in test images using our saved `unet` models (one per strain) and plots each prediction next to its ground-truth annotation.

As before, let's start by cloning the codebase:

In [None]:
# Clone the project repository into the current working directory
# (presumably /content on Colab). git refuses to clone into an existing
# non-empty directory, so re-running this cell reports an error once the
# repository is already present.
! git clone https://github.com/mahyar-osn/predict-algae-species.git

In [None]:
import sys
# Make the project's `core` package importable for `import core.config` below.
# NOTE(review): this path does not match the repository cloned above
# (`predict-algae-species`); the import only works if the package really
# lives here — TODO confirm the intended directory.
sys.path.insert(0, '/content/ml-test/src/provectus-algae-task/prediction')
import os

In [None]:
import core.config as config

In [None]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive so files stored there are reachable from the runtime;
# force_remount=True re-mounts even if a previous mount is still active.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

In [None]:
def plot(orig_image, orig_annotations, pred_annotations, figsize=(15, 5)):
    """Show an image, its ground-truth mask and the predicted mask side by side.

    Args:
        orig_image: Image array to display in the first panel (an RGB float
            array when called from ``make_predictions``).
        orig_annotations: Ground-truth segmentation mask.
        pred_annotations: Predicted segmentation mask.
        figsize: Matplotlib figure size in inches; the default suits a single
            row of three panels.
    """
    figure, axes = plt.subplots(nrows=1, ncols=3, figsize=figsize)
    panels = (
        (orig_image, "Image", None),
        # Masks are single-channel: a gray colormap shows them as black/white
        # instead of matplotlib's default false-colour map.
        (orig_annotations, "Original Annotations", "gray"),
        (pred_annotations, "Predicted Annotations", "gray"),
    )
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
    figure.tight_layout()
    # plt.show() renders the figure right away; Figure.show() is a no-op or
    # only warns on non-GUI backends such as the notebook inline backend.
    plt.show()

In [None]:
def make_predictions(model, image_path, threshold=0.15):
    """Segment one test image with ``model`` and plot it next to its ground truth.

    The ground-truth mask is read from ``config.MASK_DATASET_PATH`` as a PNG
    with the same base name as ``image_path``.

    Args:
        model: PyTorch model whose output is passed through a sigmoid and
            squeezed into a mask (presumably a single-channel logit map —
            TODO confirm).
        image_path: Path to the test image on disk.
        threshold: Sigmoid probability above which a pixel is marked as
            foreground. Defaults to 0.15, the value previously hard-coded here.

    Raises:
        FileNotFoundError: If the image or its ground-truth mask cannot be read.
    """
    model.eval()  # evaluation mode: disables dropout, freezes batch-norm stats

    # Run on whatever device the model's weights live on, so a checkpoint
    # loaded onto the CPU still works when config.DEVICE points at a GPU.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model: fall back to the configured device
        device = config.DEVICE

    # Load the image, convert OpenCV's BGR channel order to RGB, cast to
    # float32 and scale pixel values to [0, 1].
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype("float32") / 255.0
    original = image.copy()

    # The ground-truth mask shares the image's base name and is stored as PNG.
    filename = os.path.splitext(os.path.basename(image_path))[0]
    ground_truth_path = os.path.join(config.MASK_DATASET_PATH, filename + '.png')
    gt_annotation = cv2.imread(ground_truth_path, cv2.IMREAD_GRAYSCALE)
    if gt_annotation is None:
        raise FileNotFoundError(f"Could not read ground-truth mask: {ground_truth_path}")
    # NOTE(review): cv2.resize takes (width, height); INPUT_IMAGE_HEIGHT is
    # used for both, producing a square mask — confirm this is intended.
    gt_annotation = cv2.resize(gt_annotation, (config.INPUT_IMAGE_HEIGHT,
                                               config.INPUT_IMAGE_HEIGHT))

    # HWC -> CHW, add a leading batch dimension and move to the model's device.
    batch = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
    batch = torch.from_numpy(batch).to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(batch).squeeze()
        probabilities = torch.sigmoid(logits).cpu().numpy()

    # Drop weak predictions and encode the binary mask as 0/255 uint8.
    prediction = ((probabilities > threshold) * 255).astype(np.uint8)
    plot(original, gt_annotation, prediction)

Let's predict the cells of the `Pp` strain on a few randomly chosen test images:

In [None]:
strain = 'Pp'
n_samples = 4  # number of test images to visualise
seed = 42      # fixed for reproducible runs; set to None for a fresh random sample

print("[INFO] loading up test image paths...")
with open(config.TEST_PATHS) as paths_file:
    image_paths = paths_file.read().strip().split("\n")
# NOTE(review): substring match on the full path — assumes the strain code
# does not also appear elsewhere in the paths; confirm against the data layout.
image_paths = [x for x in image_paths if strain in x]
if not image_paths:
    raise ValueError(f"No test images for strain {strain!r} listed in {config.TEST_PATHS}")
# Sample without replacement so the same image is never plotted twice.
rng = np.random.default_rng(seed)
image_paths = rng.choice(image_paths, size=min(n_samples, len(image_paths)), replace=False)

print("[INFO] load up model...")
# The checkpoint is a whole pickled model object (it is called directly), so
# weights_only=False is required on PyTorch >= 2.6, where it defaults to True.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
unet = torch.load(config.MODEL_PATH + '.{}'.format(strain),
                  map_location=config.DEVICE, weights_only=False)
for path in image_paths:
    make_predictions(unet, path)  # predict and visualise


We can do the same for cells of the `Cr` strain:

In [None]:
strain = 'Cr'
n_samples = 4  # number of test images to visualise
seed = 42      # fixed for reproducible runs; set to None for a fresh random sample

print("[INFO] loading up test image paths...")
with open(config.TEST_PATHS) as paths_file:
    image_paths = paths_file.read().strip().split("\n")
# NOTE(review): substring match on the full path — assumes the strain code
# does not also appear elsewhere in the paths; confirm against the data layout.
image_paths = [x for x in image_paths if strain in x]
if not image_paths:
    raise ValueError(f"No test images for strain {strain!r} listed in {config.TEST_PATHS}")
# Sample without replacement so the same image is never plotted twice.
rng = np.random.default_rng(seed)
image_paths = rng.choice(image_paths, size=min(n_samples, len(image_paths)), replace=False)

print("[INFO] load up model...")
# The checkpoint is a whole pickled model object (it is called directly), so
# weights_only=False is required on PyTorch >= 2.6, where it defaults to True.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
unet = torch.load(config.MODEL_PATH + '.{}'.format(strain),
                  map_location=config.DEVICE, weights_only=False)
for path in image_paths:
    make_predictions(unet, path)  # predict and visualise