### Utilities for training a model per channel 

In [6]:
'''
Author: Mohamed Ghafoor 
Contact: Moeghaf@gmail.com 
20/11/2023 

'''

# Installations 
!pip uninstall segmentation_models_pytorch --yes
!pip install segmentation_models_pytorch==0.1.0
!pip install -U albumentations
!pip3 install scikit-image


# Imports 
from skimage.util import view_as_windows, view_as_blocks
import numpy as np 
import torch 
import segmentation_models_pytorch as smp
import pandas as pd 
from glob import glob 
from tifffile import imread, imwrite
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt 
import albumentations as albu
import os 
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm 
from skimage import io, color, filters, exposure


def min_max_normalize(image):
    """
    Normalize pixel values in an image to the range [0, 1].

    Parameters:
        image (array-like): Input image. It is converted with ``np.array``,
            so nested lists and other array-likes are accepted as well.

    Returns:
        numpy.ndarray: Array of the same shape with values scaled to [0, 1].
            A constant image (max == min) is mapped to all zeros instead of
            producing NaN/inf from a division by zero.

    Raises:
        ValueError: If the image is empty (NumPy cannot take the min/max of a
            zero-size array).

    Notes:
        - The normalization uses the minimum and maximum pixel values in the image.
        - Formula: normalized_image = (image - min_value) / (max_value - min_value).
        - This function assumes the input image has numerical pixel values.

    Example:
        >>> min_max_normalize([0, 5, 10])
        array([0. , 0.5, 1. ])
    """
    # Convert to array
    image = np.array(image)

    # Calculate the minimum and maximum pixel values in the image
    min_value, max_value = np.min(image), np.max(image)
    value_range = max_value - min_value

    shifted = image - min_value

    # A flat image has no dynamic range; dividing by zero would fill the
    # result with NaN. Dividing the (all-zero) shifted image by 1 yields zeros
    # with the same floating dtype the regular path would produce.
    if value_range == 0:
        return np.true_divide(shifted, 1)

    # Normalize pixel values to the range [0, 1]
    return shifted / value_range


def preprocess_data(target,
                    x_root='/home/ec2-user/SageMaker/x',
                    y_root='/home/ec2-user/SageMaker/Threshold_images_TS'):
    """
    Load, normalize and patch the input/output images of one target.

    Input images are histogram matched to the first input image and min-max
    normalized to [0, 1]; output (threshold) images are only normalized.
    Images 0, 1, 3 and 5 are used for training, image 4 for validation and
    image 2 for testing. Every split is then cut into overlapping
    1x224x224 windows with ``create_patches_all``.

    Parameters:
        target (str): Specific target/channel/metal/protein identifier. It is
            matched as a substring of the file name.
        x_root (str): Directory searched (one level of sub-folders deep) for
            input images. Defaults to the original SageMaker location.
        y_root (str): Directory searched (one level of sub-folders deep) for
            thresholded output images. Defaults to the original SageMaker location.

    Returns:
        tuple: Six tensors containing training, validation, and testing windows
            for input and output images, in the order
            (x_train, y_train, x_val, y_val, x_test, y_test).

    Raises:
        IndexError: If fewer than six input or output images match ``target``.

    Example:
        >>> x_train_patch, y_train_patch, x_val_patch, y_val_patch, x_test_patch, y_test_patch = preprocess_data("your_target")
    """
    # Load input and output paths.
    # NOTE(review): glob returns paths in an unspecified order, and x/y images
    # are paired purely by position. The order is intentionally left untouched
    # here because train_single_models re-globs the same pattern and takes
    # index 2 as the test image; sort both places together, or not at all.
    x_dirs = glob(x_root + '/*/*' + target + '*')
    y_dirs = glob(y_root + '/*/*' + target + '*')

    # The fixed split below indexes images 0..5, so fail early with a message
    # that names the target instead of a bare "list index out of range".
    required_images = 6
    if len(x_dirs) < required_images or len(y_dirs) < required_images:
        raise IndexError(
            "preprocess_data needs at least {} input and output images for target "
            "'{}', found {} input and {} output".format(
                required_images, target, len(x_dirs), len(y_dirs)))

    # Match histograms of input images to the first input
    reference_image = imread(x_dirs[0])

    # Normalize input and output images
    x_images = [min_max_normalize(exposure.match_histograms(imread(path), reference_image))
                for path in x_dirs]
    y_images = [min_max_normalize(imread(path)) for path in y_dirs]

    # Fixed split: training (4), validation (1), and testing (1)
    train_idx, val_idx, test_idx = (0, 1, 3, 5), (4,), (2,)

    def stack_images(images, indices):
        # Stack in NumPy first: building a tensor straight from a Python list
        # of arrays is far slower than wrapping one stacked array.
        return torch.from_numpy(np.stack([images[i] for i in indices]))

    x_train = stack_images(x_images, train_idx)
    x_val = stack_images(x_images, val_idx)
    x_test = stack_images(x_images, test_idx)

    y_train = stack_images(y_images, train_idx)
    y_val = stack_images(y_images, val_idx)
    y_test = stack_images(y_images, test_idx)

    # Create windows of patches for training, validation, and testing
    x_train_patch, y_train_patch = create_patches_all(x_train, y_train)
    x_val_patch, y_val_patch = create_patches_all(x_val, y_val)
    x_test_patch, y_test_patch = create_patches_all(x_test, y_test)

    return x_train_patch, y_train_patch, x_val_patch, y_val_patch, x_test_patch, y_test_patch



