### Class used to load the trained model for each channel and run predictions

In [None]:
'''
Author: Mohamed Ghafoor 
Contact: Moeghaf@gmail.com 
20/11/2023 

'''

import numpy as np
import torch 
import segmentation_models_pytorch as smp
from skimage import exposure
from skimage.util import view_as_windows, view_as_blocks

import torch 
import segmentation_models_pytorch as smp
from skimage.util import view_as_windows, view_as_blocks

class single_ch_model:
    """Load a trained model for one channel and threshold images with it.

    ``__init__`` loads the model and normalizes the input images. ``predict``
    then tiles each normalized image into square blocks, runs the model on
    all blocks as one batch, rounds the outputs and stitches them back into
    a mask of the original image size. The masks are stored in
    ``self.thresholded_images``.
    """

    def __init__(self, ch_model_path, imgs_to_threshold):
        """
        Initialize the single channel model.

        Args:
            ch_model_path (str): Path to the trained model file. It must hold a
                whole pickled model object (``.eval()`` and ``.predict()`` are
                called on it), not a state_dict.
            imgs_to_threshold (list): List of same-shaped 2-D images to be
                thresholded.

        """
        # Load model
        print('Loading trained model ...')
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted files. PyTorch >= 2.6 defaults to weights_only=True, which
        # refuses whole pickled modules; pass weights_only=False there if
        # loading fails.
        self.model = torch.load(ch_model_path)
        self.model.eval()
        self.imgs_to_threshold = imgs_to_threshold
        self.normalize_imgs()
        print('Initialization complete')

    def normalize_imgs(self):
        """
        Histogram-match every image to the first one, then min-max normalize it.

        Reads `self.imgs_to_threshold` and stores the result as a numpy array
        in `self.normalized_images`.

        """
        print('Normalizing ...')
        # Stack the list of images into one numpy array
        images = np.array(self.imgs_to_threshold)
        # `min_max_normalize` is not defined in this cell; it is expected to be
        # defined elsewhere in the notebook. images[0] sits inside the
        # comprehension so that an empty input yields an empty array instead
        # of an IndexError.
        self.normalized_images = np.array(
            [min_max_normalize(exposure.match_histograms(img, images[0]))
             for img in images]
        )

    def pad_image(self, image, block_size=224):
        """
        Zero-pad a 2-D image so both dimensions are divisible by `block_size`.

        The padding is split as evenly as possible. When the total padding on
        an axis is odd, the extra row/column goes on the bottom/right, so the
        result is always an exact multiple of `block_size`.

        Args:
            image (numpy.ndarray): Input 2-D image.
            block_size (int): Size of the square block to align to.

        Returns:
            padded_image (numpy.ndarray): Padded image.
            padding (tuple): ``(ph, pw)``, the rows added at the top and the
                columns added at the left.
        """
        height, width = image.shape
        # Total padding needed on each axis (0 when already aligned)
        pad_height = (block_size - height % block_size) % block_size
        pad_width = (block_size - width % block_size) % block_size
        ph = pad_height // 2
        pw = pad_width // 2

        # Any odd remainder goes to the bottom/right so the shape is aligned
        padded_image = np.pad(
            image,
            ((ph, pad_height - ph), (pw, pad_width - pw)),
            mode='constant',
            constant_values=0,
        )

        return padded_image, (ph, pw)

    def image_to_blocks(self, image, block_size=224):
        """
        Convert an image into non-overlapping square blocks.

        Args:
            image (numpy.ndarray): Input 2-D image.
            block_size (int): Side length of each block (default 224).

        Returns:
            blocks (torch.Tensor): float32 tensor of shape
                ``(n_rows, n_cols, block_size, block_size)``.
            padding (tuple): ``(ph, pw)`` top/left padding, as from `pad_image`.
        """
        padded_image, padding = self.pad_image(image, block_size)

        # Divide the padded image into a grid of block_size x block_size tiles
        blocks = view_as_blocks(padded_image, (block_size, block_size))

        return torch.tensor(blocks.astype(np.float32)), padding

    def unpad_image(self, image, padding, original_shape=None):
        """
        Remove the padding added by `pad_image`.

        Args:
            image (numpy.ndarray): Padded image.
            padding (tuple): ``(ph, pw)`` top/left padding.
            original_shape (tuple, optional): ``(height, width)`` of the image
                before padding. Pass it to crop exactly, including when the
                padding was asymmetric. If omitted, the padding is assumed to
                be symmetric (``ph`` rows at top and bottom, ``pw`` columns at
                left and right).

        Returns:
            unpadded_image (numpy.ndarray): Unpadded image.
        """
        ph, pw = padding
        if original_shape is not None:
            height, width = original_shape[:2]
        else:
            height = image.shape[0] - 2 * ph
            width = image.shape[1] - 2 * pw

        # Explicit end indices: a negative slice like [0:-0] would return an
        # empty array when no padding was applied.
        return image[ph:ph + height, pw:pw + width]

    def blocks_to_image(self, blocks):
        """
        Reconstruct an image from a grid of non-overlapping blocks.

        Args:
            blocks (numpy.ndarray): Blocks of shape ``(n_rows, n_cols, bh, bw)``.

        Returns:
            reconstructed_image (numpy.ndarray): Image of shape
                ``(n_rows * bh, n_cols * bw)``.
        """
        rows = [np.hstack(row) for row in blocks]
        return np.vstack(rows)

    def _model_device(self):
        """Return the device holding the model's parameters.

        Falls back to CUDA if available, else CPU, when the model has no
        parameters.
        """
        param = next(self.model.parameters(), None)
        if param is not None:
            return param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict(self):
        """
        Predict thresholded images using the trained model.

        Stores one mask per normalized image, at the same size as that image,
        in `self.thresholded_images` as a numpy array.

        """
        print('Predicting ...')
        device = self._model_device()
        self.thresholded_images = []
        for img in self.normalized_images:

            # Split the image into blocks and flatten the block grid into a
            # batch of single-channel inputs for the network
            x_in, padding = self.image_to_blocks(img)
            n_rows, n_cols, block_h, block_w = x_in.shape
            x_in = x_in.reshape(n_rows * n_cols, 1, block_h, block_w)

            # Predict
            with torch.no_grad():
                x_out = self.model.predict(x_in.to(device))

            # round() binarizes at 0.5. NOTE(review): this assumes the model
            # outputs probabilities (e.g. a sigmoid activation) — confirm.
            # Also assumes a single output channel per block.
            pred_blocks = x_out.cpu().numpy().round()
            pred_blocks = pred_blocks.reshape(n_rows, n_cols, block_h, block_w)

            # Stitch the blocks back together and crop to the original size
            pr_mask_stacked = self.blocks_to_image(pred_blocks)
            predicted_mask = self.unpad_image(pr_mask_stacked, padding, img.shape)
            self.thresholded_images.append(predicted_mask)
        self.thresholded_images = np.array(self.thresholded_images)