def create_patches(image, stride=28, patch_height=224, patch_width=224):
    """
    Cut a 2-D image into overlapping, single-channel patches.

    Parameters:
        image (array-like): Input image; it is copied into a NumPy array first.
        stride (int): Step between neighbouring windows. Default is 28.
        patch_height (int): Height of each patch. Default is 224.
        patch_width (int): Width of each patch. Default is 224.

    Returns:
        numpy.ndarray: Patches with shape
            (n_patches, 1, patch_height, patch_width).

    Example:
        >>> patches = create_patches(image_array, stride=20,
                                     patch_height=128, patch_width=128)
    """
    pixels = np.array(image)
    window_shape = (patch_height, patch_width)

    # Sliding-window view: one (patch_height, patch_width) window per grid position
    windows = view_as_windows(pixels, window_shape, step=stride)

    # Flatten the window grid into a batch axis and insert the channel axis
    # in the same reshape.
    return windows.reshape(-1, 1, patch_height, patch_width)


def create_patches_all(x, y, stride=28, patch_height=224, patch_width=224):
    """
    Create overlapping patches for a collection of input/target image pairs.

    Parameters:
        x (sequence): Input images (NumPy arrays or a stacked tensor).
        y (sequence): Target images, paired with ``x`` by position.
        stride (int): Stride for patch extraction. Default is 28.
        patch_height (int): Height of each patch. Default is 224.
        patch_width (int): Width of each patch. Default is 224.

    Returns:
        tuple: Two float32 tensors ``(x_patch, y_patch)`` of shape
            (n_patches, 1, patch_height, patch_width). When ``x`` is empty,
            both tensors have zero patches.

    Raises:
        IndexError: If ``y`` holds fewer images than ``x``.

    Example:
        >>> x_patches, y_patches = create_patches_all(x_images_list,
                                                    y_images_list,
                                                    stride=20,
                                                    patch_height=128,
                                                    patch_width=128)
    """
    x_chunks = []
    y_chunks = []

    for idx, x_image in enumerate(x):
        y_image = y[idx]
        # Cast each image's patches to float32 right away so the concatenated
        # buffer is already in the dtype of the returned tensors.
        x_chunks.append(
            create_patches(x_image, stride, patch_height, patch_width).astype(np.float32, copy=False))
        y_chunks.append(
            create_patches(y_image, stride, patch_height, patch_width).astype(np.float32, copy=False))

    if not x_chunks:
        empty_shape = (0, 1, patch_height, patch_width)
        return (torch.empty(empty_shape, dtype=torch.float),
                torch.empty(empty_shape, dtype=torch.float))

    # Concatenate in NumPy and wrap once; converting a long Python list of
    # arrays element by element is much slower.
    x_patch = torch.from_numpy(np.concatenate(x_chunks, axis=0))
    y_patch = torch.from_numpy(np.concatenate(y_chunks, axis=0))
    return x_patch, y_patch

def get_training_augmentation():
    """
    Defines a set of data augmentations for training images.

    The pipeline applies, in order: a random horizontal flip (50% chance),
    a random shift/scale without rotation (always), and a 224x224 random crop
    (always).

    Returns:
        albumentations.Compose: Composition of image augmentations.

    Example:
        >>> augmentation = get_training_augmentation()
        >>> augmented = augmentation(image=my_input_image, mask=my_mask_image)
        >>> augmented_image, augmented_mask = augmented['image'], augmented['mask']
    """
    train_transform = [
        albu.HorizontalFlip(p=0.5),
        # border_mode=0 is cv2.BORDER_CONSTANT: uncovered pixels are filled
        # with a constant value instead of reflected content.
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        # p=1 always applies the crop. It replaces `always_apply=True`, which
        # newer Albumentations releases deprecate; the notebook installs
        # Albumentations with `-U`, so the newest release is what gets used.
        albu.RandomCrop(height=224, width=224, p=1),
    ]
    return albu.Compose(train_transform)



class CustomDataset(Dataset):
    """
    Custom PyTorch dataset for handling input and target data with optional augmentation.

    Parameters:
        x_data (list or torch.Tensor): Input samples, each shaped (C, H, W).
        y_data (list or torch.Tensor): Target samples, each shaped (C, H, W).
        augmentation (callable, optional): Called as
            ``augmentation(image=..., mask=...)`` with channel-last arrays and
            expected to return a dict with 'image' and 'mask' entries
            (e.g. an ``albumentations.Compose``).

    Methods:
        __len__: Returns the number of samples in the dataset.
        __getitem__: Returns a sample from the dataset at the specified index.

    Example:
        >>> dataset = CustomDataset(x_data=my_input_data, y_data=my_target_data, augmentation=my_augmentation)
        >>> sample = dataset[0]
        >>> input_sample, target_sample = sample
    """

    def __init__(self, x_data, y_data, augmentation=None):
        """
        Initializes the CustomDataset.

        Args:
            x_data (list or torch.Tensor): List or tensor containing input data.
            y_data (list or torch.Tensor): List or tensor containing target data.
            augmentation (callable, optional): Augmentation to be applied to the data.
        """
        self.x_data = x_data
        self.y_data = y_data
        self.augmentation = augmentation

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset (length of ``x_data``).
        """
        return len(self.x_data)

    def __getitem__(self, idx):
        """
        Returns a sample from the dataset at the specified index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the input and target samples. Without an
                augmentation the stored samples are returned unchanged; with
                one, both are returned as new (C, H, W) tensors.
        """
        x = self.x_data[idx]
        y = self.y_data[idx]

        # Compare against None explicitly: a composed pipeline may define
        # __len__, and an (empty) pipeline should not be skipped by truthiness.
        if self.augmentation is not None:
            # The augmentation works on channel-last (H, W, C) arrays
            image = np.array(x).transpose(1, 2, 0)
            mask = np.array(y).transpose(1, 2, 0)
            sample = self.augmentation(image=image, mask=mask)

            # Back to channel-first. Make the arrays contiguous first, since an
            # augmentation may hand back a strided view (e.g. a slicing-based
            # flip) that torch may refuse to convert.
            x = torch.tensor(np.ascontiguousarray(sample['image'].transpose(2, 0, 1)))
            y = torch.tensor(np.ascontiguousarray(sample['mask'].transpose(2, 0, 1)))

        return x, y


class _Padding(tuple):
    """
    Padding record that unpacks as the two-element tuple ``(ph, pw)``
    (top and left padding) and additionally stores the bottom and right
    padding as the attributes ``bottom`` and ``right``.

    The bottom/right amounts can be one pixel larger than top/left when the
    total padding along an axis is odd.
    """

    def __new__(cls, top, bottom, left, right):
        padding = super().__new__(cls, (top, left))
        padding.bottom = bottom
        padding.right = right
        return padding


def pad_image(image, block_size=224):
    """
    Pads an image with zeros so that both dimensions are divisible by a block size.

    The padding is split as evenly as possible between the two sides of each
    axis; when the required total is odd, the extra pixel goes to the
    bottom/right side.

    Args:
        image (numpy.ndarray): 2-D input image to be padded.
        block_size (int, optional): Size of the block. Defaults to 224.

    Returns:
        tuple: ``(padded_image, padding)``. ``padding`` unpacks as ``(ph, pw)``
            (top and left padding) and carries the bottom/right padding as the
            attributes ``bottom`` and ``right``. Pass it unchanged to
            ``unpad_image``.
    """
    # Calculate the padding needed to make the dimensions divisible by block_size
    height, width = image.shape
    pad_height = (block_size - height % block_size) % block_size
    pad_width = (block_size - width % block_size) % block_size

    # Split each total into two sides. The second side takes the remainder so
    # that an odd total is still applied in full; otherwise the padded size
    # would not be a multiple of block_size.
    top = pad_height // 2
    bottom = pad_height - top
    left = pad_width // 2
    right = pad_width - left

    # Pad the image with zeros if needed
    padded_image = np.pad(image, ((top, bottom), (left, right)), mode='constant', constant_values=0)

    # Return the padded image and padding dimensions as a tuple
    return padded_image, _Padding(top, bottom, left, right)


def unpad_image(image, padding):
    """
    Removes padding from an image based on the specified padding dimensions.

    Args:
        image (numpy.ndarray): Padded image to be unpadded.
        padding (tuple): Padding as returned by ``pad_image``. A plain
            ``(ph, pw)`` tuple is also accepted and is treated as symmetric
            padding (``ph`` rows on top and bottom, ``pw`` columns left and right).

    Returns:
        numpy.ndarray: The unpadded image.
    """
    # Extract padding dimensions; fall back to symmetric padding for plain tuples
    top, left = padding
    bottom = getattr(padding, 'bottom', top)
    right = getattr(padding, 'right', left)

    # Slice with explicit end positions: a negative-index slice such as
    # image[0:-0] would be empty when no padding was added.
    height, width = image.shape[:2]
    unpadded_image = image[top:height - bottom, left:width - right]

    # Return the unpadded image
    return unpadded_image


def image_to_blocks(image, block_size=224):
    """
    Divides an image into non-overlapping square blocks.

    The image is first zero-padded with ``pad_image`` so both dimensions are
    multiples of ``block_size``.

    Args:
        image (numpy.ndarray): 2-D input image to be divided into blocks.
        block_size (int, optional): Side length of each block. Defaults to 224.

    Returns:
        tuple: ``(blocks, padding)`` where ``blocks`` is a float32 PyTorch
            tensor of shape (n_block_rows, n_block_cols, block_size, block_size)
            and ``padding`` is the padding information returned by ``pad_image``.
    """
    # Pad the image
    padded_image, padding = pad_image(image, block_size)

    # Divide the padded image into block_size x block_size blocks (a strided view)
    blocks = view_as_blocks(padded_image, (block_size, block_size))

    # Materialize the view once as a contiguous float32 array and wrap it
    # without a second copy.
    blocks = np.ascontiguousarray(blocks, dtype=np.float32)

    # Return blocks as a PyTorch tensor and padding dimensions
    return torch.from_numpy(blocks), padding


def blocks_to_image(blocks):
    """
    Stitch a grid of blocks back into a single image.

    Args:
        blocks (numpy.ndarray): Grid of blocks indexed as
            ``blocks[block_row][block_col]``.

    Returns:
        numpy.ndarray: The image obtained by joining every row of blocks side
            by side and stacking the joined rows from top to bottom.
    """
    stitched_rows = []
    for block_row in blocks:
        # Join the blocks of this grid row left-to-right
        stitched_rows.append(np.hstack(block_row))

    # Pile the stitched rows on top of each other
    return np.vstack(stitched_rows)


def train_single_models(targets):
    """
    Trains a U-Net model for each target protein and evaluates its performance.

    For every target this function:
        1. builds patch datasets with ``preprocess_data``,
        2. trains a U-Net (ResNet-152 encoder, frozen) for 10 epochs and keeps
           the checkpoint with the best validation IoU,
        3. evaluates the best checkpoint on the test patches,
        4. predicts the full test image block by block, saves a comparison
           figure and writes accuracy/precision/recall/F1 to a CSV file.

    All outputs are written below ``single_model_results_unet_v2/``.

    Args:
        targets (list): List of target proteins (substrings of the image file names).

    Returns:
        None
    """
    # Create directories for saving model results
    results_root = 'single_model_results_unet_v2'
    model_dir = os.path.join(results_root, 'trained_models')
    prediction_dir = os.path.join(results_root, 'test_predictions')
    metrics_dir = os.path.join(results_root, 'model_results')
    for directory in (model_dir, prediction_dir, metrics_dir):
        os.makedirs(directory, exist_ok=True)

    # Use the GPU when present, otherwise fall back to the CPU
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    for target in targets:
        print('\n ---------------------------------------------')
        print('Training model for ', target)

        print('Preprocessing data ...')
        # Preprocess data
        x_train_patch, y_train_patch, x_val_patch, y_val_patch, x_test_patch, y_test_patch = preprocess_data(target)

        # Load into dataloaders; only the training set is augmented and shuffled
        train_dataset = CustomDataset(x_train_patch, y_train_patch, augmentation=get_training_augmentation())
        train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=8)

        val_dataset = CustomDataset(x_val_patch, y_val_patch)
        val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=8)

        test_dataset = CustomDataset(x_test_patch, y_test_patch)
        test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)

        print('Building model...')
        # Build U-Net
        ENCODER = 'resnet152'
        ENCODER_WEIGHTS = 'imagenet'
        ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation

        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            in_channels=1,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=1,
            activation=ACTIVATION,
        )

        # Freeze encoder layers so only the decoder/head are trained
        for param in model.encoder.parameters():
            param.requires_grad = False

        loss = smp.utils.losses.DiceLoss()
        metrics = [
            smp.utils.metrics.IoU(threshold=0.5),
        ]

        optimizer = torch.optim.Adam([
            dict(params=model.parameters(), lr=0.0001),
        ])

        train_epoch = smp.utils.train.TrainEpoch(
            model,
            loss=loss,
            metrics=metrics,
            optimizer=optimizer,
            device=DEVICE,
            verbose=True,
        )

        valid_epoch = smp.utils.train.ValidEpoch(
            model,
            loss=loss,
            metrics=metrics,
            device=DEVICE,
            verbose=True,
        )
        max_score = 0
        best_model_path = os.path.join(model_dir, target + '_best_model.pth')

        # Train model
        print('Training model...')
        for i in range(0, 10):

            print('\nEpoch: {}'.format(i))
            train_epoch.run(train_dataloader)
            valid_logs = valid_epoch.run(val_dataloader)

            # Keep the checkpoint with the highest validation IoU seen so far
            if max_score < valid_logs['iou_score']:
                max_score = valid_logs['iou_score']
                torch.save(model, best_model_path)
                print('Model saved!')

        # Load best model (the whole pickled module, not just a state dict)
        best_model = torch.load(best_model_path, map_location=DEVICE)
        best_model.eval()

        # evaluate model on the test set
        test_epoch = smp.utils.train.ValidEpoch(
            model=best_model,
            loss=loss,
            metrics=metrics,
            device=DEVICE,
        )

        test_epoch.run(test_dataloader)

        # Paths to data.
        # NOTE(review): this repeats the glob from preprocess_data and relies
        # on both calls returning the files in the same order, so that index 2
        # is the same held-out test image in both places.
        x_dirs = glob('/home/ec2-user/SageMaker/x/*/*'+target+'*')
        y_dirs = glob('/home/ec2-user/SageMaker/Threshold_images_TS/*/*'+target+'*')

        # Load test data and histogram match
        matched_image = imread(x_dirs[0])
        x_test = torch.tensor(min_max_normalize(exposure.match_histograms(imread(x_dirs[2]), matched_image))).unsqueeze(0)
        y_test = torch.tensor(min_max_normalize(imread(y_dirs[2]))).unsqueeze(0)

        # Split test data into non-overlapping windows/blocks. The grid size is
        # read from the block tensor so images of any size are handled.
        x_blocks, padding = image_to_blocks(x_test[0])
        n_rows, n_cols, block_h, block_w = x_blocks.shape
        x_in = x_blocks.reshape(n_rows * n_cols, 1, block_h, block_w)
        x_out = best_model.predict(x_in.to(DEVICE))

        # Binarize (round() thresholds the sigmoid output at 0.5) and reshape
        # back into the block grid, then into the original image dimensions
        pred_img = x_out.cpu().numpy().round()
        pred_img = pred_img.reshape(n_rows, n_cols, block_h, block_w)
        pr_mask_stacked = blocks_to_image(pred_img)
        predicted_mask = unpad_image(pr_mask_stacked, padding)

        # Save results and measure performance
        print('Testing model, saving results ... ')
        fig, ax = plt.subplots(1, 3)
        fig.suptitle(target)
        ax[0].imshow(x_test[0])
        ax[0].set_title('Raw')
        ax[0].axis('off')
        ax[1].imshow(y_test[0])
        ax[1].set_title('Manual threshold')
        ax[1].axis('off')
        ax[2].imshow(predicted_mask)
        ax[2].set_title('Predicted threshold')
        ax[2].axis('off')
        fig.savefig(os.path.join(prediction_dir, target + '_prediction.png'))
        plt.show()
        # Release the figure; otherwise one figure per target stays alive
        plt.close(fig)

        # Test metrics
        predicted_labels = predicted_mask.flatten()
        true_labels = y_test[0].flatten().numpy()

        # Calculate precision, recall, and F1 score
        accuracy = accuracy_score(true_labels, predicted_labels)
        precision = precision_score(true_labels, predicted_labels)
        recall = recall_score(true_labels, predicted_labels)
        f1 = f1_score(true_labels, predicted_labels)

        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 Score: {f1:.4f}")

        # Save accuracy, precision, recall and F1 for each model
        df = pd.DataFrame([[target, np.round(accuracy, 4), np.round(precision, 4), np.round(recall, 4), np.round(f1, 4)]],
                          columns=['Model', 'Accuracy', 'Precision', 'Recall', 'F1 Score'])
        df.to_csv(os.path.join(metrics_dir, target + '_metrics.csv'))


            
def plot_images(image_list, rows, cols, savename=None, figsize=(10, 10)):
    """
    Plot a grid of images from the given image list.

    Images fill the grid row by row in grayscale; grid cells without an image
    are left blank. Images beyond ``rows * cols`` are not shown.

    Parameters:
        image_list (list): List of images (e.g., NumPy arrays).
        rows (int): Number of rows in the grid.
        cols (int): Number of columns in the grid.
        savename (str, optional): If given, the figure is also saved to this path.
        figsize (tuple): Size of the figure. Default is (10, 10).
    """
    # squeeze=False always returns a 2-D array of axes, so a 1x1 or
    # single-row/column grid can be flattened like any other.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

    # Plot each image in the grid; every cell gets its axis decorations hidden
    for i, ax in enumerate(axes.flatten()):
        if i < len(image_list):
            ax.imshow(image_list[i], cmap='gray')
        ax.axis("off")

    # Adjust layout
    fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98, wspace=0.05, hspace=0.05)

    # Save before showing: some backends release the figure once it is shown
    if savename is not None:
        fig.savefig(savename)

    plt.show()
        
        

            

    



    
    
    
    

Found existing installation: segmentation-models-pytorch 0.1.0
Uninstalling segmentation-models-pytorch-0.1.0:
  Successfully uninstalled segmentation-models-pytorch-0.1.0
Collecting segmentation_models_pytorch==0.1.0
  Using cached segmentation_models_pytorch-0.1.0-py3-none-any.whl (42 kB)
Installing collected packages: segmentation_models_pytorch
Successfully installed segmentation_models_pytorch-0.1.0


